Mirror of https://github.com/modelscope/modelscope.git
fix tokenizer params
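The commit adds add_special_tokens=False to the tokenizer calls that encode prompt fragments which already spell out the Llama-2 special tokens ('<s>', '</s>') literally. A minimal sketch of the effect, assuming a standard Hugging Face Llama-2 style tokenizer (the checkpoint path below is hypothetical, not part of this commit):

    # Minimal sketch: why add_special_tokens=False matters for templates that
    # already contain literal special tokens. Path is a placeholder assumption.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('/path/to/llama-2-7b-chat')  # hypothetical path

    fragment = '<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n'

    # Default behaviour: the tokenizer prepends its own BOS token even though the
    # fragment already contains a literal '<s>', so BOS is counted twice.
    with_bos = tok(fragment, return_tensors='pt').input_ids

    # With add_special_tokens=False only the tokens written in the fragment are
    # encoded, which keeps prompt-length accounting and generation inputs correct.
    without_bos = tok(fragment, add_special_tokens=False, return_tensors='pt').input_ids

    print(with_bos.shape[-1], without_bos.shape[-1])  # the first is longer by the extra BOS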
@@ -33,10 +33,12 @@ from .backbone import MsModelMixin
 def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
                     max_length: int, tokenizer):
     system_prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
-    system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids
+    system_ids = tokenizer(
+        system_prompt, add_special_tokens=False, return_tensors='pt').input_ids

     text_prompt = f'{text.strip()} [/INST]'
-    text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids
+    text_ids = tokenizer(
+        text_prompt, add_special_tokens=False, return_tensors='pt').input_ids

     prompt_length = system_ids.shape[-1] + text_ids.shape[-1]
     if prompt_length > max_length:
@@ -51,7 +53,9 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
         assert isinstance(user, str)
         assert isinstance(bot, str)
         round_prompt = f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
-        round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids
+        round_ids = tokenizer(
+            round_prompt, add_special_tokens=False,
+            return_tensors='pt').input_ids
         if prompt_length + round_ids.shape[-1] > max_length:
             # excess history should not be appended to the prompt
             break
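For context, the fragments encoded above are concatenated into a single Llama-2 chat prompt. A rough sketch of the template follows; build_llama2_prompt is an illustrative name, not a function in this repository, and it omits the max_length truncation that get_chat_prompt applies to older history rounds:

    # Rough sketch of how the system block, history rounds, and current user turn
    # fit together in the Llama-2 chat template (truncation logic omitted).
    def build_llama2_prompt(system: str, text: str, history):
        prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
        for user, bot in history:
            prompt += f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
        prompt += f'{text.strip()} [/INST]'
        return prompt

    print(build_llama2_prompt('You are helpful.', 'Hi there', []))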
@@ -27,7 +27,9 @@ if TYPE_CHECKING:
     from .translation_quality_estimation_pipeline import TranslationQualityEstimationPipeline
     from .text_error_correction_pipeline import TextErrorCorrectionPipeline
     from .word_alignment_pipeline import WordAlignmentPipeline
-    from .text_generation_pipeline import TextGenerationPipeline, TextGenerationT5Pipeline, SeqGPTPipeline
+    from .text_generation_pipeline import TextGenerationPipeline, TextGenerationT5Pipeline, \
+        SeqGPTPipeline, ChatGLM6bTextGenerationPipeline, ChatGLM6bV2TextGenerationPipeline, \
+        QWenChatPipeline, QWenTextGenerationPipeline, Llama2TaskPipeline
     from .fid_dialogue_pipeline import FidDialoguePipeline
     from .token_classification_pipeline import TokenClassificationPipeline
     from .translation_pipeline import TranslationPipeline
@@ -80,7 +82,10 @@ else:
         'word_alignment_pipeline': ['WordAlignmentPipeline'],
         'text_generation_pipeline': [
             'TextGenerationPipeline', 'TextGenerationT5Pipeline',
-            'SeqGPTPipeline'
+            'ChatGLM6bTextGenerationPipeline',
+            'ChatGLM6bV2TextGenerationPipeline', 'QWenChatPipeline',
+            'QWenTextGenerationPipeline', 'SeqGPTPipeline',
+            'Llama2TaskPipeline'
         ],
         'fid_dialogue_pipeline': ['FidDialoguePipeline'],
         'token_classification_pipeline': ['TokenClassificationPipeline'],
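The __init__.py change above registers the new symbol in two places because the pipelines package resolves imports lazily: static type checkers read the TYPE_CHECKING branch, while at runtime names are looked up in the import map. A condensed sketch of that layout, assuming modelscope's existing LazyImportModule machinery (details omitted):

    # Condensed sketch of the lazy-import layout used by modelscope/pipelines/nlp/__init__.py.
    # Both branches must list Llama2TaskPipeline: type checkers see the first,
    # the runtime lazy loader resolves names from _import_structure on first access.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from .text_generation_pipeline import Llama2TaskPipeline
    else:
        _import_structure = {
            'text_generation_pipeline': ['Llama2TaskPipeline'],
        }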
@@ -1,99 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Copyright (c) 2022 Zhipu.AI
-from typing import Any, Dict, Union
-
-import torch
-
-from modelscope import Model, snapshot_download
-from modelscope.metainfo import Pipelines, Preprocessors
-from modelscope.models.nlp.llama2 import Llama2Tokenizer
-from modelscope.pipelines.base import Pipeline
-from modelscope.pipelines.builder import PIPELINES
-from modelscope.pipelines.nlp.text_generation_pipeline import \
-    TextGenerationPipeline
-from modelscope.preprocessors import Preprocessor
-from modelscope.utils.constant import Fields, Tasks
-
-
-@PIPELINES.register_module(
-    Tasks.text_generation,
-    module_name=Pipelines.llama2_text_generation_pipeline)
-class Llama2TaskPipeline(TextGenerationPipeline):
-
-    def __init__(self,
-                 model: Union[Model, str],
-                 preprocessor: Preprocessor = None,
-                 config_file: str = None,
-                 device: str = 'gpu',
-                 auto_collate=True,
-                 **kwargs):
-        """Use `model` and `preprocessor` to create a generation pipeline for prediction.
-
-        Args:
-            model (str or Model): Supply either a local model dir which supported the text generation task,
-            or a model id from the model hub, or a torch model instance.
-            preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
-            the model if supplied.
-            kwargs (dict, `optional`):
-                Extra kwargs passed into the preprocessor's constructor.
-        Examples:
-            >>> from modelscope.utils.constant import Tasks
-            >>> import torch
-            >>> from modelscope.pipelines import pipeline
-            >>> from modelscope import snapshot_download, Model
-            >>> model_dir = snapshot_download("modelscope/Llama-2-13b-chat-ms",
-            >>>     ignore_file_pattern = [r'\w+\.safetensors'])
-            >>> pipe = pipeline(task=Tasks.text_generation, model=model_dir, device_map='auto',
-            >>>     torch_dtype=torch.float16)
-            >>> inputs="咖啡的作用是什么?"
-            >>> result = pipe(inputs,max_length=200, do_sample=True, top_p=0.85,
-            >>>     temperature=1.0, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
-            >>> print(result['text'])
-
-        To view other examples plese check tests/pipelines/test_llama2_text_generation_pipeline.py.
-        """
-        self.model = Model.from_pretrained(
-            model, device_map='auto', torch_dtype=torch.float16)
-        self.tokenizer = Llama2Tokenizer.from_pretrained(model)
-        super().__init__(model=self.model, **kwargs)
-
-    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
-        return inputs
-
-    def _sanitize_parameters(self, **pipeline_parameters):
-        return {}, pipeline_parameters, {}
-
-    def forward(self,
-                inputs,
-                max_length=2048,
-                do_sample=True,
-                top_p=0.85,
-                temperature=1.0,
-                repetition_penalty=1.,
-                eos_token_id=2,
-                bos_token_id=1,
-                pad_token_id=0,
-                **forward_params) -> Dict[str, Any]:
-        output = {}
-        inputs = self.tokenizer(inputs, return_tensors='pt')
-        generate_ids = self.model.generate(
-            inputs.input_ids.to('cuda'),
-            max_length=max_length,
-            do_sample=do_sample,
-            top_p=top_p,
-            temperature=temperature,
-            repetition_penalty=repetition_penalty,
-            eos_token_id=eos_token_id,
-            bos_token_id=bos_token_id,
-            pad_token_id=pad_token_id,
-            **forward_params)
-        out = self.tokenizer.batch_decode(
-            generate_ids,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False)[0]
-        output['text'] = out
-        return output
-
-    # format the outputs from pipeline
-    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
-        return input
@@ -2,7 +2,6 @@

 # Copyright (c) 2022 Zhipu.AI
 import os
-import re
 from typing import Any, Dict, Optional, Union

 import torch
@@ -25,7 +24,8 @@ from modelscope.utils.torch_utils import is_on_same_device
 __all__ = [
     'TextGenerationPipeline', 'TextGenerationT5Pipeline',
     'ChatGLM6bTextGenerationPipeline', 'ChatGLM6bV2TextGenerationPipeline',
-    'QWenChatPipeline', 'QWenTextGenerationPipeline', 'SeqGPTPipeline'
+    'QWenChatPipeline', 'QWenTextGenerationPipeline', 'SeqGPTPipeline',
+    'Llama2TaskPipeline'
 ]


@@ -199,7 +199,7 @@ class ChatGLM6bTextGenerationPipeline(Pipeline):
                  use_bf16=False,
                  **kwargs):
         from modelscope.models.nlp.chatglm.text_generation import (
-            ChatGLMConfig, ChatGLMForConditionalGeneration)
+            ChatGLMForConditionalGeneration)
         if isinstance(model, str):
             model_dir = snapshot_download(
                 model) if not os.path.exists(model) else model
@@ -427,7 +427,6 @@ class QWenTextGenerationPipeline(Pipeline):
 class SeqGPTPipeline(Pipeline):

     def __init__(self, model: Union[Model, str], **kwargs):
-        from modelscope.models.nlp import BloomForTextGeneration
         from modelscope.utils.hf_util import AutoTokenizer

         if isinstance(model, str):
@@ -468,3 +467,89 @@ class SeqGPTPipeline(Pipeline):
     # format the outputs from pipeline
     def postprocess(self, input, **kwargs) -> Dict[str, Any]:
         return input
+
+
+@PIPELINES.register_module(
+    Tasks.text_generation,
+    module_name=Pipelines.llama2_text_generation_pipeline)
+class Llama2TaskPipeline(TextGenerationPipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Preprocessor = None,
+                 config_file: str = None,
+                 device: str = 'gpu',
+                 auto_collate=True,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create a generation pipeline for prediction.
+
+        Args:
+            model (str or Model): Supply either a local model dir which supports the text generation task,
+            or a model id from the model hub, or a torch model instance.
+            preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
+            the model if supplied.
+            kwargs (dict, `optional`):
+                Extra kwargs passed into the preprocessor's constructor.
+        Examples:
+            >>> from modelscope.utils.constant import Tasks
+            >>> import torch
+            >>> from modelscope.pipelines import pipeline
+            >>> from modelscope import snapshot_download, Model
+            >>> model_dir = snapshot_download("modelscope/Llama-2-13b-chat-ms",
+            >>>     ignore_file_pattern = [r'\w+\.safetensors'])
+            >>> pipe = pipeline(task=Tasks.text_generation, model=model_dir, device_map='auto',
+            >>>     torch_dtype=torch.float16)
+            >>> inputs="咖啡的作用是什么?"
+            >>> result = pipe(inputs,max_length=200, do_sample=True, top_p=0.85,
+            >>>     temperature=1.0, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
+            >>> print(result['text'])
+
+        To view other examples please check tests/pipelines/test_llama2_text_generation_pipeline.py.
+        """
+        self.model = Model.from_pretrained(
+            model, device_map='auto', torch_dtype=torch.float16)
+        from modelscope.models.nlp.llama2 import Llama2Tokenizer
+        self.tokenizer = Llama2Tokenizer.from_pretrained(model)
+        super().__init__(model=self.model, **kwargs)
+
+    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
+        return inputs
+
+    def _sanitize_parameters(self, **pipeline_parameters):
+        return {}, pipeline_parameters, {}
+
+    def forward(self,
+                inputs,
+                max_length=2048,
+                do_sample=True,
+                top_p=0.85,
+                temperature=1.0,
+                repetition_penalty=1.,
+                eos_token_id=2,
+                bos_token_id=1,
+                pad_token_id=0,
+                **forward_params) -> Dict[str, Any]:
+        output = {}
+        inputs = self.tokenizer(
+            inputs, add_special_tokens=False, return_tensors='pt')
+        generate_ids = self.model.generate(
+            inputs.input_ids.to('cuda'),
+            max_length=max_length,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            eos_token_id=eos_token_id,
+            bos_token_id=bos_token_id,
+            pad_token_id=pad_token_id,
+            **forward_params)
+        out = self.tokenizer.batch_decode(
+            generate_ids,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False)[0]
+        output['text'] = out
+        return output
+
+    # format the outputs from pipeline
+    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
+        return input