mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-18 17:27:43 +01:00
122 lines
5.5 KiB
Python
122 lines
5.5 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
import unittest
|
||
|
||
from modelscope.hub.snapshot_download import snapshot_download
|
||
from modelscope.models import Model
|
||
from modelscope.models.nlp import SbertForSequenceClassification
|
||
from modelscope.pipelines import pipeline
|
||
from modelscope.pipelines.nlp import TextClassificationPipeline
|
||
from modelscope.preprocessors import TextClassificationTransformersPreprocessor
|
||
from modelscope.utils.constant import Tasks
|
||
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool
|
||
from modelscope.utils.test_utils import test_level
|
||
|
||
|
||
class MGeoTest(unittest.TestCase):
|
||
|
||
multi_modal_inputs = {
|
||
'source_sentence': ['杭州余杭东方未来学校附近世纪华联商场(金家渡北苑店)'],
|
||
'first_sequence_gis': [[
|
||
[
|
||
13159, 13295, 13136, 13157, 13158, 13291, 13294, 74505, 74713,
|
||
75387, 75389, 75411
|
||
],
|
||
[3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4],
|
||
[3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], # noqa: E126
|
||
[[1254, 1474, 1255, 1476], [1253, 1473, 1256, 1476],
|
||
[1247, 1473, 1255, 1480], [1252, 1475, 1253, 1476],
|
||
[1253, 1475, 1253, 1476], [1252, 1471, 1254, 1475],
|
||
[1254, 1473, 1256, 1475], [1238, 1427, 1339, 1490],
|
||
[1238, 1427, 1339, 1490], [1252, 1474, 1255, 1476],
|
||
[1252, 1474, 1255, 1476], [1249, 1472, 1255, 1479]],
|
||
[[24, 23, 15, 23], [24, 28, 15, 18], [31, 24, 22, 22],
|
||
[43, 13, 37, 13], [43, 6, 35, 6], [31, 32, 22, 14],
|
||
[19, 30, 9, 16], [24, 30, 15, 16], [24, 30, 15, 16],
|
||
[29, 24, 20, 22], [28, 25, 19, 21], [31, 26, 22, 20]],
|
||
'120.08802231437534,30.343853313981505'
|
||
]],
|
||
'sentences_to_compare': [
|
||
'良渚街道金家渡北苑42号世纪华联超市(金家渡北苑店)', '金家渡路金家渡中苑南区70幢金家渡中苑70幢',
|
||
'金家渡路140-142号附近家家福足道(金家渡店)'
|
||
],
|
||
'second_sequence_gis':
|
||
[[[13083, 13081, 13084, 13085, 13131, 13134, 13136, 13147, 13148],
|
||
[3, 3, 3, 3, 3, 3, 3, 3, 3], [3, 4, 4, 4, 4, 4, 4, 4, 4],
|
||
[[1248, 1477, 1250, 1479], [1248, 1475, 1250, 1476],
|
||
[1247, 1478, 1249, 1481], [1249, 1479, 1249, 1480],
|
||
[1249, 1476, 1250, 1476], [1250, 1474, 1252, 1478],
|
||
[1247, 1473, 1255, 1480], [1250, 1478, 1251, 1479],
|
||
[1249, 1478, 1250, 1481]],
|
||
[[30, 26, 21, 20], [32, 43, 23, 43], [33, 23, 23, 23],
|
||
[31, 13, 22, 13], [25, 43, 16, 43], [20, 33, 10, 33],
|
||
[26, 29, 17, 17], [18, 21, 8, 21], [26, 23, 17, 23]],
|
||
'120.08075205680345,30.34697777462197'],
|
||
[[13291, 13159, 13295, 74713, 75387, 75389, 75411],
|
||
[3, 3, 3, 4, 4, 4, 4], [3, 4, 4, 4, 4, 4, 4],
|
||
[[1252, 1471, 1254, 1475], [1254, 1474, 1255, 1476],
|
||
[1253, 1473, 1256, 1476], [1238, 1427, 1339, 1490],
|
||
[1252, 1474, 1255, 1476], [1252, 1474, 1255, 1476],
|
||
[1249, 1472, 1255, 1479]],
|
||
[[28, 28, 19, 18], [22, 16, 12, 16], [23, 24, 13, 22],
|
||
[24, 30, 15, 16], [27, 20, 18, 20], [27, 21, 18, 21],
|
||
[30, 24, 21, 22]], '120.0872539617001,30.342783672056953'],
|
||
[[13291, 13290, 13294, 13295, 13298], [3, 3, 3, 3, 3],
|
||
[3, 4, 4, 4, 4],
|
||
[[1252, 1471, 1254, 1475], [1253, 1469, 1255, 1472],
|
||
[1254, 1473, 1256, 1475], [1253, 1473, 1256, 1476],
|
||
[1255, 1467, 1258, 1472]],
|
||
[[32, 25, 23, 21], [26, 33, 17, 33], [21, 19, 11, 19],
|
||
[25, 21, 16, 21], [21, 33, 11,
|
||
33]], '120.08839673752281,30.34156156893651']]
|
||
}
|
||
single_modal_inputs = {
|
||
'source_sentence': ['杭州余杭东方未来学校附近世纪华联商场(金家渡北苑店)'],
|
||
'sentences_to_compare': [
|
||
'良渚街道金家渡北苑42号世纪华联超市(金家渡北苑店)', '金家渡路金家渡中苑南区70幢金家渡中苑70幢',
|
||
'金家渡路140-142号附近家家福足道(金家渡店)'
|
||
]
|
||
}
|
||
|
||
pipe_input = [
|
||
[
|
||
Tasks.text_ranking,
|
||
'damo/mgeo_geographic_textual_similarity_rerank_chinese_base',
|
||
multi_modal_inputs
|
||
],
|
||
[
|
||
Tasks.text_ranking,
|
||
'damo/mgeo_geographic_textual_similarity_rerank_chinese_base',
|
||
single_modal_inputs
|
||
],
|
||
[
|
||
Tasks.token_classification,
|
||
'damo/mgeo_geographic_elements_tagging_chinese_base',
|
||
'浙江省杭州市余杭区阿里巴巴西溪园区'
|
||
],
|
||
[
|
||
Tasks.token_classification,
|
||
'damo/mgeo_geographic_composition_analysis_chinese_base',
|
||
'浙江省杭州市余杭区阿里巴巴西溪园区'
|
||
],
|
||
[
|
||
Tasks.token_classification,
|
||
'damo/mgeo_geographic_where_what_cut_chinese_base',
|
||
'浙江省杭州市余杭区阿里巴巴西溪园区'
|
||
],
|
||
[
|
||
Tasks.sentence_similarity,
|
||
'damo/mgeo_geographic_entity_alignment_chinese_base',
|
||
('后湖金桥大道绿色新都116—120栋116号(诺雅广告)', '金桥大道46号宏宇·绿色新都120幢')
|
||
],
|
||
]
|
||
|
||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||
def test_run_with_model_name(self):
|
||
for task, model, inputs in self.pipe_input:
|
||
pipeline_ins = pipeline(task=task, model=model)
|
||
print(pipeline_ins(input=inputs))
|
||
|
||
|
||
if __name__ == '__main__':
|
||
unittest.main()
|