[to #42339763] move pydataset into maas_lib

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8974892
This commit is contained in:
yingda.chen
2022-06-09 10:14:48 +08:00
parent e3b8ec3bf1
commit 0d840d519c
8 changed files with 11 additions and 12 deletions

View File

@@ -101,7 +101,7 @@ import cv2
import os.path as osp
from maas_lib.pipelines import pipeline
from maas_lib.utils.constant import Tasks
from pydatasets import PyDataset
from maas_lib.pydatasets import PyDataset
# 使用图像url构建PyDataset此处也可通过 input_location = '/dir/to/images' 来使用本地文件夹
input_location = [

View File

@@ -2,14 +2,14 @@
import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Tuple, Union
from typing import Any, Dict, Generator, List, Union
from maas_hub.snapshot_download import snapshot_download
from pydatasets import PyDataset
from maas_lib.models import Model
from maas_lib.pipelines import util
from maas_lib.preprocessors import Preprocessor
from maas_lib.pydatasets import PyDataset
from maas_lib.utils.config import Config
from .util import is_model_name

View File

@@ -11,7 +11,7 @@ logger = get_logger()
class PyDataset:
_hf_ds = None # holds the underlying HuggingFace Dataset
"""A PyDataset backed by hugging face datasets."""
"""A PyDataset backed by hugging face Dataset."""
def __init__(self, hf_ds: Dataset):
self._hf_ds = hf_ds
@@ -52,7 +52,7 @@ class PyDataset:
Mapping[str, Union[str,
Sequence[str]]]]] = None
) -> 'PyDataset':
"""Load a pydataset from the MaaS Hub, Hugging Face Hub, urls, or a local dataset.
"""Load a PyDataset from the MaaS Hub, Hugging Face Hub, urls, or a local dataset.
Args:
path (str): Path or name of the dataset.
@@ -64,7 +64,7 @@ class PyDataset:
split (str, optional): Which split of the data to load.
Returns:
pydataset (obj:`PyDataset`): PyDataset object for a certain dataset.
PyDataset (obj:`PyDataset`): PyDataset object for a certain dataset.
"""
if isinstance(path, str):
dataset = load_dataset(

View File

@@ -5,10 +5,10 @@ import tempfile
import unittest
import cv2
from pydatasets import PyDataset
from maas_lib.fileio import File
from maas_lib.pipelines import pipeline, util
from maas_lib.pydatasets import PyDataset
from maas_lib.utils.constant import Tasks

View File

@@ -1,17 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest
import zipfile
from pathlib import Path
from pydatasets import PyDataset
from maas_lib.fileio import File
from maas_lib.models import Model
from maas_lib.models.nlp import BertForSequenceClassification
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
from maas_lib.preprocessors import SequenceClassificationPreprocessor
from maas_lib.pydatasets import PyDataset
from maas_lib.utils.constant import Tasks

View File

@@ -1,13 +1,14 @@
import unittest
import datasets as hfdata
from pydatasets import PyDataset
from maas_lib.pydatasets import PyDataset
class PyDatasetTest(unittest.TestCase):
def setUp(self):
# ds1 initiazed from in memory json
# ds1 initialized from in memory json
self.json_data = {
'dummy': [{
'a': i,