mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-21 18:49:23 +01:00
add test cases
This commit is contained in:
@@ -63,15 +63,16 @@ class SpaceForDialogIntent(Model):
|
|||||||
"""return the result by the model
|
"""return the result by the model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (Dict[str, Any]): the preprocessed data
|
input (Dict[str, Tensor]): the preprocessed data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, np.ndarray]: results
|
Dict[str, Tensor]: results
|
||||||
Example:
|
Example:
|
||||||
{
|
{
|
||||||
'predictions': array([1]), # lable 0-negative 1-positive
|
'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05
|
||||||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
|
1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04
|
||||||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
|
6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01
|
||||||
|
2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -62,15 +62,17 @@ class SpaceForDialogModeling(Model):
|
|||||||
"""return the result by the model
|
"""return the result by the model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (Dict[str, Any]): the preprocessed data
|
input (Dict[str, Tensor]): the preprocessed data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, np.ndarray]: results
|
Dict[str, Tensor]: results
|
||||||
Example:
|
Example:
|
||||||
{
|
{
|
||||||
'predictions': array([1]), # lable 0-negative 1-positive
|
'labels': array([1,192,321,12]), # lable
|
||||||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
|
'resp': array([293,1023,123,1123]), #vocab label for response
|
||||||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
|
'bspn': array([123,321,2,24,1 ]),
|
||||||
|
'aspn': array([47,8345,32,29,1983]),
|
||||||
|
'db': array([19, 24, 20]),
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
from ....metainfo import Models
|
||||||
from ....utils.nlp.space.utils_dst import batch_to_device
|
from ....utils.nlp.space.utils_dst import batch_to_device
|
||||||
from ...base import Model, Tensor
|
from ...base import Model, Tensor
|
||||||
from ...builder import MODELS
|
from ...builder import MODELS
|
||||||
@@ -9,7 +10,7 @@ from ...builder import MODELS
|
|||||||
__all__ = ['SpaceForDialogStateTracking']
|
__all__ = ['SpaceForDialogStateTracking']
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space')
|
@MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space)
|
||||||
class SpaceForDialogStateTracking(Model):
|
class SpaceForDialogStateTracking(Model):
|
||||||
|
|
||||||
def __init__(self, model_dir: str, *args, **kwargs):
|
def __init__(self, model_dir: str, *args, **kwargs):
|
||||||
@@ -17,8 +18,6 @@ class SpaceForDialogStateTracking(Model):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_dir (str): the model path.
|
model_dir (str): the model path.
|
||||||
model_cls (Optional[Any], optional): model loader, if None, use the
|
|
||||||
default loader to load model weights, by default None.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__(model_dir, *args, **kwargs)
|
super().__init__(model_dir, *args, **kwargs)
|
||||||
@@ -27,7 +26,6 @@ class SpaceForDialogStateTracking(Model):
|
|||||||
self.model_dir = model_dir
|
self.model_dir = model_dir
|
||||||
|
|
||||||
self.config = SpaceConfig.from_pretrained(self.model_dir)
|
self.config = SpaceConfig.from_pretrained(self.model_dir)
|
||||||
# self.model = SpaceForDST(self.config)
|
|
||||||
self.model = SpaceForDST.from_pretrained(self.model_dir)
|
self.model = SpaceForDST.from_pretrained(self.model_dir)
|
||||||
self.model.to(self.config.device)
|
self.model.to(self.config.device)
|
||||||
|
|
||||||
@@ -35,15 +33,20 @@ class SpaceForDialogStateTracking(Model):
|
|||||||
"""return the result by the model
|
"""return the result by the model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (Dict[str, Any]): the preprocessed data
|
input (Dict[str, Tensor]): the preprocessed data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, np.ndarray]: results
|
Dict[str, Tensor]: results
|
||||||
Example:
|
Example:
|
||||||
{
|
{
|
||||||
'predictions': array([1]), # lable 0-negative 1-positive
|
'inputs': dict(input_ids, input_masks,start_pos), # tracking states
|
||||||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
|
'outputs': dict(slots_logits),
|
||||||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
|
'unique_ids': str(test-example.json-0), # default value
|
||||||
|
'input_ids_unmasked': array([101, 7632, 1010,0,0,0])
|
||||||
|
'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]),
|
||||||
|
'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]),
|
||||||
|
'prefix': str('final'), #default value
|
||||||
|
'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}])
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -88,8 +91,6 @@ class SpaceForDialogStateTracking(Model):
|
|||||||
if u != 0:
|
if u != 0:
|
||||||
diag_state[slot][i] = u
|
diag_state[slot][i] = u
|
||||||
|
|
||||||
# print(outputs)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'inputs': inputs,
|
'inputs': inputs,
|
||||||
'outputs': outputs,
|
'outputs': outputs,
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
|||||||
'damo/nlp_space_dialog-intent-prediction'),
|
'damo/nlp_space_dialog-intent-prediction'),
|
||||||
Tasks.dialog_modeling: (Pipelines.dialog_modeling,
|
Tasks.dialog_modeling: (Pipelines.dialog_modeling,
|
||||||
'damo/nlp_space_dialog-modeling'),
|
'damo/nlp_space_dialog-modeling'),
|
||||||
|
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
|
||||||
|
'damo/nlp_space_dialog-state-tracking'),
|
||||||
Tasks.image_captioning: (Pipelines.image_caption,
|
Tasks.image_captioning: (Pipelines.image_caption,
|
||||||
'damo/ofa_image-caption_coco_large_en'),
|
'damo/ofa_image-caption_coco_large_en'),
|
||||||
Tasks.image_generation:
|
Tasks.image_generation:
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
from ...metainfo import Pipelines
|
from ...metainfo import Pipelines
|
||||||
|
from ...models import Model
|
||||||
from ...models.nlp import SpaceForDialogIntent
|
from ...models.nlp import SpaceForDialogIntent
|
||||||
from ...preprocessors import DialogIntentPredictionPreprocessor
|
from ...preprocessors import DialogIntentPredictionPreprocessor
|
||||||
from ...utils.constant import Tasks
|
from ...utils.constant import Tasks
|
||||||
@@ -18,17 +19,22 @@ __all__ = ['DialogIntentPredictionPipeline']
|
|||||||
module_name=Pipelines.dialog_intent_prediction)
|
module_name=Pipelines.dialog_intent_prediction)
|
||||||
class DialogIntentPredictionPipeline(Pipeline):
|
class DialogIntentPredictionPipeline(Pipeline):
|
||||||
|
|
||||||
def __init__(self, model: SpaceForDialogIntent,
|
def __init__(self,
|
||||||
preprocessor: DialogIntentPredictionPreprocessor, **kwargs):
|
model: Union[SpaceForDialogIntent, str],
|
||||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
|
preprocessor: DialogIntentPredictionPreprocessor = None,
|
||||||
|
**kwargs):
|
||||||
|
"""use `model` and `preprocessor` to create a dialog intent prediction pipeline
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (SequenceClassificationModel): a model instance
|
model (SpaceForDialogIntent): a model instance
|
||||||
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
|
preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance
|
||||||
"""
|
"""
|
||||||
|
model = model if isinstance(
|
||||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
model, SpaceForDialogIntent) else Model.from_pretrained(model)
|
||||||
|
if preprocessor is None:
|
||||||
|
preprocessor = DialogIntentPredictionPreprocessor(model.model_dir)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||||
self.categories = preprocessor.categories
|
self.categories = preprocessor.categories
|
||||||
|
|
||||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
from ...metainfo import Pipelines
|
from ...metainfo import Pipelines
|
||||||
|
from ...models import Model
|
||||||
from ...models.nlp import SpaceForDialogModeling
|
from ...models.nlp import SpaceForDialogModeling
|
||||||
from ...preprocessors import DialogModelingPreprocessor
|
from ...preprocessors import DialogModelingPreprocessor
|
||||||
from ...utils.constant import Tasks
|
from ...utils.constant import Tasks
|
||||||
@@ -17,17 +18,22 @@ __all__ = ['DialogModelingPipeline']
|
|||||||
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling)
|
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling)
|
||||||
class DialogModelingPipeline(Pipeline):
|
class DialogModelingPipeline(Pipeline):
|
||||||
|
|
||||||
def __init__(self, model: SpaceForDialogModeling,
|
def __init__(self,
|
||||||
preprocessor: DialogModelingPreprocessor, **kwargs):
|
model: Union[SpaceForDialogModeling, str],
|
||||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
|
preprocessor: DialogModelingPreprocessor = None,
|
||||||
|
**kwargs):
|
||||||
|
"""use `model` and `preprocessor` to create a dialog modleing pipeline for dialog response generation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (SequenceClassificationModel): a model instance
|
model (SpaceForDialogModeling): a model instance
|
||||||
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
|
preprocessor (DialogModelingPreprocessor): a preprocessor instance
|
||||||
"""
|
"""
|
||||||
|
model = model if isinstance(
|
||||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
model, SpaceForDialogModeling) else Model.from_pretrained(model)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
if preprocessor is None:
|
||||||
|
preprocessor = DialogModelingPreprocessor(model.model_dir)
|
||||||
|
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
|
|
||||||
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
|
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
from ...metainfo import Pipelines
|
from ...metainfo import Pipelines
|
||||||
from ...models import SpaceForDialogStateTracking
|
from ...models import Model, SpaceForDialogStateTracking
|
||||||
from ...preprocessors import DialogStateTrackingPreprocessor
|
from ...preprocessors import DialogStateTrackingPreprocessor
|
||||||
from ...utils.constant import Tasks
|
from ...utils.constant import Tasks
|
||||||
from ..base import Pipeline
|
from ..base import Pipeline
|
||||||
@@ -15,17 +15,26 @@ __all__ = ['DialogStateTrackingPipeline']
|
|||||||
Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
|
Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
|
||||||
class DialogStateTrackingPipeline(Pipeline):
|
class DialogStateTrackingPipeline(Pipeline):
|
||||||
|
|
||||||
def __init__(self, model: SpaceForDialogStateTracking,
|
def __init__(self,
|
||||||
preprocessor: DialogStateTrackingPreprocessor, **kwargs):
|
model: Union[SpaceForDialogStateTracking, str],
|
||||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
|
preprocessor: DialogStateTrackingPreprocessor = None,
|
||||||
|
**kwargs):
|
||||||
|
"""use `model` and `preprocessor` to create a dialog state tracking pipeline for
|
||||||
|
observation of dialog states tracking after many turns of open domain dialogue
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (SequenceClassificationModel): a model instance
|
model (SpaceForDialogStateTracking): a model instance
|
||||||
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
|
preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
model = model if isinstance(
|
||||||
|
model,
|
||||||
|
SpaceForDialogStateTracking) else Model.from_pretrained(model)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
if preprocessor is None:
|
||||||
|
preprocessor = DialogStateTrackingPreprocessor(model.model_dir)
|
||||||
|
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||||
|
|
||||||
self.tokenizer = preprocessor.tokenizer
|
self.tokenizer = preprocessor.tokenizer
|
||||||
self.config = preprocessor.config
|
self.config = preprocessor.config
|
||||||
|
|
||||||
@@ -46,9 +55,7 @@ class DialogStateTrackingPipeline(Pipeline):
|
|||||||
values = inputs['values']
|
values = inputs['values']
|
||||||
inform = inputs['inform']
|
inform = inputs['inform']
|
||||||
prefix = inputs['prefix']
|
prefix = inputs['prefix']
|
||||||
# ds = {slot: 'none' for slot in self.config.dst_slot_list}
|
|
||||||
ds = inputs['ds']
|
ds = inputs['ds']
|
||||||
|
|
||||||
ds = predict_and_format(self.config, self.tokenizer, _inputs,
|
ds = predict_and_format(self.config, self.tokenizer, _inputs,
|
||||||
_outputs[2], _outputs[3], _outputs[4],
|
_outputs[2], _outputs[3], _outputs[4],
|
||||||
_outputs[5], unique_ids, input_ids_unmasked,
|
_outputs[5], unique_ids, input_ids_unmasked,
|
||||||
|
|||||||
@@ -138,13 +138,6 @@ TASK_OUTPUTS = {
|
|||||||
# }
|
# }
|
||||||
Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS],
|
Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS],
|
||||||
|
|
||||||
# sentiment classification result for single sample
|
|
||||||
# {
|
|
||||||
# "labels": ["happy", "sad", "calm", "angry"],
|
|
||||||
# "scores": [0.9, 0.1, 0.05, 0.05]
|
|
||||||
# }
|
|
||||||
Tasks.sentiment_classification: ['scores', 'labels'],
|
|
||||||
|
|
||||||
# zero-shot classification result for single sample
|
# zero-shot classification result for single sample
|
||||||
# {
|
# {
|
||||||
# "scores": [0.9, 0.1, 0.05, 0.05]
|
# "scores": [0.9, 0.1, 0.05, 0.05]
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||||
def test_run(self):
|
def test_run_by_direct_model_download(self):
|
||||||
cache_path = snapshot_download(self.model_id)
|
cache_path = snapshot_download(self.model_id)
|
||||||
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
|
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
|
||||||
model = SpaceForDialogIntent(
|
model = SpaceForDialogIntent(
|
||||||
@@ -56,6 +56,20 @@ class DialogIntentPredictionTest(unittest.TestCase):
|
|||||||
for my_pipeline, item in list(zip(pipelines, self.test_case)):
|
for my_pipeline, item in list(zip(pipelines, self.test_case)):
|
||||||
print(my_pipeline(item))
|
print(my_pipeline(item))
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
|
def test_run_with_model_name(self):
|
||||||
|
pipelines = [
|
||||||
|
pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id)
|
||||||
|
]
|
||||||
|
for my_pipeline, item in list(zip(pipelines, self.test_case)):
|
||||||
|
print(my_pipeline(item))
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||||
|
def test_run_with_default_model(self):
|
||||||
|
pipelines = [pipeline(task=Tasks.dialog_intent_prediction)]
|
||||||
|
for my_pipeline, item in list(zip(pipelines, self.test_case)):
|
||||||
|
print(my_pipeline(item))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
from modelscope.models import Model
|
from modelscope.models import Model
|
||||||
@@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def generate_and_print_dialog_response(
|
||||||
|
self, pipelines: List[DialogModelingPipeline]):
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for step, item in enumerate(self.test_case['sng0073']['log']):
|
||||||
|
user = item['user']
|
||||||
|
print('user: {}'.format(user))
|
||||||
|
|
||||||
|
result = pipelines[step % 2]({
|
||||||
|
'user_input': user,
|
||||||
|
'history': result
|
||||||
|
})
|
||||||
|
print('response : {}'.format(result['response']))
|
||||||
|
|
||||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||||
def test_run(self):
|
def test_run_by_direct_model_download(self):
|
||||||
|
|
||||||
cache_path = snapshot_download(self.model_id)
|
cache_path = snapshot_download(self.model_id)
|
||||||
|
|
||||||
@@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase):
|
|||||||
model=model,
|
model=model,
|
||||||
preprocessor=preprocessor)
|
preprocessor=preprocessor)
|
||||||
]
|
]
|
||||||
|
self.generate_and_print_dialog_response(pipelines)
|
||||||
result = {}
|
|
||||||
for step, item in enumerate(self.test_case['sng0073']['log']):
|
|
||||||
user = item['user']
|
|
||||||
print('user: {}'.format(user))
|
|
||||||
|
|
||||||
result = pipelines[step % 2]({
|
|
||||||
'user_input': user,
|
|
||||||
'history': result
|
|
||||||
})
|
|
||||||
print('response : {}'.format(result['response']))
|
|
||||||
|
|
||||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
def test_run_with_model_from_modelhub(self):
|
def test_run_with_model_from_modelhub(self):
|
||||||
@@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase):
|
|||||||
preprocessor=preprocessor)
|
preprocessor=preprocessor)
|
||||||
]
|
]
|
||||||
|
|
||||||
result = {}
|
self.generate_and_print_dialog_response(pipelines)
|
||||||
for step, item in enumerate(self.test_case['sng0073']['log']):
|
|
||||||
user = item['user']
|
|
||||||
print('user: {}'.format(user))
|
|
||||||
|
|
||||||
result = pipelines[step % 2]({
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
'user_input': user,
|
def test_run_with_model_name(self):
|
||||||
'history': result
|
pipelines = [
|
||||||
})
|
pipeline(task=Tasks.dialog_modeling, model=self.model_id),
|
||||||
print('response : {}'.format(result['response']))
|
pipeline(task=Tasks.dialog_modeling, model=self.model_id)
|
||||||
|
]
|
||||||
|
self.generate_and_print_dialog_response(pipelines)
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||||
|
def test_run_with_default_model(self):
|
||||||
|
pipelines = [
|
||||||
|
pipeline(task=Tasks.dialog_modeling),
|
||||||
|
pipeline(task=Tasks.dialog_modeling)
|
||||||
|
]
|
||||||
|
self.generate_and_print_dialog_response(pipelines)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
from modelscope.models import Model, SpaceForDialogStateTracking
|
from modelscope.models import Model, SpaceForDialogStateTracking
|
||||||
@@ -75,23 +76,10 @@ class DialogStateTrackingTest(unittest.TestCase):
|
|||||||
'User-8': 'Thank you, goodbye',
|
'User-8': 'Thank you, goodbye',
|
||||||
}]
|
}]
|
||||||
|
|
||||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
def tracking_and_print_dialog_states(
|
||||||
def test_run(self):
|
self, pipelines: List[DialogStateTrackingPipeline]):
|
||||||
cache_path = snapshot_download(self.model_id)
|
|
||||||
|
|
||||||
model = SpaceForDialogStateTracking(cache_path)
|
|
||||||
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
|
|
||||||
pipelines = [
|
|
||||||
DialogStateTrackingPipeline(
|
|
||||||
model=model, preprocessor=preprocessor),
|
|
||||||
pipeline(
|
|
||||||
task=Tasks.dialog_state_tracking,
|
|
||||||
model=model,
|
|
||||||
preprocessor=preprocessor)
|
|
||||||
]
|
|
||||||
|
|
||||||
pipelines_len = len(pipelines)
|
|
||||||
import json
|
import json
|
||||||
|
pipelines_len = len(pipelines)
|
||||||
history_states = [{}]
|
history_states = [{}]
|
||||||
utter = {}
|
utter = {}
|
||||||
for step, item in enumerate(self.test_case):
|
for step, item in enumerate(self.test_case):
|
||||||
@@ -106,6 +94,22 @@ class DialogStateTrackingTest(unittest.TestCase):
|
|||||||
|
|
||||||
history_states.extend([result['dialog_states'], {}])
|
history_states.extend([result['dialog_states'], {}])
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||||
|
def test_run_by_direct_model_download(self):
|
||||||
|
cache_path = snapshot_download(self.model_id)
|
||||||
|
|
||||||
|
model = SpaceForDialogStateTracking(cache_path)
|
||||||
|
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
|
||||||
|
pipelines = [
|
||||||
|
DialogStateTrackingPipeline(
|
||||||
|
model=model, preprocessor=preprocessor),
|
||||||
|
pipeline(
|
||||||
|
task=Tasks.dialog_state_tracking,
|
||||||
|
model=model,
|
||||||
|
preprocessor=preprocessor)
|
||||||
|
]
|
||||||
|
self.tracking_and_print_dialog_states(pipelines)
|
||||||
|
|
||||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
def test_run_with_model_from_modelhub(self):
|
def test_run_with_model_from_modelhub(self):
|
||||||
model = Model.from_pretrained(self.model_id)
|
model = Model.from_pretrained(self.model_id)
|
||||||
@@ -120,21 +124,19 @@ class DialogStateTrackingTest(unittest.TestCase):
|
|||||||
preprocessor=preprocessor)
|
preprocessor=preprocessor)
|
||||||
]
|
]
|
||||||
|
|
||||||
pipelines_len = len(pipelines)
|
self.tracking_and_print_dialog_states(pipelines)
|
||||||
import json
|
|
||||||
history_states = [{}]
|
|
||||||
utter = {}
|
|
||||||
for step, item in enumerate(self.test_case):
|
|
||||||
utter.update(item)
|
|
||||||
result = pipelines[step % pipelines_len]({
|
|
||||||
'utter':
|
|
||||||
utter,
|
|
||||||
'history_states':
|
|
||||||
history_states
|
|
||||||
})
|
|
||||||
print(json.dumps(result))
|
|
||||||
|
|
||||||
history_states.extend([result['dialog_states'], {}])
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
|
def test_run_with_model_name(self):
|
||||||
|
pipelines = [
|
||||||
|
pipeline(task=Tasks.dialog_state_tracking, model=self.model_id)
|
||||||
|
]
|
||||||
|
self.tracking_and_print_dialog_states(pipelines)
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||||
|
def test_run_with_default_model(self):
|
||||||
|
pipelines = [pipeline(task=Tasks.dialog_state_tracking)]
|
||||||
|
self.tracking_and_print_dialog_states(pipelines)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user