unfinished

This commit is contained in:
雨泓
2022-06-22 20:13:41 +08:00
parent 31c774936b
commit 63695d6743
5 changed files with 16 additions and 16 deletions

View File

@@ -1,10 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .audio.tts.am import SambertNetHifi16k
from .audio.tts.vocoder import Hifigan16k
# from .audio.tts.am import SambertNetHifi16k
# from .audio.tts.vocoder import Hifigan16k
from .base import Model
from .builder import MODELS, build_model
from .multi_model import OfaForImageCaptioning
# from .multi_model import OfaForImageCaptioning
from .nlp import (
BertForSequenceClassification,
SbertForNLI,
@@ -13,5 +13,5 @@ from .nlp import (
SbertForZeroShotClassification,
StructBertForMaskedLM,
VecoForMaskedLM,
StructBertForTokenClassification,
SbertForTokenClassification,
)

View File

@@ -24,10 +24,10 @@ class SbertForTokenClassification(Model):
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir
from sofa import SbertConfig, SbertForTokenClassification
self.model = SbertForTokenClassification.from_pretrained(
import sofa
self.model = sofa.SbertForTokenClassification.from_pretrained(
self.model_dir)
self.config = SbertConfig.from_pretrained(self.model_dir)
self.config = sofa.SbertConfig.from_pretrained(self.model_dir)
def forward(self, input: Dict[str,
Any]) -> Dict[str, Union[str, np.ndarray]]:
@@ -46,7 +46,7 @@ class SbertForTokenClassification(Model):
}
"""
input_ids = torch.tensor(input['input_ids']).unsqueeze(0)
return self.model(input_ids)
return {**self.model(input_ids), 'text': input['text']}
def postprocess(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
logits = input["logits"]

View File

@@ -25,15 +25,15 @@ DEFAULT_MODEL_FOR_PIPELINE = {
(Pipelines.sentence_similarity,
'damo/nlp_structbert_sentence-similarity_chinese-base'),
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
Tasks.nli: ('nlp_structbert_nli_chinese-base',
Tasks.nli: (Pipelines.nli,
'damo/nlp_structbert_nli_chinese-base'),
Tasks.sentiment_classification:
('sbert-sentiment-classification',
(Pipelines.sentiment_classification,
'damo/nlp_structbert_sentiment-classification_chinese-base'),
Tasks.text_classification: ('bert-sentiment-analysis',
'damo/bert-base-sst2'),
Tasks.zero_shot_classification:
('bert-zero-shot-classification',
(Pipelines.zero_shot_classification,
'damo/nlp_structbert_zero-shot-classification_chinese-base'),
Tasks.image_matting: (Pipelines.image_matting,
'damo/cv_unet_image-matting'),
@@ -48,7 +48,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_unet_person-image-cartoon_compound-models'),
Tasks.ocr_detection: (Pipelines.ocr_detection,
'damo/cv_resnet18_ocr-detection-line-level_damo'),
Tasks.fill_mask: ('veco', 'damo/nlp_veco_fill-mask_large')
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask_large')
}

View File

@@ -4,7 +4,7 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.models.nlp import SbertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
@@ -19,7 +19,7 @@ class WordSegmentationTest(unittest.TestCase):
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = TokenClassifcationPreprocessor(cache_path)
model = StructBertForTokenClassification(
model = SbertForTokenClassification(
cache_path, tokenizer=tokenizer)
pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer)
pipeline2 = pipeline(

View File

@@ -4,7 +4,7 @@ import unittest
from maas_hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import BertForZeroShotClassification
from modelscope.models.nlp import SbertForZeroShotClassification
from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
@@ -19,7 +19,7 @@ class ZeroShotClassificationTest(unittest.TestCase):
def test_run_from_local(self):
cache_path = snapshot_download(self.model_id)
tokenizer = ZeroShotClassificationPreprocessor(cache_path)
model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
model = SbertForZeroShotClassification(cache_path, tokenizer=tokenizer)
pipeline1 = ZeroShotClassificationPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(