mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
add PyDataset support
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8868644
This commit is contained in:
@@ -3,8 +3,9 @@
|
||||
import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from multiprocessing.sharedctypes import Value
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, Generator, List, Tuple, Union
|
||||
|
||||
from ali_maas_datasets import PyDataset
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from maas_lib.models import Model
|
||||
@@ -14,7 +15,7 @@ from maas_lib.utils.constant import CONFIGFILE
|
||||
from .util import is_model_name
|
||||
|
||||
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
||||
Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']
|
||||
Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
|
||||
|
||||
output_keys = [
|
||||
] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key
|
||||
@@ -59,8 +60,8 @@ class Pipeline(ABC):
|
||||
self.preprocessor = preprocessor
|
||||
|
||||
def __call__(self, input: Union[Input, List[Input]], *args,
|
||||
**post_kwargs) -> Dict[str, Any]:
|
||||
# model provider should leave it as it is
|
||||
**post_kwargs) -> Union[Dict[str, Any], Generator]:
|
||||
# model provider should leave it as it is
|
||||
# maas library developer will handle this function
|
||||
|
||||
# simple showcase, need to support iterator type for both tensorflow and pytorch
|
||||
@@ -69,10 +70,18 @@ class Pipeline(ABC):
|
||||
output = []
|
||||
for ele in input:
|
||||
output.append(self._process_single(ele, *args, **post_kwargs))
|
||||
|
||||
elif isinstance(input, PyDataset):
|
||||
return self._process_iterator(input, *args, **post_kwargs)
|
||||
|
||||
else:
|
||||
output = self._process_single(input, *args, **post_kwargs)
|
||||
return output
|
||||
|
||||
def _process_iterator(self, input: Input, *args, **post_kwargs):
|
||||
for ele in input:
|
||||
yield self._process_single(ele, *args, **post_kwargs)
|
||||
|
||||
def _process_single(self, input: Input, *args,
|
||||
**post_kwargs) -> Dict[str, Any]:
|
||||
out = self.preprocess(input)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/maas_lib-0.1.1-py3-none-any.whl
|
||||
https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
|
||||
https://mit-dataset.oss-cn-beijing.aliyuncs.com/release/ali_maas_datasets-0.0.1.dev0-py3-none-any.whl
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
addict
|
||||
https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
|
||||
https://mit-dataset.oss-cn-beijing.aliyuncs.com/release/ali_maas_datasets-0.0.1.dev0-py3-none-any.whl
|
||||
numpy
|
||||
opencv-python-headless
|
||||
Pillow
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Tuple, Union
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
from ali_maas_datasets import PyDataset
|
||||
|
||||
from maas_lib.fileio import File
|
||||
from maas_lib.pipelines import pipeline
|
||||
@@ -30,6 +31,25 @@ class ImageMattingTest(unittest.TestCase):
|
||||
)
|
||||
cv2.imwrite('result.png', result['output_png'])
|
||||
|
||||
def test_dataset(self):
|
||||
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
|
||||
'.com/data/test/maas/image_matting/matting_person.pb'
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model_file = osp.join(tmp_dir, 'matting_person.pb')
|
||||
with open(model_file, 'wb') as ofile:
|
||||
ofile.write(File.read(model_path))
|
||||
img_matting = pipeline(Tasks.image_matting, model=tmp_dir)
|
||||
# dataset = PyDataset.load('/dir/to/images', target='image')
|
||||
# yapf: disable
|
||||
dataset = PyDataset.load([
|
||||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
|
||||
],
|
||||
target='image')
|
||||
result = img_matting(dataset)
|
||||
for i, r in enumerate(result):
|
||||
cv2.imwrite(f'/path/to/result/{i}.png', r['output_png'])
|
||||
print('end')
|
||||
|
||||
def test_run_modelhub(self):
|
||||
img_matting = pipeline(
|
||||
Tasks.image_matting, model='damo/image-matting-person')
|
||||
|
||||
@@ -4,6 +4,8 @@ import unittest
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from ali_maas_datasets import PyDataset
|
||||
|
||||
from maas_lib.fileio import File
|
||||
from maas_lib.models import Model
|
||||
from maas_lib.models.nlp import SequenceClassificationModel
|
||||
@@ -58,6 +60,33 @@ class SequenceClassificationTest(unittest.TestCase):
|
||||
task='text-classification', model=model, preprocessor=preprocessor)
|
||||
self.predict(pipeline_ins)
|
||||
|
||||
def test_dataset(self):
|
||||
model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
|
||||
'/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
|
||||
cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
|
||||
cache_path = Path(cache_path_str)
|
||||
|
||||
if not cache_path.exists():
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_path.touch(exist_ok=True)
|
||||
with cache_path.open('wb') as ofile:
|
||||
ofile.write(File.read(model_url))
|
||||
|
||||
with zipfile.ZipFile(cache_path_str, 'r') as zipf:
|
||||
zipf.extractall(cache_path.parent)
|
||||
path = r'.cache/easynlp/bert-base-sst2'
|
||||
model = SequenceClassificationModel(path)
|
||||
preprocessor = SequenceClassificationPreprocessor(
|
||||
path, first_sequence='sentence', second_sequence=None)
|
||||
text_classification = pipeline(
|
||||
'text-classification', model=model, preprocessor=preprocessor)
|
||||
dataset = PyDataset.load('glue', name='sst2', target='sentence')
|
||||
result = text_classification(dataset)
|
||||
for i, r in enumerate(result):
|
||||
if i > 10:
|
||||
break
|
||||
print(r)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user