mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
fix(llm ppl): 1. cache position; 2. stream_greedy_search; 3. swift_mapping
This commit is contained in:
@@ -170,7 +170,7 @@ def pipeline(task: str = None,
|
||||
pipeline_props['device'] = device
|
||||
cfg = ConfigDict(pipeline_props)
|
||||
|
||||
clear_llm_info(kwargs)
|
||||
clear_llm_info(kwargs, pipeline_name)
|
||||
if kwargs:
|
||||
cfg.update(kwargs)
|
||||
|
||||
@@ -223,7 +223,7 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
|
||||
List[Model]],
|
||||
revision: Optional[str],
|
||||
kwargs: Dict[str, Any]) -> Optional[str]:
|
||||
from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, ModelTypeHelper, LLMAdapterRegistry
|
||||
from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, init_swift_model_mapping, ModelTypeHelper, LLMAdapterRegistry
|
||||
from ..hub.check_model import get_model_id_from_cache
|
||||
if isinstance(model, list):
|
||||
model = model[0]
|
||||
@@ -236,8 +236,9 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
|
||||
model_id = get_model_id_from_cache(model)
|
||||
else:
|
||||
model_id = model
|
||||
global SWIFT_MODEL_ID_MAPPING
|
||||
if model_id in SWIFT_MODEL_ID_MAPPING:
|
||||
|
||||
init_swift_model_mapping()
|
||||
if model_id.lower() in SWIFT_MODEL_ID_MAPPING:
|
||||
return 'llm'
|
||||
model_type = ModelTypeHelper.get(
|
||||
model, revision, with_adapter=True, split='-', use_cache=True)
|
||||
@@ -245,9 +246,10 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
|
||||
return 'llm'
|
||||
|
||||
|
||||
def clear_llm_info(kwargs: Dict):
|
||||
def clear_llm_info(kwargs: Dict, pipeline_name: str):
|
||||
from modelscope.utils.model_type_helper import ModelTypeHelper
|
||||
|
||||
kwargs.pop('external_engine_for_llm', None)
|
||||
kwargs.pop('llm_framework', None)
|
||||
if pipeline_name != 'llm':
|
||||
kwargs.pop('llm_framework', None)
|
||||
ModelTypeHelper.clear_cache()
|
||||
|
||||
@@ -33,6 +33,17 @@ SWIFT_MODEL_ID_MAPPING = {}
|
||||
SWIFT_FRAMEWORK = 'swift'
|
||||
|
||||
|
||||
def init_swift_model_mapping():
    """Populate ``SWIFT_MODEL_ID_MAPPING`` from Swift's ``MODEL_MAPPING``.

    Keys are the lower-cased ``model_id_or_path`` values and values are the
    matching keys of ``MODEL_MAPPING``, so a model id has to be lower-cased
    before it is looked up. Nothing is done when the mapping already holds
    entries.

    The module-level dict is filled in place instead of being rebound to a
    new object. A name obtained earlier through
    ``from ... import SWIFT_MODEL_ID_MAPPING`` keeps pointing at the original
    dict object; rebinding the global would leave such a reference empty
    forever, while in-place insertion makes the entries visible through it.
    """
    # Function-scope import: swift is only loaded when this is called.
    from swift.llm.utils import MODEL_MAPPING

    if SWIFT_MODEL_ID_MAPPING:
        return

    for model_type, model_info in MODEL_MAPPING.items():
        model_id = model_info['model_id_or_path']
        # NOTE(review): non-string ids are skipped instead of failing on
        # .lower() -- presumably some registrations carry None; confirm
        # against swift's MODEL_MAPPING.
        if isinstance(model_id, str):
            SWIFT_MODEL_ID_MAPPING[model_id.lower()] = model_type
|
||||
|
||||
|
||||
class LLMAdapterRegistry:
|
||||
|
||||
llm_format_map = {'qwen': [None, None, None]}
|
||||
@@ -216,14 +227,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
|
||||
def _init_swift(self, model_id, device) -> None:
|
||||
from swift.llm import prepare_model_template
|
||||
from swift.llm.utils import MODEL_MAPPING, InferArguments
|
||||
|
||||
global SWIFT_MODEL_ID_MAPPING
|
||||
if not SWIFT_MODEL_ID_MAPPING:
|
||||
SWIFT_MODEL_ID_MAPPING = {
|
||||
v['model_id_or_path']: k
|
||||
for k, v in MODEL_MAPPING.items()
|
||||
}
|
||||
from swift.llm.utils import InferArguments
|
||||
|
||||
def format_messages(messages: Dict[str, List[Dict[str, str]]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
@@ -261,9 +265,11 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
else:
|
||||
return dict(system=system, prompt=prompt, history=history)
|
||||
|
||||
assert model_id in SWIFT_MODEL_ID_MAPPING,\
|
||||
init_swift_model_mapping()
|
||||
|
||||
assert model_id.lower() in SWIFT_MODEL_ID_MAPPING,\
|
||||
f'Invalid model id {model_id} or Swift framework does not support this model.'
|
||||
args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id])
|
||||
args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id.lower()])
|
||||
model, template = prepare_model_template(
|
||||
args, device_map=self.device_map)
|
||||
self.model = add_stream_generate(model)
|
||||
|
||||
@@ -175,7 +175,10 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
|
||||
|
||||
@contextmanager
|
||||
def _replace_generate(self, model: PreTrainedModel) -> Generator:
|
||||
if version.parse(transformers.__version__) >= version.parse('4.39.0'):
|
||||
if version.parse(transformers.__version__) >= version.parse('4.43.0'):
|
||||
greedy_search_name = 'stream_greedy_search'
|
||||
sample_name = '_sample'
|
||||
elif version.parse(transformers.__version__) >= version.parse('4.39.0'):
|
||||
greedy_search_name = '_greedy_search'
|
||||
sample_name = '_sample'
|
||||
else:
|
||||
@@ -449,6 +452,7 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
|
||||
break
|
||||
|
||||
# prepare model inputs
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, **model_kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user