diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index a2ecc210..1dd6c6d5 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -170,7 +170,7 @@ def pipeline(task: str = None, pipeline_props['device'] = device cfg = ConfigDict(pipeline_props) - clear_llm_info(kwargs) + clear_llm_info(kwargs, pipeline_name) if kwargs: cfg.update(kwargs) @@ -223,7 +223,7 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model, List[Model]], revision: Optional[str], kwargs: Dict[str, Any]) -> Optional[str]: - from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, ModelTypeHelper, LLMAdapterRegistry + from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, init_swift_model_mapping, ModelTypeHelper, LLMAdapterRegistry from ..hub.check_model import get_model_id_from_cache if isinstance(model, list): model = model[0] @@ -236,8 +236,9 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model, model_id = get_model_id_from_cache(model) else: model_id = model - global SWIFT_MODEL_ID_MAPPING - if model_id in SWIFT_MODEL_ID_MAPPING: + + init_swift_model_mapping() + if model_id.lower() in SWIFT_MODEL_ID_MAPPING: return 'llm' model_type = ModelTypeHelper.get( model, revision, with_adapter=True, split='-', use_cache=True) @@ -245,9 +246,10 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model, return 'llm' -def clear_llm_info(kwargs: Dict): +def clear_llm_info(kwargs: Dict, pipeline_name: str): from modelscope.utils.model_type_helper import ModelTypeHelper kwargs.pop('external_engine_for_llm', None) - kwargs.pop('llm_framework', None) + if pipeline_name != 'llm': + kwargs.pop('llm_framework', None) ModelTypeHelper.clear_cache() diff --git a/modelscope/pipelines/nlp/llm_pipeline.py b/modelscope/pipelines/nlp/llm_pipeline.py index c46bb46a..2c08c498 100644 --- a/modelscope/pipelines/nlp/llm_pipeline.py +++ b/modelscope/pipelines/nlp/llm_pipeline.py @@ -33,6 +33,17 @@ 
SWIFT_MODEL_ID_MAPPING = {} SWIFT_FRAMEWORK = 'swift' +def init_swift_model_mapping(): + from swift.llm.utils import MODEL_MAPPING + + global SWIFT_MODEL_ID_MAPPING + if not SWIFT_MODEL_ID_MAPPING: + SWIFT_MODEL_ID_MAPPING.update({ + v['model_id_or_path'].lower(): k + for k, v in MODEL_MAPPING.items() + }) + + class LLMAdapterRegistry: llm_format_map = {'qwen': [None, None, None]} @@ -216,14 +227,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin): def _init_swift(self, model_id, device) -> None: from swift.llm import prepare_model_template - from swift.llm.utils import MODEL_MAPPING, InferArguments - - global SWIFT_MODEL_ID_MAPPING - if not SWIFT_MODEL_ID_MAPPING: - SWIFT_MODEL_ID_MAPPING = { - v['model_id_or_path']: k - for k, v in MODEL_MAPPING.items() - } + from swift.llm.utils import InferArguments def format_messages(messages: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer, @@ -261,9 +265,11 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin): else: return dict(system=system, prompt=prompt, history=history) - assert model_id in SWIFT_MODEL_ID_MAPPING,\ + init_swift_model_mapping() + + assert model_id.lower() in SWIFT_MODEL_ID_MAPPING,\ f'Invalid model id {model_id} or Swift framework does not support this model.'
- args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id]) + args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id.lower()]) model, template = prepare_model_template( args, device_map=self.device_map) self.model = add_stream_generate(model) diff --git a/modelscope/utils/streaming_output.py b/modelscope/utils/streaming_output.py index 96dad20f..8de808bd 100644 --- a/modelscope/utils/streaming_output.py +++ b/modelscope/utils/streaming_output.py @@ -175,7 +175,10 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin): @contextmanager def _replace_generate(self, model: PreTrainedModel) -> Generator: - if version.parse(transformers.__version__) >= version.parse('4.39.0'): + if version.parse(transformers.__version__) >= version.parse('4.43.0'): + greedy_search_name = 'stream_greedy_search' + sample_name = '_sample' + elif version.parse(transformers.__version__) >= version.parse('4.39.0'): greedy_search_name = '_greedy_search' sample_name = '_sample' else: @@ -449,6 +452,7 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin): break # prepare model inputs + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) model_inputs = self.prepare_inputs_for_generation( input_ids, **model_kwargs)