Files
modelscope/tests/pipelines/test_addr_mgeo.py
xingjun.wang 48c0d2a9af add 1.6
2023-05-22 10:53:18 +08:00

122 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextClassificationPipeline
from modelscope.preprocessors import TextClassificationTransformersPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
from modelscope.utils.test_utils import test_level
class MGeoTest(unittest.TestCase):
multi_modal_inputs = {
'source_sentence': ['杭州余杭东方未来学校附近世纪华联商场(金家渡北苑店)'],
'first_sequence_gis': [[
[
13159, 13295, 13136, 13157, 13158, 13291, 13294, 74505, 74713,
75387, 75389, 75411
],
[3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4],
[3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], # noqa: E126
[[1254, 1474, 1255, 1476], [1253, 1473, 1256, 1476],
[1247, 1473, 1255, 1480], [1252, 1475, 1253, 1476],
[1253, 1475, 1253, 1476], [1252, 1471, 1254, 1475],
[1254, 1473, 1256, 1475], [1238, 1427, 1339, 1490],
[1238, 1427, 1339, 1490], [1252, 1474, 1255, 1476],
[1252, 1474, 1255, 1476], [1249, 1472, 1255, 1479]],
[[24, 23, 15, 23], [24, 28, 15, 18], [31, 24, 22, 22],
[43, 13, 37, 13], [43, 6, 35, 6], [31, 32, 22, 14],
[19, 30, 9, 16], [24, 30, 15, 16], [24, 30, 15, 16],
[29, 24, 20, 22], [28, 25, 19, 21], [31, 26, 22, 20]],
'120.08802231437534,30.343853313981505'
]],
'sentences_to_compare': [
'良渚街道金家渡北苑42号世纪华联超市(金家渡北苑店)', '金家渡路金家渡中苑南区70幢金家渡中苑70幢',
'金家渡路140-142号附近家家福足道(金家渡店)'
],
'second_sequence_gis':
[[[13083, 13081, 13084, 13085, 13131, 13134, 13136, 13147, 13148],
[3, 3, 3, 3, 3, 3, 3, 3, 3], [3, 4, 4, 4, 4, 4, 4, 4, 4],
[[1248, 1477, 1250, 1479], [1248, 1475, 1250, 1476],
[1247, 1478, 1249, 1481], [1249, 1479, 1249, 1480],
[1249, 1476, 1250, 1476], [1250, 1474, 1252, 1478],
[1247, 1473, 1255, 1480], [1250, 1478, 1251, 1479],
[1249, 1478, 1250, 1481]],
[[30, 26, 21, 20], [32, 43, 23, 43], [33, 23, 23, 23],
[31, 13, 22, 13], [25, 43, 16, 43], [20, 33, 10, 33],
[26, 29, 17, 17], [18, 21, 8, 21], [26, 23, 17, 23]],
'120.08075205680345,30.34697777462197'],
[[13291, 13159, 13295, 74713, 75387, 75389, 75411],
[3, 3, 3, 4, 4, 4, 4], [3, 4, 4, 4, 4, 4, 4],
[[1252, 1471, 1254, 1475], [1254, 1474, 1255, 1476],
[1253, 1473, 1256, 1476], [1238, 1427, 1339, 1490],
[1252, 1474, 1255, 1476], [1252, 1474, 1255, 1476],
[1249, 1472, 1255, 1479]],
[[28, 28, 19, 18], [22, 16, 12, 16], [23, 24, 13, 22],
[24, 30, 15, 16], [27, 20, 18, 20], [27, 21, 18, 21],
[30, 24, 21, 22]], '120.0872539617001,30.342783672056953'],
[[13291, 13290, 13294, 13295, 13298], [3, 3, 3, 3, 3],
[3, 4, 4, 4, 4],
[[1252, 1471, 1254, 1475], [1253, 1469, 1255, 1472],
[1254, 1473, 1256, 1475], [1253, 1473, 1256, 1476],
[1255, 1467, 1258, 1472]],
[[32, 25, 23, 21], [26, 33, 17, 33], [21, 19, 11, 19],
[25, 21, 16, 21], [21, 33, 11,
33]], '120.08839673752281,30.34156156893651']]
}
single_modal_inputs = {
'source_sentence': ['杭州余杭东方未来学校附近世纪华联商场(金家渡北苑店)'],
'sentences_to_compare': [
'良渚街道金家渡北苑42号世纪华联超市(金家渡北苑店)', '金家渡路金家渡中苑南区70幢金家渡中苑70幢',
'金家渡路140-142号附近家家福足道(金家渡店)'
]
}
pipe_input = [
[
Tasks.text_ranking,
'damo/mgeo_geographic_textual_similarity_rerank_chinese_base',
multi_modal_inputs
],
[
Tasks.text_ranking,
'damo/mgeo_geographic_textual_similarity_rerank_chinese_base',
single_modal_inputs
],
[
Tasks.token_classification,
'damo/mgeo_geographic_elements_tagging_chinese_base',
'浙江省杭州市余杭区阿里巴巴西溪园区'
],
[
Tasks.token_classification,
'damo/mgeo_geographic_composition_analysis_chinese_base',
'浙江省杭州市余杭区阿里巴巴西溪园区'
],
[
Tasks.token_classification,
'damo/mgeo_geographic_where_what_cut_chinese_base',
'浙江省杭州市余杭区阿里巴巴西溪园区'
],
[
Tasks.sentence_similarity,
'damo/mgeo_geographic_entity_alignment_chinese_base',
('后湖金桥大道绿色新都116—120栋116号诺雅广告', '金桥大道46号宏宇·绿色新都120幢')
],
]
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
for task, model, inputs in self.pipe_input:
pipeline_ins = pipeline(task=task, model=model)
print(pipeline_ins(input=inputs))
if __name__ == '__main__':
unittest.main()