diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index a8b93cc3..ea56efb5 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -514,6 +514,7 @@ class Pipelines(object): document_grounded_dialog_generate = 'document-grounded-dialog-generate' language_identification = 'language_identification' machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner' + llm = 'llm' # audio tasks sambert_hifigan_tts = 'sambert-hifigan-tts' diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 9f225383..a3b65812 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -117,6 +117,7 @@ class Model(ABC): else: invoked_by = Invoke.PRETRAINED + ignore_file_pattern = kwargs.pop('ignore_file_pattern', None) if osp.exists(model_name_or_path): local_model_dir = model_name_or_path else: @@ -126,7 +127,6 @@ class Model(ABC): ) invoked_by = '%s/%s' % (Invoke.KEY, invoked_by) - ignore_file_pattern = kwargs.pop('ignore_file_pattern', None) local_model_dir = snapshot_download( model_name_or_path, revision, diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index bffbebbd..d97a95f9 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -325,13 +325,9 @@ TASK_INPUTS = { }, # ============ nlp tasks =================== - Tasks.chat: [ - InputType.TEXT, - { - 'text': InputType.TEXT, - 'history': InputType.LIST, - } - ], + Tasks.chat: { + 'messages': InputType.LIST + }, Tasks.text_classification: [ InputType.TEXT, (InputType.TEXT, InputType.TEXT), diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index d6dff693..525bc92c 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -1,14 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os +import os.path as osp from typing import List, Optional, Union +from modelscope.hub.file_download import model_file_download from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE, Pipelines from modelscope.models.base import Model -from modelscope.utils.config import ConfigDict, check_config +from modelscope.utils.config import Config, ConfigDict, check_config from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, - ThirdParty) + ModelFile, ThirdParty) from modelscope.utils.hub import read_config from modelscope.utils.plugins import (register_modelhub_repo, register_plugins_repo) @@ -117,6 +119,8 @@ def pipeline(task: str = None, model_revision, third_party=third_party, ignore_file_pattern=ignore_file_pattern) + if pipeline_name is None and kwargs.get('llm_first'): + pipeline_name = llm_first_checker(model, model_revision) pipeline_props = {'type': pipeline_name} if pipeline_name is None: # get default pipeline for this task @@ -196,3 +200,39 @@ def get_default_pipeline_info(task): else: pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] return pipeline_name, default_model + + +def llm_first_checker(model: Union[str, List[str], Model, List[Model]], + revision: Optional[str]) -> Optional[str]: + from modelscope.pipelines.nlp.llm_pipeline import LLM_FORMAT_MAP + + def get_file_name(model: str, cfg_name: str, + revision: Optional[str]) -> Optional[str]: + if osp.exists(model): + return osp.join(model, cfg_name) + try: + return model_file_download(model, cfg_name, revision=revision) + except Exception: + return None + + def parse_model_type(file: Optional[str], pattern: str) -> Optional[str]: + if file is None or not osp.exists(file): + return None + return Config.from_file(file).safe_get(pattern) + + def get_model_type(model: str, revision: Optional[str]) -> Optional[str]: + cfg_file = get_file_name(model, ModelFile.CONFIGURATION, revision) + hf_cfg_file = get_file_name(model, ModelFile.CONFIG, revision) + cfg_model_type = parse_model_type(cfg_file, 'model.type') + hf_cfg_model_type = parse_model_type(hf_cfg_file, 'model_type') + return cfg_model_type or hf_cfg_model_type + + if isinstance(model, list): + model = model[0] + if not isinstance(model, str): + model = model.model_dir + model_type = get_model_type(model, revision) + if model_type is not None: + model_type = model_type.lower().split('-')[0] + if model_type in LLM_FORMAT_MAP: + return 'llm' diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 23473007..df7e2068 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -48,6 +48,7 @@ if TYPE_CHECKING: from .document_grounded_dialog_rerank_pipeline import DocumentGroundedDialogRerankPipeline from .language_identification_pipline import LanguageIdentificationPipeline from .machine_reading_comprehension_pipeline import MachineReadingComprehensionForNERPipeline + from .llm_pipeline import LLMPipeline else: _import_structure = { @@ -119,6 +120,7 @@ else: 'machine_reading_comprehension_pipeline': [ 'MachineReadingComprehensionForNERPipeline' ], + 'llm_pipeline': ['LLMPipeline'], } import sys diff --git a/modelscope/pipelines/nlp/llm_pipeline.py b/modelscope/pipelines/nlp/llm_pipeline.py index 63fc55ea..e2979ccb 100644 --- a/modelscope/pipelines/nlp/llm_pipeline.py +++ b/modelscope/pipelines/nlp/llm_pipeline.py @@ -11,6 +11,7 @@ from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline, 
snapshot_download) from modelscope.models.base import Model from modelscope.models.nlp import ChatGLM2Tokenizer, Llama2Tokenizer +from modelscope.outputs import OutputKeys from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.util import is_model, is_official_hub_path from modelscope.utils.constant import Invoke, ModelFile, Tasks @@ -19,7 +20,8 @@ from modelscope.utils.logger import get_logger logger = get_logger() -@PIPELINES.register_module(Tasks.chat, module_name='llm-pipeline') +@PIPELINES.register_module(Tasks.chat, module_name='llm') +@PIPELINES.register_module(Tasks.text_generation, module_name='llm') class LLMPipeline(Pipeline): def initiate_single_model(self, model): @@ -55,6 +57,9 @@ class LLMPipeline(Pipeline): *args, **kwargs): self.device_map = kwargs.pop('device_map', None) + # TODO: qwen-int4 need 'cuda'/'auto' device_map. + if not self.device_map and 'qwen' in kwargs['model'].lower(): + self.device_map = 'cuda' self.torch_dtype = kwargs.pop('torch_dtype', None) self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None) with self._temp_configuration_file(kwargs): @@ -138,6 +143,8 @@ class LLMPipeline(Pipeline): outputs, skip_special_tokens=True, **kwargs) if is_messages: response = self.format_output(response, **kwargs) + else: + response = {OutputKeys.TEXT: response} return response @@ -260,7 +267,7 @@ def chatglm2_format_output(response, **kwargs): response = response.replace('[[训练时间]]', '2023年') messages = {'role': 'assistant', 'content': response} outputs = { - 'messages': messages, + 'message': messages, } return outputs diff --git a/modelscope/utils/input_output.py b/modelscope/utils/input_output.py index d8e32cce..679069c1 100644 --- a/modelscope/utils/input_output.py +++ b/modelscope/utils/input_output.py @@ -773,6 +773,7 @@ def pipeline_output_to_service_base64_output(task_name, pipeline_output): pipeline_output = pipeline_output[0] for key, value in pipeline_output.items(): if key not in task_outputs: + json_serializable_output[key] = value continue # skip the output not defined. if key in [ OutputKeys.OUTPUT_IMG, OutputKeys.OUTPUT_IMGS, diff --git a/modelscope/utils/pipeline_inputs.json b/modelscope/utils/pipeline_inputs.json index 2ba31bcc..0cb9c1b1 100644 --- a/modelscope/utils/pipeline_inputs.json +++ b/modelscope/utils/pipeline_inputs.json @@ -16,13 +16,20 @@ }, "chat":{ "input":{ - "text":"你有什么推荐吗?", - "history":[ - [ - "今天天气真好,", - "今天天气真好,出去走走怎么样?" - ] - ] + "messages": [{ + "role": "user", + "content": "Hello! 你是谁?" + }, { + "role": "assistant", + "content": "我是你的助手。" + }, { + "role": "user", + "content": "你叫什么名字?" 
+ }] + }, + "parameters": { + "do_sample": true, + "max_length": 512 } }, "domain-specific-object-detection":{ diff --git a/tests/json_call_test.py b/tests/json_call_test.py index 658c947f..7073a90d 100644 --- a/tests/json_call_test.py +++ b/tests/json_call_test.py @@ -37,12 +37,15 @@ class ModelJsonTest: # init pipeline ppl = pipeline( - task=task, model=model_id, model_revision=model_revision) + task=task, + model=model_id, + model_revision=model_revision, + llm_first=True) pipeline_info = get_pipeline_information_by_pipeline(ppl) # call pipeline data = get_task_input_examples(task) - print(task, data) + infer_result = call_pipeline_with_json(pipeline_info, ppl, data) result = pipeline_output_to_service_base64_output(task, infer_result) return result @@ -50,27 +53,20 @@ class ModelJsonTest: if __name__ == '__main__': model_list = [ - 'damo/nlp_structbert_nli_chinese-base', - 'damo/nlp_structbert_word-segmentation_chinese-base', - 'damo/nlp_structbert_zero-shot-classification_chinese-base', - 'damo/cv_unet_person-image-cartoon_compound-models', - 'damo/nlp_structbert_sentiment-classification_chinese-tiny', - 'damo/nlp_csanmt_translation_zh2en', - 'damo/nlp_rom_passage-ranking_chinese-base', - 'damo/ofa_image-caption_muge_base_zh', - 'damo/nlp_raner_named-entity-recognition_chinese-base-ecom-50cls', - 'damo/nlp_structbert_sentiment-classification_chinese-ecommerce-base', - 'damo/text-to-video-synthesis', - 'qwen/Qwen-7B', - 'qwen/Qwen-7B-Chat', - 'ZhipuAI/ChatGLM-6B', + 'qwen/Qwen-7B-Chat-Int4', + 'qwen/Qwen-14B-Chat-Int4', + 'baichuan-inc/Baichuan2-7B-Chat-4bits', + 'baichuan-inc/Baichuan2-13B-Chat-4bits', + 'ZhipuAI/chatglm2-6b-int4', ] tester = ModelJsonTest() for model in model_list: try: res = tester.test_single(model) - print(f'\nmodel_id {model} call_pipeline_with_json run ok.\n') + print( + f'\nmodel_id {model} call_pipeline_with_json run ok. 
{res}\n\n\n\n' + ) except BaseException as e: print( - f'\nmodel_id {model} call_pipeline_with_json run failed: {e}.\n' + f'\nmodel_id {model} call_pipeline_with_json run failed: {e}.\n\n\n\n' ) diff --git a/tests/pipelines/test_llm_pipeline.py b/tests/pipelines/test_llm_pipeline.py index 1b6d211a..9b7e832f 100644 --- a/tests/pipelines/test_llm_pipeline.py +++ b/tests/pipelines/test_llm_pipeline.py @@ -3,6 +3,7 @@ import unittest import torch +from modelscope import pipeline from modelscope.pipelines.nlp.llm_pipeline import LLMPipeline from modelscope.utils.test_utils import test_level @@ -132,143 +133,172 @@ class LLMPipelineTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_chatglm2(self): - pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b', device_map='auto') + pipe = pipeline( + task='chat', model='ZhipuAI/chatglm2-6b', llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_chatglm2int4(self): - pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-int4') + pipe = pipeline( + task='chat', model='ZhipuAI/chatglm2-6b-int4', llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_chatglm232k(self): - pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-32k', device_map='auto') + pipe = pipeline( + task='chat', model='ZhipuAI/chatglm2-6b-32k', llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_llama2(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='modelscope/Llama-2-7b-ms', torch_dtype=torch.float16, device_map='auto', - ignore_file_pattern=[r'.+\.bin$']) + ignore_file_pattern=[r'.+\.bin$'], + llm_first=True) print('messages: ', pipe(self.messages_en, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_en, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_llama2chat(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='modelscope/Llama-2-7b-chat-ms', revision='v1.0.2', torch_dtype=torch.float16, device_map='auto', - ignore_file_pattern=[r'.+\.bin$']) + ignore_file_pattern=[r'.+\.bin$'], + llm_first=True) print('messages: ', pipe(self.messages_en, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_en, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_codellama(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/CodeLlama-7b-Instruct-hf', torch_dtype=torch.float16, device_map='auto', - ignore_file_pattern=[r'.+\.bin$']) + ignore_file_pattern=[r'.+\.bin$'], + llm_first=True) print('messages: ', pipe(self.messages_code, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_code, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan_7b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/baichuan-7B', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) 
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan_13b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan-13B-Base', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan_13bchat(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan-13B-Chat', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan2_7b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-7B-Base', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan2_7bchat(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-7B-Chat', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skip('Need bitsandbytes') def test_baichuan2_7bchat_int4(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-7B-Chat-4bits', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skip('Need bitsandbytes') def test_baichuan2_13bchat_int4(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-13B-Chat-4bits', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_wizardlm_13b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/WizardLM-13B-V1.2', device_map='auto', torch_dtype=torch.float16, - format_messages='wizardlm') + format_messages='wizardlm', + llm_first=True) print('messages: ', pipe(self.messages_en, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_en, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_wizardmath(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/WizardMath-7B-V1.0', device_map='auto', torch_dtype=torch.float16, - format_messages='wizardcode') + format_messages='wizardcode', + llm_first=True) print('messages: ', pipe(self.message_wizard_math, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_wizard_math, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_wizardcode_13b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/WizardCoder-Python-13B-V1.0', 
device_map='auto', torch_dtype=torch.float16, - format_messages='wizardcode') + format_messages='wizardcode', + llm_first=True) print('messages: ', pipe(self.message_wizard_code, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_wizard_code, **self.gen_cfg)) @@ -284,19 +314,20 @@ class LLMPipelineTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_qwen(self): - pipe = LLMPipeline(model='qwen/Qwen-7B-Chat', device_map='auto') + pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True) print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skip('Need optimum and auto-gptq') def test_qwen_int4(self): - pipe = LLMPipeline(model='qwen/Qwen-7B-Chat-Int4', device_map='auto') + pipe = pipeline( + task='chat', model='qwen/Qwen-7B-Chat-Int4', llm_first=True) print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_qwen_vl(self): - pipe = LLMPipeline(model='qwen/Qwen-VL-Chat', device_map='auto') + pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat', llm_first=True) print('messages: ', pipe(self.messages_mm, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
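
Usage sketch (not part of the diff): a minimal, hedged example of how the pieces above fit together, based on the tests and pipeline_inputs.json in this change. With llm_first=True the builder calls llm_first_checker(), which reads 'model.type' from configuration.json (or 'model_type' from config.json) and routes to the 'llm' pipeline when the type appears in LLM_FORMAT_MAP. The model ID, prompt text, and generation parameters below are illustrative, taken from the tests and the example inputs; the exact output keys ('message' for messages input, per the chatglm2 formatter, and 'text' for plain prompts) follow the post-processing changes in llm_pipeline.py and may differ for other models' formatters.

    # Sketch only; assumes a GPU environment comparable to the tests above.
    from modelscope import pipeline

    # llm_first=True lets the builder fall back to the new 'llm' pipeline
    # when no explicit pipeline type is resolved for the model.
    pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)

    # New chat input schema: a dict with a 'messages' list
    # (see the updated TASK_INPUTS and pipeline_inputs.json).
    inputs = {
        'messages': [
            {'role': 'user', 'content': 'Hello! 你是谁?'},
            {'role': 'assistant', 'content': '我是你的助手。'},
            {'role': 'user', 'content': '你叫什么名字?'},
        ]
    }
    # Generation kwargs are passed through, as in the tests' **self.gen_cfg.
    print(pipe(inputs, do_sample=True, max_length=512))

    # Plain-text prompts are still accepted; the postprocess change wraps the
    # reply as {'text': ...} instead of returning a bare string.
    print(pipe('写一首关于春天的诗', do_sample=True, max_length=512))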