Mirror of https://github.com/modelscope/modelscope.git
Support VLLM in LLMPipeline (#604)
modelscope/pipelines/accelerate/__init__.py  (new file, 0 lines)
modelscope/pipelines/accelerate/base.py  (new file, 78 lines)
@@ -0,0 +1,78 @@
import os.path
from abc import abstractmethod
from typing import List, Union

import torch.cuda

from modelscope import read_config, snapshot_download
from modelscope.utils.config import Config


class InferFramework:

    def __init__(self, model_id_or_dir: str, **kwargs):
        """
        Args:
            model_id_or_dir(`str`): The model id of the modelhub or a local dir containing model files.
        """
        if os.path.exists(model_id_or_dir):
            self.model_dir = model_id_or_dir
        else:
            self.model_dir = snapshot_download(model_id_or_dir)

        model_supported = self.model_type_supported(model_id_or_dir)
        config: Config = read_config(self.model_dir)
        model_type = config.safe_get('model.type')
        if model_type is not None:
            model_supported = model_supported or self.model_type_supported(
                model_type)
        config_file = os.path.join(self.model_dir, 'config.json')
        if os.path.isfile(config_file):
            config = Config.from_file(config_file)
            model_type = config.safe_get('model_type')
            if model_type is not None:
                model_supported = model_supported or self.model_type_supported(
                    model_type)

        if not model_supported:
            raise ValueError(
                f'Model accelerating not supported: {model_id_or_dir}')

    @abstractmethod
    def __call__(self, prompts: Union[List[str], List[List[int]]],
                 **kwargs) -> List[str]:
        """
        Args:
            prompts(`Union[List[str], List[List[int]]]`):
                The string batch or the token list batch to input to the model.
        Returns:
            The answers in list according to the input prompt batch.
        """
        pass

    def model_type_supported(self, model_type: str):
        return False

    @staticmethod
    def check_gpu_compatibility(major_version: int):
        """Check the GPU compatibility.
        """
        major, _ = torch.cuda.get_device_capability()
        return major >= major_version

    @classmethod
    def from_pretrained(cls, model_id_or_dir, framework='vllm', **kwargs):
        """Instantiate the model wrapped by an accelerate framework.
        Args:
            model_id_or_dir(`str`): The model id of the modelhub or a local dir containing model files.
            framework(`str`): The framework to use.
        Returns:
            The wrapped model.
        """
        if framework == 'vllm':
            from .vllm import Vllm
            vllm = Vllm(model_id_or_dir, **kwargs)
            vllm.llm_framework = framework
            return vllm
        else:
            raise ValueError(f'Framework not supported: {framework}')
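For orientation, a minimal sketch of driving this factory directly; the model id and sampling values below are illustrative placeholders, not something the diff prescribes:

# Hypothetical direct use of the factory added above.
from modelscope.pipelines.accelerate.base import InferFramework

# from_pretrained resolves the id to a local dir (snapshot_download if needed),
# checks that the architecture is supported, and returns a Vllm wrapper.
llm = InferFramework.from_pretrained('qwen/Qwen-7B-Chat', framework='vllm')

# The wrapper is callable on a batch of prompt strings (or token-id lists);
# extra kwargs are forwarded to vLLM's SamplingParams.
answers = llm(['Hello, who are you?'], temperature=0.8, max_tokens=128)
print(answers[0])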
modelscope/pipelines/accelerate/vllm.py  (new file, 74 lines)
@@ -0,0 +1,74 @@
from typing import List, Union

from modelscope.pipelines.accelerate.base import InferFramework
from modelscope.utils.import_utils import is_vllm_available


class Vllm(InferFramework):

    def __init__(self,
                 model_id_or_dir: str,
                 dtype: str = 'auto',
                 quantization: str = None,
                 tensor_parallel_size: int = 1):
        """
        Args:
            dtype: The dtype to use, support `auto`, `float16`, `bfloat16`, `float32`
            quantization: The quantization bit, default None means do not do any quantization.
            tensor_parallel_size: The tensor parallel size.
        """
        super().__init__(model_id_or_dir)
        if not is_vllm_available():
            raise ImportError(
                'Install vllm by `pip install vllm` before using vllm to accelerate inference'
            )

        from vllm import LLM
        if not Vllm.check_gpu_compatibility(8) and (dtype
                                                    in ('bfloat16', 'auto')):
            dtype = 'float16'
        self.model = LLM(
            self.model_dir,
            dtype=dtype,
            quantization=quantization,
            trust_remote_code=True,
            tensor_parallel_size=tensor_parallel_size)

    def __call__(self, prompts: Union[List[str], List[List[int]]],
                 **kwargs) -> List[str]:
        """Generate tokens.
        Args:
            prompts(`Union[List[str], List[List[int]]]`):
                The string batch or the token list batch to input to the model.
            kwargs: Sampling parameters.
        """
        from vllm import SamplingParams
        sampling_params = SamplingParams(**kwargs)
        if isinstance(prompts[0], str):
            return [
                output.outputs[0].text for output in self.model.generate(
                    prompts, sampling_params=sampling_params)
            ]
        else:
            return [
                output.outputs[0].text for output in self.model.generate(
                    prompt_token_ids=prompts, sampling_params=sampling_params)
            ]

    def model_type_supported(self, model_type: str):
        return any([
            model in model_type.lower() for model in [
                'llama',
                'baichuan',
                'internlm',
                'mistral',
                'aquila',
                'bloom',
                'falcon',
                'gpt',
                'mpt',
                'opt',
                'qwen',
                'aquila',
            ]
        ])
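A hedged sketch of exercising this wrapper on its own, assuming vllm is installed and a CUDA GPU is available; the model id and sampling values are placeholders:

from modelscope.pipelines.accelerate.vllm import Vllm

# Builds the underlying vllm.LLM engine; dtype falls back to float16 on GPUs
# below compute capability 8.0, as in __init__ above.
llm = Vllm('qwen/Qwen-7B-Chat', dtype='auto', tensor_parallel_size=1)

# String prompts go through vllm.LLM.generate(); kwargs become SamplingParams.
print(llm(['Tell me a joke.'], temperature=0.7, top_p=0.9, max_tokens=64))

# Token-id batches are also accepted and routed through prompt_token_ids.
# print(llm([[1, 2, 3]], max_tokens=16))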
@@ -30,6 +30,10 @@ class LLMPipeline(Pipeline):
         if isinstance(model, str):
             logger.info(f'initiate model from {model}')
         if self._is_swift_model(model):
+            if self.llm_framework is not None:
+                logger.warning(
+                    f'Cannot using swift with llm_framework, ignoring {self.llm_framework}.'
+                )
             from swift import Swift
 
             base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
@@ -45,9 +49,15 @@ class LLMPipeline(Pipeline):
                 trust_remote_code=True)
             swift_model = Swift.from_pretrained(base_model, model_id=model)
             return swift_model
+
         if isinstance(model, str) and is_official_hub_path(model):
             logger.info(f'initiate model from location {model}.')
-            if is_model(model):
+            if self.llm_framework is not None:
+                model_dir = model if os.path.exists(
+                    model) else snapshot_download(model)
+                return self._wrap_infer_framework(model_dir,
+                                                  self.llm_framework)
+            elif is_model(model):
                 return Model.from_pretrained(
                     model,
                     invoked_by=Invoke.PIPELINE,
@@ -82,13 +92,19 @@ class LLMPipeline(Pipeline):
         self.cfg = Config.from_file(cfg_file)
         return self.cfg.safe_get('adapter_cfg.tuner_backend') == 'swift'
 
+    def _wrap_infer_framework(self, model_dir, framework='vllm'):
+        from modelscope.pipelines.accelerate.base import InferFramework
+        return InferFramework.from_pretrained(model_dir, framework)
+
     def __init__(self,
                  format_messages: Union[Callable, str] = None,
                  format_output: Callable = None,
                  tokenizer: PreTrainedTokenizer = None,
+                 llm_framework: str = None,
                  *args,
                  **kwargs):
         self.device_map = kwargs.pop('device_map', None)
+        self.llm_framework = llm_framework
         # TODO: qwen-int4 need 'cuda'/'auto' device_map.
         if not self.device_map and 'qwen' in kwargs['model'].lower():
             self.device_map = 'cuda'
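Taken together with the new constructor argument, pipeline usage might look like the sketch below; the task name, model id, and plain-string input are assumptions for illustration, not values fixed by this commit:

from modelscope.pipelines import pipeline

# With llm_framework='vllm', _wrap_infer_framework replaces the usual torch
# Model with an InferFramework (Vllm) wrapper around the downloaded weights.
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_framework='vllm')
print(pipe('What is the capital of France?'))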
@@ -139,15 +155,20 @@ class LLMPipeline(Pipeline):
         is_messages = isinstance(inputs, dict) and 'messages' in inputs
         tokens = self.preprocess(inputs, is_messages, **preprocess_params)
 
-        if hasattr(self.model, 'generate'):
-            outputs = self.model.generate(**tokens, **forward_params)
-        elif hasattr(self.model, 'model') and hasattr(self.model.model,
-                                                      'generate'):
-            outputs = self.model.model.generate(**tokens, **forward_params)
-        else:
-            raise ValueError('model does not support `generate`!')
+        if self.llm_framework is None:
+            if hasattr(self.model, 'generate'):
+                outputs = self.model.generate(**tokens, **forward_params)
+            elif hasattr(self.model, 'model') and hasattr(
+                    self.model.model, 'generate'):
+                outputs = self.model.model.generate(**tokens, **forward_params)
+            else:
+                raise ValueError('model does not support `generate`!')
+        else:
+            tokens = [list(tokens['inputs'].flatten().numpy())]
+            outputs = self.model(tokens, **forward_params)[0]
 
-        outputs = outputs.tolist()[0][len(tokens['inputs'][0]):]
+        if not isinstance(outputs, str):
+            outputs = outputs.tolist()[0][len(tokens['inputs'][0]):]
         response = self.postprocess(outputs, is_messages, **postprocess_params)
         return response
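On the new vllm branch the wrapper already returns a decoded string, which is why the slicing above is now guarded by isinstance. A small sketch of that token handoff, with placeholder ids and the wrapper from the earlier sketches:

import torch

# Mimic the vllm branch of forward(): the tokenized 1 x seq_len tensor is
# flattened back into a plain token-id list and sent to the wrapper as a
# one-element batch.
tokens = {'inputs': torch.tensor([[100, 101, 102]])}  # placeholder token ids
token_batch = [list(tokens['inputs'].flatten().numpy())]
# outputs = llm(token_batch, max_tokens=32)[0]  # `llm` as in the sketches above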
@@ -165,14 +186,19 @@ class LLMPipeline(Pipeline):
         elif hasattr(self.model, 'model') and hasattr(self.model.model,
                                                       'device'):
             device = self.model.model.device
+        elif hasattr(self.model, 'llm_framework'):
+            device = 'cpu'
         else:
             raise ValueError('model does not have `device` attribute!')
         return {k: v.to(device) for k, v in tokens.items()}
 
     def postprocess(self, outputs, is_messages: bool, **kwargs):
 
-        response = self.tokenizer.decode(
-            outputs, skip_special_tokens=True, **kwargs)
+        if not isinstance(outputs, str):
+            response = self.tokenizer.decode(
+                outputs, skip_special_tokens=True, **kwargs)
+        else:
+            response = outputs
         if is_messages:
             response = self.format_output(response, **kwargs)
         else:
@@ -275,6 +275,14 @@ def is_espnet_available(pkg_name):
         and importlib.util.find_spec('espnet')
 
 
+def is_vllm_available():
+    return importlib.util.find_spec('vllm') is not None
+
+
+def is_tensorrt_llm_available():
+    return importlib.util.find_spec('tensorrt_llm') is not None
+
+
 REQUIREMENTS_MAAPING = OrderedDict([
     ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
     ('sentencepiece', (is_sentencepiece_available,
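A small sketch of how a caller might use the new probe to decide whether to request the vllm backend; the fallback logic is illustrative, not part of this commit:

from modelscope.utils.import_utils import is_vllm_available

# Ask for the accelerated backend only when the package can be imported;
# None keeps LLMPipeline on the plain generate() path.
llm_framework = 'vllm' if is_vllm_available() else None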