* fix(llm ppl): 1. cache position; 2. stream_greedy_search; 3. swift_mapping

* fix punkt

---------

Co-authored-by: suluyan <suluyan.sly@alibaba-inc.com>
This commit is contained in:
suluyana
2024-12-23 09:55:12 +08:00
committed by GitHub
parent 9304d40539
commit 60780769b1
5 changed files with 40 additions and 21 deletions

View File

@@ -31,6 +31,7 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
python -m spacy download en_core_web_sm
pip install faiss-gpu
pip install healpy
pip install huggingface-hub==0.25.2
# test with install
pip install .
else

View File

@@ -170,7 +170,7 @@ def pipeline(task: str = None,
pipeline_props['device'] = device
cfg = ConfigDict(pipeline_props)
clear_llm_info(kwargs)
clear_llm_info(kwargs, pipeline_name)
if kwargs:
cfg.update(kwargs)
@@ -223,7 +223,7 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
List[Model]],
revision: Optional[str],
kwargs: Dict[str, Any]) -> Optional[str]:
from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, ModelTypeHelper, LLMAdapterRegistry
from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, init_swift_model_mapping, ModelTypeHelper, LLMAdapterRegistry
from ..hub.check_model import get_model_id_from_cache
if isinstance(model, list):
model = model[0]
@@ -236,8 +236,9 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
model_id = get_model_id_from_cache(model)
else:
model_id = model
global SWIFT_MODEL_ID_MAPPING
if model_id in SWIFT_MODEL_ID_MAPPING:
init_swift_model_mapping()
if model_id.lower() in SWIFT_MODEL_ID_MAPPING:
return 'llm'
model_type = ModelTypeHelper.get(
model, revision, with_adapter=True, split='-', use_cache=True)
@@ -245,9 +246,10 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
return 'llm'
def clear_llm_info(kwargs: Dict):
def clear_llm_info(kwargs: Dict, pipeline_name: str):
from modelscope.utils.model_type_helper import ModelTypeHelper
kwargs.pop('external_engine_for_llm', None)
kwargs.pop('llm_framework', None)
if pipeline_name != 'llm':
kwargs.pop('llm_framework', None)
ModelTypeHelper.clear_cache()

View File

@@ -33,6 +33,17 @@ SWIFT_MODEL_ID_MAPPING = {}
SWIFT_FRAMEWORK = 'swift'
def init_swift_model_mapping():
from swift.llm.utils import MODEL_MAPPING
global SWIFT_MODEL_ID_MAPPING
if not SWIFT_MODEL_ID_MAPPING:
SWIFT_MODEL_ID_MAPPING = {
v['model_id_or_path'].lower(): k
for k, v in MODEL_MAPPING.items()
}
class LLMAdapterRegistry:
llm_format_map = {'qwen': [None, None, None]}
@@ -216,14 +227,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
def _init_swift(self, model_id, device) -> None:
from swift.llm import prepare_model_template
from swift.llm.utils import MODEL_MAPPING, InferArguments
global SWIFT_MODEL_ID_MAPPING
if not SWIFT_MODEL_ID_MAPPING:
SWIFT_MODEL_ID_MAPPING = {
v['model_id_or_path']: k
for k, v in MODEL_MAPPING.items()
}
from swift.llm.utils import InferArguments
def format_messages(messages: Dict[str, List[Dict[str, str]]],
tokenizer: PreTrainedTokenizer,
@@ -261,9 +265,12 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
else:
return dict(system=system, prompt=prompt, history=history)
assert model_id in SWIFT_MODEL_ID_MAPPING,\
init_swift_model_mapping()
assert model_id.lower() in SWIFT_MODEL_ID_MAPPING,\
f'Invalid model id {model_id} or Swift framework does not support this model.'
args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id])
args = InferArguments(
model_type=SWIFT_MODEL_ID_MAPPING[model_id.lower()])
model, template = prepare_model_template(
args, device_map=self.device_map)
self.model = add_stream_generate(model)

View File

@@ -213,11 +213,14 @@ class FillMaskPoNetPreprocessor(FillMaskPreprocessorBase):
osp.join(model_dir, ModelFile.CONFIGURATION))
self.language = self.cfg.model.get('language', 'en')
if self.language == 'en':
from nltk.tokenize import sent_tokenize
import nltk
nltk.download('punkt_tab')
# import_external_nltk_data(
# osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt_tab')
from nltk.tokenize import sent_tokenize
from packaging import version
if version.parse(nltk.__version__) >= version.parse('3.8.2'):
nltk.download('punkt_tab')
else:
import_external_nltk_data(
osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt_tab')
elif self.language in ['zh', 'cn']:
def sent_tokenize(para):

View File

@@ -175,7 +175,11 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
@contextmanager
def _replace_generate(self, model: PreTrainedModel) -> Generator:
if version.parse(transformers.__version__) >= version.parse('4.39.0'):
if version.parse(transformers.__version__) >= version.parse('4.43.0'):
greedy_search_name = 'stream_greedy_search'
sample_name = '_sample'
elif version.parse(
transformers.__version__) >= version.parse('4.39.0'):
greedy_search_name = '_greedy_search'
sample_name = '_sample'
else:
@@ -449,6 +453,8 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
break
# prepare model inputs
model_kwargs = self._get_initial_cache_position(
input_ids, model_kwargs)
model_inputs = self.prepare_inputs_for_generation(
input_ids, **model_kwargs)