mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-17 00:37:43 +01:00
202 lines
8.3 KiB
Python
202 lines
8.3 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
import os
|
||
import unittest
|
||
from threading import Thread
|
||
from typing import List
|
||
|
||
import json
|
||
from transformers import BertTokenizer
|
||
|
||
from modelscope.hub.snapshot_download import snapshot_download
|
||
from modelscope.models import Model
|
||
from modelscope.outputs import OutputKeys
|
||
from modelscope.pipelines import pipeline
|
||
from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
|
||
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
|
||
from modelscope.preprocessors.nlp.space_T_cn.fields.database import Database
|
||
from modelscope.utils.constant import ModelFile, Tasks
|
||
from modelscope.utils.test_utils import test_level
|
||
|
||
|
||
def tableqa_tracking_and_print_results_with_history(
|
||
pipelines: List[TableQuestionAnsweringPipeline]):
|
||
test_case = {
|
||
'utterance': [
|
||
'有哪些风险类型?',
|
||
'风险类型有多少种?',
|
||
'珠江流域的小型水库的库容总量是多少?',
|
||
'那平均值是多少?',
|
||
'那水库的名称呢?',
|
||
'换成中型的呢?',
|
||
]
|
||
}
|
||
for p in pipelines:
|
||
historical_queries = None
|
||
for question in test_case['utterance']:
|
||
output_dict = p({
|
||
'question': question,
|
||
'history_sql': historical_queries
|
||
})[OutputKeys.OUTPUT]
|
||
print('question', question)
|
||
print('sql text:', output_dict[OutputKeys.SQL_STRING])
|
||
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
|
||
print('query result:', output_dict[OutputKeys.QUERY_RESULT])
|
||
print('json dumps', json.dumps(output_dict, ensure_ascii=False))
|
||
print()
|
||
historical_queries = output_dict[OutputKeys.HISTORY]
|
||
|
||
|
||
def tableqa_tracking_and_print_results_without_history(
|
||
pipelines: List[TableQuestionAnsweringPipeline]):
|
||
test_case = {
|
||
'utterance': [['列出油耗大于8但是功率低于200的名称和价格', 'car'],
|
||
['油耗低于5的suv有哪些?', 'car'], ['上个月收益超过3的有几个基金?', 'fund'],
|
||
['净值不等于1的基金平均月收益率和年收益率是多少?', 'fund'],
|
||
['计算机或者成绩优秀的同学有哪些?学号是多少?', 'student'],
|
||
['本部博士生中平均身高是多少?', 'student'],
|
||
['长江流域和珠江流域的水库库容总量是多少?', 'reservoir'],
|
||
['今天星期几?', 'reservoir']]
|
||
}
|
||
for p in pipelines:
|
||
for question, table_id in test_case['utterance']:
|
||
output_dict = p({
|
||
'question': question,
|
||
'table_id': table_id
|
||
})[OutputKeys.OUTPUT]
|
||
print('question', question)
|
||
print('sql text:', output_dict[OutputKeys.SQL_STRING])
|
||
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
|
||
print('query result:', output_dict[OutputKeys.QUERY_RESULT])
|
||
print('json dumps', json.dumps(output_dict, ensure_ascii=False))
|
||
print()
|
||
|
||
|
||
def tableqa_tracking_and_print_results_with_tableid(
|
||
pipelines: List[TableQuestionAnsweringPipeline]):
|
||
test_case = {
|
||
'utterance': [
|
||
['有哪些风险类型?', 'fund', False],
|
||
['风险类型有多少种?', 'fund', True],
|
||
['珠江流域的小型水库的库容总量是多少?', 'reservoir', False],
|
||
['那平均值是多少?', 'reservoir', True],
|
||
['那水库的名称呢?', 'reservoir', True],
|
||
['换成中型的呢?', 'reservoir', True],
|
||
['近7年来车辆的销量趋势?', 'car_sales', False],
|
||
['近7年来车辆的销量月环比是多少呢?', 'car_sales', True],
|
||
],
|
||
}
|
||
for p in pipelines:
|
||
historical_queries = None
|
||
for question, table_id, use_history in test_case['utterance']:
|
||
output_dict = p({
|
||
'question':
|
||
question,
|
||
'table_id':
|
||
table_id,
|
||
'history_sql':
|
||
historical_queries if use_history else None
|
||
})[OutputKeys.OUTPUT]
|
||
print('question', question)
|
||
print('sql text:', output_dict[OutputKeys.SQL_STRING])
|
||
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
|
||
print('query result:', output_dict[OutputKeys.QUERY_RESULT])
|
||
print('json dumps', json.dumps(output_dict, ensure_ascii=False))
|
||
print()
|
||
historical_queries = output_dict[OutputKeys.HISTORY]
|
||
|
||
|
||
class TableQuestionAnswering(unittest.TestCase):
|
||
|
||
def setUp(self) -> None:
|
||
self.task = Tasks.table_question_answering
|
||
self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
|
||
|
||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||
def test_run_by_direct_model_download(self):
|
||
cache_path = snapshot_download(self.model_id)
|
||
preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
|
||
pipelines = [
|
||
pipeline(
|
||
Tasks.table_question_answering,
|
||
model=cache_path,
|
||
preprocessor=preprocessor)
|
||
]
|
||
tableqa_tracking_and_print_results_with_history(pipelines)
|
||
|
||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||
def test_run_by_direct_model_download_with_multithreads(self):
|
||
cache_path = snapshot_download(self.model_id)
|
||
pl = pipeline(Tasks.table_question_answering, model=cache_path)
|
||
|
||
def print_func(pl, i):
|
||
result = pl({
|
||
'question': '上个月收益从低到高排前七的基金的名称和风险等级是什么',
|
||
'table_id': 'fund',
|
||
'history_sql': None
|
||
})
|
||
print(i, result[OutputKeys.OUTPUT][OutputKeys.SQL_QUERY],
|
||
result[OutputKeys.OUTPUT][OutputKeys.QUERY_RESULT],
|
||
json.dumps(result))
|
||
|
||
procs = []
|
||
for i in range(5):
|
||
proc = Thread(target=print_func, args=(pl, i))
|
||
procs.append(proc)
|
||
proc.start()
|
||
for proc in procs:
|
||
proc.join()
|
||
|
||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||
def test_run_with_model_from_modelhub(self):
|
||
model = Model.from_pretrained(self.model_id)
|
||
self.tokenizer = BertTokenizer(
|
||
os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
|
||
db = Database(
|
||
tokenizer=self.tokenizer,
|
||
table_file_path=[
|
||
os.path.join(model.model_dir, 'databases', fname)
|
||
for fname in os.listdir(
|
||
os.path.join(model.model_dir, 'databases'))
|
||
],
|
||
syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
|
||
is_use_sqlite=True)
|
||
preprocessor = TableQuestionAnsweringPreprocessor(
|
||
model_dir=model.model_dir, db=db)
|
||
pipelines = [
|
||
pipeline(
|
||
Tasks.table_question_answering,
|
||
model=model,
|
||
preprocessor=preprocessor,
|
||
db=db)
|
||
]
|
||
tableqa_tracking_and_print_results_with_tableid(pipelines)
|
||
|
||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||
def test_run_with_model_from_modelhub_with_other_classes(self):
|
||
model = Model.from_pretrained(self.model_id)
|
||
self.tokenizer = BertTokenizer(
|
||
os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
|
||
db = Database(
|
||
tokenizer=self.tokenizer,
|
||
table_file_path=[
|
||
os.path.join(model.model_dir, 'databases', fname)
|
||
for fname in os.listdir(
|
||
os.path.join(model.model_dir, 'databases'))
|
||
],
|
||
syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
|
||
is_use_sqlite=True)
|
||
preprocessor = TableQuestionAnsweringPreprocessor(
|
||
model_dir=model.model_dir, db=db)
|
||
pipelines = [
|
||
pipeline(
|
||
Tasks.table_question_answering,
|
||
model=model,
|
||
preprocessor=preprocessor,
|
||
db=db)
|
||
]
|
||
tableqa_tracking_and_print_results_without_history(pipelines)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
unittest.main()
|