mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
Merge remote-tracking branch 'origin/master' into ofa/finetune
This commit is contained in:
@@ -12,6 +12,7 @@ from http.cookiejar import CookieJar
|
||||
from os.path import expanduser
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import attrs
|
||||
import requests
|
||||
|
||||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
@@ -21,9 +22,14 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_USERNAME,
|
||||
DEFAULT_CREDENTIALS_PATH, Licenses,
|
||||
ModelVisibility)
|
||||
from modelscope.hub.deploy import (DeleteServiceParameters,
|
||||
DeployServiceParameters,
|
||||
GetServiceParameters, ListServiceParameters,
|
||||
ServiceParameters, ServiceResourceConfig,
|
||||
Vendor)
|
||||
from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
NotLoginException, RequestError,
|
||||
datahub_raise_on_error,
|
||||
NotLoginException, NotSupportError,
|
||||
RequestError, datahub_raise_on_error,
|
||||
handle_http_post_error,
|
||||
handle_http_response, is_ok, raise_on_error)
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
@@ -306,6 +312,169 @@ class HubApi:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def deploy_model(self, model_id: str, revision: str, instance_name: str,
                 resource: ServiceResourceConfig,
                 provider: ServiceParameters):
    """Request deployment of a model as a cloud service instance.

    Only the EAS vendor is accepted. Per the original author's note, the
    deployment itself is asynchronous and may take a long time, so the
    instance status should be checked through the console or by querying
    the instance afterwards.

    Args:
        model_id (str): The id of the model to deploy.
        revision (str): The model revision.
        instance_name (str): The name of the deployed instance.
        resource (ServiceResourceConfig): The resource information.
        provider (ServiceParameters): The cloud service provider
            parameters; its ``vendor`` must equal ``Vendor.EAS``.

    Raises:
        NotLoginException: No cookies are stored, i.e. not logged in.
        NotSupportError: ``provider.vendor`` is not ``Vendor.EAS``.
        RequestError: The response JSON is not ok according to ``is_ok``.

    Returns:
        The ``API_RESPONSE_FIELD_DATA`` entry of the response JSON.
    """
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise NotLoginException('Token does not exist, please login first.')
    if provider.vendor != Vendor.EAS:
        raise NotSupportError(
            'Not support vendor: %s ,only support EAS current.' %
            (provider.vendor))

    deploy_request = DeployServiceParameters(
        instance_name=instance_name,
        model_id=model_id,
        revision=revision,
        resource=resource,
        provider=provider)
    url = f'{self.endpoint}/api/v1/deployer/endpoint'
    response = requests.post(
        url, json=attrs.asdict(deploy_request), cookies=cookies)
    handle_http_response(response, logger, cookies, 'create_eas_instance')

    # Any 2xx status counts as success for the create call.
    succeeded = (HTTPStatus.OK <= response.status_code <
                 HTTPStatus.MULTIPLE_CHOICES)
    if not succeeded:
        response.raise_for_status()
        return None

    payload = response.json()
    if not is_ok(payload):
        raise RequestError(payload[API_RESPONSE_FIELD_MESSAGE])
    return payload[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def list_deployed_model_instances(self,
                                  provider: ServiceParameters,
                                  skip: int = 0,
                                  limit: int = 100):
    """List the deployed model instances.

    Args:
        provider (ServiceParameters): The cloud service provider
            parameters; for eas, access_key_id and access_key_secret
            are needed.
        skip (int): Start of the list; currently not supported.
        limit (int): Maximum number of instances to return; currently
            not supported.

    Raises:
        NotLoginException: No cookies are stored, i.e. not logged in.
        RequestError: The response JSON is not ok according to ``is_ok``.

    Returns:
        The ``API_RESPONSE_FIELD_DATA`` entry of the response JSON.
    """
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise NotLoginException('Token does not exist, please login first.')

    query = ListServiceParameters(
        provider=provider, skip=skip, limit=limit).to_query_str()
    url = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, query)
    response = requests.get(url, cookies=cookies)
    handle_http_response(response, logger, cookies, 'list_deployed_model')

    if response.status_code != HTTPStatus.OK:
        response.raise_for_status()
        return None

    payload = response.json()
    if not is_ok(payload):
        raise RequestError(payload[API_RESPONSE_FIELD_MESSAGE])
    return payload[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def get_deployed_model_instance(self, instance_name: str,
                                provider: ServiceParameters):
    """Query the information of one deployed instance.

    Args:
        instance_name (str): The deployed instance name.
        provider (ServiceParameters): The cloud provider information;
            for eas, region (eg: cn-hangzhou), access_key_id and
            access_key_secret are needed.

    Raises:
        NotLoginException: No cookies are stored, i.e. not logged in.
        RequestError: The response JSON is not ok according to ``is_ok``.

    Returns:
        The ``API_RESPONSE_FIELD_DATA`` entry of the response JSON.
    """
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise NotLoginException('Token does not exist, please login first.')

    query = GetServiceParameters(provider=provider).to_query_str()
    url = '%s/api/v1/deployer/endpoint/%s?%s' % (self.endpoint,
                                                 instance_name, query)
    response = requests.get(url, cookies=cookies)
    handle_http_response(response, logger, cookies, 'get_deployed_model')

    if response.status_code != HTTPStatus.OK:
        response.raise_for_status()
        return None

    payload = response.json()
    if not is_ok(payload):
        raise RequestError(payload[API_RESPONSE_FIELD_MESSAGE])
    return payload[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def delete_deployed_model_instance(self, instance_name: str,
                                   provider: ServiceParameters):
    """Send the delete command for a deployed instance and return.

    Per the original author's note the deletion itself takes some time;
    check its progress through the cloud console.

    Args:
        instance_name (str): The name of the instance to delete.
        provider (ServiceParameters): The cloud provider information;
            for eas, region (eg: cn-hangzhou), access_key_id and
            access_key_secret are needed.

    Raises:
        NotLoginException: No cookies are stored, i.e. not logged in.
        RequestError: The response JSON is not ok according to ``is_ok``.

    Returns:
        The ``API_RESPONSE_FIELD_DATA`` entry of the response JSON.
    """
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise NotLoginException('Token does not exist, please login first.')

    query = DeleteServiceParameters(provider=provider).to_query_str()
    url = '%s/api/v1/deployer/endpoint/%s?%s' % (self.endpoint,
                                                 instance_name, query)
    response = requests.delete(url, cookies=cookies)
    handle_http_response(response, logger, cookies, 'delete_deployed_model')

    if response.status_code != HTTPStatus.OK:
        response.raise_for_status()
        return None

    payload = response.json()
    if not is_ok(payload):
        raise RequestError(payload[API_RESPONSE_FIELD_MESSAGE])
    return payload[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def _check_cookie(self,
|
||||
use_cookies: Union[bool,
|
||||
CookieJar] = False) -> CookieJar:
|
||||
|
||||
189
modelscope/hub/deploy.py
Normal file
189
modelscope/hub/deploy.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import urllib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
import json
|
||||
from attr import fields
|
||||
from attrs import asdict, define, field, validators
|
||||
|
||||
|
||||
class Accelerator:
    """String constants naming the accelerator kinds (cpu or gpu)."""

    CPU = 'cpu'
    GPU = 'gpu'
|
||||
|
||||
|
||||
class Vendor:
    """String constants naming the cloud service vendors; only EAS here."""

    EAS = 'eas'
|
||||
|
||||
|
||||
class EASRegion:
    """Region identifier strings for EAS."""

    beijing = 'cn-beijing'
    hangzhou = 'cn-hangzhou'
|
||||
|
||||
|
||||
class EASCpuInstanceType:
    """EAS CPU instance type names, from smallest to largest preset.

    Reference: https://help.aliyun.com/document_detail/144261.html
    """

    tiny = 'ecs.c6.2xlarge'
    small = 'ecs.c6.4xlarge'
    medium = 'ecs.c6.6xlarge'
    large = 'ecs.c6.8xlarge'
|
||||
|
||||
|
||||
class EASGpuInstanceType:
    """EAS GPU instance type names, keyed by preset size.

    Reference: https://help.aliyun.com/document_detail/144261.html
    """

    tiny = 'ecs.gn5-c28g1.7xlarge'
    small = 'ecs.gn5-c8g1.4xlarge'
    medium = 'ecs.gn6i-c24g1.12xlarge'
    large = 'ecs.gn6e-c12g1.3xlarge'
|
||||
|
||||
|
||||
def min_smaller_than_max(instance, attribute, value):
    """attrs validator: reject a ``min_replica`` above ``max_replica``.

    Equality is allowed (only ``value > instance.max_replica`` is
    rejected), so the error message says "must not be larger than"
    rather than "smaller than".

    Args:
        instance: The object being validated; its ``max_replica`` is read.
        attribute: The attribute being validated (unused).
        value: The candidate ``min_replica`` value.

    Raises:
        ValueError: ``value`` is greater than ``instance.max_replica``.
    """
    if value > instance.max_replica:
        raise ValueError(
            "'min_replica' value: %s must not be larger than "
            "'max_replica' value: %s!" % (value, instance.max_replica))
|
||||
|
||||
|
||||
@define
class ServiceScalingConfig:
    """Replica scaling configuration for a deployed service.

    Both values must be at least 1, and ``min_replica`` may not exceed
    ``max_replica``. Per the original note, ``max_replica`` is currently
    ignored.

    Args:
        max_replica: maximum replica
        min_replica: minimum replica
    """
    max_replica: int = field(default=1, validator=validators.ge(1))
    # Validators run in order: lower bound first, then the min/max check.
    min_replica: int = field(
        default=1, validator=[validators.ge(1), min_smaller_than_max])
|
||||
|
||||
|
||||
@define
class ServiceResourceConfig:
    """Resource request for an EAS deployment.

    Args:
        instance_type: the instance type.
        scaling: the instance scaling config.
        accelerator: the accelerator, one of ``Accelerator.CPU`` or
            ``Accelerator.GPU``; defaults to cpu.
    """
    instance_type: str
    scaling: ServiceScalingConfig
    accelerator: str = field(
        default=Accelerator.CPU,
        validator=validators.in_([Accelerator.CPU, Accelerator.GPU]))
|
||||
|
||||
|
||||
@define
class ServiceParameters(ABC):
    """Common base type for cloud provider parameter classes; no fields."""
|
||||
|
||||
|
||||
@define
class EASDeployParameters(ServiceParameters):
    """Parameters for EAS Deployment.

    Args:
        region: The eas instance region(eg: cn-hangzhou).
        access_key_id: The eas account access key id.
        access_key_secret: The eas account access key secret.
        resource_group: the resource group to deploy, current default.
        vendor: must be 'eas'
    """
    region: str
    access_key_id: str
    access_key_secret: str
    resource_group: Optional[str] = None
    vendor: str = field(
        default=Vendor.EAS, validator=validators.in_([Vendor.EAS]))
|
||||
|
||||
|
||||
@define
class EASListParameters(ServiceParameters):
    """EAS instance list parameters.

    Args:
        access_key_id: The eas account access key id.
        access_key_secret: The eas account access key secret.
        region: The eas instance region(eg: cn-hangzhou); optional.
        resource_group: the resource group; optional.
        vendor: must be 'eas'
    """
    access_key_id: str
    access_key_secret: str
    # Optional[...] spells out that None is the accepted default.
    region: Optional[str] = None
    resource_group: Optional[str] = None
    vendor: str = field(
        default=Vendor.EAS, validator=validators.in_([Vendor.EAS]))
|
||||
|
||||
|
||||
@define
class DeployServiceParameters:
    """Request payload describing one service deployment.

    Args:
        instance_name: the name of the service.
        model_id: the modelscope model_id
        revision: the modelscope model revision
        resource: the resource requirement.
        provider: the cloud service provider.
    """
    instance_name: str
    model_id: str
    revision: str
    resource: ServiceResourceConfig
    provider: ServiceParameters
|
||||
|
||||
|
||||
class AttrsToQueryString(ABC):
    """Mixin that serializes the ``provider`` attribute to a query string.

    Subclasses are expected to define a ``provider`` attribute holding an
    attrs instance.
    """

    def to_query_str(self):
        """Build the ``provider=...`` query parameter.

        ``self.provider`` is converted to a dict (attributes whose value
        is ``None`` are dropped), dumped to JSON and URL-encoded with
        ``quote_plus``. Only ``provider`` is serialized; other attributes
        of the subclass are not included.

        Returns:
            str: ``'provider=<url-encoded JSON>'``.
        """
        # Import the submodule explicitly: a bare ``import urllib`` does
        # not guarantee that ``urllib.parse`` has been loaded.
        from urllib.parse import quote_plus

        provider_dict = asdict(
            self.provider, filter=lambda attr, value: value is not None)
        # Deliberately not printed: the provider may carry credentials
        # (the EAS parameter classes hold access_key_id/access_key_secret).
        return 'provider=%s' % quote_plus(json.dumps(provider_dict))
|
||||
|
||||
|
||||
@define
class ListServiceParameters(AttrsToQueryString):
    """Query parameters for listing deployed service instances."""
    provider: ServiceParameters
    # Paging values; to_query_str() serializes only ``provider``.
    skip: int = 0
    limit: int = 100
|
||||
|
||||
|
||||
@define
class GetServiceParameters(AttrsToQueryString):
    """Query parameters for fetching one deployed service instance."""
    provider: ServiceParameters
|
||||
|
||||
|
||||
@define
class DeleteServiceParameters(AttrsToQueryString):
    """Query parameters for deleting one deployed service instance."""
    provider: ServiceParameters
|
||||
@@ -9,6 +9,10 @@ from modelscope.utils.logger import get_logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class NotSupportError(Exception):
    """Error type signalling an unsupported option (judging by the name)."""
|
||||
|
||||
|
||||
class NotExistError(Exception):
    """Error type for something that does not exist (judging by the name)."""
|
||||
|
||||
@@ -66,6 +70,7 @@ def handle_http_response(response, logger, cookies, model_id):
|
||||
logger.error(
|
||||
f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
|
||||
private. Please login first.')
|
||||
logger.error('Response details: %s' % response.content)
|
||||
raise error
|
||||
|
||||
|
||||
|
||||
@@ -67,8 +67,9 @@ class Models(object):
|
||||
space_dst = 'space-dst'
|
||||
space_intent = 'space-intent'
|
||||
space_modeling = 'space-modeling'
|
||||
star = 'star'
|
||||
star3 = 'star3'
|
||||
space_T_en = 'space-T-en'
|
||||
space_T_cn = 'space-T-cn'
|
||||
|
||||
tcrf = 'transformer-crf'
|
||||
transformer_softmax = 'transformer-softmax'
|
||||
lcrf = 'lstm-crf'
|
||||
|
||||
@@ -16,6 +16,7 @@ from modelscope.models.builder import MODELS
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .utils import timestamp_format
|
||||
from .yolox.data.data_augment import ValTransform
|
||||
from .yolox.exp import get_exp_by_name
|
||||
from .yolox.utils import postprocess
|
||||
@@ -99,14 +100,17 @@ class RealtimeVideoDetector(TorchModel):
|
||||
def inference_video(self, v_path):
|
||||
outputs = []
|
||||
desc = 'Detecting video: {}'.format(v_path)
|
||||
for frame, result in tqdm(
|
||||
self.inference_video_iter(v_path), desc=desc):
|
||||
for frame_idx, (frame, result) in enumerate(
|
||||
tqdm(self.inference_video_iter(v_path), desc=desc)):
|
||||
result = result + (timestamp_format(seconds=frame_idx
|
||||
/ self.fps), )
|
||||
outputs.append(result)
|
||||
|
||||
return outputs
|
||||
|
||||
def inference_video_iter(self, v_path):
|
||||
capture = cv2.VideoCapture(v_path)
|
||||
self.fps = capture.get(cv2.CAP_PROP_FPS)
|
||||
while capture.isOpened():
|
||||
ret, frame = capture.read()
|
||||
if not ret:
|
||||
|
||||
9
modelscope/models/cv/realtime_object_detection/utils.py
Normal file
9
modelscope/models/cv/realtime_object_detection/utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import math
|
||||
|
||||
|
||||
def timestamp_format(seconds):
    """Format a duration in seconds as ``HH:MM:SS.mmm``.

    The value is rounded to milliseconds before it is split into hours,
    minutes and seconds. Without that, a remainder such as 59.9996 would
    be printed as ``60.000`` in the seconds field instead of carrying
    into the minutes.

    Args:
        seconds: The duration in seconds (int or float).

    Returns:
        str: The formatted timestamp, e.g. ``'01:01:01.500'``.
    """
    seconds = round(seconds, 3)
    minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return '%02d:%02d:%06.3f' % (hours, minutes, secs)
|
||||
@@ -24,8 +24,8 @@ import json
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Star3Config(object):
|
||||
"""Configuration class to store the configuration of a `Star3Model`.
|
||||
class SpaceTCnConfig(object):
|
||||
"""Configuration class to store the configuration of a `SpaceTCnModel`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -40,10 +40,10 @@ class Star3Config(object):
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02):
|
||||
"""Constructs Star3Config.
|
||||
"""Constructs SpaceTCnConfig.
|
||||
|
||||
Args:
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`.
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `SpaceTCnConfig`.
|
||||
hidden_size: Size of the encoder layers and the pooler layer.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads: Number of attention heads for each attention layer in
|
||||
@@ -59,7 +59,7 @@ class Star3Config(object):
|
||||
max_position_embeddings: The maximum sequence length that this model might
|
||||
ever be used with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`.
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnConfig`.
|
||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
@@ -89,15 +89,15 @@ class Star3Config(object):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `Star3Config` from a Python dictionary of parameters."""
|
||||
config = Star3Config(vocab_size_or_config_json_file=-1)
|
||||
"""Constructs a `SpaceTCnConfig` from a Python dictionary of parameters."""
|
||||
config = SpaceTCnConfig(vocab_size_or_config_json_file=-1)
|
||||
for key, value in json_object.items():
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `Star3Config` from a json file of parameters."""
|
||||
"""Constructs a `SpaceTCnConfig` from a json file of parameters."""
|
||||
with open(json_file, 'r', encoding='utf-8') as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
@@ -27,7 +27,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from modelscope.models.nlp.star3.configuration_star3 import Star3Config
|
||||
from modelscope.models.nlp.space_T_cn.configuration_space_T_cn import \
|
||||
SpaceTCnConfig
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -609,9 +610,9 @@ class PreTrainedBertModel(nn.Module):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(PreTrainedBertModel, self).__init__()
|
||||
if not isinstance(config, Star3Config):
|
||||
if not isinstance(config, SpaceTCnConfig):
|
||||
raise ValueError(
|
||||
'Parameter config in `{}(config)` should be an instance of class `Star3Config`. '
|
||||
'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. '
|
||||
'To create a model from a Google pretrained model use '
|
||||
'`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
|
||||
self.__class__.__name__, self.__class__.__name__))
|
||||
@@ -676,7 +677,7 @@ class PreTrainedBertModel(nn.Module):
|
||||
serialization_dir = tempdir
|
||||
# Load config
|
||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
config = Star3Config.from_json_file(config_file)
|
||||
config = SpaceTCnConfig.from_json_file(config_file)
|
||||
logger.info('Model config {}'.format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
@@ -742,11 +743,11 @@ class PreTrainedBertModel(nn.Module):
|
||||
return model
|
||||
|
||||
|
||||
class Star3Model(PreTrainedBertModel):
|
||||
"""Star3Model model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0").
|
||||
class SpaceTCnModel(PreTrainedBertModel):
|
||||
"""SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN").
|
||||
|
||||
Params:
|
||||
config: a Star3Config class instance with the configuration to build a new model
|
||||
config: a SpaceTCnConfig class instance with the configuration to build a new model
|
||||
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
@@ -780,16 +781,16 @@ class Star3Model(PreTrainedBertModel):
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
||||
|
||||
model = modeling.Star3Model(config=config)
|
||||
model = modeling.SpaceTCnModel(config=config)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config, schema_link_module='none'):
|
||||
super(Star3Model, self).__init__(config)
|
||||
super(SpaceTCnModel, self).__init__(config)
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = BertEncoder(
|
||||
config, schema_link_module=schema_link_module)
|
||||
@@ -20,7 +20,7 @@ __all__ = ['StarForTextToSql']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.conversational_text_to_sql, module_name=Models.star)
|
||||
Tasks.table_question_answering, module_name=Models.space_T_en)
|
||||
class StarForTextToSql(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
|
||||
@@ -3,27 +3,25 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import json
|
||||
import numpy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model, Tensor
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.nlp.star3.configuration_star3 import Star3Config
|
||||
from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model
|
||||
from modelscope.preprocessors.star3.fields.struct import Constant
|
||||
from modelscope.preprocessors.space_T_cn.fields.struct import Constant
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.device import verify_device
|
||||
from .space_T_cn.configuration_space_T_cn import SpaceTCnConfig
|
||||
from .space_T_cn.modeling_space_T_cn import Seq2SQL, SpaceTCnModel
|
||||
|
||||
__all__ = ['TableQuestionAnswering']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.table_question_answering, module_name=Models.star3)
|
||||
Tasks.table_question_answering, module_name=Models.space_T_cn)
|
||||
class TableQuestionAnswering(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
@@ -43,9 +41,9 @@ class TableQuestionAnswering(Model):
|
||||
os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
|
||||
map_location='cpu')
|
||||
|
||||
self.backbone_config = Star3Config.from_json_file(
|
||||
self.backbone_config = SpaceTCnConfig.from_json_file(
|
||||
os.path.join(self.model_dir, ModelFile.CONFIGURATION))
|
||||
self.backbone_model = Star3Model(
|
||||
self.backbone_model = SpaceTCnModel(
|
||||
config=self.backbone_config, schema_link_module='rat')
|
||||
self.backbone_model.load_state_dict(state_dict['backbone_model'])
|
||||
|
||||
|
||||
@@ -606,21 +606,12 @@ TASK_OUTPUTS = {
|
||||
# }
|
||||
Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],
|
||||
|
||||
# conversational text-to-sql result for single sample
|
||||
# {
|
||||
# "text": "SELECT shop.Name FROM shop."
|
||||
# }
|
||||
Tasks.conversational_text_to_sql: [OutputKeys.TEXT],
|
||||
|
||||
# table-question-answering result for single sample
|
||||
# {
|
||||
# "sql": "SELECT shop.Name FROM shop."
|
||||
# "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]}
|
||||
# }
|
||||
Tasks.table_question_answering: [
|
||||
OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY,
|
||||
OutputKeys.QUERT_RESULT
|
||||
],
|
||||
Tasks.table_question_answering: [OutputKeys.OUTPUT],
|
||||
|
||||
# ============ audio tasks ===================
|
||||
# asr result for single sample
|
||||
|
||||
@@ -69,9 +69,6 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/nlp_space_dialog-modeling'),
|
||||
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
|
||||
'damo/nlp_space_dialog-state-tracking'),
|
||||
Tasks.conversational_text_to_sql:
|
||||
(Pipelines.conversational_text_to_sql,
|
||||
'damo/nlp_star_conversational-text-to-sql'),
|
||||
Tasks.table_question_answering:
|
||||
(Pipelines.table_question_answering_pipeline,
|
||||
'damo/nlp-convai-text2sql-pretrain-cn'),
|
||||
|
||||
@@ -113,9 +113,8 @@ class AnimalRecognitionPipeline(Pipeline):
|
||||
label_mapping = f.readlines()
|
||||
score = torch.max(inputs['outputs'])
|
||||
inputs = {
|
||||
OutputKeys.SCORES:
|
||||
score.item(),
|
||||
OutputKeys.SCORES: [score.item()],
|
||||
OutputKeys.LABELS:
|
||||
label_mapping[inputs['outputs'].argmax()].split('\t')[1]
|
||||
[label_mapping[inputs['outputs'].argmax()].split('\t')[1]]
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -114,9 +114,8 @@ class GeneralRecognitionPipeline(Pipeline):
|
||||
label_mapping = f.readlines()
|
||||
score = torch.max(inputs['outputs'])
|
||||
inputs = {
|
||||
OutputKeys.SCORES:
|
||||
score.item(),
|
||||
OutputKeys.SCORES: [score.item()],
|
||||
OutputKeys.LABELS:
|
||||
label_mapping[inputs['outputs'].argmax()].split('\t')[1]
|
||||
[label_mapping[inputs['outputs'].argmax()].split('\t')[1]]
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -45,15 +45,17 @@ class RealtimeVideoObjectDetectionPipeline(Pipeline):
|
||||
**kwargs) -> str:
|
||||
forward_output = input['forward_output']
|
||||
|
||||
scores, boxes, labels = [], [], []
|
||||
scores, boxes, labels, timestamps = [], [], [], []
|
||||
for result in forward_output:
|
||||
box, score, label = result
|
||||
box, score, label, timestamp = result
|
||||
scores.append(score)
|
||||
boxes.append(box)
|
||||
labels.append(label)
|
||||
timestamps.append(timestamp)
|
||||
|
||||
return {
|
||||
OutputKeys.BOXES: boxes,
|
||||
OutputKeys.SCORES: scores,
|
||||
OutputKeys.LABELS: labels,
|
||||
OutputKeys.TIMESTAMPS: timestamps,
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ __all__ = ['ConversationalTextToSqlPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.conversational_text_to_sql,
|
||||
Tasks.table_question_answering,
|
||||
module_name=Pipelines.conversational_text_to_sql)
|
||||
class ConversationalTextToSqlPipeline(Pipeline):
|
||||
|
||||
@@ -62,7 +62,7 @@ class ConversationalTextToSqlPipeline(Pipeline):
|
||||
Dict[str, str]: the prediction results
|
||||
"""
|
||||
sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db'])
|
||||
result = {OutputKeys.TEXT: sql}
|
||||
result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}}
|
||||
return result
|
||||
|
||||
def _collate_fn(self, data):
|
||||
|
||||
@@ -13,8 +13,9 @@ from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
|
||||
from modelscope.preprocessors.star3.fields.database import Database
|
||||
from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery
|
||||
from modelscope.preprocessors.space_T_cn.fields.database import Database
|
||||
from modelscope.preprocessors.space_T_cn.fields.struct import (Constant,
|
||||
SQLQuery)
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
__all__ = ['TableQuestionAnsweringPipeline']
|
||||
@@ -320,7 +321,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
OutputKeys.QUERT_RESULT: tabledata,
|
||||
}
|
||||
|
||||
return output
|
||||
return {OutputKeys.OUTPUT: output}
|
||||
|
||||
def _collate_fn(self, data):
|
||||
return data
|
||||
|
||||
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
|
||||
DialogStateTrackingPreprocessor)
|
||||
from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
|
||||
from .star import ConversationalTextToSqlPreprocessor
|
||||
from .star3 import TableQuestionAnsweringPreprocessor
|
||||
from .space_T_cn import TableQuestionAnsweringPreprocessor
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
'DialogStateTrackingPreprocessor', 'InputFeatures'
|
||||
],
|
||||
'star': ['ConversationalTextToSqlPreprocessor'],
|
||||
'star3': ['TableQuestionAnsweringPreprocessor'],
|
||||
'space_T_cn': ['TableQuestionAnsweringPreprocessor'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -4,7 +4,7 @@ import sqlite3
|
||||
import json
|
||||
import tqdm
|
||||
|
||||
from modelscope.preprocessors.star3.fields.struct import Trie
|
||||
from modelscope.preprocessors.space_T_cn.fields.struct import Trie
|
||||
|
||||
|
||||
class Database:
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import re
|
||||
|
||||
from modelscope.preprocessors.star3.fields.struct import TypeInfo
|
||||
from modelscope.preprocessors.space_T_cn.fields.struct import TypeInfo
|
||||
|
||||
|
||||
class SchemaLinker:
|
||||
@@ -8,8 +8,8 @@ from transformers import BertTokenizer
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.preprocessors.base import Preprocessor
|
||||
from modelscope.preprocessors.builder import PREPROCESSORS
|
||||
from modelscope.preprocessors.star3.fields.database import Database
|
||||
from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker
|
||||
from modelscope.preprocessors.space_T_cn.fields.database import Database
|
||||
from modelscope.preprocessors.space_T_cn.fields.schema_link import SchemaLinker
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Fields, ModelFile
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
@@ -123,7 +123,6 @@ class NLPTasks(object):
|
||||
backbone = 'backbone'
|
||||
text_error_correction = 'text-error-correction'
|
||||
faq_question_answering = 'faq-question-answering'
|
||||
conversational_text_to_sql = 'conversational-text-to-sql'
|
||||
information_extraction = 'information-extraction'
|
||||
document_segmentation = 'document-segmentation'
|
||||
feature_extraction = 'feature-extraction'
|
||||
|
||||
@@ -20,7 +20,7 @@ def text2sql_tracking_and_print_results(
|
||||
results = p(case)
|
||||
print({'question': item})
|
||||
print(results)
|
||||
last_sql = results['text']
|
||||
last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT]
|
||||
history.append(item)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
addict
|
||||
attrs
|
||||
datasets
|
||||
easydict
|
||||
einops
|
||||
|
||||
@@ -16,7 +16,7 @@ from modelscope.utils.test_utils import test_level
|
||||
class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.conversational_text_to_sql
|
||||
self.task = Tasks.table_question_answering
|
||||
self.model_id = 'damo/nlp_star_conversational-text-to-sql'
|
||||
|
||||
model_id = 'damo/nlp_star_conversational-text-to-sql'
|
||||
@@ -66,11 +66,6 @@ class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck):
|
||||
pipelines = [pipeline(task=self.task, model=self.model_id)]
|
||||
text2sql_tracking_and_print_results(self.test_case, pipelines)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_default_model(self):
|
||||
pipelines = [pipeline(task=self.task)]
|
||||
text2sql_tracking_and_print_results(self.test_case, pipelines)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_demo_compatibility(self):
|
||||
self.compatibility_check()
|
||||
|
||||
@@ -12,7 +12,7 @@ from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
|
||||
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
|
||||
from modelscope.preprocessors.star3.fields.database import Database
|
||||
from modelscope.preprocessors.space_T_cn.fields.database import Database
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
@@ -38,7 +38,7 @@ def tableqa_tracking_and_print_results_with_history(
|
||||
output_dict = p({
|
||||
'question': question,
|
||||
'history_sql': historical_queries
|
||||
})
|
||||
})[OutputKeys.OUTPUT]
|
||||
print('question', question)
|
||||
print('sql text:', output_dict[OutputKeys.SQL_STRING])
|
||||
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
|
||||
@@ -61,7 +61,7 @@ def tableqa_tracking_and_print_results_without_history(
|
||||
}
|
||||
for p in pipelines:
|
||||
for question in test_case['utterance']:
|
||||
output_dict = p({'question': question})
|
||||
output_dict = p({'question': question})[OutputKeys.OUTPUT]
|
||||
print('question', question)
|
||||
print('sql text:', output_dict[OutputKeys.SQL_STRING])
|
||||
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
|
||||
@@ -92,7 +92,7 @@ def tableqa_tracking_and_print_results_with_tableid(
|
||||
'question': question,
|
||||
'table_id': table_id,
|
||||
'history_sql': historical_queries
|
||||
})
|
||||
})[OutputKeys.OUTPUT]
|
||||
print('question', question)
|
||||
print('sql text:', output_dict[OutputKeys.SQL_STRING])
|
||||
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
|
||||
@@ -147,11 +147,6 @@ class TableQuestionAnswering(unittest.TestCase):
|
||||
]
|
||||
tableqa_tracking_and_print_results_with_tableid(pipelines)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_from_task(self):
|
||||
pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
|
||||
tableqa_tracking_and_print_results_with_history(pipelines)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub_with_other_classes(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
|
||||
Reference in New Issue
Block a user