mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #43112771] requirements check and lazy import support
This commit is contained in:
committed by
zhangzhicheng.zzc
parent
69777a97df
commit
d55525bfb6
@@ -4,6 +4,7 @@ from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
||||
DEFAULT_MODELSCOPE_GROUP,
|
||||
MODEL_ID_SEPARATOR,
|
||||
MODELSCOPE_URL_SCHEME)
|
||||
from modelscope.utils.utils import get_default_cache_dir
|
||||
|
||||
|
||||
def model_id_to_group_owner_name(model_id):
|
||||
@@ -21,8 +22,7 @@ def get_cache_dir():
|
||||
cache dir precedence:
|
||||
function parameter > enviroment > ~/.cache/modelscope/hub
|
||||
"""
|
||||
default_cache_dir = os.path.expanduser(
|
||||
os.path.join('~/.cache', 'modelscope'))
|
||||
default_cache_dir = get_default_cache_dir()
|
||||
return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir,
|
||||
'hub'))
|
||||
|
||||
|
||||
@@ -1,8 +1,36 @@
|
||||
from .base import Metric
|
||||
from .builder import METRICS, build_metric, task_default_metrics
|
||||
from .image_color_enhance_metric import ImageColorEnhanceMetric
|
||||
from .image_denoise_metric import ImageDenoiseMetric
|
||||
from .image_instance_segmentation_metric import \
|
||||
ImageInstanceSegmentationCOCOMetric
|
||||
from .sequence_classification_metric import SequenceClassificationMetric
|
||||
from .text_generation_metric import TextGenerationMetric
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import Metric
|
||||
from .builder import METRICS, build_metric, task_default_metrics
|
||||
from .image_color_enhance_metric import ImageColorEnhanceMetric
|
||||
from .image_denoise_metric import ImageDenoiseMetric
|
||||
from .image_instance_segmentation_metric import \
|
||||
ImageInstanceSegmentationCOCOMetric
|
||||
from .sequence_classification_metric import SequenceClassificationMetric
|
||||
from .text_generation_metric import TextGenerationMetric
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'base': ['Metric'],
|
||||
'builder': ['METRICS', 'build_metric', 'task_default_metrics'],
|
||||
'image_color_enhance_metric': ['ImageColorEnhanceMetric'],
|
||||
'image_denoise_metric': ['ImageDenoiseMetric'],
|
||||
'image_instance_segmentation_metric':
|
||||
['ImageInstanceSegmentationCOCOMetric'],
|
||||
'sequence_classification_metric': ['SequenceClassificationMetric'],
|
||||
'text_generation_metric': ['TextGenerationMetric'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,39 +1,12 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.error import (AUDIO_IMPORT_ERROR,
|
||||
TENSORFLOW_IMPORT_WARNING)
|
||||
from .base import Model
|
||||
from .builder import MODELS, build_model
|
||||
from modelscope.utils.import_utils import is_torch_available
|
||||
from . import audio, cv, multi_modal, nlp
|
||||
from .base import Head, Model
|
||||
from .builder import BACKBONES, HEADS, MODELS, build_model
|
||||
|
||||
try:
|
||||
from .audio.ans.frcrn import FRCRNModel
|
||||
from .audio.asr import GenericAutomaticSpeechRecognition
|
||||
from .audio.kws import GenericKeyWordSpotting
|
||||
from .audio.tts import SambertHifigan
|
||||
except ModuleNotFoundError as e:
|
||||
print(AUDIO_IMPORT_ERROR.format(e))
|
||||
|
||||
try:
|
||||
from .nlp.csanmt_for_translation import CsanmtForTranslation
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'tensorflow'":
|
||||
print(TENSORFLOW_IMPORT_WARNING.format('CsanmtForTranslation'))
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
|
||||
try:
|
||||
from .multi_modal import OfaForImageCaptioning
|
||||
from .cv import NAFNetForImageDenoise
|
||||
from .nlp import (BertForMaskedLM, BertForSequenceClassification,
|
||||
SbertForNLI, SbertForSentenceSimilarity,
|
||||
SbertForSentimentClassification,
|
||||
SbertForTokenClassification,
|
||||
SbertForZeroShotClassification, SpaceForDialogIntent,
|
||||
SpaceForDialogModeling, SpaceForDialogStateTracking,
|
||||
StructBertForMaskedLM, VecoForMaskedLM)
|
||||
from .nlp.backbones import SbertModel
|
||||
from .nlp.heads import SequenceClassificationHead
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'pytorch'":
|
||||
pass
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
if is_torch_available():
|
||||
from .base import TorchModel, TorchHead
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from . import ans, asr, kws, tts
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .frcrn import FRCRNModel
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'frcrn': ['FRCRNModel'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1 +1,23 @@
|
||||
from .generic_automatic_speech_recognition import * # noqa F403
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .generic_automatic_speech_recognition import GenericAutomaticSpeechRecognition
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'generic_automatic_speech_recognition':
|
||||
['GenericAutomaticSpeechRecognition'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1 +1,22 @@
|
||||
from .generic_key_word_spotting import * # noqa F403
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .generic_key_word_spotting import GenericKeyWordSpotting
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'generic_key_word_spotting': ['GenericKeyWordSpotting'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1 +1,22 @@
|
||||
from .sambert_hifi import * # noqa F403
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sambert_hifi import SambertHifigan
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'sambert_hifi': ['SambertHifigan'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.torch_utils import create_device
|
||||
from .base_model import Model
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .image_color_enhance.image_color_enhance import ImageColorEnhance
|
||||
from .image_denoise.nafnet_for_image_denoise import * # noqa F403
|
||||
from . import (action_recognition, animal_recognition, cartoon,
|
||||
cmdssl_video_embedding, face_generation, image_color_enhance,
|
||||
image_colorization, image_denoise, image_instance_segmentation,
|
||||
super_resolution, virual_tryon)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .models import BaseVideoModel
|
||||
from .tada_convnext import TadaConvNeXt
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'models': ['BaseVideoModel'],
|
||||
'tada_convnext': ['TadaConvNeXt'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .resnet import ResNet, Bottleneck
|
||||
from .splat import SplAtConv2d
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'resnet': ['ResNet', 'Bottleneck'],
|
||||
'splat': ['SplAtConv2d']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .facelib.facer import FaceAna
|
||||
from .mtcnn_pytorch.src.align_trans import (get_reference_facial_points,
|
||||
warp_and_crop_face)
|
||||
from .utils import (get_f5p, padTo16x, resize_size)
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'facelib.facer': ['FaceAna'],
|
||||
'mtcnn_pytorch.src.align_trans':
|
||||
['get_reference_facial_points', 'warp_and_crop_face'],
|
||||
'utils': ['get_f5p', 'padTo16x', 'resize_size']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,3 +1,26 @@
|
||||
from .c3d import C3D
|
||||
from .resnet2p1d import resnet26_2p1d
|
||||
from .resnet3d import resnet26_3d
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .c3d import C3D
|
||||
from .resnet2p1d import resnet26_2p1d
|
||||
from .resnet3d import resnet26_3d
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'c3d': ['C3D'],
|
||||
'resnet2p1d': ['resnet26_2p1d'],
|
||||
'resnet3d': ['resnet26_3d']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .stylegan2 import Generator
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'stylegan2': ['Generator'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .image_color_enhance import ImageColorEnhance
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'image_color_enhance': ['ImageColorEnhance'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .unet import DynamicUnetWide, DynamicUnetDeep
|
||||
from .utils import NormType
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'unet': ['DynamicUnetWide', 'DynamicUnetDeep'],
|
||||
'utils': ['NormType']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .nafnet_for_image_denoise import NAFNetForImageDenoise
|
||||
|
||||
else:
|
||||
_import_structure = {'nafnet_for_image_denoise': ['NAFNetForImageDenoise']}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,2 +1,27 @@
|
||||
from .cascade_mask_rcnn_swin import CascadeMaskRCNNSwin
|
||||
from .model import CascadeMaskRCNNSwinModel
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cascade_mask_rcnn_swin import CascadeMaskRCNNSwin
|
||||
from .model import CascadeMaskRCNNSwinModel
|
||||
from .postprocess_utils import get_img_ins_seg_result
|
||||
from .datasets import ImageInstanceSegmentationCocoDataset
|
||||
else:
|
||||
_import_structure = {
|
||||
'cascade_mask_rcnn_swin': ['CascadeMaskRCNNSwin'],
|
||||
'model': ['CascadeMaskRCNNSwinModel'],
|
||||
'postprocess_utils': ['get_img_ins_seg_result'],
|
||||
'datasets': ['ImageInstanceSegmentationCocoDataset']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1 +1,22 @@
|
||||
from .swin_transformer import SwinTransformer
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .swin_transformer import SwinTransformer
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'swin_transformer': ['SwinTransformer'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .rrdbnet_arch import RRDBNet
|
||||
|
||||
else:
|
||||
_import_structure = {'rrdbnet_arch': ['RRDBNet']}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sdafnet import SDAFNet_Tryon
|
||||
|
||||
else:
|
||||
_import_structure = {'sdafnet': ['SDAFNet_Tryon']}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,8 +1,35 @@
|
||||
from .clip.clip_model import CLIPForMultiModalEmbedding
|
||||
from .gemm.gemm_model import GEMMForMultiModalEmbedding
|
||||
from .imagen.imagen_model import ImagenForTextToImageSynthesis
|
||||
from .mmr.models.clip_for_multi_model_video_embedding import \
|
||||
VideoCLIPForMultiModalEmbedding
|
||||
from .mplug_for_visual_question_answering import \
|
||||
MPlugForVisualQuestionAnswering
|
||||
from .ofa_for_image_captioning_model import OfaForImageCaptioning
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .clip import CLIPForMultiModalEmbedding
|
||||
from .gemm import GEMMForMultiModalEmbedding
|
||||
from .imagen import ImagenForTextToImageSynthesis
|
||||
from .mmr import VideoCLIPForMultiModalEmbedding
|
||||
from .mplug_for_visual_question_answering import \
|
||||
MPlugForVisualQuestionAnswering
|
||||
from .ofa_for_image_captioning_model import OfaForImageCaptioning
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'clip': ['CLIPForMultiModalEmbedding'],
|
||||
'imagen': ['ImagenForTextToImageSynthesis'],
|
||||
'gemm': ['GEMMForMultiModalEmbedding'],
|
||||
'mmr': ['VideoCLIPForMultiModalEmbedding'],
|
||||
'mplug_for_visual_question_answering':
|
||||
['MPlugForVisualQuestionAnswering'],
|
||||
'ofa_for_image_captioning_model': ['OfaForImageCaptioning']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .clip_model import CLIPForMultiModalEmbedding
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .gemm_model import GEMMForMultiModalEmbedding
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .imagen_model import ImagenForTextToImageSynthesis
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .models import VideoCLIPForMultiModalEmbedding
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .clip_for_mm_video_embedding import VideoCLIPForMultiModalEmbedding
|
||||
|
||||
@@ -11,13 +11,11 @@ from PIL import Image
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.mmr.dataloaders.rawvideo_util import \
|
||||
RawVideoExtractor
|
||||
from modelscope.models.multi_modal.mmr.models.modeling import CLIP4Clip
|
||||
from modelscope.models.multi_modal.mmr.models.tokenization_clip import \
|
||||
SimpleTokenizer as ClipTokenizer
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ..dataloaders.rawvideo_util import RawVideoExtractor
|
||||
from .modeling import CLIP4Clip
|
||||
from .tokenization_clip import SimpleTokenizer as ClipTokenizer
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -1,25 +1,59 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from modelscope.utils.error import TENSORFLOW_IMPORT_WARNING
|
||||
from .backbones import * # noqa F403
|
||||
from .bert_for_sequence_classification import * # noqa F403
|
||||
from .heads import * # noqa F403
|
||||
from .masked_language import * # noqa F403
|
||||
from .nncrf_for_named_entity_recognition import * # noqa F403
|
||||
from .palm_for_text_generation import * # noqa F403
|
||||
from .sbert_for_nli import * # noqa F403
|
||||
from .sbert_for_sentence_similarity import * # noqa F403
|
||||
from .sbert_for_sentiment_classification import * # noqa F403
|
||||
from .sbert_for_token_classification import * # noqa F403
|
||||
from .sbert_for_zero_shot_classification import * # noqa F403
|
||||
from .sequence_classification import * # noqa F403
|
||||
from .space_for_dialog_intent_prediction import * # noqa F403
|
||||
from .space_for_dialog_modeling import * # noqa F403
|
||||
from .space_for_dialog_state_tracking import * # noqa F403
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
try:
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .backbones import (SbertModel, SpaceGenerator, SpaceModelBase)
|
||||
from .heads import SequenceClassificationHead
|
||||
from .bert_for_sequence_classification import BertForSequenceClassification
|
||||
from .csanmt_for_translation import CsanmtForTranslation
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'tensorflow'":
|
||||
print(TENSORFLOW_IMPORT_WARNING.format('CsanmtForTranslation'))
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
|
||||
BertForMaskedLM)
|
||||
from .nncrf_for_named_entity_recognition import TransformerCRFForNamedEntityRecognition
|
||||
from .palm_for_text_generation import PalmForTextGeneration
|
||||
from .sbert_for_nli import SbertForNLI
|
||||
from .sbert_for_sentence_similarity import SbertForSentenceSimilarity
|
||||
from .sbert_for_sentiment_classification import SbertForSentimentClassification
|
||||
from .sbert_for_token_classification import SbertForTokenClassification
|
||||
from .sbert_for_zero_shot_classification import SbertForZeroShotClassification
|
||||
from .sequence_classification import SequenceClassificationModel
|
||||
from .space_for_dialog_intent_prediction import SpaceForDialogIntent
|
||||
from .space_for_dialog_modeling import SpaceForDialogModeling
|
||||
from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
|
||||
from .task_model import SingleBackboneTaskModelBase
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'backbones': ['SbertModel', 'SpaceGenerator', 'SpaceModelBase'],
|
||||
'heads': ['SequenceClassificationHead'],
|
||||
'csanmt_for_translation': ['CsanmtForTranslation'],
|
||||
'bert_for_sequence_classification': ['BertForSequenceClassification'],
|
||||
'masked_language':
|
||||
['StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM'],
|
||||
'nncrf_for_named_entity_recognition':
|
||||
['TransformerCRFForNamedEntityRecognition'],
|
||||
'palm_for_text_generation': ['PalmForTextGeneration'],
|
||||
'sbert_for_nli': ['SbertForNLI'],
|
||||
'sbert_for_sentence_similarity': ['SbertForSentenceSimilarity'],
|
||||
'sbert_for_sentiment_classification':
|
||||
['SbertForSentimentClassification'],
|
||||
'sbert_for_token_classification': ['SbertForTokenClassification'],
|
||||
'sbert_for_zero_shot_classification':
|
||||
['SbertForZeroShotClassification'],
|
||||
'sequence_classification': ['SequenceClassificationModel'],
|
||||
'space_for_dialog_intent_prediction': ['SpaceForDialogIntent'],
|
||||
'space_for_dialog_modeling': ['SpaceForDialogModeling'],
|
||||
'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
|
||||
'task_model': ['SingleBackboneTaskModelBase'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,4 +1,23 @@
|
||||
from .space import SpaceGenerator, SpaceModelBase
|
||||
from .structbert import SbertModel
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
__all__ = ['SbertModel', 'SpaceGenerator', 'SpaceModelBase']
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .space import SpaceGenerator, SpaceModelBase
|
||||
from .structbert import SbertModel
|
||||
else:
|
||||
_import_structure = {
|
||||
'space': ['SpaceGenerator', 'SpaceModelBase'],
|
||||
'structbert': ['SbertModel']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,3 +1,21 @@
|
||||
from .sequence_classification_head import SequenceClassificationHead
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
__all__ = ['SequenceClassificationHead']
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sequence_classification_head import SequenceClassificationHead
|
||||
else:
|
||||
_import_structure = {
|
||||
'sequence_classification_head': ['SequenceClassificationHead']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -6,11 +6,10 @@ from typing import Any, Dict
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model, Tensor
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.preprocessors.space.fields.intent_field import \
|
||||
IntentBPETextField
|
||||
from modelscope.models.nlp.backbones import SpaceGenerator, SpaceModelBase
|
||||
from modelscope.preprocessors.space import IntentBPETextField
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .backbones import SpaceGenerator, SpaceModelBase
|
||||
|
||||
__all__ = ['SpaceForDialogIntent']
|
||||
|
||||
|
||||
@@ -6,11 +6,10 @@ from typing import Any, Dict, Optional
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model, Tensor
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.preprocessors.space.fields.gen_field import \
|
||||
MultiWOZBPETextField
|
||||
from modelscope.models.nlp.backbones import SpaceGenerator, SpaceModelBase
|
||||
from modelscope.preprocessors.space import MultiWOZBPETextField
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .backbones import SpaceGenerator, SpaceModelBase
|
||||
|
||||
__all__ = ['SpaceForDialogModeling']
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .image_denoise_dataset import PairedImageDataset
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'image_denoise_dataset': ['PairedImageDataset'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
from modelscope.utils.error import AUDIO_IMPORT_ERROR
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
from . import audio, cv, multi_modal, nlp
|
||||
from .base import Pipeline
|
||||
from .builder import pipeline
|
||||
from .cv import * # noqa F403
|
||||
from .multi_modal import * # noqa F403
|
||||
from .nlp import * # noqa F403
|
||||
|
||||
try:
|
||||
from .audio import LinearAECPipeline
|
||||
from .audio.ans_pipeline import ANSPipeline
|
||||
except ModuleNotFoundError as e:
|
||||
print(AUDIO_IMPORT_ERROR.format(e))
|
||||
|
||||
@@ -1,21 +1,30 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
from .ans_pipeline import ANSPipeline
|
||||
from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
|
||||
from .kws_kwsbp_pipeline import * # noqa F403
|
||||
from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline
|
||||
from .linear_aec_pipeline import LinearAECPipeline
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'torch'":
|
||||
pass
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline
|
||||
|
||||
try:
|
||||
from .text_to_speech_pipeline import * # noqa F403
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'tensorflow'":
|
||||
print(TENSORFLOW_IMPORT_ERROR.format('tts'))
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
else:
|
||||
_import_structure = {
|
||||
'ans_pipeline': ['ANSPipeline'],
|
||||
'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'],
|
||||
'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'],
|
||||
'linear_aec_pipeline': ['LinearAECPipeline'],
|
||||
'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.torch_utils import create_device
|
||||
|
||||
|
||||
def audio_norm(x):
|
||||
|
||||
@@ -12,7 +12,7 @@ from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors.audio import LinearAECAndFbank
|
||||
from modelscope.preprocessors import LinearAECAndFbank
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
||||
Input = Union[str, tuple, MsDataset, 'PIL.Image.Image', 'numpy.ndarray']
|
||||
Input = Union[str, tuple, MsDataset, 'Image.Image', 'numpy.ndarray']
|
||||
InputModel = Union[str, Model]
|
||||
|
||||
logger = get_logger()
|
||||
@@ -233,7 +233,7 @@ class Pipeline(ABC):
|
||||
|
||||
"""
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from modelscope.preprocessors.space.dst_processors import InputFeatures
|
||||
from modelscope.preprocessors import InputFeatures
|
||||
if isinstance(data, dict) or isinstance(data, Mapping):
|
||||
return type(data)(
|
||||
{k: self._collate_fn(v)
|
||||
|
||||
@@ -1,33 +1,48 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
from .action_recognition_pipeline import ActionRecognitionPipeline
|
||||
from .animal_recog_pipeline import AnimalRecogPipeline
|
||||
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
|
||||
from .face_image_generation_pipeline import FaceImageGenerationPipeline
|
||||
from .image_cartoon_pipeline import ImageCartoonPipeline
|
||||
from .image_denoise_pipeline import ImageDenoisePipeline
|
||||
from .image_color_enhance_pipeline import ImageColorEnhancePipeline
|
||||
from .virtual_tryon_pipeline import VirtualTryonPipeline
|
||||
from .image_colorization_pipeline import ImageColorizationPipeline
|
||||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
|
||||
from .face_image_generation_pipeline import FaceImageGenerationPipeline
|
||||
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'torch'":
|
||||
pass
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
|
||||
try:
|
||||
from .image_cartoon_pipeline import ImageCartoonPipeline
|
||||
from .image_matting_pipeline import ImageMattingPipeline
|
||||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
|
||||
from .style_transfer_pipeline import StyleTransferPipeline
|
||||
from .ocr_detection_pipeline import OCRDetectionPipeline
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'tensorflow'":
|
||||
print(
|
||||
TENSORFLOW_IMPORT_ERROR.format(
|
||||
'image-cartoon image-matting ocr-detection style-transfer'))
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
from .virtual_tryon_pipeline import VirtualTryonPipeline
|
||||
else:
|
||||
_import_structure = {
|
||||
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
|
||||
'animal_recog_pipeline': ['AnimalRecogPipeline'],
|
||||
'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'],
|
||||
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
|
||||
'virtual_tryon_pipeline': ['VirtualTryonPipeline'],
|
||||
'image_colorization_pipeline': ['ImageColorizationPipeline'],
|
||||
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
|
||||
'image_denoise_pipeline': ['ImageDenoisePipeline'],
|
||||
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
|
||||
'image_cartoon_pipeline': ['ImageCartoonPipeline'],
|
||||
'image_matting_pipeline': ['ImageMattingPipeline'],
|
||||
'style_transfer_pipeline': ['StyleTransferPipeline'],
|
||||
'ocr_detection_pipeline': ['OCRDetectionPipeline'],
|
||||
'image_instance_segmentation_pipeline':
|
||||
['ImageInstanceSegmentationPipeline'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -5,11 +5,11 @@ from typing import Any, Dict
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.action_recognition.models import BaseVideoModel
|
||||
from modelscope.models.cv.action_recognition import BaseVideoModel
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors.video import ReadVideoData
|
||||
from modelscope.preprocessors import ReadVideoData
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -9,11 +9,11 @@ from torchvision import transforms
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.animal_recognition import resnet
|
||||
from modelscope.models.cv.animal_recognition import Bottleneck, ResNet
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage, load_image
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -34,8 +34,8 @@ class AnimalRecogPipeline(Pipeline):
|
||||
import torch
|
||||
|
||||
def resnest101(**kwargs):
|
||||
model = resnet.ResNet(
|
||||
resnet.Bottleneck, [3, 4, 23, 3],
|
||||
model = ResNet(
|
||||
Bottleneck, [3, 4, 23, 3],
|
||||
radix=2,
|
||||
groups=1,
|
||||
bottleneck_width=64,
|
||||
|
||||
@@ -7,8 +7,7 @@ import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.cmdssl_video_embedding.resnet2p1d import \
|
||||
resnet26_2p1d
|
||||
from modelscope.models.cv.cmdssl_video_embedding import resnet26_2p1d
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
|
||||
@@ -7,7 +7,7 @@ import PIL
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.face_generation import stylegan2
|
||||
from modelscope.models.cv.face_generation import Generator
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
@@ -36,7 +36,7 @@ class FaceImageGenerationPipeline(Pipeline):
|
||||
self.channel_multiplier = 2
|
||||
self.truncation = 0.7
|
||||
self.truncation_mean = 4096
|
||||
self.generator = stylegan2.Generator(
|
||||
self.generator = Generator(
|
||||
self.size,
|
||||
self.latent,
|
||||
self.n_mlp,
|
||||
|
||||
@@ -6,10 +6,10 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.cartoon.facelib.facer import FaceAna
|
||||
from modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans import (
|
||||
get_reference_facial_points, warp_and_crop_face)
|
||||
from modelscope.models.cv.cartoon.utils import get_f5p, padTo16x, resize_size
|
||||
from modelscope.models.cv.cartoon import (FaceAna, get_f5p,
|
||||
get_reference_facial_points,
|
||||
padTo16x, resize_size,
|
||||
warp_and_crop_face)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
|
||||
@@ -5,8 +5,7 @@ from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.cv.image_color_enhance.image_color_enhance import \
|
||||
ImageColorEnhance
|
||||
from modelscope.models.cv.image_color_enhance import ImageColorEnhance
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
|
||||
@@ -7,8 +7,8 @@ import torch
|
||||
from torchvision import models, transforms
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.image_colorization import unet
|
||||
from modelscope.models.cv.image_colorization.utils import NormType
|
||||
from modelscope.models.cv.image_colorization import (DynamicUnetDeep,
|
||||
DynamicUnetWide, NormType)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
@@ -46,7 +46,7 @@ class ImageColorizationPipeline(Pipeline):
|
||||
if self.model_type == 'stable':
|
||||
body = models.resnet101(pretrained=True)
|
||||
body = torch.nn.Sequential(*list(body.children())[:self.cut])
|
||||
self.model = unet.DynamicUnetWide(
|
||||
self.model = DynamicUnetWide(
|
||||
body,
|
||||
n_classes=3,
|
||||
blur=True,
|
||||
@@ -61,7 +61,7 @@ class ImageColorizationPipeline(Pipeline):
|
||||
else:
|
||||
body = models.resnet34(pretrained=True)
|
||||
body = torch.nn.Sequential(*list(body.children())[:cut])
|
||||
model = unet.DynamicUnetDeep(
|
||||
model = DynamicUnetDeep(
|
||||
body,
|
||||
n_classes=3,
|
||||
blur=True,
|
||||
@@ -84,7 +84,7 @@ class ImageColorizationPipeline(Pipeline):
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
if isinstance(input, str):
|
||||
img = load_image(input).convert('LA').convert('RGB')
|
||||
elif isinstance(input, PIL.Image.Image):
|
||||
elif isinstance(input, Image.Image):
|
||||
img = input.convert('LA').convert('RGB')
|
||||
elif isinstance(input, np.ndarray):
|
||||
if len(input.shape) == 2:
|
||||
|
||||
@@ -5,7 +5,7 @@ from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.cv import NAFNetForImageDenoise
|
||||
from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
|
||||
@@ -7,10 +7,8 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.image_instance_segmentation.model import \
|
||||
CascadeMaskRCNNSwinModel
|
||||
from modelscope.models.cv.image_instance_segmentation.postprocess_utils import \
|
||||
get_img_ins_seg_result
|
||||
from modelscope.models.cv.image_instance_segmentation import (
|
||||
CascadeMaskRCNNSwinModel, get_img_ins_seg_result)
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import (ImageInstanceSegmentationPreprocessor,
|
||||
|
||||
@@ -6,11 +6,11 @@ import PIL
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.super_resolution import rrdbnet_arch
|
||||
from modelscope.models.cv.super_resolution import RRDBNet
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage, load_image
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -32,7 +32,7 @@ class ImageSuperResolutionPipeline(Pipeline):
|
||||
self.num_feat = 64
|
||||
self.num_block = 23
|
||||
self.scale = 4
|
||||
self.sr_model = rrdbnet_arch.RRDBNet(
|
||||
self.sr_model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=self.num_feat,
|
||||
|
||||
@@ -12,7 +12,9 @@ from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils
|
||||
from .ocr_utils import (SegLinkDetector, cal_width, combine_segments_python,
|
||||
decode_segments_links_python, nms_python,
|
||||
rboxes_to_polygons)
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
@@ -66,7 +68,7 @@ class OCRDetectionPipeline(Pipeline):
|
||||
0.997, global_step)
|
||||
|
||||
# detector
|
||||
detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
|
||||
detector = SegLinkDetector()
|
||||
all_maps = detector.build_model(
|
||||
self.input_images, is_training=False)
|
||||
|
||||
@@ -90,7 +92,7 @@ class OCRDetectionPipeline(Pipeline):
|
||||
|
||||
# decode segments and links
|
||||
image_size = tf.shape(self.input_images)[1:3]
|
||||
segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
|
||||
segments, group_indices, segment_counts, _ = decode_segments_links_python(
|
||||
image_size,
|
||||
all_nodes,
|
||||
all_links,
|
||||
@@ -98,7 +100,7 @@ class OCRDetectionPipeline(Pipeline):
|
||||
anchor_sizes=list(detector.anchor_sizes))
|
||||
|
||||
# combine segments
|
||||
combined_rboxes, combined_counts = ops.combine_segments_python(
|
||||
combined_rboxes, combined_counts = combine_segments_python(
|
||||
segments, group_indices, segment_counts)
|
||||
self.output['combined_rboxes'] = combined_rboxes
|
||||
self.output['combined_counts'] = combined_counts
|
||||
@@ -145,7 +147,7 @@ class OCRDetectionPipeline(Pipeline):
|
||||
# convert rboxes to polygons and find its coordinates on the original image
|
||||
orig_h, orig_w = inputs['orig_size']
|
||||
resize_h, resize_w = inputs['resize_size']
|
||||
polygons = utils.rboxes_to_polygons(rboxes)
|
||||
polygons = rboxes_to_polygons(rboxes)
|
||||
scale_y = float(orig_h) / float(resize_h)
|
||||
scale_x = float(orig_w) / float(resize_w)
|
||||
|
||||
@@ -157,8 +159,8 @@ class OCRDetectionPipeline(Pipeline):
|
||||
polygons = np.round(polygons).astype(np.int32)
|
||||
|
||||
# nms
|
||||
dt_n9 = [o + [utils.cal_width(o)] for o in polygons.tolist()]
|
||||
dt_nms = utils.nms_python(dt_n9)
|
||||
dt_n9 = [o + [cal_width(o)] for o in polygons.tolist()]
|
||||
dt_nms = nms_python(dt_n9)
|
||||
dt_polygons = np.array([o[:8] for o in dt_nms])
|
||||
|
||||
result = {OutputKeys.POLYGONS: dt_polygons}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model_resnet_mutex_v4_linewithchar import SegLinkDetector
|
||||
from .ops import decode_segments_links_python, combine_segments_python
|
||||
from .utils import rboxes_to_polygons, cal_width, nms_python
|
||||
else:
|
||||
_import_structure = {
|
||||
'model_resnet_mutex_v4_linewithchar': ['SegLinkDetector'],
|
||||
'ops': ['decode_segments_links_python', 'combine_segments_python'],
|
||||
'utils': ['rboxes_to_polygons', 'cal_width', 'nms_python']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from PIL import Image
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon
|
||||
from modelscope.models.cv.virual_tryon import SDAFNet_Tryon
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
|
||||
@@ -1,14 +1,36 @@
|
||||
try:
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline
|
||||
from .image_captioning_pipeline import ImageCaptionPipeline
|
||||
from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
|
||||
from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline
|
||||
from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline
|
||||
from .video_multi_modal_embedding_pipeline import \
|
||||
VideoMultiModalEmbeddingPipeline
|
||||
from .visual_question_answering_pipeline import \
|
||||
VisualQuestionAnsweringPipeline
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'torch'":
|
||||
pass
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'image_captioning_pipeline': ['ImageCaptionPipeline'],
|
||||
'multi_modal_embedding_pipeline': ['MultiModalEmbeddingPipeline'],
|
||||
'text_to_image_synthesis_pipeline': ['TextToImageSynthesisPipeline'],
|
||||
'visual_question_answering_pipeline':
|
||||
['VisualQuestionAnsweringPipeline'],
|
||||
'video_multi_modal_embedding_pipeline':
|
||||
['VideoMultiModalEmbeddingPipeline'],
|
||||
'generative_multi_modal_embedding_pipeline':
|
||||
['GEMMMultiModalEmbeddingPipeline']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,30 +1,50 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.error import TENSORFLOW_IMPORT_WARNING
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
try:
|
||||
from .translation_pipeline import * # noqa F403
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'tensorflow'":
|
||||
print(TENSORFLOW_IMPORT_WARNING.format('translation'))
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
if TYPE_CHECKING:
|
||||
from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
|
||||
from .dialog_modeling_pipeline import DialogModelingPipeline
|
||||
from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
|
||||
from .fill_mask_pipeline import FillMaskPipeline
|
||||
from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
|
||||
from .nli_pipeline import NLIPipeline
|
||||
from .sentence_similarity_pipeline import SentenceSimilarityPipeline
|
||||
from .sentiment_classification_pipeline import SentimentClassificationPipeline
|
||||
from .sequence_classification_pipeline import SequenceClassificationPipeline
|
||||
from .text_generation_pipeline import TextGenerationPipeline
|
||||
from .translation_pipeline import TranslationPipeline
|
||||
from .word_segmentation_pipeline import WordSegmentationPipeline
|
||||
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
|
||||
|
||||
try:
|
||||
from .dialog_intent_prediction_pipeline import * # noqa F403
|
||||
from .dialog_modeling_pipeline import * # noqa F403
|
||||
from .dialog_state_tracking_pipeline import * # noqa F403
|
||||
from .fill_mask_pipeline import * # noqa F403
|
||||
from .named_entity_recognition_pipeline import * # noqa F403
|
||||
from .nli_pipeline import * # noqa F403
|
||||
from .sentence_similarity_pipeline import * # noqa F403
|
||||
from .sentiment_classification_pipeline import * # noqa F403
|
||||
from .sequence_classification_pipeline import * # noqa F403
|
||||
from .text_generation_pipeline import * # noqa F403
|
||||
from .word_segmentation_pipeline import * # noqa F403
|
||||
from .zero_shot_classification_pipeline import * # noqa F403
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'torch'":
|
||||
pass
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
else:
|
||||
_import_structure = {
|
||||
'dialog_intent_prediction_pipeline':
|
||||
['DialogIntentPredictionPipeline'],
|
||||
'dialog_modeling_pipeline': ['DialogModelingPipeline'],
|
||||
'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
|
||||
'fill_mask_pipeline': ['FillMaskPipeline'],
|
||||
'nli_pipeline': ['NLIPipeline'],
|
||||
'sentence_similarity_pipeline': ['SentenceSimilarityPipeline'],
|
||||
'sentiment_classification_pipeline':
|
||||
['SentimentClassificationPipeline'],
|
||||
'sequence_classification_pipeline': ['SequenceClassificationPipeline'],
|
||||
'text_generation_pipeline': ['TextGenerationPipeline'],
|
||||
'word_segmentation_pipeline': ['WordSegmentationPipeline'],
|
||||
'zero_shot_classification_pipeline':
|
||||
['ZeroShotClassificationPipeline'],
|
||||
'named_entity_recognition_pipeline':
|
||||
['NamedEntityRecognitionPipeline'],
|
||||
'translation_pipeline': ['TranslationPipeline'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model, SpaceForDialogStateTracking
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import SpaceForDialogStateTracking
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
|
||||
@@ -1,29 +1,67 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.error import AUDIO_IMPORT_ERROR, TENSORFLOW_IMPORT_ERROR
|
||||
from .asr import WavToScp
|
||||
from .base import Preprocessor
|
||||
from .builder import PREPROCESSORS, build_preprocessor
|
||||
from .common import Compose
|
||||
from .image import LoadImage, load_image
|
||||
from .kws import WavToLists
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
from .base import Preprocessor
|
||||
from .builder import PREPROCESSORS, build_preprocessor
|
||||
from .common import Compose
|
||||
from .asr import WavToScp
|
||||
from .audio import LinearAECAndFbank
|
||||
except ModuleNotFoundError as e:
|
||||
print(AUDIO_IMPORT_ERROR.format(e))
|
||||
from .image import (LoadImage, load_image,
|
||||
ImageColorEnhanceFinetunePreprocessor,
|
||||
ImageInstanceSegmentationPreprocessor,
|
||||
ImageDenoisePreprocessor)
|
||||
from .kws import WavToLists
|
||||
from .multi_modal import (OfaImageCaptionPreprocessor,
|
||||
MPlugVisualQuestionAnsweringPreprocessor)
|
||||
from .nlp import (Tokenize, SequenceClassificationPreprocessor,
|
||||
TextGenerationPreprocessor,
|
||||
TokenClassificationPreprocessor, NLIPreprocessor,
|
||||
SentimentClassificationPreprocessor,
|
||||
SentenceSimilarityPreprocessor, FillMaskPreprocessor,
|
||||
ZeroShotClassificationPreprocessor, NERPreprocessor)
|
||||
from .space import (DialogIntentPredictionPreprocessor,
|
||||
DialogModelingPreprocessor,
|
||||
DialogStateTrackingPreprocessor)
|
||||
from .video import ReadVideoData
|
||||
|
||||
try:
|
||||
from .multi_modal import * # noqa F403
|
||||
from .nlp import * # noqa F403
|
||||
from .space.dialog_intent_prediction_preprocessor import * # noqa F403
|
||||
from .space.dialog_modeling_preprocessor import * # noqa F403
|
||||
from .space.dialog_state_tracking_preprocessor import * # noqa F403
|
||||
from .image import ImageColorEnhanceFinetunePreprocessor
|
||||
from .image import ImageInstanceSegmentationPreprocessor
|
||||
from .image import ImageDenoisePreprocessor
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'tensorflow'":
|
||||
print(TENSORFLOW_IMPORT_ERROR.format('tts'))
|
||||
else:
|
||||
raise ModuleNotFoundError(e)
|
||||
else:
|
||||
_import_structure = {
|
||||
'base': ['Preprocessor'],
|
||||
'builder': ['PREPROCESSORS', 'build_preprocessor'],
|
||||
'common': ['Compose'],
|
||||
'asr': ['WavToScp'],
|
||||
'video': ['ReadVideoData'],
|
||||
'image': [
|
||||
'LoadImage', 'load_image', 'ImageColorEnhanceFinetunePreprocessor',
|
||||
'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor'
|
||||
],
|
||||
'kws': ['WavToLists'],
|
||||
'multi_modal': [
|
||||
'OfaImageCaptionPreprocessor',
|
||||
'MPlugVisualQuestionAnsweringPreprocessor'
|
||||
],
|
||||
'nlp': [
|
||||
'Tokenize', 'SequenceClassificationPreprocessor',
|
||||
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
|
||||
'NLIPreprocessor', 'SentimentClassificationPreprocessor',
|
||||
'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
|
||||
'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
|
||||
],
|
||||
'space': [
|
||||
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
|
||||
'DialogStateTrackingPreprocessor', 'InputFeatures'
|
||||
],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .dialog_intent_prediction_preprocessor import \
        DialogIntentPredictionPreprocessor
    from .dialog_modeling_preprocessor import DialogModelingPreprocessor
    from .dialog_state_tracking_preprocessor import DialogStateTrackingPreprocessor
    from .dst_processors import InputFeatures
    from .fields import MultiWOZBPETextField, IntentBPETextField

else:
    _import_structure = {
        'dialog_intent_prediction_preprocessor':
        ['DialogIntentPredictionPreprocessor'],
        'dialog_modeling_preprocessor': ['DialogModelingPreprocessor'],
        'dialog_state_tracking_preprocessor':
        ['DialogStateTrackingPreprocessor'],
        'dst_processors': ['InputFeatures'],
        'fields': ['MultiWOZBPETextField', 'IntentBPETextField']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

@@ -0,0 +1,2 @@
from .gen_field import MultiWOZBPETextField
from .intent_field import IntentBPETextField
@@ -1,3 +1,25 @@
from .base import TaskDataset
from .builder import build_task_dataset
from .torch_base_dataset import TorchTaskDataset
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule, is_torch_available

if TYPE_CHECKING:
    from .base import TaskDataset
    from .builder import TASK_DATASETS, build_task_dataset
    from .torch_base_dataset import TorchTaskDataset

else:
    _import_structure = {
        'base': ['TaskDataset'],
        'builder': ['TASK_DATASETS', 'build_task_dataset'],
        'torch_base_dataset': ['TorchTaskDataset'],
    }
    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
598
modelscope/utils/ast_utils.py
Normal file
@@ -0,0 +1,598 @@
import ast
import contextlib
import hashlib
import importlib
import os
import os.path as osp
import time
from functools import reduce
from typing import Generator, Union

import gast
import json

from modelscope import __version__
from modelscope.fileio.file import LocalStorage
from modelscope.metainfo import (Heads, Metrics, Models, Pipelines,
                                 Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group
from modelscope.utils.utils import get_default_cache_dir

logger = get_logger()
storage = LocalStorage()

# get the path of package 'modelscope'
MODELSCOPE_PATH = '/'.join(os.path.dirname(__file__).split('/')[:-1])
REGISTER_MODULE = 'register_module'
IGNORED_PACKAGES = ['modelscope', '.']
SCAN_SUB_FOLDERS = [
    'models', 'metrics', 'pipelines', 'preprocessors', 'task_datasets'
]
INDEXER_FILE = 'ast_indexer'
DECORATOR_KEY = 'decorators'
FROM_IMPORT_KEY = 'from_imports'
IMPORT_KEY = 'imports'
FILE_NAME_KEY = 'filepath'
VERSION_KEY = 'version'
MD5_KEY = 'md5'
INDEX_KEY = 'index'
REQUIREMENT_KEY = 'requirements'
MODULE_KEY = 'module'
class AstScaning(object):

    def __init__(self) -> None:
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []

    def _is_sub_node(self, node: object) -> bool:
        return isinstance(node,
                          ast.AST) and not isinstance(node, ast.expr_context)

    def _is_leaf(self, node: ast.AST) -> bool:
        for field in node._fields:
            attr = getattr(node, field)
            if self._is_sub_node(attr):
                return False
            elif isinstance(attr, (list, tuple)):
                for val in attr:
                    if self._is_sub_node(val):
                        return False
        else:
            return True

    def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
        if show_offsets:
            return n._attributes + n._fields
        else:
            return n._fields

    def _leaf(self, node: ast.AST, show_offsets: bool = True) -> str:
        output = dict()
        local_print = list()
        if isinstance(node, ast.AST):
            local_dict = dict()
            for field in self._fields(node, show_offsets=show_offsets):
                field_output, field_prints = self._leaf(
                    getattr(node, field), show_offsets=show_offsets)
                local_dict[field] = field_output
                local_print.append('{}={}'.format(field, field_prints))

            prints = '{}({})'.format(
                type(node).__name__,
                ', '.join(local_print),
            )
            output[type(node).__name__] = local_dict
            return output, prints
        elif isinstance(node, list):
            if '_fields' not in node:
                return node, repr(node)
            for item in node:
                item_output, item_prints = self._leaf(
                    getattr(node, item), show_offsets=show_offsets)
                local_print.append(item_prints)
            return node, '[{}]'.format(', '.join(local_print), )
        else:
            return node, repr(node)

    def _refresh(self):
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []

    def scan_ast(self, node: Union[ast.AST, None, str]):
        self._setup_global()
        self.scan_import(node, indent=' ', show_offsets=False)

    def scan_import(
        self,
        node: Union[ast.AST, None, str],
        indent: Union[str, int] = ' ',
        show_offsets: bool = True,
        _indent: int = 0,
        parent_node_name: str = '',
    ) -> tuple:
        if node is None:
            return node, repr(node)
        elif self._is_leaf(node):
            return self._leaf(node, show_offsets=show_offsets)
        else:
            if isinstance(indent, int):
                indent_s = indent * ' '
            else:
                indent_s = indent

            class state:
                indent = _indent

            @contextlib.contextmanager
            def indented() -> Generator[None, None, None]:
                state.indent += 1
                yield
                state.indent -= 1

            def indentstr() -> str:
                return state.indent * indent_s

            def _scan_import(el: Union[ast.AST, None, str],
                             _indent: int = 0,
                             parent_node_name: str = '') -> str:
                return self.scan_import(
                    el,
                    indent=indent,
                    show_offsets=show_offsets,
                    _indent=_indent,
                    parent_node_name=parent_node_name)

            out = type(node).__name__ + '(\n'
            outputs = dict()
            # add relative path expression
            if type(node).__name__ == 'ImportFrom':
                level = getattr(node, 'level')
                if level >= 1:
                    path_level = ''.join(['.'] * level)
                    setattr(node, 'level', 0)
                    module_name = getattr(node, 'module')
                    if module_name is None:
                        setattr(node, 'module', path_level)
                    else:
                        setattr(node, 'module', path_level + module_name)
            with indented():
                for field in self._fields(node, show_offsets=show_offsets):
                    attr = getattr(node, field)
                    if attr == []:
                        representation = '[]'
                        outputs[field] = []
                    elif (isinstance(attr, list) and len(attr) == 1
                          and isinstance(attr[0], ast.AST)
                          and self._is_leaf(attr[0])):
                        local_out, local_print = _scan_import(attr[0])
                        representation = f'[{local_print}]'
                        outputs[field] = local_out

                    elif isinstance(attr, list):
                        representation = '[\n'
                        el_dict = dict()
                        with indented():
                            for el in attr:
                                local_out, local_print = _scan_import(
                                    el, state.indent,
                                    type(el).__name__)
                                representation += '{}{},\n'.format(
                                    indentstr(),
                                    local_print,
                                )
                                name = type(el).__name__
                                if (name == 'Import' or name == 'ImportFrom'
                                        or parent_node_name == 'ImportFrom'):
                                    if name not in el_dict:
                                        el_dict[name] = []
                                    el_dict[name].append(local_out)
                        representation += indentstr() + ']'
                        outputs[field] = el_dict
                    elif isinstance(attr, ast.AST):
                        output, representation = _scan_import(
                            attr, state.indent)
                        outputs[field] = output
                    else:
                        representation = repr(attr)
                        outputs[field] = attr

                    if (type(node).__name__ == 'Import'
                            or type(node).__name__ == 'ImportFrom'):
                        if type(node).__name__ == 'ImportFrom':
                            if field == 'module':
                                self.result_from_import[
                                    outputs[field]] = dict()
                            if field == 'names':
                                if isinstance(outputs[field]['alias'], list):
                                    item_name = []
                                    for item in outputs[field]['alias']:
                                        local_name = item['alias']['name']
                                        item_name.append(local_name)
                                    self.result_from_import[
                                        outputs['module']] = item_name
                                else:
                                    local_name = outputs[field]['alias'][
                                        'name']
                                    self.result_from_import[
                                        outputs['module']] = [local_name]

                        if type(node).__name__ == 'Import':
                            final_dict = outputs[field]['alias']
                            self.result_import[outputs[field]['alias']
                                               ['name']] = final_dict

                    if 'decorator_list' == field and attr != []:
                        self.result_decorator.extend(attr)

                    out += f'{indentstr()}{field}={representation},\n'

            out += indentstr() + ')'
            return {
                IMPORT_KEY: self.result_import,
                FROM_IMPORT_KEY: self.result_from_import,
                DECORATOR_KEY: self.result_decorator
            }, out
    def _parse_decorator(self, node: ast.AST) -> tuple:

        def _get_attribute_item(node: ast.AST) -> tuple:
            value, id, attr = None, None, None
            if type(node).__name__ == 'Attribute':
                value = getattr(node, 'value')
                id = getattr(value, 'id')
                attr = getattr(node, 'attr')
            if type(node).__name__ == 'Name':
                id = getattr(node, 'id')
            return id, attr

        def _get_args_name(nodes: list) -> list:
            result = []
            for node in nodes:
                result.append(_get_attribute_item(node))
            return result

        def _get_keyword_name(nodes: ast.AST) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'keyword':
                    attribute_node = getattr(node, 'value')
                    if type(attribute_node).__name__ == 'Str':
                        result.append((attribute_node.s, None))
                    else:
                        result.append(_get_attribute_item(attribute_node))
            return result

        functions = _get_attribute_item(node.func)
        args_list = _get_args_name(node.args)
        keyword_list = _get_keyword_name(node.keywords)
        return functions, args_list, keyword_list

    def _get_registry_value(self, key_item):
        if key_item is None:
            return None
        if key_item == 'default_group':
            return default_group
        split_list = key_item.split('.')
        # in this case, the key_item is raw data, not a registered attribute
        if len(split_list) == 1:
            return key_item
        else:
            return getattr(eval(split_list[0]), split_list[1])

    def _registry_indexer(self, parsed_input: tuple) -> tuple:
        """format registry information to a tuple indexer

        Return:
            tuple: (MODELS, Tasks.text-classification, Models.structbert)
        """
        functions, args_list, keyword_list = parsed_input

        # ignore decorators other than register_module
        if REGISTER_MODULE != functions[1]:
            return None
        output = [functions[0]]

        if len(args_list) == 0 and len(keyword_list) == 0:
            args_list.append(None)
        if len(keyword_list) == 0 and len(args_list) == 1:
            args_list.append(None)

        args_list.extend(keyword_list)

        for item in args_list:
            # the case of empty input
            if item is None:
                output.append(None)
            # the case of (default_group)
            elif item[1] is None:
                output.append(item[0])
            else:
                output.append('.'.join(item))
        return (output[0], self._get_registry_value(output[1]),
                self._get_registry_value(output[2]))

    def parse_decorators(self, nodes: list) -> list:
        """parse the AST nodes of decorator objects into registry indexers

        Args:
            nodes (list): list of AST decorator nodes

        Returns:
            list: list of registry indexers
        """
        results = []
        for node in nodes:
            if type(node).__name__ != 'Call':
                continue
            parse_output = self._parse_decorator(node)
            index = self._registry_indexer(parse_output)
            if None is not index:
                results.append(index)
        return results

    def generate_ast(self, file):
        self._refresh()
        with open(file, 'r') as code:
            data = code.readlines()
            data = ''.join(data)

        node = gast.parse(data)
        output, _ = self.scan_import(node, indent=' ', show_offsets=False)
        output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
        return output
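A minimal usage sketch of the scanner above on a single source file (the example path is hypothetical; the printed values depend entirely on the file being scanned):

# Hedged sketch: scan one file and inspect the collected imports / from-imports / decorators.
from modelscope.utils.ast_utils import AstScaning

scanner = AstScaning()
result = scanner.generate_ast('example.py')  # 'example.py' is an illustrative path

print(result['imports'])       # e.g. {'os': {...}, 'torch': {...}}
print(result['from_imports'])  # e.g. {'modelscope.metainfo': ['Pipelines']}
print(result['decorators'])    # e.g. [('PIPELINES', 'text-classification', 'sentiment-analysis')]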
class FilesAstScaning(object):

    def __init__(self) -> None:
        self.astScaner = AstScaning()
        self.file_dirs = []

    def _parse_import_path(self,
                           import_package: str,
                           current_path: str = None) -> str:
        """
        Args:
            import_package (str): relative import or abs import
            current_path (str): path/to/current/file
        """
        if import_package.startswith(IGNORED_PACKAGES[0]):
            return MODELSCOPE_PATH + '/' + '/'.join(
                import_package.split('.')[1:]) + '.py'
        elif import_package.startswith(IGNORED_PACKAGES[1]):
            current_path_list = current_path.split('/')
            import_package_list = import_package.split('.')
            level = 0
            for index, item in enumerate(import_package_list):
                if item != '':
                    level = index
                    break

            abs_path_list = current_path_list[0:-level]
            abs_path_list.extend(import_package_list[index:])
            return '/' + '/'.join(abs_path_list) + '.py'
        else:
            return current_path

    def _traversal_import(
        self,
        import_abs_path,
    ):
        pass

    def parse_import(self, scan_result: dict) -> list:
        """parse the import and from-import dicts into a third-party package list

        Args:
            scan_result (dict): including the import and from-import results

        Returns:
            list: a list of packages, ignoring 'modelscope' and relative path imports
        """
        output = []
        output.extend(list(scan_result[IMPORT_KEY].keys()))
        output.extend(list(scan_result[FROM_IMPORT_KEY].keys()))

        # get the package name
        for index, item in enumerate(output):
            if '' == item.split('.')[0]:
                output[index] = '.'
            else:
                output[index] = item.split('.')[0]

        ignored = set()
        for item in output:
            for ignored_package in IGNORED_PACKAGES:
                if item.startswith(ignored_package):
                    ignored.add(item)
        return list(set(output) - set(ignored))

    def traversal_files(self, path, check_sub_dir):
        self.file_dirs = []
        if check_sub_dir is None or len(check_sub_dir) == 0:
            self._traversal_files(path)

        for item in check_sub_dir:
            sub_dir = os.path.join(path, item)
            if os.path.isdir(sub_dir):
                self._traversal_files(sub_dir)

    def _traversal_files(self, path):
        dir_list = os.scandir(path)
        for item in dir_list:
            if item.name.startswith('__'):
                continue
            if item.is_dir():
                self._traversal_files(item.path)
            elif item.is_file() and item.name.endswith('.py'):
                self.file_dirs.append(item.path)

    def _get_single_file_scan_result(self, file):
        output = self.astScaner.generate_ast(file)
        import_list = self.parse_import(output)
        return output[DECORATOR_KEY], import_list

    def _inverted_index(self, forward_index):
        inverted_index = dict()
        for index in forward_index:
            for item in forward_index[index][DECORATOR_KEY]:
                inverted_index[item] = {
                    FILE_NAME_KEY: index,
                    IMPORT_KEY: forward_index[index][IMPORT_KEY],
                    MODULE_KEY: forward_index[index][MODULE_KEY],
                }
        return inverted_index

    def _module_import(self, forward_index):
        module_import = dict()
        for index, value_dict in forward_index.items():
            module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
        return module_import

    def get_files_scan_results(self,
                               target_dir=MODELSCOPE_PATH,
                               target_folders=SCAN_SUB_FOLDERS):
        """the entry method of the ast scan

        Args:
            target_dir (str, optional): the absolute path of the target directory to be scanned. Defaults to None.
            target_folders (list, optional): the list of
                sub-folders to be scanned in the target folder.
                Defaults to SCAN_SUB_FOLDERS.

        Returns:
            dict: indexer of registry
        """

        self.traversal_files(target_dir, target_folders)
        start = time.time()
        logger.info(
            f'AST-Scaning the path "{target_dir}" with the following sub folders {target_folders}'
        )

        result = dict()
        for file in self.file_dirs:
            filepath = file[file.find('modelscope'):]
            module_name = filepath.replace(osp.sep, '.').replace('.py', '')
            decorator_list, import_list = self._get_single_file_scan_result(
                file)
            result[file] = {
                DECORATOR_KEY: decorator_list,
                IMPORT_KEY: import_list,
                MODULE_KEY: module_name
            }
        inverted_index_with_results = self._inverted_index(result)
        module_import = self._module_import(result)
        index = {
            INDEX_KEY: inverted_index_with_results,
            REQUIREMENT_KEY: module_import
        }
        logger.info(
            f'Scaning done! A number of {len(inverted_index_with_results)}'
            f' files indexed! Time consumed {time.time()-start}s')
        return index

    def files_mtime_md5(self,
                        target_path=MODELSCOPE_PATH,
                        target_subfolder=SCAN_SUB_FOLDERS):
        self.file_dirs = []
        self.traversal_files(target_path, target_subfolder)
        files_mtime = []
        for item in self.file_dirs:
            files_mtime.append(os.path.getmtime(item))
        result_str = reduce(lambda x, y: str(x) + str(y), files_mtime, '')
        md5 = hashlib.md5(result_str.encode())
        return md5.hexdigest()


fileScaner = FilesAstScaning()
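A brief, hedged usage sketch of the whole-package scan (the printed signature and module are whatever the scan finds first): the forward scan produces one record per file, which is then inverted so each registry signature points at the file, imports and module that define it.

# Hedged sketch: run the package scan and look at the two top-level keys.
from modelscope.utils.ast_utils import FilesAstScaning

scanner = FilesAstScaning()
index = scanner.get_files_scan_results()

# 'index' maps (REGISTRY, group, module_name) tuples to file/import/module info,
# 'requirements' maps a module path to the third-party packages it imports.
signature = next(iter(index['index']))
print(signature, index['index'][signature]['module'])
print(index['requirements'][index['index'][signature]['module']])

# The md5 over file modification times is what load_index() uses to detect changes.
print(scanner.files_mtime_md5())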
def _save_index(index, file_path):
    # convert tuple keys to str keys
    index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()}
    index[VERSION_KEY] = __version__
    index[MD5_KEY] = fileScaner.files_mtime_md5()
    json_index = json.dumps(index)
    storage.write(json_index.encode(), file_path)
    index[INDEX_KEY] = {
        ast.literal_eval(k): v
        for k, v in index[INDEX_KEY].items()
    }


def _load_index(file_path):
    bytes_index = storage.read(file_path)
    wrapped_index = json.loads(bytes_index)
    # convert str keys to tuple keys
    wrapped_index[INDEX_KEY] = {
        ast.literal_eval(k): v
        for k, v in wrapped_index[INDEX_KEY].items()
    }
    return wrapped_index


def load_index(force_rebuild=False):
    """get the index from scan results or cache

    Args:
        force_rebuild: if set to True, rebuild and load the index

    Returns:
        dict: the index information for all registered modules, including the keys
        index, requirements, version and md5, as in the example below:
        {
            'index': {
                ('MODELS', 'nlp', 'bert'): {
                    'filepath': 'path/to/the/registered/model',
                    'imports': ['os', 'torch', 'typing'],
                    'module': 'modelscope.models.nlp.bert'
                },
                ...
            },
            'requirements': {
                'modelscope.models.nlp.bert': ['os', 'torch', 'typing'],
                'modelscope.models.nlp.structbert': ['os', 'torch', 'typing'],
                ...
            },
            'version': '0.2.3',
            'md5': '8616924970fe6bc119d1562832625612',
        }
    """
    cache_dir = os.getenv('MODELSCOPE_CACHE', get_default_cache_dir())
    file_path = os.path.join(cache_dir, INDEXER_FILE)
    logger.info(f'Loading ast index from {file_path}')
    index = None
    if not force_rebuild and os.path.exists(file_path):
        wrapped_index = _load_index(file_path)
        md5 = fileScaner.files_mtime_md5()
        if (wrapped_index[VERSION_KEY] == __version__
                and wrapped_index[MD5_KEY] == md5):
            index = wrapped_index

    if index is None:
        if force_rebuild:
            logger.info('Force rebuilding ast index')
        else:
            logger.info(
                f'No valid ast index found from {file_path}, rebuilding ast index!'
            )
        index = fileScaner.get_files_scan_results()
        _save_index(index, file_path)
    return index
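A short, hedged example of the cache behaviour described above; the signature used for the lookup is the sentiment-analysis pipeline that the unit test further below also checks, and the printed module path is only illustrative:

# Hedged sketch: load (or rebuild) the cached index and query one signature.
from modelscope.utils.ast_utils import load_index

index = load_index()  # reads <cache_dir>/ast_indexer if version and md5 still match
entry = index['index'].get(
    ('PIPELINES', 'text-classification', 'sentiment-analysis'))
if entry is not None:
    print(entry['module'])  # e.g. modelscope.pipelines.nlp.sequence_classification_pipeline
    print(index['requirements'][entry['module']])

# Touch a scanned file or change the modelscope version and the md5/version check
# fails, so the next load_index() call rescans and rewrites the cache.
index = load_index(force_rebuild=True)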
def check_import_module_avaliable(module_dicts: dict) -> list:
    missed_module = []
    for module in module_dicts.keys():
        loader = importlib.find_loader(module)
        if loader is None:
            missed_module.append(module)
    return missed_module


if __name__ == '__main__':
    index = load_index()
    print(index)
@@ -1,79 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.constant import Fields, Requirements
from modelscope.utils.import_utils import requires


def get_msg(field):
    msg = f'\n{field} requirements not installed, please execute ' \
          f'`pip install requirements/{field}.txt` or ' \
          f'`pip install modelscope[{field}]`'
    return msg


class NLPModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.nlp)
        super().__init__(e)


class CVModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.cv)
        super().__init__(e)


class AudioModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.audio)
        super().__init__(e)


class MultiModalModuleNotFoundError(ModuleNotFoundError):

    def __init__(self, e: ModuleNotFoundError) -> None:
        e.msg += get_msg(Fields.multi_modal)
        super().__init__(e)


def check_nlp():
    try:
        requires('nlp models', (
            Requirements.torch,
            Requirements.tokenizers,
        ))
    except ImportError as e:
        raise NLPModuleNotFoundError(e)


def check_cv():
    try:
        requires('cv models', (
            Requirements.torch,
            Requirements.tokenizers,
        ))
    except ImportError as e:
        raise CVModuleNotFoundError(e)


def check_audio():
    try:
        requires('audio models', (
            Requirements.torch,
            Requirements.tf,
        ))
    except ImportError as e:
        raise AudioModuleNotFoundError(e)


def check_multi_modal():
    try:
        requires('multi-modal models', (
            Requirements.torch,
            Requirements.tokenizers,
        ))
    except ImportError as e:
        raise MultiModalModuleNotFoundError(e)
@@ -75,3 +75,19 @@ SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
`pip install scipy`
"""

# docstyle-ignore
OPENCV_IMPORT_ERROR = """
{0} requires the opencv library but it was not found in your environment. You can install it with pip:
`pip install opencv-python`
"""

PILLOW_IMPORT_ERROR = """
{0} requires the Pillow library but it was not found in your environment. You can install it with pip:
`pip install Pillow`
"""

GENERAL_IMPORT_ERROR = """
{0} requires the REQ library but it was not found in your environment. You can install it with pip:
`pip install REQ`
"""
@@ -2,11 +2,10 @@
# Part of the implementation is borrowed from huggingface/transformers.
import ast
import functools
import importlib.util
import importlib
import os
import os.path as osp
import sys
import types
from collections import OrderedDict
from functools import wraps
from importlib import import_module
@@ -14,18 +13,15 @@ from itertools import chain
from types import ModuleType
from typing import Any

import json
from packaging import version

from modelscope.utils.constant import Fields
from modelscope.utils.error import (PROTOBUF_IMPORT_ERROR,
                                    PYTORCH_IMPORT_ERROR, SCIPY_IMPORT_ERROR,
                                    SENTENCEPIECE_IMPORT_ERROR,
                                    SKLEARN_IMPORT_ERROR,
                                    TENSORFLOW_IMPORT_ERROR, TIMM_IMPORT_ERROR,
                                    TOKENIZERS_IMPORT_ERROR)
from modelscope.utils.ast_utils import (INDEX_KEY, MODULE_KEY, REQUIREMENT_KEY,
                                        load_index)
from modelscope.utils.error import *  # noqa
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)

if sys.version_info < (3, 8):
    import importlib_metadata
else:
@@ -33,6 +29,8 @@ else:

logger = get_logger()

AST_INDEX = None


def import_modules_from_file(py_file: str):
    """ Import module from a certain file
@@ -250,18 +248,44 @@ def is_tf_available():
    return _tf_available


def is_opencv_available():
    return importlib.util.find_spec('cv2') is not None


def is_pillow_available():
    return importlib.util.find_spec('PIL.Image') is not None


def is_package_available(pkg_name):
    return importlib.util.find_spec(pkg_name) is not None


def is_espnet_available(pkg_name):
    return importlib.util.find_spec('espnet2') is not None \
        and importlib.util.find_spec('espnet')


REQUIREMENTS_MAAPING = OrderedDict([
    ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
    ('sentencepiece', (is_sentencepiece_available,
                       SENTENCEPIECE_IMPORT_ERROR)),
    ('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
    ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('tensorflow', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
    ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
    ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
    ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
    ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)),
    ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)),
    ('espnet2', (is_espnet_available,
                 GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
    ('espnet', (is_espnet_available,
                GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
])

SYSTEM_PACKAGE = set(['os', 'sys', 'typing'])


def requires(obj, requirements):
    if not isinstance(requirements, (list, tuple)):
@@ -271,7 +295,18 @@ def requires(obj, requirements):
    else:
        name = obj.__name__ if hasattr(obj,
                                       '__name__') else obj.__class__.__name__
    checks = (REQUIREMENTS_MAAPING[req] for req in requirements)
    checks = []
    for req in requirements:
        if req == '' or req in SYSTEM_PACKAGE:
            continue
        if req in REQUIREMENTS_MAAPING:
            check = REQUIREMENTS_MAAPING[req]
        else:
            check_fn = functools.partial(is_package_available, req)
            err_msg = GENERAL_IMPORT_ERROR.replace('REQ', req)
            check = (check_fn, err_msg)
        checks.append(check)

    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError(''.join(failed))
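A hedged sketch of how the reworked requires() behaves: curated packages use the check/message pairs above, anything else falls back to importlib.util.find_spec with the generic message, and system packages are skipped ('some_missing_pkg' below is a made-up name used only for illustration):

# Hedged usage sketch for requires().
from modelscope.utils.import_utils import requires

requires('MyPipeline', ['torch', 'os'])  # 'os' is in SYSTEM_PACKAGE and is skipped;
                                         # raises ImportError only if torch is missing
try:
    requires('MyPipeline', ['some_missing_pkg'])
except ImportError as e:
    # Falls back to GENERAL_IMPORT_ERROR with REQ replaced by the package name.
    print(e)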
@@ -299,3 +334,99 @@ def tf_required(func):
            raise ImportError(f'Method `{func.__name__}` requires TF.')

    return wrapper


class LazyImportModule(ModuleType):
    AST_INDEX = None
    if AST_INDEX is None:
        AST_INDEX = load_index()

    def __init__(self,
                 name,
                 module_file,
                 import_structure,
                 module_spec=None,
                 extra_objects=None,
                 try_to_pre_import=False):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(
            chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure
        if try_to_pre_import:
            self._try_to_import()

    def _try_to_import(self):
        for sub_module in self._class_to_module.keys():
            try:
                getattr(self, sub_module)
            except Exception as e:
                logger.warn(
                    f'pre load module {sub_module} error, please check {e}')

    # Needed for autocompletion in an IDE
    def __dir__(self):
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module.keys():
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(
                f'module {self.__name__} has no attribute {name}')

        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        try:
            # check requirements before module import
            module_name_full = self.__name__ + '.' + module_name
            if module_name_full in LazyImportModule.AST_INDEX[REQUIREMENT_KEY]:
                requirements = LazyImportModule.AST_INDEX[REQUIREMENT_KEY][
                    module_name_full]
                requires(module_name_full, requirements)
            return importlib.import_module('.' + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f'Failed to import {self.__name__}.{module_name} because of the following error '
                f'(look up to see its traceback):\n{e}') from e

    def __reduce__(self):
        return self.__class__, (self._name, self.__file__,
                                self._import_structure)

    @staticmethod
    def import_module(signature):
        """ import a lazy import module using signature

        Args:
            signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name)
        """
        if signature in LazyImportModule.AST_INDEX[INDEX_KEY]:
            mod_index = LazyImportModule.AST_INDEX[INDEX_KEY][signature]
            module_name = mod_index[MODULE_KEY]
            importlib.import_module(module_name)
        else:
            logger.warning(f'{signature} not found in ast index file')
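A minimal, hedged illustration of import_module(): given a registry signature, it looks the module path up in the AST index and imports it eagerly (the triple below is the one exercised by tests/utils/test_ast.py further down):

# Hedged sketch: force-load the module that registered a given signature.
from modelscope.utils.import_utils import LazyImportModule

# Signature format is (registry_name, group_key, module_name).
LazyImportModule.import_module(
    ('PIPELINES', 'text-classification', 'sentiment-analysis'))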
@@ -1,14 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import importlib
import inspect
from typing import List, Tuple, Union

from modelscope.utils.import_utils import requires
from modelscope.utils.logger import get_logger

TYPE_NAME = 'type'
default_group = 'default'
logger = get_logger()
AST_INDEX = None


class Registry(object):
@@ -55,14 +56,10 @@ class Registry(object):
    def _register_module(self,
                         group_key=default_group,
                         module_name=None,
                         module_cls=None,
                         requirements=None):
                         module_cls=None):
        assert isinstance(group_key,
                          str), 'group_key is required and must be str'

        if requirements is not None:
            requires(module_cls, requirements)

        if group_key not in self._modules:
            self._modules[group_key] = dict()

@@ -81,8 +78,7 @@ class Registry(object):
    def register_module(self,
                        group_key: str = default_group,
                        module_name: str = None,
                        module_cls: type = None,
                        requirements: Union[List, Tuple] = None):
                        module_cls: type = None):
        """ Register module

        Example:
@@ -106,7 +102,6 @@ class Registry(object):
                default group name is 'default'
            module_name: Module name
            module_cls: Module class object
            requirements: Module necessary requirements

        """
        if not (module_name is None or isinstance(module_name, str)):
@@ -116,8 +111,7 @@ class Registry(object):
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls,
                requirements=requirements)
                module_cls=module_cls)
            return module_cls

        # if module_cls is None, should return a decorator function
@@ -125,8 +119,7 @@
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls,
                requirements=requirements)
                module_cls=module_cls)
            return module_cls

        return _register
@@ -178,6 +171,11 @@ def build_from_cfg(cfg,
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # dynamically load the installation requirements for this module
    from modelscope.utils.import_utils import LazyImportModule
    sig = (registry.name.upper(), group_key, cfg['type'])
    LazyImportModule.import_module(sig)

    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
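A hedged sketch of the decorator-based registration that the AST scanner picks up and that build_from_cfg() resolves lazily; the class name and module_name below (MyTextClassifier, 'my-structbert') are illustrative only, and the exact registry/constant import paths are assumptions:

# Hedged sketch: register a class, which the scanner records as a signature.
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks


@MODELS.register_module(Tasks.text_classification, module_name='my-structbert')
class MyTextClassifier:

    def __init__(self, model_dir=None, **kwargs):
        self.model_dir = model_dir


# build_from_cfg() then derives sig = ('MODELS', 'text-classification', 'my-structbert')
# and calls LazyImportModule.import_module(sig) before instantiating the class.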
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import inspect
import os


def if_func_receive_dict_inputs(func, inputs):
@@ -26,3 +27,12 @@ def if_func_receive_dict_inputs(func, inputs):
        return True
    else:
        return False


def get_default_cache_dir():
    """
    default base dir: '~/.cache/modelscope'
    """
    default_cache_dir = os.path.expanduser(
        os.path.join('~/.cache', 'modelscope'))
    return default_cache_dir
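A small, hedged sketch of how this default is consumed: the ast_indexer cache file lives under the default cache dir unless the MODELSCOPE_CACHE environment variable overrides it, mirroring the lookup in load_index():

# Hedged sketch: resolve the cache location used for the ast index.
import os
from modelscope.utils.utils import get_default_cache_dir

cache_dir = os.getenv('MODELSCOPE_CACHE', get_default_cache_dir())
print(os.path.join(cache_dir, 'ast_indexer'))  # where load_index() reads/writes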
@@ -9,6 +9,7 @@ librosa
lxml
matplotlib
nara_wpe
nltk
numpy<=1.18
# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged.
protobuf>3,<3.21.0

@@ -3,6 +3,7 @@ datasets
easydict
einops
filelock>=3.3.0
gast>=0.5.3
numpy
opencv-python
Pillow>=6.2.0

@@ -21,5 +21,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 120
ignore = F401,F821,W503
ignore = F401,F405,F821,W503
exclude = docs/src,*.pyi,.git
@@ -4,7 +4,7 @@ import unittest
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import PIL
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import Pipeline, pipeline
@@ -54,7 +54,7 @@ class CustomPipelineTest(unittest.TestCase):
        """ Provide default implementation based on preprocess_cfg and user can reimplement it

        """
        if not isinstance(input, PIL.Image.Image):
        if not isinstance(input, Image.Image):
            from modelscope.preprocessors import load_image
            data_dict = {'img': load_image(input), 'url': input}
        else:

@@ -3,7 +3,7 @@ import shutil
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import TranslationPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
@@ -4,7 +4,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogIntent
from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import DialogIntentPredictionPipeline
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -5,7 +5,8 @@ from typing import List
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogModeling
from modelscope.pipelines import DialogModelingPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import DialogModelingPipeline
from modelscope.preprocessors import DialogModelingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -3,8 +3,10 @@ import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model, SpaceForDialogStateTracking
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogStateTracking
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import DialogStateTrackingPipeline
from modelscope.preprocessors import DialogStateTrackingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -5,7 +5,8 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import (BertForMaskedLM, StructBertForMaskedLM,
                                   VecoForMaskedLM)
from modelscope.pipelines import FillMaskPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FillMaskPipeline
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -7,7 +7,8 @@ from PIL import Image
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import ImageDenoisePipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.cv import ImageDenoisePipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -4,10 +4,11 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.cv.image_instance_segmentation.model import \
    CascadeMaskRCNNSwinModel
from modelscope.models.cv.image_instance_segmentation import (
    CascadeMaskRCNNSwinModel, get_img_ins_seg_result)
from modelscope.outputs import OutputKeys
from modelscope.pipelines import ImageInstanceSegmentationPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.cv import ImageInstanceSegmentationPipeline
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks

@@ -4,7 +4,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import TransformerCRFForNamedEntityRecognition
from modelscope.pipelines import NamedEntityRecognitionPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import NamedEntityRecognitionPipeline
from modelscope.preprocessors import NERPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -4,7 +4,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForNLI
from modelscope.pipelines import NLIPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import NLIPipeline
from modelscope.preprocessors import NLIPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -5,7 +5,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import SentenceSimilarityPipeline
from modelscope.preprocessors import SentenceSimilarityPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -5,7 +5,8 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import (SbertForSentimentClassification,
                                   SequenceClassificationModel)
from modelscope.pipelines import SentimentClassificationPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import SentimentClassificationPipeline
from modelscope.preprocessors import SentimentClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -4,7 +4,8 @@ import unittest

from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import SequenceClassificationPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import SequenceClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Hubs, Tasks
from modelscope.utils.test_utils import test_level

@@ -4,7 +4,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.pipelines import TextGenerationPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextGenerationPipeline
from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -5,7 +5,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
from modelscope.pipelines import VisualQuestionAnsweringPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.multi_modal import VisualQuestionAnsweringPipeline
from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -5,7 +5,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import WordSegmentationPipeline
from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -4,7 +4,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForZeroShotClassification
from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import ZeroShotClassificationPipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -2,7 +2,7 @@

import unittest

import PIL
from PIL import Image

from modelscope.preprocessors import load_image

@@ -11,7 +11,7 @@ class ImagePreprocessorTest(unittest.TestCase):

    def test_load(self):
        img = load_image('data/test/images/image_matting.png')
        self.assertTrue(isinstance(img, PIL.Image.Image))
        self.assertTrue(isinstance(img, Image.Image))
        self.assertEqual(img.size, (948, 533))

@@ -11,8 +11,7 @@ import torch
from torch.utils import data as data

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_color_enhance.image_color_enhance import \
    ImageColorEnhance
from modelscope.models.cv.image_color_enhance import ImageColorEnhance
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level

@@ -5,9 +5,8 @@ import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import NAFNetForImageDenoise
from modelscope.msdatasets.image_denoise_data.image_denoise_dataset import \
    PairedImageDataset
from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
from modelscope.msdatasets.image_denoise_data import PairedImageDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile

@@ -7,10 +7,8 @@ import zipfile
from functools import partial

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_instance_segmentation import \
    CascadeMaskRCNNSwinModel
from modelscope.models.cv.image_instance_segmentation.datasets import \
    ImageInstanceSegmentationCocoDataset
from modelscope.models.cv.image_instance_segmentation import (
    CascadeMaskRCNNSwinModel, ImageInstanceSegmentationCocoDataset)
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
97
tests/utils/test_ast.py
Normal file
@@ -0,0 +1,97 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import shutil
import tempfile
import time
import unittest

import gast

from modelscope.utils.ast_utils import AstScaning, FilesAstScaning, load_index

MODELSCOPE_PATH = '/'.join(
    os.path.dirname(__file__).split('/')[:-2]) + '/modelscope'


class AstScaningTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        self.test_file = os.path.join(self.tmp_dir, 'test.py')
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def test_ast_scaning_class(self):
        astScaner = AstScaning()
        pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp',
                                     'sequence_classification_pipeline.py')
        output = astScaner.generate_ast(pipeline_file)
        self.assertTrue(output['imports'] is not None)
        self.assertTrue(output['from_imports'] is not None)
        self.assertTrue(output['decorators'] is not None)
        imports, from_imports, decorators = output['imports'], output[
            'from_imports'], output['decorators']
        self.assertIsInstance(imports, dict)
        self.assertIsInstance(from_imports, dict)
        self.assertIsInstance(decorators, list)
        self.assertListEqual(
            list(set(imports.keys()) - set(['typing', 'numpy'])), [])
        self.assertEqual(len(from_imports.keys()), 9)
        self.assertTrue(from_imports['modelscope.metainfo'] is not None)
        self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines'])
        self.assertEqual(
            decorators,
            [('PIPELINES', 'text-classification', 'sentiment-analysis')])

    def test_files_scaning_method(self):
        fileScaner = FilesAstScaning()
        output = fileScaner.get_files_scan_results()
        self.assertTrue(output['index'] is not None)
        self.assertTrue(output['requirements'] is not None)
        index, requirements = output['index'], output['requirements']
        self.assertIsInstance(index, dict)
        self.assertIsInstance(requirements, dict)
        self.assertIsInstance(list(index.keys())[0], tuple)
        index_0 = list(index.keys())[0]
        self.assertIsInstance(index[index_0], dict)
        self.assertTrue(index[index_0]['imports'] is not None)
        self.assertIsInstance(index[index_0]['imports'], list)
        self.assertTrue(index[index_0]['module'] is not None)
        self.assertIsInstance(index[index_0]['module'], str)
        index_0 = list(requirements.keys())[0]
        self.assertIsInstance(requirements[index_0], list)

    def test_file_mtime_md5_method(self):
        fileScaner = FilesAstScaning()
        # create the first file
        with open(self.test_file, 'w', encoding='utf-8') as f:
            f.write('This is the new test!')

        md5_1 = fileScaner.files_mtime_md5(self.tmp_dir, [])
        md5_2 = fileScaner.files_mtime_md5(self.tmp_dir, [])
        self.assertEqual(md5_1, md5_2)
        time.sleep(2)
        # case of revising an existing file
        with open(self.test_file, 'w', encoding='utf-8') as f:
            f.write('test again')
        md5_3 = fileScaner.files_mtime_md5(self.tmp_dir, [])
        self.assertNotEqual(md5_1, md5_3)

        # case of creating a new file
        self.test_file_new = os.path.join(self.tmp_dir, 'test_1.py')
        time.sleep(2)
        with open(self.test_file_new, 'w', encoding='utf-8') as f:
            f.write('test again')
        md5_4 = fileScaner.files_mtime_md5(self.tmp_dir, [])
        self.assertNotEqual(md5_1, md5_4)
        self.assertNotEqual(md5_3, md5_4)


if __name__ == '__main__':
    unittest.main()
@@ -1,22 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest
from typing import List, Union

from modelscope.utils.check_requirements import NLPModuleNotFoundError, get_msg
from modelscope.utils.constant import Fields


class ImportUtilsTest(unittest.TestCase):

    def test_type_module_not_found(self):
        with self.assertRaises(NLPModuleNotFoundError) as ctx:
            try:
                import not_found
            except ModuleNotFoundError as e:
                raise NLPModuleNotFoundError(e)
        self.assertTrue(get_msg(Fields.nlp) in ctx.exception.msg.msg)


if __name__ == '__main__':
    unittest.main()