add code_generation
modelscope/metainfo.py

@@ -257,6 +257,7 @@ class Pipelines(object):
    feature_extraction = 'feature-extraction'
    mglm_text_summarization = 'mglm-text-summarization'
    codegeex_code_translation = 'codegeex-code-translation'
    codegeex_code_generation = 'codegeex-code-generation'
    translation_en_to_de = 'translation_en_to_de'  # keep it underscore
    translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
    translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore

@@ -384,7 +385,6 @@ class Preprocessors(object):
    document_segmentation = 'document-segmentation'
    feature_extraction = 'feature-extraction'
    mglm_summarization = 'mglm-summarization'
    codegeex = 'codegeex'
    sentence_piece = 'sentence-piece'

    # audio preprocessor
modelscope/models/nlp/__init__.py

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
    )
    from .T5 import T5ForConditionalGeneration
    from .mglm import MGLMForTextSummarization
    from .codegeex import CodeGeeXForCodeTranslation
    from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration
    from .task_models import (
        FeatureExtractionModel,
        InformationExtractionModel,

@@ -109,7 +109,7 @@ else:
        'sentence_embedding': ['SentenceEmbedding'],
        'T5': ['T5ForConditionalGeneration'],
        'mglm': ['MGLMForTextSummarization'],
        'codegeex': ['CodeGeeXForCodeTranslation'],
        'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
        'gpt_neo': ['GPTNeoModel'],
        'bloom': ['BloomModel'],
    }
modelscope/models/nlp/codegeex/__init__.py

@@ -6,9 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .codegeex_for_code_translation import CodeGeeXForCodeTranslation
    from .codegeex_for_code_generation import CodeGeeXForCodeGeneration
else:
    _import_structure = {
        'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'],
        'codegeex_for_code_generation': ['CodeGeeXForCodeGeneration'],
    }

    import sys
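For context, `_import_structure` follows modelscope's lazy-import convention: each submodule maps to the symbols it exports, and `LazyImportModule` only imports a submodule when one of those symbols is first accessed. A minimal sketch of that idea follows; it is an illustration only, not modelscope's actual `LazyImportModule` implementation, which carries extra bookkeeping:

# Illustrative lazy-module sketch; not the real LazyImportModule.
import importlib
import types


class _LazyModule(types.ModuleType):

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map every exported symbol to the submodule that defines it.
        self._symbol_to_module = {
            symbol: module
            for module, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol):
        # Import the defining submodule only on first attribute access.
        if symbol not in self._symbol_to_module:
            raise AttributeError(symbol)
        submodule = importlib.import_module(
            f'.{self._symbol_to_module[symbol]}', self.__name__)
        return getattr(submodule, symbol)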
modelscope/models/nlp/codegeex/codegeex_for_code_generation.py (new executable file, 111 lines)
@@ -0,0 +1,111 @@
# Copyright (c) 2022 Zhipu.AI
import copy
from typing import Any, Dict

import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .codegeex import CodeGeeXModel
from .inference import get_token_stream
from .tokenizer import CodeGeeXTokenizer


def model_provider():
    """Build the model."""

    hidden_size = 5120
    num_attention_heads = 40
    num_layers = 39
    padded_vocab_size = 52224
    max_position_embeddings = 2048

    model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads,
                          padded_vocab_size, max_position_embeddings)

    return model


@MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex)
class CodeGeeXForCodeGeneration(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the CodeGeeX code generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        logger = get_logger()
        # loading tokenizer
        logger.info('Loading tokenizer ...')
        self.tokenizer = CodeGeeXTokenizer(
            tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b')
        # loading model
        state_dict_path = model_dir + '/ckpt_ms_213000_fp32_52224.pt'
        logger.info('Loading state dict ...')
        state_dict = torch.load(state_dict_path, map_location='cpu')
        state_dict = state_dict['module']

        logger.info('Building CodeGeeX model ...')
        self.model = model_provider()
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.half()
        self.model.cuda()

    def forward(self, input: Dict[str, str]) -> Dict[str, str]:
        micro_batch_size = 1
        seq_length = 2048
        out_seq_length = 256
        bad_ids = None
        lang = input['language']
        prompt = input['prompt']
        prompt = f"# language: {lang}\n{prompt}"
        logger = get_logger()
        tokenizer = self.tokenizer
        model = self.model
        for prompt in [prompt]:
            tokens = tokenizer.encode_code(prompt)
            n_token_prompt = len(tokens)
            token_stream = get_token_stream(
                model,
                tokenizer,
                seq_length,
                out_seq_length,
                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                micro_batch_size=micro_batch_size,
                bad_ids=bad_ids,
                topk=1,
                topp=0.9,
                temperature=0.9,
                greedy=True)
            is_finished = [False for _ in range(micro_batch_size)]
            for i, generated in enumerate(token_stream):
                generated_tokens = generated[0]
                for j in range(micro_batch_size):
                    if is_finished[j]:
                        continue
                    if generated_tokens[j].cpu().numpy(
                    )[-1] == tokenizer.eos_token_id or len(
                            generated_tokens[j]) >= out_seq_length:
                        is_finished[j] = True
                        generated_tokens_ = generated_tokens[j].cpu().numpy(
                        ).tolist()
                        generated_code = tokenizer.decode_code(
                            generated_tokens_[n_token_prompt:])
                        generated_code = ''.join(generated_code)
                        logger.info(
                            '================================= Generated code:'
                        )
                        logger.info(generated_code)
                if all(is_finished):
                    break

        logger.info('Generation finished.')
        return {OutputKeys.TEXT: generated_code}
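The contract implemented above: `forward` consumes a dict with 'prompt' and 'language' keys and returns the generated code under `OutputKeys.TEXT`. A minimal sketch of driving the model class directly, assuming a locally downloaded model directory that contains the `tokenizer` folder and the `ckpt_ms_213000_fp32_52224.pt` checkpoint, plus a CUDA-capable GPU; the path below is a placeholder, not a value from this commit:

# Sketch only: '/path/to/codegeex-13b' is a hypothetical local directory laid
# out the way CodeGeeXForCodeGeneration.__init__ expects (tokenizer/ + checkpoint).
from modelscope.models.nlp import CodeGeeXForCodeGeneration
from modelscope.outputs import OutputKeys

model = CodeGeeXForCodeGeneration('/path/to/codegeex-13b')
result = model.forward({
    'prompt': 'def factorial(n):\n    """Return n! computed iteratively."""\n',
    'language': 'Python',
})
print(result[OutputKeys.TEXT])  # generated continuation of the prompt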
modelscope/pipelines/nlp/__init__.py

@@ -33,6 +33,7 @@ if TYPE_CHECKING:
    from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
    from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
    from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
    from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
    from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \
        WordSegmentationThaiPipeline

@@ -76,6 +77,8 @@ else:
        'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'],
        'codegeex_code_translation_pipeline':
        ['CodeGeeXCodeTranslationPipeline'],
        'codegeex_code_generation_pipeline':
        ['CodeGeeXCodeGenerationPipeline'],
        'multilingual_word_segmentation_pipeline': [
            'MultilingualWordSegmentationPipeline',
            'WordSegmentationThaiPipeline'
modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py (new executable file, 48 lines)
@@ -0,0 +1,48 @@
# Copyright (c) 2022 Zhipu.AI

from typing import Any, Dict, Optional, Union

from modelscope.metainfo import Pipelines
from modelscope.models.nlp import CodeGeeXForCodeGeneration
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    group_key=Tasks.code_generation,
    module_name=Pipelines.codegeex_code_generation)
class CodeGeeXCodeGenerationPipeline(Pipeline):

    def __init__(self,
                 model: Union[CodeGeeXForCodeGeneration, str],
                 preprocessor: Optional[Preprocessor] = None,
                 *args,
                 **kwargs):
        model = CodeGeeXForCodeGeneration(model) if isinstance(model,
                                                               str) else model
        self.model = model
        self.model.eval()
        self.model.half()
        self.model.cuda()

        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        return inputs

    # define the forward pass
    def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:
        # check input format
        for para in ['prompt', 'language']:
            if para not in inputs:
                raise Exception('Please check your input format.')
        if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
            raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa

        return self.model(inputs)

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        return input
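With the registrations above, the new pipeline should be reachable through modelscope's standard `pipeline()` factory under `Tasks.code_generation`. A hedged usage sketch; the hub model id below is an assumption, not part of this commit, so substitute whatever CodeGeeX code-generation model is actually published on the ModelScope hub:

# Usage sketch. Assumption: 'ZhipuAI/CodeGeeX-Code-Generation-13B' is a
# placeholder model id; look up the real id on the ModelScope hub.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

generator = pipeline(
    task=Tasks.code_generation,
    model='ZhipuAI/CodeGeeX-Code-Generation-13B')
result = generator({
    'prompt': '# write a quicksort function\n',
    'language': 'Python',
})
print(result)  # expected shape: {'text': '<generated code>'}

Because `CodeGeeXCodeGenerationPipeline.preprocess` is a pass-through, the dict reaches `forward` unchanged, which is why both 'prompt' and 'language' must be present in the input.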
modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py

@@ -38,6 +38,12 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
        for para in ['prompt', 'source language', 'target language']:
            if para not in inputs:
                raise Exception('please check your input format.')
        if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
            raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa

        if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
            raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa

        return self.model(inputs)

    # format the outputs from pipeline
modelscope/utils/constant.py

@@ -121,6 +121,7 @@ class NLPTasks(object):
    text_summarization = 'text-summarization'
    question_answering = 'question-answering'
    code_translation = 'code-translation'
    code_generation = 'code-generation'
    zero_shot_classification = 'zero-shot-classification'
    backbone = 'backbone'
    text_error_correction = 'text-error-correction'