This commit is contained in:
shuaigezhu
2022-11-24 10:50:38 +08:00
parent db0f25a594
commit f171552ee3
9 changed files with 29 additions and 113 deletions

View File

@@ -1,6 +1,6 @@
# Modified by Zhipu.AI
# Original Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from modelscope.utils.import_utils import LazyImportModule

View File

@@ -1,8 +1,8 @@
# Copyright (c) 2022 Zhipu.AI
import math
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
def fast_gelu(x):

View File

@@ -1,20 +1,15 @@
# Copyright (c) 2022 Zhipu.AI
import copy
import os
import random
import time
from typing import Dict
from typing import Any, Dict
import numpy as np
import torch
from IPython import embed
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .codegeex import CodeGeeXModel
from .inference import get_token_stream
from .tokenizer import CodeGeeXTokenizer
@@ -45,18 +40,18 @@ class CodeGeeXForCodeTranslation(TorchModel):
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
logger = get_logger()
# loading tokenizer
print('Loading tokenizer ...')
logger.info('Loading tokenizer ...')
self.tokenizer = CodeGeeXTokenizer(
tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b')
# loading model
state_dict_path = model_dir + '/ckpt_ms_translation_0817.pt'
print('Loading state dict ...')
logger.info('Loading state dict ...')
state_dict = torch.load(state_dict_path, map_location='cpu')
state_dict = state_dict['module']
print('Building CodeGeeX model ...')
logger.info('Building CodeGeeX model ...')
self.model = model_provider()
self.model.load_state_dict(state_dict)
self.model.eval()
@@ -68,21 +63,16 @@ class CodeGeeXForCodeTranslation(TorchModel):
seq_length = 2048
out_seq_length = 256
bad_ids = None
print('Generating ...')
src_lang = input['source language']
dst_lang = input['target language']
prompt = input['prompt']
prompt = f'code translation\n{src_lang}:\n{prompt}\n{dst_lang}:\n'
t0 = time.perf_counter()
logger = get_logger()
tokenizer = self.tokenizer
model = self.model
for prompt in [prompt]:
tokens = tokenizer.encode_code(prompt)
print(tokens)
print('Current prompt:')
print(prompt)
n_token_prompt = len(tokens)
print('N_token_prompt:', n_token_prompt)
token_stream = get_token_stream(
model,
tokenizer,
@@ -108,19 +98,10 @@ class CodeGeeXForCodeTranslation(TorchModel):
generated_code = tokenizer.decode_code(
generated_tokens_[n_token_prompt:])
generated_code = ''.join(generated_code)
t1 = time.perf_counter()
print('Total generation time:', t1 - t0, '# Tokens:',
len(generated_tokens_) - n_token_prompt)
print(
f'{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token'
)
print(
'================================= Generated code:'
)
print(generated_code)
t0 = time.perf_counter()
logger.info('================================= Generated code:')
logger.info(generated_code)
if all(is_finished):
break
print('Generation finished.')
logger.info('Generation finished.')
return {OutputKeys.TEXT: generated_code}

View File

@@ -1,12 +1,8 @@
import copy
import os
import time
import typing
from dataclasses import dataclass
# Copyright (c) 2022 Zhipu.AI
import json
import torch
import torch.nn.functional as F
from typing import List
def get_ltor_masks_and_position_ids(
@@ -128,38 +124,7 @@ def pad_batch(batch, pad_id, seq_length):
tokens.extend([pad_id] * (seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
def forward_step(
model,
tokens,
seq_length,
position_ids,
attention_mask,
layer_past=None,
get_key_value=None,
prompt_length=None,
context_length=None,
):
# Forward pass through the model.
output_tensor = model(
tokens,
position_ids,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
prompt_length=prompt_length,
context_length=context_length,
)
if get_key_value:
output_tensor, layer_past = output_tensor
if get_key_value:
return output_tensor, layer_past
return output_tensor
def get_token_stream(
model,

View File

@@ -1,8 +1,8 @@
import typing
# Copyright (c) 2022 Zhipu.AI
import torch
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
from typing import List, Union
def encode_whitespaces(text, start_extra_id: int, max_len: int):

View File

@@ -1,13 +1,12 @@
# Copyright (c) 2022 Zhipu.AI
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union
from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.models.nlp import CodeGeeXForCodeTranslation
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import CodeGeeXPreprocessor, Preprocessor
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks
@@ -27,16 +26,18 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
self.model.eval()
self.model.half()
self.model.cuda()
if preprocessor is None:
preprocessor = CodeGeeXPreprocessor()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
super().__init__(model=model, **kwargs)
def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
return inputs
# define the forward pass
def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:
# check input format
for para in ['prompt', 'source language', 'target language']:
if para not in inputs:
return ('please check your input format.')
raise Exception('please check your input format.')
return self.model(inputs)
# format the outputs from pipeline

View File

@@ -30,7 +30,6 @@ if TYPE_CHECKING:
from .space_T_en import ConversationalTextToSqlPreprocessor
from .space_T_cn import TableQuestionAnsweringPreprocessor
from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor
from .codegeex_preprocessor import CodeGeeXPreprocessor
else:
_import_structure = {
'nlp_base': [
@@ -65,7 +64,6 @@ else:
'TextErrorCorrectionPreprocessor',
],
'mglm_summarization_preprocessor': ['MGLMSummarizationPreprocessor'],
'codegeex_preprocessor': ['CodeGeeXPreprocessor'],
'token_classification_thai_preprocessor': [
'NERPreprocessorThai',
'WordSegmentationPreprocessorThai',

View File

@@ -1,25 +0,0 @@
# Copyright (c) 2022 Zhipu.AI
import re
from typing import Any, Dict, Iterable, Optional, Tuple, Union
from modelscope.metainfo import Models, Preprocessors
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile
from modelscope.utils.type_assert import type_assert
@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.codegeex)
class CodeGeeXPreprocessor(Preprocessor):
def __init__(self, *args, **kwargs):
"""preprocess the data
Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)
@type_assert(object, (str, tuple, Dict))
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
return data

View File

@@ -2,9 +2,7 @@
import os
import unittest
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import CodeGeeXPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
@@ -19,11 +17,9 @@ class CodeGeeXCodeTranslationTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_CodeGeeX_with_name(self):
model = 'ZhipuAI/CodeGeeX-Code-Translation-13B'
preprocessor = CodeGeeXPreprocessor()
pipe = pipeline(
task=Tasks.code_translation,
model=model,
preprocessor=preprocessor,
model=model
)
inputs = {
'prompt': 'for i in range(10):\n\tprint(i)\n',