This commit is contained in:
shuaigezhu
2023-03-17 14:05:58 +08:00
parent 8abfffc7e5
commit ffa7975ba2
6 changed files with 113 additions and 0 deletions

View File

@@ -147,6 +147,7 @@ class Models(object):
T5 = 'T5'
mglm = 'mglm'
codegeex = 'codegeex'
chatglm6b = 'chatglm6b'
bloom = 'bloom'
unite = 'unite'
megatron_bert = 'megatron-bert'
@@ -424,6 +425,7 @@ class Pipelines(object):
feature_extraction = 'feature-extraction'
mglm_text_summarization = 'mglm-text-summarization'
codegeex_code_translation = 'codegeex-code-translation'
chatglm6b_text_generation = 'chatglm6b-text-generation'
codegeex_code_generation = 'codegeex-code-generation'
translation_en_to_de = 'translation_en_to_de' # keep it underscore
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore

View File

@@ -18,6 +18,7 @@ if TYPE_CHECKING:
)
from .bloom import BloomModel
from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration
from .chatglm_6b import ChatGLM6bForTextGeneration
from .csanmt import CsanmtForTranslation
from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model
from .gpt_neo import GPTNeoModel
@@ -89,6 +90,7 @@ else:
'csanmt': ['CsanmtForTranslation'],
'codegeex':
['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'chatglm_6b': ['ChatGLM6bForTextGeneration'],
'deberta_v2': ['DebertaV2ForMaskedLM', 'DebertaV2Model'],
'heads': ['TextClassificationHead'],
'hf_transformers': ['TransformersModel'],

View File

@@ -0,0 +1,22 @@
# Modified by Zhipu.AI
# Original Copyright (c) Alibaba, Inc. and its affiliates.
"""Lazy-import package for the ChatGLM-6B text-generation model."""
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Static type checkers resolve the real symbol eagerly.
    from .chatglm6b_for_text_generation import ChatGLM6bForTextGeneration
else:
    # At runtime the heavy submodule is only imported on first attribute
    # access, via the LazyImportModule installed below.
    _import_structure = {
        'chatglm6b_for_text_generation': ['ChatGLM6bForTextGeneration']
    }

    import sys

    # Replace this package's module object with a lazy proxy that imports
    # submodules from _import_structure on demand.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2022 Zhipu.AI
import copy
from typing import Any, Dict
import torch
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from transformers import AutoTokenizer, AutoModel
@MODELS.register_module(Tasks.text_generation, module_name=Models.chatglm6b)
class ChatGLM6bForTextGeneration(TorchModel):
    """ChatGLM-6B dialogue text-generation model.

    Loads the Hugging Face checkpoint (tokenizer + weights) found under
    ``model_dir`` and exposes a chat-style generation interface.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the chatglm6b model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # The checkpoint ships its own modeling code, hence trust_remote_code.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True)
        # Load in fp16 to halve memory. Only move to GPU when one exists:
        # the previous unconditional .cuda() crashed on CPU-only hosts.
        model = AutoModel.from_pretrained(
            model_dir, trust_remote_code=True).half()
        self.model = model.cuda() if torch.cuda.is_available() else model

    def forward(self, input: Dict) -> Dict:
        """Run one chat turn.

        Delegates to :meth:`chat`, which already wraps its result under
        ``OutputKeys.TEXT``; wrapping again here (as the original code did)
        produced a doubly nested ``{TEXT: {TEXT: ...}}`` dict.

        Args:
            input: dict with 'text' and optionally 'history' (see chat()).

        Returns:
            ``{OutputKeys.TEXT: {'response': str, 'history': list}}``
        """
        return self.chat(input)

    def chat(self, input: Dict) -> Dict:
        """Generate a response for ``input['text']`` given prior history.

        Args:
            input: dict with key 'text' (the user query) and, optionally,
                'history' — the conversation turns returned by a previous
                call. A missing 'history' means a fresh conversation.

        Returns:
            ``{OutputKeys.TEXT: {'response': str, 'history': list}}``
        """
        text = input['text']
        # First turn may carry no history; default to an empty conversation
        # instead of raising KeyError.
        history = input.get('history', [])
        response, history = self.model.chat(self.tokenizer, text, history)
        self.logger.info('Generation finished.')
        res = {'response': response, 'history': history}
        return {OutputKeys.TEXT: res}

    def quantize(self, bits: int):
        """Quantize the underlying model weights to ``bits`` bits.

        Returns:
            self, to allow fluent chaining.
        """
        self.model = self.model.quantize(bits)
        return self

View File

@@ -35,6 +35,7 @@ if TYPE_CHECKING:
from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
from .chatglm6b_text_generation_pipeline import ChatGLM6bTextGenerationPipeline
from .translation_evaluation_pipeline import TranslationEvaluationPipeline
from .user_satisfaction_estimation_pipeline import UserSatisfactionEstimationPipeline
from .siamese_uie_pipeline import SiameseUiePipeline
@@ -89,6 +90,7 @@ else:
['CodeGeeXCodeTranslationPipeline'],
'codegeex_code_generation_pipeline':
['CodeGeeXCodeGenerationPipeline'],
'chatglm6b_text_generation_pipeline': ['ChatGLM6bTextGenerationPipeline'],
'translation_evaluation_pipeline': ['TranslationEvaluationPipeline'],
'user_satisfaction_estimation_pipeline':
['UserSatisfactionEstimationPipeline'],

View File

@@ -0,0 +1,39 @@
# Copyright (c) 2022 Zhipu.AI
from typing import Any, Dict, Union
from modelscope.metainfo import Pipelines
from modelscope.models.nlp import ChatGLM6bForTextGeneration
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks
@PIPELINES.register_module(
    group_key=Tasks.text_generation,
    module_name=Pipelines.chatglm6b_text_generation)
class ChatGLM6bTextGenerationPipeline(Pipeline):
    """Text-generation pipeline wrapping ChatGLM6bForTextGeneration.

    Inputs are passed to the model verbatim; pre/post-processing are
    identity operations because the model's ``chat`` interface consumes
    and produces plain dicts.
    """

    def __init__(self,
                 model: Union[ChatGLM6bForTextGeneration, str],
                 preprocessor: Union[Preprocessor, None] = None,
                 *args,
                 **kwargs):
        """Build the pipeline.

        Args:
            model: an instantiated ChatGLM6bForTextGeneration, or a model
                directory/id string from which one is constructed.
            preprocessor: unused; kept for interface compatibility.
                (Fixed annotation: the original ``[Preprocessor]`` was a
                list literal, not a valid type annotation.)
        """
        # Accept either a ready model object or a path/model-id string.
        if isinstance(model, str):
            model = ChatGLM6bForTextGeneration(model)
        self.model = model
        # Inference only: disable dropout and other training-mode behavior.
        self.model.eval()
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Identity: the model consumes the raw input dict directly."""
        return inputs

    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Run the model's forward pass on the input dict."""
        return self.model(inputs)

    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Identity: the model's output dict is already final."""
        return input