add code_generation

shuaigezhu
2022-11-25 16:35:19 +08:00
parent 65adde14d8
commit c9064caa58
8 changed files with 174 additions and 3 deletions

modelscope/metainfo.py

@@ -257,6 +257,7 @@ class Pipelines(object):
    feature_extraction = 'feature-extraction'
    mglm_text_summarization = 'mglm-text-summarization'
    codegeex_code_translation = 'codegeex-code-translation'
    codegeex_code_generation = 'codegeex-code-generation'
    translation_en_to_de = 'translation_en_to_de'  # keep it underscore
    translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
    translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore

@@ -384,7 +385,6 @@ class Preprocessors(object):
    document_segmentation = 'document-segmentation'
    feature_extraction = 'feature-extraction'
    mglm_summarization = 'mglm-summarization'
    codegeex = 'codegeex'
    sentence_piece = 'sentence-piece'
    # audio preprocessor
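With these registry names in place, the new pipeline can be constructed by task. A minimal usage sketch, assuming the model files have been downloaded locally; the '/path/to/codegeex_model' directory is a placeholder, not a published model id:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# placeholder path: any directory containing the CodeGeeX tokenizer and checkpoint
pipe = pipeline(Tasks.code_generation, model='/path/to/codegeex_model')
result = pipe({'language': 'Python', 'prompt': '# write a bubble sort\n'})
print(result)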

modelscope/models/nlp/__init__.py

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
    )
    from .T5 import T5ForConditionalGeneration
    from .mglm import MGLMForTextSummarization
    from .codegeex import CodeGeeXForCodeTranslation
    from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration
    from .task_models import (
        FeatureExtractionModel,
        InformationExtractionModel,

@@ -109,7 +109,7 @@ else:
        'sentence_embedding': ['SentenceEmbedding'],
        'T5': ['T5ForConditionalGeneration'],
        'mglm': ['MGLMForTextSummarization'],
        'codegeex': ['CodeGeeXForCodeTranslation'],
        'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
        'gpt_neo': ['GPTNeoModel'],
        'bloom': ['BloomModel'],
    }
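The same import appears under `if TYPE_CHECKING` and again in `_import_structure` because the first branch runs only for static type checkers; at runtime the torch-backed modules are resolved lazily. A minimal sketch of the guard; the `build` helper is hypothetical, for illustration only:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by mypy/IDEs only, never at runtime, so importing the
    # package does not pull in torch until a model is actually needed
    from .codegeex import CodeGeeXForCodeGeneration


def build(model_dir: str) -> 'CodeGeeXForCodeGeneration':  # hypothetical helper
    from .codegeex import CodeGeeXForCodeGeneration  # deferred to call time
    return CodeGeeXForCodeGeneration(model_dir)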

modelscope/models/nlp/codegeex/__init__.py

@@ -6,9 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .codegeex_for_code_translation import CodeGeeXForCodeTranslation
    from .codegeex_for_code_generation import CodeGeeXForCodeGeneration
else:
    _import_structure = {
        'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'],
        'codegeex_for_code_generation': ['CodeGeeXForCodeGeneration'],
    }

    import sys
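After `import sys`, these `__init__` files typically replace the module in `sys.modules` with a `LazyImportModule` built from `_import_structure`; the tail of the file is truncated above. A toy sketch of the idea, not ModelScope's actual implementation:

import importlib
import types


class _ToyLazyModule(types.ModuleType):
    """Resolve exported symbols from their submodules on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map exported symbol -> submodule that defines it
        self._where = {
            sym: mod
            for mod, syms in import_structure.items() for sym in syms
        }

    def __getattr__(self, attr):
        if attr not in self._where:
            raise AttributeError(attr)
        # import the defining submodule only when the symbol is first used
        submodule = importlib.import_module(f'.{self._where[attr]}',
                                            self.__name__)
        return getattr(submodule, attr)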

modelscope/models/nlp/codegeex/codegeex_for_code_generation.py

@@ -0,0 +1,111 @@
# Copyright (c) 2022 Zhipu.AI
import copy
from typing import Any, Dict

import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .codegeex import CodeGeeXModel
from .inference import get_token_stream
from .tokenizer import CodeGeeXTokenizer


def model_provider():
    """Build the model."""
    # configuration of the released CodeGeeX-13B checkpoint; roughly
    # 39 * 12 * 5120**2 + 52224 * 5120 ~= 12.5e9 weights before extras
    hidden_size = 5120
    num_attention_heads = 40
    num_layers = 39
    padded_vocab_size = 52224
    max_position_embeddings = 2048
    model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads,
                          padded_vocab_size, max_position_embeddings)
    return model


@MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex)
class CodeGeeXForCodeGeneration(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the CodeGeeX code generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        logger = get_logger()
        # loading tokenizer
        logger.info('Loading tokenizer ...')
        self.tokenizer = CodeGeeXTokenizer(
            tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b')
        # loading model
        state_dict_path = model_dir + '/ckpt_ms_213000_fp32_52224.pt'
        logger.info('Loading state dict ...')
        state_dict = torch.load(state_dict_path, map_location='cpu')
        state_dict = state_dict['module']
        logger.info('Building CodeGeeX model ...')
        self.model = model_provider()
        self.model.load_state_dict(state_dict)
        # inference only: run with fp16 weights on the GPU
        self.model.eval()
        self.model.half()
        self.model.cuda()

    def forward(self, input: Dict[str, str]) -> Dict[str, str]:
        micro_batch_size = 1
        seq_length = 2048
        out_seq_length = 256
        bad_ids = None
        lang = input['language']
        prompt = input['prompt']
        # prepend a language tag so the model conditions on the target language
        prompt = f'# language: {lang}\n{prompt}'
        logger = get_logger()
        tokenizer = self.tokenizer
        model = self.model
        tokens = tokenizer.encode_code(prompt)
        n_token_prompt = len(tokens)
        token_stream = get_token_stream(
            model,
            tokenizer,
            seq_length,
            out_seq_length,
            [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
            micro_batch_size=micro_batch_size,
            bad_ids=bad_ids,
            topk=1,
            topp=0.9,
            temperature=0.9,
            greedy=True)
        is_finished = [False for _ in range(micro_batch_size)]
        for generated in token_stream:
            generated_tokens = generated[0]
            for j in range(micro_batch_size):
                if is_finished[j]:
                    continue
                # a sequence is done once it emits EOS or reaches the length limit
                last_token = generated_tokens[j].cpu().numpy()[-1]
                if (last_token == tokenizer.eos_token_id
                        or len(generated_tokens[j]) >= out_seq_length):
                    is_finished[j] = True
                    generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                    # decode only the newly generated tokens, skipping the prompt
                    generated_code = tokenizer.decode_code(
                        generated_tokens_[n_token_prompt:])
                    generated_code = ''.join(generated_code)
                    logger.info('================================= Generated code:')
                    logger.info(generated_code)
            if all(is_finished):
                break
        logger.info('Generation finished.')
        return {OutputKeys.TEXT: generated_code}
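A hedged example of using the model class directly, assuming `model_dir` contains the `tokenizer/` directory and the `ckpt_ms_213000_fp32_52224.pt` checkpoint referenced above; a CUDA GPU is required since the weights are moved to fp16 on device, and the path is a placeholder:

from modelscope.models.nlp import CodeGeeXForCodeGeneration
from modelscope.outputs import OutputKeys

model = CodeGeeXForCodeGeneration('/path/to/codegeex_model')  # placeholder path
output = model({'language': 'Python', 'prompt': '# compute fibonacci numbers\n'})
print(output[OutputKeys.TEXT])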

modelscope/pipelines/nlp/__init__.py

@@ -33,6 +33,7 @@ if TYPE_CHECKING:
    from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
    from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
    from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
    from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
    from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \
        WordSegmentationThaiPipeline

@@ -76,6 +77,8 @@ else:
        'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'],
        'codegeex_code_translation_pipeline':
        ['CodeGeeXCodeTranslationPipeline'],
        'codegeex_code_generation_pipeline':
        ['CodeGeeXCodeGenerationPipeline'],
        'multilingual_word_segmentation_pipeline': [
            'MultilingualWordSegmentationPipeline',
            'WordSegmentationThaiPipeline'

modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py

@@ -0,0 +1,48 @@
# Copyright (c) 2022 Zhipu.AI
from typing import Any, Dict, Optional, Union

from modelscope.metainfo import Pipelines
from modelscope.models.nlp import CodeGeeXForCodeGeneration
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks

SUPPORTED_LANGUAGES = [
    'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', 'Python',
    'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', 'TypeScript', 'Go',
    'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', 'Pascal', 'R', 'Fortran', 'Lean'
]


@PIPELINES.register_module(
    group_key=Tasks.code_generation,
    module_name=Pipelines.codegeex_code_generation)
class CodeGeeXCodeGenerationPipeline(Pipeline):

    def __init__(self,
                 model: Union[CodeGeeXForCodeGeneration, str],
                 preprocessor: Optional[Preprocessor] = None,
                 *args,
                 **kwargs):
        model = CodeGeeXForCodeGeneration(model) if isinstance(
            model, str) else model
        self.model = model
        # inference only: fp16 weights on the GPU
        self.model.eval()
        self.model.half()
        self.model.cuda()
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        return inputs

    # define the forward pass
    def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]:
        # check input format
        for para in ['prompt', 'language']:
            if para not in inputs:
                raise ValueError(f'Input is missing the required key "{para}".')
        if inputs['language'] not in SUPPORTED_LANGUAGES:
            raise ValueError(f'Unsupported language "{inputs["language"]}"; '
                             f'expected one of {SUPPORTED_LANGUAGES}.')
        return self.model(inputs)

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        return input
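Because the constructor accepts either a model instance or a directory path, the pipeline class can also be instantiated directly; a sketch with a placeholder path:

from modelscope.pipelines.nlp import CodeGeeXCodeGenerationPipeline

pipe = CodeGeeXCodeGenerationPipeline('/path/to/codegeex_model')  # placeholder path
out = pipe({'language': 'Rust', 'prompt': '// parse a CSV line\n'})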

modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py

@@ -38,6 +38,12 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
        for para in ['prompt', 'source language', 'target language']:
            if para not in inputs:
                raise ValueError(f'Input is missing the required key "{para}".')
        supported_languages = [
            'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
            'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
            'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
            'Pascal', 'R', 'Fortran', 'Lean'
        ]
        if inputs['source language'] not in supported_languages:
            raise ValueError('Unsupported source language; expected one of '
                             f'{supported_languages}.')
        if inputs['target language'] not in supported_languages:
            raise ValueError('Unsupported target language; expected one of '
                             f'{supported_languages}.')
        return self.model(inputs)

    # format the outputs from pipeline
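For comparison, the translation pipeline validates a three-key input. A usage sketch, again with a placeholder model path:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

translator = pipeline(Tasks.code_translation, model='/path/to/codegeex_model')  # placeholder
out = translator({
    'prompt': 'def add(a, b):\n    return a + b\n',
    'source language': 'Python',
    'target language': 'C++',
})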

modelscope/utils/constant.py

@@ -121,6 +121,7 @@ class NLPTasks(object):
    text_summarization = 'text-summarization'
    question_answering = 'question-answering'
    code_translation = 'code-translation'
    code_generation = 'code-generation'
    zero_shot_classification = 'zero-shot-classification'
    backbone = 'backbone'
    text_error_correction = 'text-error-correction'
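A quick sanity check tying the pieces together, assuming `Tasks` exposes the `NLPTasks` fields as it does for the other task groups (the model file above already uses `Tasks.code_generation`):

from modelscope.metainfo import Pipelines
from modelscope.utils.constant import Tasks

# the task string keys the pipeline registry; the pipeline name picks the implementation
assert Tasks.code_generation == 'code-generation'
assert Pipelines.codegeex_code_generation == 'codegeex-code-generation'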