# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqRanking, SbertForFaqRetrieval
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FaqPipeline
from modelscope.preprocessors import FaqPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FaqTest(unittest.TestCase):
    """Smoke test that runs the FAQ pipeline on a small fixed input.

    The only active test loads a model, builds the matching preprocessor,
    assembles a ``Tasks.faq`` pipeline and prints what the pipeline returns.
    Nothing is asserted about the returned value.
    """

    # NOTE(review): this looks like a developer-local absolute path rather
    # than a hub model id -- confirm before relying on this test elsewhere.
    model_id = '/Users/tanfan/Desktop/Workdir/Gitlab/maas/MaaS-lib/.faq_test_model'

    # Input handed to the pipeline: three queries plus two labelled
    # support examples.
    param = {
        'query_set': ['明天星期几', '今天星期六', '今天星期六'],
        'support_set': [
            {
                'text': '今天星期六',
                'label': 'label0'
            },
            {
                'text': '明天星期几',
                'label': 'label1'
            },
        ]
    }

    # Disabled test kept for reference. NOTE(review): it refers to
    # ``SbertForFaq``, which is not among this file's imports.
    # @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    # def test_run_with_direct_file_download(self):
    #     cache_path = self.model_id
    #     snapshot_download(self.model_id)
    #     preprocessor = FaqPreprocessor(cache_path)
    #     model = SbertForFaq(cache_path)
    #     pipeline_ins = FaqPipeline(model, preprocessor=preprocessor)
    #
    #     result = pipeline_ins(self.param)
    #     print(result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        # Skipped unless test_level() is at least 2.
        # Load the model first; the preprocessor is created from the
        # directory the loaded model reports via ``model_dir``.
        faq_model = Model.from_pretrained(self.model_id)
        faq_preprocessor = FaqPreprocessor(faq_model.model_dir)

        # Assemble the pipeline from the explicit model/preprocessor pair.
        faq_pipeline = pipeline(
            task=Tasks.faq,
            model=faq_model,
            preprocessor=faq_preprocessor,
        )

        # Run once on the class-level fixture and print the raw output.
        output = faq_pipeline(self.param)
        print(output)

    # Disabled test kept for reference.
    # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    # def test_run_with_model_name(self):
    #     pipeline_ins = pipeline(task=Tasks.faq, model=self.model_id)
    #     result = pipeline_ins(self.param)
    #     print(result)

    # Disabled test kept for reference.
    # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    # def test_run_with_default_model(self):
    #     pipeline_ins = pipeline(task=Tasks.faq)
    #     print(pipeline_ins(self.param))


if __name__ == '__main__':
    unittest.main()