support ms-swift 3.0.0 (#1166)
@@ -32,6 +32,7 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
     pip install faiss-gpu
     pip install healpy
     pip install huggingface-hub==0.25.2
+    pip install ms-swift>=3.0.1
     # test with install
     pip install .
 else
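One note on the added line: unquoted, the shell parses `>` as output redirection, so `pip install ms-swift>=3.0.1` runs `pip install ms-swift` and writes its stdout to a file named `=3.0.1`; quoting the requirement (`pip install 'ms-swift>=3.0.1'`) applies the version spec as intended. A small, illustrative check of the resulting environment (the distribution names are the PyPI ones):

# Illustrative sanity check for the environment built above; prints the
# installed versions of the two pinned/added dependencies.
from importlib.metadata import version

for dist in ('huggingface-hub', 'ms-swift'):
    print(dist, version(dist))  # raises PackageNotFoundError if missing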
@@ -223,8 +223,9 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
                                                      List[Model]],
                                     revision: Optional[str],
                                     kwargs: Dict[str, Any]) -> Optional[str]:
-    from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, init_swift_model_mapping, ModelTypeHelper, LLMAdapterRegistry
+    from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
     from ..hub.check_model import get_model_id_from_cache
+    from swift.llm import get_model_info_meta
     if isinstance(model, list):
         model = model[0]
     if not isinstance(model, str):
@@ -237,9 +238,17 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
     else:
         model_id = model

-    init_swift_model_mapping()
-    if model_id.lower() in SWIFT_MODEL_ID_MAPPING:
+    try:
+        info = get_model_info_meta(model_id)
+        model_type = info[0].model_type
+    except Exception as e:
+        logger.warning(
+            f'Cannot using llm_framework with {model_id}, '
+            f'ignoring llm_framework={self.llm_framework} : {e}')
+        model_type = None
+    if model_type:
         return 'llm'

     model_type = ModelTypeHelper.get(
         model, revision, with_adapter=True, split='-', use_cache=True)
     if LLMAdapterRegistry.contains(model_type):
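Net effect of the two hunks above: instead of consulting a pre-built model-id table, the checker now asks ms-swift itself whether it recognizes the model, falling back to ModelTypeHelper when it does not. A minimal standalone sketch of that probe (the function name and the except-to-None policy are mine, mirroring the diff):

from typing import Optional


def probe_swift_model_type(model_id: str) -> Optional[str]:
    """Return ms-swift's model_type for model_id, or None if swift
    cannot resolve it (unknown id, missing files, network errors)."""
    try:
        # ms-swift 3.x API, as used in the diff: the first element of
        # the returned value carries model_type.
        from swift.llm import get_model_info_meta
        info = get_model_info_meta(model_id)
        return info[0].model_type
    except Exception:
        return None

A truthy result routes the request to the 'llm' pipeline. (As an aside, the warning message references `self.llm_framework` inside a function whose signature has no `self` parameter, which looks out of place.)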
@@ -29,21 +29,9 @@ from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,

 logger = get_logger()

-SWIFT_MODEL_ID_MAPPING = {}
 SWIFT_FRAMEWORK = 'swift'


-def init_swift_model_mapping():
-    from swift.llm.utils import MODEL_MAPPING
-
-    global SWIFT_MODEL_ID_MAPPING
-    if not SWIFT_MODEL_ID_MAPPING:
-        SWIFT_MODEL_ID_MAPPING = {
-            v['model_id_or_path'].lower(): k
-            for k, v in MODEL_MAPPING.items()
-        }
-
-
 class LLMAdapterRegistry:

     llm_format_map = {'qwen': [None, None, None]}
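The deleted helper inverted ms-swift 2.x's MODEL_MAPPING (model_type -> metadata) into a lowercase model-id lookup table. Under ms-swift 3.x that indirection is unnecessary because swift resolves a model id to its metadata on demand; roughly (the model id below is a placeholder, not from the diff):

# Old (ms-swift 2.x, removed above): one-time table inversion.
#   SWIFT_MODEL_ID_MAPPING = {v['model_id_or_path'].lower(): k
#                             for k, v in MODEL_MAPPING.items()}
#   model_type = SWIFT_MODEL_ID_MAPPING[model_id.lower()]

# New (ms-swift 3.x, as used by the new code): resolve on demand.
from swift.llm import get_model_info_meta

info = get_model_info_meta('Qwen/Qwen2.5-7B-Instruct')  # placeholder id
print(info[0].model_type)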
@@ -227,12 +215,12 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):

     def _init_swift(self, model_id, device) -> None:
         from swift.llm import prepare_model_template
-        from swift.llm.utils import InferArguments
+        from swift.llm import InferArguments, get_model_info_meta

         def format_messages(messages: Dict[str, List[Dict[str, str]]],
                             tokenizer: PreTrainedTokenizer,
                             **kwargs) -> Dict[str, torch.Tensor]:
-            inputs, _ = self.template.encode(get_example(messages))
+            inputs = self.template.encode(messages)
+            inputs.pop('labels', None)
             if 'input_ids' in inputs:
                 input_ids = torch.tensor(inputs['input_ids'])[None]
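The encoding change tracks the ms-swift 3.x Template API: `encode` now takes the messages payload directly and returns a single dict that can include training-only `labels`, which inference drops. A hedged sketch of that path in isolation (`template` stands for the object prepared in `_init_swift`; the helper name is mine):

from typing import Any, Dict

import torch


def encode_for_inference(template: Any, messages: Dict) -> torch.Tensor:
    """Encode a chat payload with a swift 3.x template for generation."""
    inputs = template.encode(messages)
    inputs.pop('labels', None)  # labels are for training, not inference
    return torch.tensor(inputs['input_ids'])[None]  # add batch dimension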
@@ -265,12 +253,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
         else:
             return dict(system=system, prompt=prompt, history=history)

-        init_swift_model_mapping()
-
-        assert model_id.lower() in SWIFT_MODEL_ID_MAPPING,\
-            f'Invalid model id {model_id} or Swift framework does not support this model.'
-        args = InferArguments(
-            model_type=SWIFT_MODEL_ID_MAPPING[model_id.lower()])
+        args = InferArguments(model=model_id)
         model, template = prepare_model_template(
             args, device_map=self.device_map)
         self.model = add_stream_generate(model)
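Taken together, the new path reduces swift initialization to two calls, with no id-to-model_type table or assertion in between; a hedged end-to-end sketch (the model id and device_map value are examples, not from the diff):

from swift.llm import InferArguments, prepare_model_template

# ms-swift 3.x takes the model id directly via the 'model' argument.
args = InferArguments(model='Qwen/Qwen2.5-7B-Instruct')  # example id
model, template = prepare_model_template(args, device_map='auto')
# model and template are then wrapped for streaming generation,
# as in self.model = add_stream_generate(model) above.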