# Copyright (c) Alibaba, Inc. and its affiliates.
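"""Tests for the table question answering pipeline (Chinese conversational
text2sql): direct model download, a hub-loaded model with a custom Database,
and a sqlite-backed variant, with and without multi-turn SQL history."""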
import json
import os
import unittest
from typing import List

from transformers import BertTokenizer

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


def tableqa_tracking_and_print_results_with_history(
        pipelines: List[TableQuestionAnsweringPipeline]):
    """Ask a multi-turn sequence of questions, threading the SQL history
    of each answer into the next question."""
    test_case = {
        'utterance': [
            '有哪些风险类型?',
            '风险类型有多少种?',
            '珠江流域的小(2)型水库的库容总量是多少?',
            '那平均值是多少?',
            '那水库的名称呢?',
            '换成中型的呢?',
            '枣庄营业厅的电话',
            '那地址呢?',
            '枣庄营业厅的电话和地址',
        ]
    }
    for p in pipelines:
        historical_queries = None
        for question in test_case['utterance']:
            output_dict = p({
                'question': question,
                'history_sql': historical_queries
            })[OutputKeys.OUTPUT]
            print('question:', question)
            print('sql text:', output_dict[OutputKeys.SQL_STRING])
            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
            print('query result:', output_dict[OutputKeys.QUERT_RESULT])
            print('json dumps:', json.dumps(output_dict, ensure_ascii=False))
            print()
            historical_queries = output_dict[OutputKeys.HISTORY]
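
# For reference, each pipeline call above returns (under OutputKeys.OUTPUT) a
# dict shaped roughly as follows -- a sketch based only on the keys this file
# reads, not the full output schema:
#
#     {
#         OutputKeys.SQL_STRING: '...',   # human-readable SQL text
#         OutputKeys.SQL_QUERY: '...',    # executable SQL query
#         OutputKeys.QUERT_RESULT: ...,   # query result ('QUERT' is the
#                                         # spelling defined by OutputKeys)
#         OutputKeys.HISTORY: ...,        # fed back as 'history_sql' next turn
#     }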


def tableqa_tracking_and_print_results_without_history(
        pipelines: List[TableQuestionAnsweringPipeline]):
    """Ask a series of independent single-turn questions (no history)."""
    test_case = {
        'utterance': [
            '有哪些风险类型?',
            '风险类型有多少种?',
            '珠江流域的小(2)型水库的库容总量是多少?',
            '枣庄营业厅的电话',
            '枣庄营业厅的电话和地址',
        ]
    }
    for p in pipelines:
        for question in test_case['utterance']:
            output_dict = p({'question': question})[OutputKeys.OUTPUT]
            print('question:', question)
            print('sql text:', output_dict[OutputKeys.SQL_STRING])
            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
            print('query result:', output_dict[OutputKeys.QUERT_RESULT])
            print('json dumps:', json.dumps(output_dict, ensure_ascii=False))
            print()


def tableqa_tracking_and_print_results_with_tableid(
        pipelines: List[TableQuestionAnsweringPipeline]):
    """Ask multi-turn questions, each pinned to an explicit table_id."""
    test_case = {
        'utterance': [
            ['有哪些风险类型?', 'fund'],
            ['风险类型有多少种?', 'reservoir'],
            ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'],
            ['那平均值是多少?', 'reservoir'],
            ['那水库的名称呢?', 'reservoir'],
            ['换成中型的呢?', 'reservoir'],
            ['枣庄营业厅的电话', 'business'],
            ['那地址呢?', 'business'],
            ['枣庄营业厅的电话和地址', 'business'],
        ],
    }
    for p in pipelines:
        historical_queries = None
        for question, table_id in test_case['utterance']:
            output_dict = p({
                'question': question,
                'table_id': table_id,
                'history_sql': historical_queries
            })[OutputKeys.OUTPUT]
            print('question:', question)
            print('sql text:', output_dict[OutputKeys.SQL_STRING])
            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
            print('query result:', output_dict[OutputKeys.QUERT_RESULT])
            print('json dumps:', json.dumps(output_dict, ensure_ascii=False))
            print()
            historical_queries = output_dict[OutputKeys.HISTORY]
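
# The helpers above differ in how each turn is built: the history variants
# thread OutputKeys.HISTORY from one call into the next call's 'history_sql'
# field, so follow-ups such as '那平均值是多少?' resolve against the previous
# SQL; the table_id variant additionally pins each question to a named table
# ('fund', 'reservoir', 'business'), which presumably correspond to the files
# shipped in the model's 'databases' directory.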


class TableQuestionAnswering(unittest.TestCase):

    def setUp(self) -> None:
        self.task = Tasks.table_question_answering
        self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
        pipelines = [
            pipeline(
                Tasks.table_question_answering,
                model=cache_path,
                preprocessor=preprocessor)
        ]
        tableqa_tracking_and_print_results_with_history(pipelines)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        self.tokenizer = BertTokenizer(
            os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
        db = Database(
            tokenizer=self.tokenizer,
            table_file_path=[
                os.path.join(model.model_dir, 'databases', fname)
                for fname in os.listdir(
                    os.path.join(model.model_dir, 'databases'))
            ],
            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
            is_use_sqlite=False)  # exercise the non-sqlite code path
        preprocessor = TableQuestionAnsweringPreprocessor(
            model_dir=model.model_dir, db=db)
        pipelines = [
            pipeline(
                Tasks.table_question_answering,
                model=model,
                preprocessor=preprocessor,
                db=db)
        ]
        tableqa_tracking_and_print_results_with_tableid(pipelines)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub_with_other_classes(self):
        model = Model.from_pretrained(self.model_id)
        self.tokenizer = BertTokenizer(
            os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
        db = Database(
            tokenizer=self.tokenizer,
            table_file_path=[
                os.path.join(model.model_dir, 'databases', fname)
                for fname in os.listdir(
                    os.path.join(model.model_dir, 'databases'))
            ],
            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
            is_use_sqlite=True)  # same as above, but sqlite-backed
        preprocessor = TableQuestionAnsweringPreprocessor(
            model_dir=model.model_dir, db=db)
        pipelines = [
            pipeline(
                Tasks.table_question_answering,
                model=model,
                preprocessor=preprocessor,
                db=db)
        ]
        tableqa_tracking_and_print_results_without_history(pipelines)
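
# Minimal usage sketch (not part of the test suite), assuming the model id
# used above is available on the hub:
#
#     pipe = pipeline(
#         Tasks.table_question_answering,
#         model='damo/nlp_convai_text2sql_pretrain_cn')
#     result = pipe({'question': '有哪些风险类型?'})[OutputKeys.OUTPUT]
#     print(result[OutputKeys.SQL_QUERY])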

if __name__ == '__main__':
    unittest.main()