From cb0a44eda5d5b7728406ac7bb13271f5b7e8b014 Mon Sep 17 00:00:00 2001
From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com>
Date: Fri, 27 Oct 2023 18:20:53 +0800
Subject: [PATCH] Support VLLM in LLMPipeline (#604)

---
 modelscope/pipelines/accelerate/__init__.py |  0
 modelscope/pipelines/accelerate/base.py     | 78 +++++++++++++++++++++
 modelscope/pipelines/accelerate/vllm.py     | 74 +++++++++++++++++++
 modelscope/pipelines/nlp/llm_pipeline.py    | 46 +++++++++---
 modelscope/utils/import_utils.py            |  8 +++
 5 files changed, 196 insertions(+), 10 deletions(-)
 create mode 100644 modelscope/pipelines/accelerate/__init__.py
 create mode 100644 modelscope/pipelines/accelerate/base.py
 create mode 100644 modelscope/pipelines/accelerate/vllm.py

diff --git a/modelscope/pipelines/accelerate/__init__.py b/modelscope/pipelines/accelerate/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/accelerate/base.py b/modelscope/pipelines/accelerate/base.py
new file mode 100644
index 00000000..05c1428b
--- /dev/null
+++ b/modelscope/pipelines/accelerate/base.py
@@ -0,0 +1,78 @@
+import os.path
+from abc import abstractmethod
+from typing import List, Union
+
+import torch.cuda
+
+from modelscope import read_config, snapshot_download
+from modelscope.utils.config import Config
+
+
+class InferFramework:
+
+    def __init__(self, model_id_or_dir: str, **kwargs):
+        """
+        Args:
+            model_id_or_dir(`str`): The model id of the modelhub or a local dir containing model files.
+        """
+        if os.path.exists(model_id_or_dir):
+            self.model_dir = model_id_or_dir
+        else:
+            self.model_dir = snapshot_download(model_id_or_dir)
+
+        model_supported = self.model_type_supported(model_id_or_dir)
+        config: Config = read_config(self.model_dir)
+        model_type = config.safe_get('model.type')
+        if model_type is not None:
+            model_supported = model_supported or self.model_type_supported(
+                model_type)
+        config_file = os.path.join(self.model_dir, 'config.json')
+        if os.path.isfile(config_file):
+            config = Config.from_file(config_file)
+            model_type = config.safe_get('model_type')
+            if model_type is not None:
+                model_supported = model_supported or self.model_type_supported(
+                    model_type)
+
+        if not model_supported:
+            raise ValueError(
+                f'Model acceleration is not supported: {model_id_or_dir}')
+
+    @abstractmethod
+    def __call__(self, prompts: Union[List[str], List[List[int]]],
+                 **kwargs) -> List[str]:
+        """
+        Args:
+            prompts(`Union[List[str], List[List[int]]]`):
+                The string batch or the token list batch to input to the model.
+        Returns:
+            The answers as a list, corresponding to the input prompt batch.
+        """
+        pass
+
+    def model_type_supported(self, model_type: str):
+        return False
+
+    @staticmethod
+    def check_gpu_compatibility(major_version: int):
+        """Check the GPU compatibility.
+        """
+        major, _ = torch.cuda.get_device_capability()
+        return major >= major_version
+
+    @classmethod
+    def from_pretrained(cls, model_id_or_dir, framework='vllm', **kwargs):
+        """Instantiate the model wrapped by an acceleration framework.
+        Args:
+            model_id_or_dir(`str`): The model id of the modelhub or a local dir containing model files.
+            framework(`str`): The framework to use.
+        Returns:
+            The wrapped model.
+        """
+        if framework == 'vllm':
+            from .vllm import Vllm
+            vllm = Vllm(model_id_or_dir, **kwargs)
+            vllm.llm_framework = framework
+            return vllm
+        else:
+            raise ValueError(f'Framework not supported: {framework}')
diff --git a/modelscope/pipelines/accelerate/vllm.py b/modelscope/pipelines/accelerate/vllm.py
new file mode 100644
index 00000000..5c11c29b
--- /dev/null
+++ b/modelscope/pipelines/accelerate/vllm.py
@@ -0,0 +1,74 @@
+from typing import List, Union
+
+from modelscope.pipelines.accelerate.base import InferFramework
+from modelscope.utils.import_utils import is_vllm_available
+
+
+class Vllm(InferFramework):
+
+    def __init__(self,
+                 model_id_or_dir: str,
+                 dtype: str = 'auto',
+                 quantization: str = None,
+                 tensor_parallel_size: int = 1):
+        """
+        Args:
+            dtype: The dtype to use, supports `auto`, `float16`, `bfloat16`, `float32`
+            quantization: The quantization method; the default None means no quantization.
+            tensor_parallel_size: The tensor parallel size.
+        """
+        super().__init__(model_id_or_dir)
+        if not is_vllm_available():
+            raise ImportError(
+                'Install vllm by `pip install vllm` before using vllm to accelerate inference'
+            )
+
+        from vllm import LLM
+        if not Vllm.check_gpu_compatibility(8) and (dtype
+                                                    in ('bfloat16', 'auto')):
+            dtype = 'float16'
+        self.model = LLM(
+            self.model_dir,
+            dtype=dtype,
+            quantization=quantization,
+            trust_remote_code=True,
+            tensor_parallel_size=tensor_parallel_size)
+
+    def __call__(self, prompts: Union[List[str], List[List[int]]],
+                 **kwargs) -> List[str]:
+        """Generate tokens.
+        Args:
+            prompts(`Union[List[str], List[List[int]]]`):
+                The string batch or the token list batch to input to the model.
+            kwargs: Sampling parameters.
+        """
+        from vllm import SamplingParams
+        sampling_params = SamplingParams(**kwargs)
+        if isinstance(prompts[0], str):
+            return [
+                output.outputs[0].text for output in self.model.generate(
+                    prompts, sampling_params=sampling_params)
+            ]
+        else:
+            return [
+                output.outputs[0].text for output in self.model.generate(
+                    prompt_token_ids=prompts, sampling_params=sampling_params)
+            ]
+
+    def model_type_supported(self, model_type: str):
+        return any([
+            model in model_type.lower() for model in [
+                'llama',
+                'baichuan',
+                'internlm',
+                'mistral',
+                'aquila',
+                'bloom',
+                'falcon',
+                'gpt',
+                'mpt',
+                'opt',
+                'qwen',
+                'aquila',
+            ]
+        ])
diff --git a/modelscope/pipelines/nlp/llm_pipeline.py b/modelscope/pipelines/nlp/llm_pipeline.py
index e2f669d8..5976c4b5 100644
--- a/modelscope/pipelines/nlp/llm_pipeline.py
+++ b/modelscope/pipelines/nlp/llm_pipeline.py
@@ -30,6 +30,10 @@ class LLMPipeline(Pipeline):
         if isinstance(model, str):
             logger.info(f'initiate model from {model}')
         if self._is_swift_model(model):
+            if self.llm_framework is not None:
+                logger.warning(
+                    f'Cannot use swift with llm_framework, ignoring {self.llm_framework}.'
+                )
             from swift import Swift
 
             base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
@@ -45,9 +49,15 @@ class LLMPipeline(Pipeline):
                 trust_remote_code=True)
             swift_model = Swift.from_pretrained(base_model, model_id=model)
             return swift_model
+
         if isinstance(model, str) and is_official_hub_path(model):
             logger.info(f'initiate model from location {model}.')
-            if is_model(model):
+            if self.llm_framework is not None:
+                model_dir = model if os.path.exists(
+                    model) else snapshot_download(model)
+                return self._wrap_infer_framework(model_dir,
+                                                  self.llm_framework)
+            elif is_model(model):
                 return Model.from_pretrained(
                     model,
                     invoked_by=Invoke.PIPELINE,
@@ -82,13 +92,19 @@ class LLMPipeline(Pipeline):
         self.cfg = Config.from_file(cfg_file)
         return self.cfg.safe_get('adapter_cfg.tuner_backend') == 'swift'
 
+    def _wrap_infer_framework(self, model_dir, framework='vllm'):
+        from modelscope.pipelines.accelerate.base import InferFramework
+        return InferFramework.from_pretrained(model_dir, framework)
+
     def __init__(self,
                  format_messages: Union[Callable, str] = None,
                  format_output: Callable = None,
                  tokenizer: PreTrainedTokenizer = None,
+                 llm_framework: str = None,
                  *args,
                  **kwargs):
         self.device_map = kwargs.pop('device_map', None)
+        self.llm_framework = llm_framework
         # TODO: qwen-int4 need 'cuda'/'auto' device_map.
         if not self.device_map and 'qwen' in kwargs['model'].lower():
             self.device_map = 'cuda'
@@ -139,15 +155,20 @@ class LLMPipeline(Pipeline):
 
         is_messages = isinstance(inputs, dict) and 'messages' in inputs
         tokens = self.preprocess(inputs, is_messages, **preprocess_params)
-        if hasattr(self.model, 'generate'):
-            outputs = self.model.generate(**tokens, **forward_params)
-        elif hasattr(self.model, 'model') and hasattr(self.model.model,
-                                                      'generate'):
-            outputs = self.model.model.generate(**tokens, **forward_params)
+        if self.llm_framework is None:
+            if hasattr(self.model, 'generate'):
+                outputs = self.model.generate(**tokens, **forward_params)
+            elif hasattr(self.model, 'model') and hasattr(
+                    self.model.model, 'generate'):
+                outputs = self.model.model.generate(**tokens, **forward_params)
+            else:
+                raise ValueError('model does not support `generate`!')
         else:
-            raise ValueError('model does not support `generate`!')
+            tokens = [list(tokens['inputs'].flatten().numpy())]
+            outputs = self.model(tokens, **forward_params)[0]
 
-        outputs = outputs.tolist()[0][len(tokens['inputs'][0]):]
+        if not isinstance(outputs, str):
+            outputs = outputs.tolist()[0][len(tokens['inputs'][0]):]
         response = self.postprocess(outputs, is_messages, **postprocess_params)
         return response
 
@@ -165,14 +186,19 @@ class LLMPipeline(Pipeline):
         elif hasattr(self.model, 'model') and hasattr(self.model.model,
                                                       'device'):
             device = self.model.model.device
+        elif hasattr(self.model, 'llm_framework'):
+            device = 'cpu'
         else:
             raise ValueError('model does not have `device` attribute!')
 
         return {k: v.to(device) for k, v in tokens.items()}
 
     def postprocess(self, outputs, is_messages: bool, **kwargs):
-        response = self.tokenizer.decode(
-            outputs, skip_special_tokens=True, **kwargs)
+        if not isinstance(outputs, str):
+            response = self.tokenizer.decode(
+                outputs, skip_special_tokens=True, **kwargs)
+        else:
+            response = outputs
         if is_messages:
             response = self.format_output(response, **kwargs)
         else:
diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py
index 1910039a..2de8770f 100644
--- a/modelscope/utils/import_utils.py
+++ b/modelscope/utils/import_utils.py
@@ -275,6 +275,14 @@ def is_espnet_available(pkg_name):
         and importlib.util.find_spec('espnet')
 
 
+def is_vllm_available():
+    return importlib.util.find_spec('vllm') is not None
+
+
+def is_tensorrt_llm_available():
+    return importlib.util.find_spec('tensorrt_llm') is not None
+
+
 REQUIREMENTS_MAAPING = OrderedDict([
     ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
     ('sentencepiece', (is_sentencepiece_available,
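
Reviewer note (not part of the patch): a minimal usage sketch of the new code path, under the assumption that the model id, prompt, and sampling arguments below are placeholders chosen for illustration only.

    # Wrap a model with the vllm backend via the new InferFramework entry point.
    from modelscope.pipelines.accelerate.base import InferFramework

    llm = InferFramework.from_pretrained(
        'qwen/Qwen-7B-Chat',  # placeholder model id
        framework='vllm',
        tensor_parallel_size=1)
    # Extra keyword arguments of the call are forwarded to vllm.SamplingParams.
    print(llm(['Hello, who are you?'], max_tokens=64, temperature=0.8))

    # Or through LLMPipeline, which routes generation to vllm when
    # llm_framework is set instead of calling model.generate().
    from modelscope.pipelines.nlp.llm_pipeline import LLMPipeline

    pipe = LLMPipeline(model='qwen/Qwen-7B-Chat', llm_framework='vllm')
    print(pipe('Hello, who are you?'))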