Mirror of https://github.com/modelscope/modelscope.git (synced 2026-02-24 12:10:09 +01:00)
[to #42322933] add conversational_text_to_sql pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9580066
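For reference, a minimal usage sketch of the new pipeline. This is not part of the commit; the model id, input keys, and the 'text' output key are taken from the test file and TASK_OUTPUTS changes in this diff.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the model registered for this task in the diff below.
text2sql = pipeline(
    task=Tasks.conversational_text_to_sql,
    model='damo/nlp_star_conversational-text-to-sql')

# Each call carries the current utterance plus the dialogue history and the SQL
# predicted for the previous turn; the predicted SQL comes back under 'text'.
case = {
    'utterance': "I'd like to see Shop names.",
    'history': [],
    'last_sql': '',
    'database_id': 'employee_hire_evaluation',
    'local_db_path': None,
}
result = text2sql(case)
print(result['text'])  # e.g. SELECT shop.Name FROM shop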
@@ -29,6 +29,7 @@ class Models(object):
    space_dst = 'space-dst'
    space_intent = 'space-intent'
    space_modeling = 'space-modeling'
    star = 'star'
    tcrf = 'transformer-crf'
    bart = 'bart'
    gpt3 = 'gpt3'
@@ -123,6 +124,7 @@ class Pipelines(object):
    dialog_state_tracking = 'dialog-state-tracking'
    zero_shot_classification = 'zero-shot-classification'
    text_error_correction = 'text-error-correction'
    conversational_text_to_sql = 'conversational-text-to-sql'

    # audio tasks
    sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -201,6 +203,7 @@ class Preprocessors(object):
    text_error_correction = 'text-error-correction'
    word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
    fill_mask = 'fill-mask'
    conversational_text_to_sql = 'conversational-text-to-sql'

    # audio preprocessor
    linear_aec_fbank = 'linear-aec-fbank'
@@ -17,12 +17,14 @@ if TYPE_CHECKING:
    from .space import SpaceForDialogIntent
    from .space import SpaceForDialogModeling
    from .space import SpaceForDialogStateTracking
    from .star_text_to_sql import StarForTextToSql
    from .task_models.task_model import SingleBackboneTaskModelBase
    from .bart_for_text_error_correction import BartForTextErrorCorrection
    from .gpt3 import GPT3ForTextGeneration

else:
    _import_structure = {
        'star_text_to_sql': ['StarForTextToSql'],
        'backbones': ['SbertModel'],
        'heads': ['SequenceClassificationHead'],
        'csanmt_for_translation': ['CsanmtForTranslation'],
modelscope/models/nlp/star_text_to_sql.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Dict, Optional

import torch
import torch.nn as nn
from text2sql_lgesql.asdl.asdl import ASDLGrammar
from text2sql_lgesql.asdl.transition_system import TransitionSystem
from text2sql_lgesql.model.model_constructor import Text2SQL
from text2sql_lgesql.utils.constants import GRAMMAR_FILEPATH

from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['StarForTextToSql']


@MODELS.register_module(
    Tasks.conversational_text_to_sql, module_name=Models.star)
class StarForTextToSql(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the STAR text-to-SQL model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.beam_size = 5
        self.config = kwargs.pop(
            'config',
            Config.from_file(
                os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
        self.config.model.model_dir = model_dir
        self.grammar = ASDLGrammar.from_filepath(
            os.path.join(model_dir, 'sql_asdl_v2.txt'))
        self.trans = TransitionSystem.get_class_by_lang('sql')(self.grammar)
        self.arg = self.config.model
        self.device = 'cuda' if \
            ('device' not in kwargs or kwargs['device'] == 'gpu') \
            and torch.cuda.is_available() else 'cpu'
        self.model = Text2SQL(self.arg, self.trans)
        check_point = torch.load(
            open(
                os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'),
            map_location=self.device)
        self.model.load_state_dict(check_point['model'])

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Return the model prediction for the preprocessed input.

        Args:
            input (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, Tensor]: the beam-search hypotheses and the database schema
        """
        self.model.eval()
        hyps = self.model.parse(input['batch'], self.beam_size)
        db = input['batch'].examples[0].db

        predict = {'predict': hyps, 'db': db}
        return predict
@@ -389,6 +389,12 @@ TASK_OUTPUTS = {
    # }
    Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],

    # conversational text-to-sql result for single sample
    # {
    #     "text": "SELECT shop.Name FROM shop."
    # }
    Tasks.conversational_text_to_sql: [OutputKeys.TEXT],

    # ============ audio tasks ===================
    # asr result for single sample
    # { "text": "每一天都要快乐喔"}
@@ -239,6 +239,7 @@ class Pipeline(ABC):
        """
        from torch.utils.data.dataloader import default_collate
        from modelscope.preprocessors import InputFeatures
        from text2sql_lgesql.utils.batch import Batch
        if isinstance(data, dict) or isinstance(data, Mapping):
            return type(data)(
                {k: self._collate_fn(v)
@@ -259,6 +260,8 @@ class Pipeline(ABC):
            return data
        elif isinstance(data, InputFeatures):
            return data
        elif isinstance(data, Batch):
            return data
        else:
            import mmcv
            if isinstance(data, mmcv.parallel.data_container.DataContainer):
@@ -50,6 +50,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
        'damo/nlp_structbert_zero-shot-classification_chinese-base'),
    Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
                                       'damo/nlp_space_dialog-modeling'),
    Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
                                  'damo/nlp_space_dialog-state-tracking'),
    Tasks.conversational_text_to_sql:
    (Pipelines.conversational_text_to_sql,
     'damo/nlp_star_conversational-text-to-sql'),
    Tasks.text_error_correction:
    (Pipelines.text_error_correction,
     'damo/nlp_bart_text-error-correction_chinese'),
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline
    from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
    from .dialog_modeling_pipeline import DialogModelingPipeline
    from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
@@ -22,6 +23,8 @@ if TYPE_CHECKING:

else:
    _import_structure = {
        'conversational_text_to_sql_pipeline':
        ['ConversationalTextToSqlPipeline'],
        'dialog_intent_prediction_pipeline':
        ['DialogIntentPredictionPipeline'],
        'dialog_modeling_pipeline': ['DialogModelingPipeline'],
@@ -0,0 +1,66 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import torch
from text2sql_lgesql.utils.example import Example

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import StarForTextToSql
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import ConversationalTextToSqlPreprocessor
from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
from modelscope.preprocessors.star.fields.process_dataset import process_tables
from modelscope.utils.constant import Tasks

__all__ = ['ConversationalTextToSqlPipeline']


@PIPELINES.register_module(
    Tasks.conversational_text_to_sql,
    module_name=Pipelines.conversational_text_to_sql)
class ConversationalTextToSqlPipeline(Pipeline):

    def __init__(self,
                 model: Union[StarForTextToSql, str],
                 preprocessor: ConversationalTextToSqlPreprocessor = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a conversational text-to-sql prediction pipeline.

        Args:
            model (StarForTextToSql): a model instance
            preprocessor (ConversationalTextToSqlPreprocessor):
                a preprocessor instance
        """
        model = model if isinstance(
            model, StarForTextToSql) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir)

        preprocessor.device = 'cuda' if \
            ('device' not in kwargs or kwargs['device'] == 'gpu') \
            and torch.cuda.is_available() else 'cpu'
        use_device = preprocessor.device == 'cuda'
        preprocessor.processor = \
            SubPreprocessor(model_dir=model.model_dir,
                            db_content=True,
                            use_gpu=use_device)
        preprocessor.output_tables = \
            process_tables(preprocessor.processor,
                           preprocessor.tables)
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): the model output, containing the
                predicted SQL hypotheses and the database schema

        Returns:
            Dict[str, str]: the prediction results
        """
        sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db'])
        result = {OutputKeys.TEXT: sql}
        return result
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
                        DialogModelingPreprocessor,
                        DialogStateTrackingPreprocessor)
    from .video import ReadVideoData
    from .star import ConversationalTextToSqlPreprocessor

else:
    _import_structure = {
@@ -55,6 +56,7 @@ else:
            'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
            'DialogStateTrackingPreprocessor', 'InputFeatures'
        ],
        'star': ['ConversationalTextToSqlPreprocessor'],
    }

    import sys
modelscope/preprocessors/star/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .conversational_text_to_sql_preprocessor import \
        ConversationalTextToSqlPreprocessor
    from .fields import MultiWOZBPETextField, IntentBPETextField

else:
    _import_structure = {
        'conversational_text_to_sql_preprocessor':
        ['ConversationalTextToSqlPreprocessor'],
        'fields': [
            'get_label', 'SubPreprocessor', 'preprocess_dataset',
            'process_dataset'
        ]
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
import torch
from text2sql_lgesql.preprocess.graph_utils import GraphProcessor
from text2sql_lgesql.preprocess.process_graphs import process_dataset_graph
from text2sql_lgesql.utils.batch import Batch
from text2sql_lgesql.utils.example import Example

from modelscope.metainfo import Preprocessors
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.preprocessors.star.fields.preprocess_dataset import \
    preprocess_dataset
from modelscope.preprocessors.star.fields.process_dataset import (
    process_dataset, process_tables)
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.type_assert import type_assert

__all__ = ['ConversationalTextToSqlPreprocessor']


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.conversational_text_to_sql)
class ConversationalTextToSqlPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab.txt from the `model_dir` path.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        self.model_dir: str = model_dir

        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        self.device = 'cuda' if \
            ('device' not in kwargs or kwargs['device'] == 'gpu') \
            and torch.cuda.is_available() else 'cpu'
        self.processor = None
        self.table_path = os.path.join(self.model_dir, 'tables.json')
        self.tables = json.load(open(self.table_path, 'r'))
        self.output_tables = None
        self.path_cache = []
        self.graph_processor = GraphProcessor()

        Example.configuration(
            plm=self.config['model']['plm'],
            tables=self.output_tables,
            table_path=os.path.join(model_dir, 'tables.json'),
            model_dir=self.model_dir,
            db_dir=os.path.join(model_dir, 'db'))

    @type_assert(object, dict)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (dict):
                utterance: a sentence
                last_sql: predicted sql of last utterance
            Example:
                utterance: 'Which of these are hiring?'
                last_sql: ''

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        # use local database
        if data['local_db_path'] is not None and data[
                'local_db_path'] not in self.path_cache:
            self.path_cache.append(data['local_db_path'])
            path = os.path.join(data['local_db_path'], 'tables.json')
            self.tables = json.load(open(path, 'r'))
            self.processor.db_dir = os.path.join(data['local_db_path'], 'db')
            self.output_tables = process_tables(self.processor, self.tables)
            Example.configuration(
                plm=self.config['model']['plm'],
                tables=self.output_tables,
                table_path=path,
                model_dir=self.model_dir,
                db_dir=self.processor.db_dir)

        theresult, sql_label = \
            preprocess_dataset(
                self.processor,
                data,
                self.output_tables,
                data['database_id'],
                self.tables
            )
        output_dataset = process_dataset(self.model_dir, self.processor,
                                         theresult, self.output_tables)
        output_dataset = \
            process_dataset_graph(
                self.graph_processor,
                output_dataset,
                self.output_tables,
                method='lgesql'
            )
        dev_ex = Example(output_dataset[0],
                         self.output_tables[data['database_id']], sql_label)
        current_batch = Batch.from_example_list([dev_ex],
                                                self.device,
                                                train=False)
        return {'batch': current_batch, 'db': data['database_id']}
modelscope/preprocessors/star/fields/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
from modelscope.preprocessors.star.fields.parse import get_label
from modelscope.preprocessors.star.fields.preprocess_dataset import \
    preprocess_dataset
from modelscope.preprocessors.star.fields.process_dataset import \
    process_dataset
modelscope/preprocessors/star/fields/common_utils.py (new file, 471 lines)
@@ -0,0 +1,471 @@
|
||||
# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql.
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
from itertools import combinations, product
|
||||
|
||||
import nltk
|
||||
import numpy as np
|
||||
from text2sql_lgesql.utils.constants import MAX_RELATIVE_DIST
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
mwtokenizer = nltk.MWETokenizer(separator='')
|
||||
mwtokenizer.add_mwe(('[', 'CLS', ']'))
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def is_number(s):
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def quote_normalization(question):
|
||||
""" Normalize all usage of quotation marks into a separate \" """
|
||||
new_question, quotation_marks = [], [
|
||||
"'", '"', '`', '‘', '’', '“', '”', '``', "''", '‘‘', '’’'
|
||||
]
|
||||
for idx, tok in enumerate(question):
|
||||
if len(tok) > 2 and tok[0] in quotation_marks and tok[
|
||||
-1] in quotation_marks:
|
||||
new_question += ["\"", tok[1:-1], "\""]
|
||||
elif len(tok) > 2 and tok[0] in quotation_marks:
|
||||
new_question += ["\"", tok[1:]]
|
||||
elif len(tok) > 2 and tok[-1] in quotation_marks:
|
||||
new_question += [tok[:-1], "\""]
|
||||
elif tok in quotation_marks:
|
||||
new_question.append("\"")
|
||||
elif len(tok) == 2 and tok[0] in quotation_marks:
|
||||
# special case: the length of entity value is 1
|
||||
if idx + 1 < len(question) and question[idx
|
||||
+ 1] in quotation_marks:
|
||||
new_question += ["\"", tok[1]]
|
||||
else:
|
||||
new_question.append(tok)
|
||||
else:
|
||||
new_question.append(tok)
|
||||
return new_question
|
||||
|
||||
|
||||
class SubPreprocessor():
|
||||
|
||||
def __init__(self, model_dir, use_gpu=False, db_content=True):
|
||||
super(SubPreprocessor, self).__init__()
|
||||
self.model_dir = model_dir
|
||||
self.db_dir = os.path.join(model_dir, 'db')
|
||||
self.db_content = db_content
|
||||
|
||||
from nltk import data
|
||||
from nltk.corpus import stopwords
|
||||
data.path.append(os.path.join(self.model_dir, 'nltk_data'))
|
||||
self.stopwords = stopwords.words('english')
|
||||
|
||||
import stanza
|
||||
from stanza.resources import common
|
||||
from stanza.pipeline import core
|
||||
self.nlp = stanza.Pipeline(
|
||||
'en',
|
||||
use_gpu=use_gpu,
|
||||
dir=self.model_dir,
|
||||
processors='tokenize,pos,lemma',
|
||||
tokenize_pretokenized=True,
|
||||
download_method=core.DownloadMethod.REUSE_RESOURCES)
|
||||
self.nlp1 = stanza.Pipeline(
|
||||
'en',
|
||||
use_gpu=use_gpu,
|
||||
dir=self.model_dir,
|
||||
processors='tokenize,pos,lemma',
|
||||
download_method=core.DownloadMethod.REUSE_RESOURCES)
|
||||
|
||||
def pipeline(self, entry: dict, db: dict, verbose: bool = False):
|
||||
""" db should be preprocessed """
|
||||
entry = self.preprocess_question(entry, db, verbose=verbose)
|
||||
entry = self.schema_linking(entry, db, verbose=verbose)
|
||||
entry = self.extract_subgraph(entry, db, verbose=verbose)
|
||||
return entry
|
||||
|
||||
def preprocess_database(self, db: dict, verbose: bool = False):
|
||||
table_toks, table_names = [], []
|
||||
for tab in db['table_names']:
|
||||
doc = self.nlp1(tab)
|
||||
tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
|
||||
table_toks.append(tab)
|
||||
table_names.append(' '.join(tab))
|
||||
db['processed_table_toks'], db[
|
||||
'processed_table_names'] = table_toks, table_names
|
||||
column_toks, column_names = [], []
|
||||
for _, c in db['column_names']:
|
||||
doc = self.nlp1(c)
|
||||
c = [w.lemma.lower() for s in doc.sentences for w in s.words]
|
||||
column_toks.append(c)
|
||||
column_names.append(' '.join(c))
|
||||
db['processed_column_toks'], db[
|
||||
'processed_column_names'] = column_toks, column_names
|
||||
column2table = list(map(lambda x: x[0], db['column_names']))
|
||||
table2columns = [[] for _ in range(len(table_names))]
|
||||
for col_id, col in enumerate(db['column_names']):
|
||||
if col_id == 0:
|
||||
continue
|
||||
table2columns[col[0]].append(col_id)
|
||||
db['column2table'], db['table2columns'] = column2table, table2columns
|
||||
|
||||
t_num, c_num, dtype = len(db['table_names']), len(
|
||||
db['column_names']), '<U100'
|
||||
|
||||
tab_mat = np.array([['table-table-generic'] * t_num
|
||||
for _ in range(t_num)],
|
||||
dtype=dtype)
|
||||
table_fks = set(
|
||||
map(lambda pair: (column2table[pair[0]], column2table[pair[1]]),
|
||||
db['foreign_keys']))
|
||||
for (tab1, tab2) in table_fks:
|
||||
if (tab2, tab1) in table_fks:
|
||||
tab_mat[tab1, tab2], tab_mat[
|
||||
tab2, tab1] = 'table-table-fkb', 'table-table-fkb'
|
||||
else:
|
||||
tab_mat[tab1, tab2], tab_mat[
|
||||
tab2, tab1] = 'table-table-fk', 'table-table-fkr'
|
||||
tab_mat[list(range(t_num)),
|
||||
list(range(t_num))] = 'table-table-identity'
|
||||
|
||||
col_mat = np.array([['column-column-generic'] * c_num
|
||||
for _ in range(c_num)],
|
||||
dtype=dtype)
|
||||
for i in range(t_num):
|
||||
col_ids = [idx for idx, t in enumerate(column2table) if t == i]
|
||||
col1, col2 = list(zip(*list(product(col_ids, col_ids))))
|
||||
col_mat[col1, col2] = 'column-column-sametable'
|
||||
col_mat[list(range(c_num)),
|
||||
list(range(c_num))] = 'column-column-identity'
|
||||
if len(db['foreign_keys']) > 0:
|
||||
col1, col2 = list(zip(*db['foreign_keys']))
|
||||
col_mat[col1, col2], col_mat[
|
||||
col2, col1] = 'column-column-fk', 'column-column-fkr'
|
||||
col_mat[0, list(range(c_num))] = '*-column-generic'
|
||||
col_mat[list(range(c_num)), 0] = 'column-*-generic'
|
||||
col_mat[0, 0] = '*-*-identity'
|
||||
|
||||
# relations between tables and columns, t_num*c_num and c_num*t_num
|
||||
tab_col_mat = np.array([['table-column-generic'] * c_num
|
||||
for _ in range(t_num)],
|
||||
dtype=dtype)
|
||||
col_tab_mat = np.array([['column-table-generic'] * t_num
|
||||
for _ in range(c_num)],
|
||||
dtype=dtype)
|
||||
cols, tabs = list(
|
||||
zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num)))))
|
||||
col_tab_mat[cols, tabs], tab_col_mat[
|
||||
tabs, cols] = 'column-table-has', 'table-column-has'
|
||||
if len(db['primary_keys']) > 0:
|
||||
cols, tabs = list(
|
||||
zip(*list(
|
||||
map(lambda x: (x, column2table[x]), db['primary_keys']))))
|
||||
col_tab_mat[cols, tabs], tab_col_mat[
|
||||
tabs, cols] = 'column-table-pk', 'table-column-pk'
|
||||
col_tab_mat[0, list(range(t_num))] = '*-table-generic'
|
||||
tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'
|
||||
|
||||
relations = \
|
||||
np.concatenate([
|
||||
np.concatenate([tab_mat, tab_col_mat], axis=1),
|
||||
np.concatenate([col_tab_mat, col_mat], axis=1)
|
||||
], axis=0)
|
||||
db['relations'] = relations.tolist()
|
||||
|
||||
if verbose:
|
||||
print('Tables:', ', '.join(db['table_names']))
|
||||
print('Lemmatized:', ', '.join(table_names))
|
||||
print('Columns:',
|
||||
', '.join(list(map(lambda x: x[1], db['column_names']))))
|
||||
print('Lemmatized:', ', '.join(column_names), '\n')
|
||||
return db
|
||||
|
||||
def preprocess_question(self,
|
||||
entry: dict,
|
||||
db: dict,
|
||||
verbose: bool = False):
|
||||
""" Tokenize, lemmatize, lowercase question"""
|
||||
# stanza tokenize, lemmatize and POS tag
|
||||
question = ' '.join(quote_normalization(entry['question_toks']))
|
||||
|
||||
from nltk import data
|
||||
data.path.append(os.path.join(self.model_dir, 'nltk_data'))
|
||||
question = nltk.word_tokenize(question)
|
||||
question = mwtokenizer.tokenize(question)
|
||||
|
||||
doc = self.nlp([question])
|
||||
raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
|
||||
toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
|
||||
pos_tags = [w.xpos for s in doc.sentences for w in s.words]
|
||||
|
||||
entry['raw_question_toks'] = raw_toks
|
||||
entry['processed_question_toks'] = toks
|
||||
entry['pos_tags'] = pos_tags
|
||||
|
||||
q_num, dtype = len(toks), '<U100'
|
||||
if q_num <= MAX_RELATIVE_DIST + 1:
|
||||
dist_vec = [
|
||||
'question-question-dist'
|
||||
+ str(i) if i != 0 else 'question-question-identity'
|
||||
for i in range(-MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)
|
||||
]
|
||||
starting = MAX_RELATIVE_DIST
|
||||
else:
|
||||
dist_vec = ['question-question-generic'] \
|
||||
* (q_num - MAX_RELATIVE_DIST - 1) + \
|
||||
[
|
||||
'question-question-dist' + str(i)
|
||||
if i != 0 else 'question-question-identity'
|
||||
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1,
|
||||
1)]\
|
||||
+ ['question-question-generic'] \
|
||||
* (q_num - MAX_RELATIVE_DIST - 1)
|
||||
starting = q_num - 1
|
||||
list_data = \
|
||||
[dist_vec[starting - i:starting - i + q_num] for i in range(q_num)]
|
||||
q_mat = \
|
||||
np.array(
|
||||
list_data,
|
||||
dtype=dtype
|
||||
)
|
||||
entry['relations'] = q_mat.tolist()
|
||||
|
||||
if verbose:
|
||||
print('Question:', entry['question'])
|
||||
print('Tokenized:', ' '.join(entry['raw_question_toks']))
|
||||
print('Lemmatized:', ' '.join(entry['processed_question_toks']))
|
||||
print('Pos tags:', ' '.join(entry['pos_tags']), '\n')
|
||||
return entry
|
||||
|
||||
def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
|
||||
used_schema = {'table': set(), 'column': set()}
|
||||
entry['used_tables'] = sorted(list(used_schema['table']))
|
||||
entry['used_columns'] = sorted(list(used_schema['column']))
|
||||
|
||||
if verbose:
|
||||
print('Used tables:', entry['used_tables'])
|
||||
print('Used columns:', entry['used_columns'], '\n')
|
||||
return entry
|
||||
|
||||
def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
|
||||
select_items = sql['select'][1]
|
||||
# select clause
|
||||
for _, val_unit in select_items:
|
||||
if val_unit[0] == 0:
|
||||
col_unit = val_unit[1]
|
||||
used_schema['column'].add(col_unit[1])
|
||||
else:
|
||||
col_unit1, col_unit2 = val_unit[1:]
|
||||
used_schema['column'].add(col_unit1[1])
|
||||
used_schema['column'].add(col_unit2[1])
|
||||
# from clause conds
|
||||
table_units = sql['from']['table_units']
|
||||
for _, t in table_units:
|
||||
if type(t) == dict:
|
||||
used_schema = self.extract_subgraph_from_sql(t, used_schema)
|
||||
else:
|
||||
used_schema['table'].add(t)
|
||||
# from, where and having conds
|
||||
used_schema = self.extract_subgraph_from_conds(sql['from']['conds'],
|
||||
used_schema)
|
||||
used_schema = self.extract_subgraph_from_conds(sql['where'],
|
||||
used_schema)
|
||||
used_schema = self.extract_subgraph_from_conds(sql['having'],
|
||||
used_schema)
|
||||
# groupBy and orderBy clause
|
||||
groupBy = sql['groupBy']
|
||||
for col_unit in groupBy:
|
||||
used_schema['column'].add(col_unit[1])
|
||||
orderBy = sql['orderBy']
|
||||
if len(orderBy) > 0:
|
||||
orderBy = orderBy[1]
|
||||
for val_unit in orderBy:
|
||||
if val_unit[0] == 0:
|
||||
col_unit = val_unit[1]
|
||||
used_schema['column'].add(col_unit[1])
|
||||
else:
|
||||
col_unit1, col_unit2 = val_unit[1:]
|
||||
used_schema['column'].add(col_unit1[1])
|
||||
used_schema['column'].add(col_unit2[1])
|
||||
# union, intersect and except clause
|
||||
if sql['intersect']:
|
||||
used_schema = self.extract_subgraph_from_sql(
|
||||
sql['intersect'], used_schema)
|
||||
if sql['union']:
|
||||
used_schema = self.extract_subgraph_from_sql(
|
||||
sql['union'], used_schema)
|
||||
if sql['except']:
|
||||
used_schema = self.extract_subgraph_from_sql(
|
||||
sql['except'], used_schema)
|
||||
return used_schema
|
||||
|
||||
def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
|
||||
if len(conds) == 0:
|
||||
return used_schema
|
||||
for cond in conds:
|
||||
if cond in ['and', 'or']:
|
||||
continue
|
||||
val_unit, val1, val2 = cond[2:]
|
||||
if val_unit[0] == 0:
|
||||
col_unit = val_unit[1]
|
||||
used_schema['column'].add(col_unit[1])
|
||||
else:
|
||||
col_unit1, col_unit2 = val_unit[1:]
|
||||
used_schema['column'].add(col_unit1[1])
|
||||
used_schema['column'].add(col_unit2[1])
|
||||
if type(val1) == list:
|
||||
used_schema['column'].add(val1[1])
|
||||
elif type(val1) == dict:
|
||||
used_schema = self.extract_subgraph_from_sql(val1, used_schema)
|
||||
if type(val2) == list:
|
||||
used_schema['column'].add(val1[1])
|
||||
elif type(val2) == dict:
|
||||
used_schema = self.extract_subgraph_from_sql(val2, used_schema)
|
||||
return used_schema
|
||||
|
||||
def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
|
||||
raw_question_toks, question_toks = entry['raw_question_toks'], entry[
|
||||
'processed_question_toks']
|
||||
table_toks, column_toks = db['processed_table_toks'], db[
|
||||
'processed_column_toks']
|
||||
table_names, column_names = db['processed_table_names'], db[
|
||||
'processed_column_names']
|
||||
q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len(
|
||||
column_toks), '<U100'
|
||||
|
||||
# relations between questions and tables, q_num*t_num and t_num*q_num
|
||||
table_matched_pairs = {'partial': [], 'exact': []}
|
||||
q_tab_mat = np.array([['question-table-nomatch'] * t_num
|
||||
for _ in range(q_num)],
|
||||
dtype=dtype)
|
||||
tab_q_mat = np.array([['table-question-nomatch'] * q_num
|
||||
for _ in range(t_num)],
|
||||
dtype=dtype)
|
||||
max_len = max([len(t) for t in table_toks])
|
||||
index_pairs = list(
|
||||
filter(lambda x: x[1] - x[0] <= max_len,
|
||||
combinations(range(q_num + 1), 2)))
|
||||
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
|
||||
for i, j in index_pairs:
|
||||
phrase = ' '.join(question_toks[i:j])
|
||||
if phrase in self.stopwords:
|
||||
continue
|
||||
for idx, name in enumerate(table_names):
|
||||
if phrase == name:
|
||||
q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
|
||||
tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
|
||||
if verbose:
|
||||
table_matched_pairs['exact'].append(
|
||||
str((name, idx, phrase, i, j)))
|
||||
elif (j - i == 1
|
||||
and phrase in name.split()) or (j - i > 1
|
||||
and phrase in name):
|
||||
q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
|
||||
tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
|
||||
if verbose:
|
||||
table_matched_pairs['partial'].append(
|
||||
str((name, idx, phrase, i, j)))
|
||||
|
||||
# relations between questions and columns
|
||||
column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
|
||||
q_col_mat = np.array([['question-column-nomatch'] * c_num
|
||||
for _ in range(q_num)],
|
||||
dtype=dtype)
|
||||
col_q_mat = np.array([['column-question-nomatch'] * q_num
|
||||
for _ in range(c_num)],
|
||||
dtype=dtype)
|
||||
max_len = max([len(c) for c in column_toks])
|
||||
index_pairs = list(
|
||||
filter(lambda x: x[1] - x[0] <= max_len,
|
||||
combinations(range(q_num + 1), 2)))
|
||||
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
|
||||
for i, j in index_pairs:
|
||||
phrase = ' '.join(question_toks[i:j])
|
||||
if phrase in self.stopwords:
|
||||
continue
|
||||
for idx, name in enumerate(column_names):
|
||||
if phrase == name:
|
||||
q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
|
||||
col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
|
||||
if verbose:
|
||||
column_matched_pairs['exact'].append(
|
||||
str((name, idx, phrase, i, j)))
|
||||
elif (j - i == 1
|
||||
and phrase in name.split()) or (j - i > 1
|
||||
and phrase in name):
|
||||
q_col_mat[range(i, j),
|
||||
idx] = 'question-column-partialmatch'
|
||||
col_q_mat[idx,
|
||||
range(i, j)] = 'column-question-partialmatch'
|
||||
if verbose:
|
||||
column_matched_pairs['partial'].append(
|
||||
str((name, idx, phrase, i, j)))
|
||||
if self.db_content:
|
||||
db_file = os.path.join(self.db_dir, db['db_id'],
|
||||
db['db_id'] + '.sqlite')
|
||||
if not os.path.exists(db_file):
|
||||
raise ValueError('[ERROR]: database file %s not found ...' %
|
||||
(db_file))
|
||||
conn = sqlite3.connect(db_file)
|
||||
conn.text_factory = lambda b: b.decode(errors='ignore')
|
||||
conn.execute('pragma foreign_keys=ON')
|
||||
for i, (tab_id,
|
||||
col_name) in enumerate(db['column_names_original']):
|
||||
if i == 0 or 'id' in column_toks[
|
||||
i]: # ignore * and special token 'id'
|
||||
continue
|
||||
tab_name = db['table_names_original'][tab_id]
|
||||
try:
|
||||
cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";"
|
||||
% (col_name, tab_name))
|
||||
cell_values = cursor.fetchall()
|
||||
cell_values = [str(each[0]) for each in cell_values]
|
||||
cell_values = [[str(float(each))] if is_number(each) else
|
||||
each.lower().split()
|
||||
for each in cell_values]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
for j, word in enumerate(raw_question_toks):
|
||||
word = str(float(word)) if is_number(word) else word
|
||||
for c in cell_values:
|
||||
if word in c and 'nomatch' in q_col_mat[
|
||||
j, i] and word not in self.stopwords:
|
||||
q_col_mat[j, i] = 'question-column-valuematch'
|
||||
col_q_mat[i, j] = 'column-question-valuematch'
|
||||
if verbose:
|
||||
column_matched_pairs['value'].append(
|
||||
str((column_names[i], i, word, j, j + 1)))
|
||||
break
|
||||
conn.close()
|
||||
|
||||
q_col_mat[:, 0] = 'question-*-generic'
|
||||
col_q_mat[0] = '*-question-generic'
|
||||
q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1)
|
||||
schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0)
|
||||
entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist())
|
||||
|
||||
if verbose:
|
||||
print('Question:', ' '.join(question_toks))
|
||||
print('Table matched: (table name, column id, \
|
||||
question span, start id, end id)')
|
||||
print(
|
||||
'Exact match:', ', '.join(table_matched_pairs['exact'])
|
||||
if table_matched_pairs['exact'] else 'empty')
|
||||
print(
|
||||
'Partial match:', ', '.join(table_matched_pairs['partial'])
|
||||
if table_matched_pairs['partial'] else 'empty')
|
||||
print('Column matched: (column name, column id, \
|
||||
question span, start id, end id)')
|
||||
print(
|
||||
'Exact match:', ', '.join(column_matched_pairs['exact'])
|
||||
if column_matched_pairs['exact'] else 'empty')
|
||||
print(
|
||||
'Partial match:', ', '.join(column_matched_pairs['partial'])
|
||||
if column_matched_pairs['partial'] else 'empty')
|
||||
print(
|
||||
'Value match:', ', '.join(column_matched_pairs['value'])
|
||||
if column_matched_pairs['value'] else 'empty', '\n')
|
||||
return entry
|
||||
modelscope/preprocessors/star/fields/parse.py (new file, 333 lines)
@@ -0,0 +1,333 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
CLAUSE_KEYWORDS = ('SELECT', 'FROM', 'WHERE', 'GROUP', 'ORDER', 'LIMIT',
|
||||
'INTERSECT', 'UNION', 'EXCEPT')
|
||||
JOIN_KEYWORDS = ('JOIN', 'ON', 'AS')
|
||||
|
||||
WHERE_OPS = ('NOT_IN', 'BETWEEN', '=', '>', '<', '>=', '<=', '!=', 'IN',
|
||||
'LIKE', 'IS', 'EXISTS')
|
||||
UNIT_OPS = ('NONE', '-', '+', '*', '/')
|
||||
AGG_OPS = ('', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG')
|
||||
TABLE_TYPE = {
|
||||
'sql': 'sql',
|
||||
'table_unit': 'table_unit',
|
||||
}
|
||||
COND_OPS = ('AND', 'OR')
|
||||
SQL_OPS = ('INTERSECT', 'UNION', 'EXCEPT')
|
||||
ORDER_OPS = ('DESC', 'ASC')
|
||||
|
||||
|
||||
def get_select_labels(select, slot, cur_nest):
|
||||
for item in select[1]:
|
||||
if AGG_OPS[item[0]] != '':
|
||||
if slot[item[1][1][1]] == '':
|
||||
slot[item[1][1][1]] += (cur_nest + ' ' + AGG_OPS[item[0]])
|
||||
else:
|
||||
slot[item[1][1][1]] += (' ' + cur_nest + ' '
|
||||
+ AGG_OPS[item[0]])
|
||||
else:
|
||||
if slot[item[1][1][1]] == '':
|
||||
slot[item[1][1][1]] += (cur_nest)
|
||||
else:
|
||||
slot[item[1][1][1]] += (' ' + cur_nest)
|
||||
return slot
|
||||
|
||||
|
||||
def get_groupby_labels(groupby, slot, cur_nest):
|
||||
for item in groupby:
|
||||
if slot[item[1]] == '':
|
||||
slot[item[1]] += (cur_nest)
|
||||
else:
|
||||
slot[item[1]] += (' ' + cur_nest)
|
||||
return slot
|
||||
|
||||
|
||||
def get_orderby_labels(orderby, limit, slot, cur_nest):
|
||||
if limit is None:
|
||||
thelimit = ''
|
||||
else:
|
||||
thelimit = ' LIMIT'
|
||||
for item in orderby[1]:
|
||||
if AGG_OPS[item[1][0]] != '':
|
||||
agg = ' ' + AGG_OPS[item[1][0]] + ' '
|
||||
else:
|
||||
agg = ' '
|
||||
if slot[item[1][1]] == '':
|
||||
slot[item[1][1]] += (
|
||||
cur_nest + agg + orderby[0].upper() + thelimit)
|
||||
else:
|
||||
slot[item[1][1]] += (' ' + cur_nest + agg + orderby[0].upper()
|
||||
+ thelimit)
|
||||
|
||||
return slot
|
||||
|
||||
|
||||
def get_intersect_labels(intersect, slot, cur_nest):
|
||||
if isinstance(intersect, dict):
|
||||
if cur_nest != '':
|
||||
slot = get_labels(intersect, slot, cur_nest)
|
||||
else:
|
||||
slot = get_labels(intersect, slot, 'INTERSECT')
|
||||
else:
|
||||
return slot
|
||||
return slot
|
||||
|
||||
|
||||
def get_except_labels(texcept, slot, cur_nest):
|
||||
if isinstance(texcept, dict):
|
||||
if cur_nest != '':
|
||||
slot = get_labels(texcept, slot, cur_nest)
|
||||
else:
|
||||
slot = get_labels(texcept, slot, 'EXCEPT')
|
||||
else:
|
||||
return slot
|
||||
return slot
|
||||
|
||||
|
||||
def get_union_labels(union, slot, cur_nest):
|
||||
if isinstance(union, dict):
|
||||
if cur_nest != '':
|
||||
slot = get_labels(union, slot, cur_nest)
|
||||
else:
|
||||
slot = get_labels(union, slot, 'UNION')
|
||||
else:
|
||||
return slot
|
||||
return slot
|
||||
|
||||
|
||||
def get_from_labels(tfrom, slot, cur_nest):
|
||||
if tfrom['table_units'][0][0] == 'sql':
|
||||
slot = get_labels(tfrom['table_units'][0][1], slot, 'OP_SEL')
|
||||
else:
|
||||
return slot
|
||||
return slot
|
||||
|
||||
|
||||
def get_having_labels(having, slot, cur_nest):
|
||||
if len(having) == 1:
|
||||
item = having[0]
|
||||
if item[0] is True:
|
||||
neg = ' NOT'
|
||||
else:
|
||||
neg = ''
|
||||
if isinstance(item[3], dict):
|
||||
if AGG_OPS[item[2][1][0]] != '':
|
||||
agg = ' ' + AGG_OPS[item[2][1][0]]
|
||||
else:
|
||||
agg = ''
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + agg + neg + ' '
|
||||
+ WHERE_OPS[item[1]])
|
||||
slot = get_labels(item[3], slot, 'OP_SEL')
|
||||
else:
|
||||
if AGG_OPS[item[2][1][0]] != '':
|
||||
agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
|
||||
else:
|
||||
agg = ' '
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (cur_nest + agg + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + agg
|
||||
+ WHERE_OPS[item[1]])
|
||||
else:
|
||||
for index, item in enumerate(having):
|
||||
if item[0] is True:
|
||||
neg = ' NOT'
|
||||
else:
|
||||
neg = ''
|
||||
if (index + 1 < len(having) and having[index + 1]) == 'or' or (
|
||||
index - 1 >= 0 and having[index - 1] == 'or'):
|
||||
if AGG_OPS[item[2][1][0]] != '':
|
||||
agg = ' ' + AGG_OPS[item[2][1][0]]
|
||||
else:
|
||||
agg = ''
|
||||
if isinstance(item[3], dict):
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + agg + neg
|
||||
+ ' ' + WHERE_OPS[item[1]])
|
||||
slot = get_labels(item[3], slot, 'OP_SEL')
|
||||
else:
|
||||
if AGG_OPS[item[2][1][0]] != '':
|
||||
agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
|
||||
else:
|
||||
agg = ' '
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + ' OR' + agg + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + agg
|
||||
+ WHERE_OPS[item[1]])
|
||||
elif item == 'and' or item == 'or':
|
||||
continue
|
||||
else:
|
||||
if isinstance(item[3], dict):
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
|
||||
+ WHERE_OPS[item[1]])
|
||||
slot = get_labels(item[3], slot, 'OP_SEL')
|
||||
else:
|
||||
if AGG_OPS[item[2][1][0]] != '':
|
||||
agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
|
||||
else:
|
||||
agg = ' '
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + agg + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + agg
|
||||
+ WHERE_OPS[item[1]])
|
||||
return slot
|
||||
|
||||
|
||||
def get_where_labels(where, slot, cur_nest):
|
||||
if len(where) == 1:
|
||||
item = where[0]
|
||||
if item[0] is True:
|
||||
neg = ' NOT'
|
||||
else:
|
||||
neg = ''
|
||||
if isinstance(item[3], dict):
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
|
||||
+ WHERE_OPS[item[1]])
|
||||
slot = get_labels(item[3], slot, 'OP_SEL')
|
||||
else:
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
|
||||
+ WHERE_OPS[item[1]])
|
||||
else:
|
||||
for index, item in enumerate(where):
|
||||
if item[0] is True:
|
||||
neg = ' NOT'
|
||||
else:
|
||||
neg = ''
|
||||
if (index + 1 < len(where) and where[index + 1]) == 'or' or (
|
||||
index - 1 >= 0 and where[index - 1] == 'or'):
|
||||
if isinstance(item[3], dict):
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
|
||||
+ WHERE_OPS[item[1]])
|
||||
slot = get_labels(item[3], slot, 'OP_SEL')
|
||||
else:
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + ' OR' + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + neg
|
||||
+ ' ' + WHERE_OPS[item[1]])
|
||||
elif item == 'and' or item == 'or':
|
||||
continue
|
||||
else:
|
||||
if isinstance(item[3], dict):
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
|
||||
+ WHERE_OPS[item[1]])
|
||||
slot = get_labels(item[3], slot, 'OP_SEL')
|
||||
else:
|
||||
if slot[item[2][1][1]] == '':
|
||||
slot[item[2][1][1]] += (
|
||||
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
|
||||
else:
|
||||
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
|
||||
+ WHERE_OPS[item[1]])
|
||||
return slot
|
||||
|
||||
|
||||
def get_labels(sql_struct, slot, cur_nest):
|
||||
|
||||
if len(sql_struct['select']) > 0:
|
||||
if cur_nest != '':
|
||||
slot = get_select_labels(sql_struct['select'], slot,
|
||||
cur_nest + ' SELECT')
|
||||
else:
|
||||
slot = get_select_labels(sql_struct['select'], slot, 'SELECT')
|
||||
|
||||
if sql_struct['from']:
|
||||
if cur_nest != '':
|
||||
slot = get_from_labels(sql_struct['from'], slot, 'FROM')
|
||||
else:
|
||||
slot = get_from_labels(sql_struct['from'], slot, 'FROM')
|
||||
|
||||
if len(sql_struct['where']) > 0:
|
||||
if cur_nest != '':
|
||||
slot = get_where_labels(sql_struct['where'], slot,
|
||||
cur_nest + ' WHERE')
|
||||
else:
|
||||
slot = get_where_labels(sql_struct['where'], slot, 'WHERE')
|
||||
|
||||
if len(sql_struct['groupBy']) > 0:
|
||||
if cur_nest != '':
|
||||
slot = get_groupby_labels(sql_struct['groupBy'], slot,
|
||||
cur_nest + ' GROUP_BY')
|
||||
else:
|
||||
slot = get_groupby_labels(sql_struct['groupBy'], slot, 'GROUP_BY')
|
||||
|
||||
if len(sql_struct['having']) > 0:
|
||||
if cur_nest != '':
|
||||
slot = get_having_labels(sql_struct['having'], slot,
|
||||
cur_nest + ' HAVING')
|
||||
else:
|
||||
slot = get_having_labels(sql_struct['having'], slot, 'HAVING')
|
||||
|
||||
if len(sql_struct['orderBy']) > 0:
|
||||
if cur_nest != '':
|
||||
slot = get_orderby_labels(sql_struct['orderBy'],
|
||||
sql_struct['limit'], slot,
|
||||
cur_nest + ' ORDER_BY')
|
||||
else:
|
||||
slot = get_orderby_labels(sql_struct['orderBy'],
|
||||
sql_struct['limit'], slot, 'ORDER_BY')
|
||||
|
||||
if sql_struct['intersect']:
|
||||
if cur_nest != '':
|
||||
slot = get_intersect_labels(sql_struct['intersect'], slot,
|
||||
cur_nest + ' INTERSECT')
|
||||
else:
|
||||
slot = get_intersect_labels(sql_struct['intersect'], slot,
|
||||
'INTERSECT')
|
||||
|
||||
if sql_struct['except']:
|
||||
if cur_nest != '':
|
||||
slot = get_except_labels(sql_struct['except'], slot,
|
||||
cur_nest + ' EXCEPT')
|
||||
else:
|
||||
slot = get_except_labels(sql_struct['except'], slot, 'EXCEPT')
|
||||
|
||||
if sql_struct['union']:
|
||||
if cur_nest != '':
|
||||
slot = get_union_labels(sql_struct['union'], slot,
|
||||
cur_nest + ' UNION')
|
||||
else:
|
||||
slot = get_union_labels(sql_struct['union'], slot, 'UNION')
|
||||
return slot
|
||||
|
||||
|
||||
def get_label(sql, column_len):
|
||||
thelabel = []
|
||||
slot = {}
|
||||
for idx in range(column_len):
|
||||
slot[idx] = ''
|
||||
for value in get_labels(sql, slot, '').values():
|
||||
thelabel.append(value)
|
||||
return thelabel
|
||||
modelscope/preprocessors/star/fields/preprocess_dataset.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from text2sql_lgesql.preprocess.parse_raw_json import Schema, get_schemas
from text2sql_lgesql.process_sql import get_sql

from modelscope.preprocessors.star.fields.parse import get_label


def preprocess_dataset(processor, dataset, output_tables, database_id, tables):

    schemas, db_names, thetables = get_schemas(tables)
    intables = output_tables[database_id]
    schema = schemas[database_id]
    table = thetables[database_id]
    sql_label = []
    if len(dataset['history']) == 0 or dataset['last_sql'] == '':
        sql_label = [''] * len(intables['column_names'])
    else:
        schema = Schema(schema, table)
        try:
            sql_label = get_sql(schema, dataset['last_sql'])
        except Exception:
            sql_label = [''] * len(intables['column_names'])
        sql_label = get_label(sql_label, len(table['column_names_original']))
    theone = {'db_id': database_id}
    theone['query'] = ''
    theone['query_toks_no_value'] = []
    theone['sql'] = {}
    if len(dataset['history']) != 0:
        theone['question'] = dataset['utterance'] + ' [CLS] ' + ' [CLS] '.join(
            dataset['history'][::-1][:4])
        theone['question_toks'] = theone['question'].split()
    else:
        theone['question'] = dataset['utterance']
        theone['question_toks'] = dataset['utterance'].split()

    return [theone], sql_label
modelscope/preprocessors/star/fields/process_dataset.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql.

import argparse
import os
import pickle
import sys
import time

import json
from text2sql_lgesql.asdl.asdl import ASDLGrammar
from text2sql_lgesql.asdl.transition_system import TransitionSystem

from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor

sys.path.append(os.path.dirname(os.path.dirname(__file__)))


def process_example(processor, entry, db, trans, verbose=False):
    # preprocess raw tokens, schema linking and subgraph extraction
    entry = processor.pipeline(entry, db, verbose=verbose)
    # generate target output actions
    entry['ast'] = []
    entry['actions'] = []
    return entry


def process_tables(processor, tables_list, output_path=None, verbose=False):
    tables = {}
    for each in tables_list:
        if verbose:
            print('*************** Processing database %s **************' %
                  (each['db_id']))
        tables[each['db_id']] = processor.preprocess_database(
            each, verbose=verbose)
    print('In total, process %d databases .' % (len(tables)))
    if output_path is not None:
        pickle.dump(tables, open(output_path, 'wb'))
    return tables


def process_dataset(model_dir,
                    processor,
                    dataset,
                    tables,
                    output_path=None,
                    skip_large=False,
                    verbose=False):
    grammar = ASDLGrammar.from_filepath(
        os.path.join(model_dir, 'sql_asdl_v2.txt'))
    trans = TransitionSystem.get_class_by_lang('sql')(grammar)
    processed_dataset = []
    for idx, entry in enumerate(dataset):
        if skip_large and len(tables[entry['db_id']]['column_names']) > 100:
            continue
        if verbose:
            print('*************** Processing %d-th sample **************' %
                  (idx))
        entry = process_example(
            processor, entry, tables[entry['db_id']], trans, verbose=verbose)
        processed_dataset.append(entry)
    if output_path is not None:
        # serialize preprocessed dataset
        pickle.dump(processed_dataset, open(output_path, 'wb'))
    return processed_dataset
@@ -89,6 +89,7 @@ class NLPTasks(object):
    zero_shot_classification = 'zero-shot-classification'
    backbone = 'backbone'
    text_error_correction = 'text-error-correction'
    conversational_text_to_sql = 'conversational-text-to-sql'


class AudioTasks(object):
@@ -6,5 +6,6 @@ pai-easynlp
rouge_score<=0.0.4
seqeval
spacy>=2.3.5
text2sql_lgesql
tokenizers
transformers>=4.12.0
tests/pipelines/test_conversational_text_to_sql.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StarForTextToSql
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline
from modelscope.preprocessors import ConversationalTextToSqlPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ConversationalTextToSql(unittest.TestCase):
    model_id = 'damo/nlp_star_conversational-text-to-sql'
    test_case = {
        'database_id':
        'employee_hire_evaluation',
        'local_db_path':
        None,
        'utterance': [
            "I'd like to see Shop names.", 'Which of these are hiring?',
            'Which shop is hiring the highest number of employees? | do you want the name of the shop ? | Yes'
        ]
    }

    def tracking_and_print_results(
            self, pipelines: List[ConversationalTextToSqlPipeline]):
        for my_pipeline in pipelines:
            last_sql, history = '', []
            for item in self.test_case['utterance']:
                case = {
                    'utterance': item,
                    'history': history,
                    'last_sql': last_sql,
                    'database_id': self.test_case['database_id'],
                    'local_db_path': self.test_case['local_db_path']
                }
                results = my_pipeline(case)
                print({'question': item})
                print(results)
                last_sql = results['text']
                history.append(item)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        preprocessor = ConversationalTextToSqlPreprocessor(
            model_dir=cache_path,
            database_id=self.test_case['database_id'],
            db_content=True)
        model = StarForTextToSql(
            model_dir=cache_path, config=preprocessor.config)

        pipelines = [
            ConversationalTextToSqlPipeline(
                model=model, preprocessor=preprocessor),
            pipeline(
                task=Tasks.conversational_text_to_sql,
                model=model,
                preprocessor=preprocessor)
        ]
        self.tracking_and_print_results(pipelines)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        preprocessor = ConversationalTextToSqlPreprocessor(
            model_dir=model.model_dir)

        pipelines = [
            ConversationalTextToSqlPipeline(
                model=model, preprocessor=preprocessor),
            pipeline(
                task=Tasks.conversational_text_to_sql,
                model=model,
                preprocessor=preprocessor)
        ]
        self.tracking_and_print_results(pipelines)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipelines = [
            pipeline(
                task=Tasks.conversational_text_to_sql, model=self.model_id)
        ]
        self.tracking_and_print_results(pipelines)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipelines = [pipeline(task=Tasks.conversational_text_to_sql)]
        self.tracking_and_print_results(pipelines)


if __name__ == '__main__':
    unittest.main()
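Since the test module above ends with a __main__ guard, it can also be run directly, e.g. python tests/pipelines/test_conversational_text_to_sql.py, assuming the model weights are reachable from the hub and the configured test level enables the chosen cases.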