mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-21 10:39:24 +01:00
add test cases
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
@@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
|
||||
def generate_and_print_dialog_response(
|
||||
self, pipelines: List[DialogModelingPipeline]):
|
||||
|
||||
result = {}
|
||||
for step, item in enumerate(self.test_case['sng0073']['log']):
|
||||
user = item['user']
|
||||
print('user: {}'.format(user))
|
||||
|
||||
result = pipelines[step % 2]({
|
||||
'user_input': user,
|
||||
'history': result
|
||||
})
|
||||
print('response : {}'.format(result['response']))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run(self):
|
||||
def test_run_by_direct_model_download(self):
|
||||
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
|
||||
@@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase):
|
||||
model=model,
|
||||
preprocessor=preprocessor)
|
||||
]
|
||||
|
||||
result = {}
|
||||
for step, item in enumerate(self.test_case['sng0073']['log']):
|
||||
user = item['user']
|
||||
print('user: {}'.format(user))
|
||||
|
||||
result = pipelines[step % 2]({
|
||||
'user_input': user,
|
||||
'history': result
|
||||
})
|
||||
print('response : {}'.format(result['response']))
|
||||
self.generate_and_print_dialog_response(pipelines)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
@@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase):
|
||||
preprocessor=preprocessor)
|
||||
]
|
||||
|
||||
result = {}
|
||||
for step, item in enumerate(self.test_case['sng0073']['log']):
|
||||
user = item['user']
|
||||
print('user: {}'.format(user))
|
||||
self.generate_and_print_dialog_response(pipelines)
|
||||
|
||||
result = pipelines[step % 2]({
|
||||
'user_input': user,
|
||||
'history': result
|
||||
})
|
||||
print('response : {}'.format(result['response']))
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_name(self):
|
||||
pipelines = [
|
||||
pipeline(task=Tasks.dialog_modeling, model=self.model_id),
|
||||
pipeline(task=Tasks.dialog_modeling, model=self.model_id)
|
||||
]
|
||||
self.generate_and_print_dialog_response(pipelines)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_default_model(self):
|
||||
pipelines = [
|
||||
pipeline(task=Tasks.dialog_modeling),
|
||||
pipeline(task=Tasks.dialog_modeling)
|
||||
]
|
||||
self.generate_and_print_dialog_response(pipelines)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user