From 4b29f28a04158acc679e702dff388e57740be735 Mon Sep 17 00:00:00 2001 From: "fubang.zfb" Date: Tue, 14 Feb 2023 08:36:46 +0000 Subject: [PATCH] siamese uie supports string-type schema input Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11655620 --- modelscope/pipelines/nlp/siamese_uie_pipeline.py | 5 ++++- tests/pipelines/test_siamese_uie.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modelscope/pipelines/nlp/siamese_uie_pipeline.py b/modelscope/pipelines/nlp/siamese_uie_pipeline.py index 1fc677ca..4637607c 100644 --- a/modelscope/pipelines/nlp/siamese_uie_pipeline.py +++ b/modelscope/pipelines/nlp/siamese_uie_pipeline.py @@ -8,6 +8,7 @@ from math import ceil from time import time from typing import Any, Dict, Generator, List, Mapping, Optional, Union +import json import torch from scipy.special import softmax from torch.cuda.amp import autocast @@ -89,7 +90,7 @@ class SiameseUiePipeline(Pipeline): """ Args: input(str): sentence to extract - schema: (dict) schema of uie task + schema: (dict or str) schema of uie task Default Returns: List[List]: predicted info list i.e. [[{'type': '人物', 'span': '谷口清太郎', 'offset': [18, 23]}], @@ -111,6 +112,8 @@ class SiameseUiePipeline(Pipeline): # sanitize the parameters text = input schema = kwargs.pop('schema') + if type(schema) == str: + schema = json.loads(schema) output_all_prefix = kwargs.pop('output_all_prefix', False) tokenized_text = self.preprocessor([text])[0] pred_info_list = [] diff --git a/tests/pipelines/test_siamese_uie.py b/tests/pipelines/test_siamese_uie.py index 0f450a1d..9097813c 100644 --- a/tests/pipelines/test_siamese_uie.py +++ b/tests/pipelines/test_siamese_uie.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import unittest +import json + from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model from modelscope.models.nlp import SiameseUieModel @@ -38,7 +40,8 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): print( f'sentence: {self.sentence}\n' - f'pipeline1:{pipeline1(input=self.sentence, schema=self.schema)}') + f'pipeline1:{pipeline1(input=self.sentence, schema=json.dumps(self.schema, ensure_ascii=False))}' + ) print(f'sentence: {self.sentence}\n' f'pipeline2: {pipeline2(self.sentence, schema=self.schema)}')