siamese uie supports string-type schema input

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11655620
This commit is contained in:
fubang.zfb
2023-02-14 08:36:46 +00:00
committed by wenmeng.zwm
parent 2b8236e6d0
commit 4b29f28a04
2 changed files with 8 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ from math import ceil
from time import time
from typing import Any, Dict, Generator, List, Mapping, Optional, Union
import json
import torch
from scipy.special import softmax
from torch.cuda.amp import autocast
@@ -89,7 +90,7 @@ class SiameseUiePipeline(Pipeline):
"""
Args:
input(str): sentence to extract
schema: (dict) schema of uie task
schema: (dict or str) schema of uie task
Default Returns:
List[List]: predicted info list i.e.
[[{'type': '人物', 'span': '谷口清太郎', 'offset': [18, 23]}],
@@ -111,6 +112,8 @@ class SiameseUiePipeline(Pipeline):
# sanitize the parameters
text = input
schema = kwargs.pop('schema')
if type(schema) == str:
schema = json.loads(schema)
output_all_prefix = kwargs.pop('output_all_prefix', False)
tokenized_text = self.preprocessor([text])[0]
pred_info_list = []

View File

@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import json
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SiameseUieModel
@@ -38,7 +40,8 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
print(
f'sentence: {self.sentence}\n'
f'pipeline1:{pipeline1(input=self.sentence, schema=self.schema)}')
f'pipeline1:{pipeline1(input=self.sentence, schema=json.dumps(self.schema, ensure_ascii=False))}'
)
print(f'sentence: {self.sentence}\n'
f'pipeline2: {pipeline2(self.sentence, schema=self.schema)}')