Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-24 20:19:22 +01:00)

Merge pull request #33 from modelscope/codegeex_code_translation

The CodeGeeX code translation and generation unit tests failed due to a known run.py environment setup issue that is being fixed; the failures are unrelated to this change.
@@ -84,6 +84,7 @@ class Models(object):
     ponet = 'ponet'
     T5 = 'T5'
     mglm = 'mglm'
+    codegeex = 'codegeex'
     bloom = 'bloom'

     # audio models

@@ -256,6 +257,8 @@ class Pipelines(object):
     document_segmentation = 'document-segmentation'
     feature_extraction = 'feature-extraction'
     mglm_text_summarization = 'mglm-text-summarization'
+    codegeex_code_translation = 'codegeex-code-translation'
+    codegeex_code_generation = 'codegeex-code-generation'
     translation_en_to_de = 'translation_en_to_de'  # keep it underscore
     translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
     translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
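These registry keys are what the new CodeGeeX model and pipeline classes added later in this PR register themselves under; the registries key on these string values. A condensed sketch of the wiring (the decorators are copied from the new files below; the prints only show the constant values):

from modelscope.metainfo import Models, Pipelines

# In codegeex_for_code_generation.py (added below):
#     @MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex)
# In codegeex_code_generation_pipeline.py (added below):
#     @PIPELINES.register_module(
#         group_key=Tasks.code_generation,
#         module_name=Pipelines.codegeex_code_generation)
print(Models.codegeex)                      # 'codegeex'
print(Pipelines.codegeex_code_generation)   # 'codegeex-code-generation'
print(Pipelines.codegeex_code_translation)  # 'codegeex-code-translation'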
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
     )
     from .T5 import T5ForConditionalGeneration
     from .mglm import MGLMForTextSummarization
+    from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration
     from .task_models import (
         FeatureExtractionModel,
         InformationExtractionModel,

@@ -108,6 +109,8 @@ else:
         'sentence_embedding': ['SentenceEmbedding'],
         'T5': ['T5ForConditionalGeneration'],
         'mglm': ['MGLMForTextSummarization'],
+        'codegeex':
+        ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
         'gpt_neo': ['GPTNeoModel'],
         'bloom': ['BloomModel'],
     }
modelscope/models/nlp/codegeex/__init__.py (new executable file, 24 lines)
@@ -0,0 +1,24 @@
# Modified by Zhipu.AI
# Original Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING, Union

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .codegeex_for_code_translation import CodeGeeXForCodeTranslation
    from .codegeex_for_code_generation import CodeGeeXForCodeGeneration
else:
    _import_structure = {
        'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'],
        'codegeex_for_code_generation': ['CodeGeeXForCodeGeneration'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
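For readers unfamiliar with the pattern above: LazyImportModule replaces the package module so that the heavy submodules (and the torch code they pull in) are only imported when one of the exported names is first accessed. A rough, self-contained sketch of the idea; this is illustrative only, not ModelScope's actual implementation, and the class name _LazyModule is hypothetical:

import importlib
import types


class _LazyModule(types.ModuleType):
    """Minimal lazy-import proxy (illustration of the pattern, not ModelScope's code)."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        submodule = self._attr_to_module.get(attr)
        if submodule is None:
            raise AttributeError(attr)
        # Import the real submodule only on first access, then cache the result.
        module = importlib.import_module(f'{self.__name__}.{submodule}')
        value = getattr(module, attr)
        setattr(self, attr, value)
        return value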
modelscope/models/nlp/codegeex/codegeex.py (new executable file, 1030 lines)
File diff suppressed because it is too large.
modelscope/models/nlp/codegeex/codegeex_for_code_generation.py (new executable file, 110 lines)
@@ -0,0 +1,110 @@
# Copyright (c) 2022 Zhipu.AI
import copy
from typing import Any, Dict

import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .codegeex import CodeGeeXModel
from .inference import get_token_stream
from .tokenizer import CodeGeeXTokenizer


def model_provider():
    """Build the model."""

    hidden_size = 5120
    num_attention_heads = 40
    num_layers = 39
    padded_vocab_size = 52224
    max_position_embeddings = 2048

    model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads,
                          padded_vocab_size, max_position_embeddings)

    return model


@MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex)
class CodeGeeXForCodeGeneration(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the CodeGeeX code generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        logger = get_logger()
        # loading tokenizer
        logger.info('Loading tokenizer ...')
        self.tokenizer = CodeGeeXTokenizer(
            tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b')
        # loading model
        state_dict_path = model_dir + '/ckpt_ms_213000_fp32_52224.pt'
        logger.info('Loading state dict ...')
        state_dict = torch.load(state_dict_path, map_location='cpu')
        state_dict = state_dict['module']

        logger.info('Building CodeGeeX model ...')
        self.model = model_provider()
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.half()
        self.model.cuda()

    def forward(self, input: Dict[str, str]) -> Dict[str, str]:
        micro_batch_size = 1
        seq_length = 2048
        out_seq_length = 256
        bad_ids = None
        lang = input['language']
        prompt = input['prompt']
        prompt = f'# language: {lang}\n{prompt}'
        logger = get_logger()
        tokenizer = self.tokenizer
        model = self.model
        for prompt in [prompt]:
            tokens = tokenizer.encode_code(prompt)
            n_token_prompt = len(tokens)
            token_stream = get_token_stream(
                model,
                tokenizer,
                seq_length,
                out_seq_length,
                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                micro_batch_size=micro_batch_size,
                bad_ids=bad_ids,
                topk=1,
                topp=0.9,
                temperature=0.9,
                greedy=True)
            is_finished = [False for _ in range(micro_batch_size)]
            for i, generated in enumerate(token_stream):
                generated_tokens = generated[0]
                for j in range(micro_batch_size):
                    if is_finished[j]:
                        continue
                    if generated_tokens[j].cpu().numpy(
                    )[-1] == tokenizer.eos_token_id or len(
                            generated_tokens[j]) >= out_seq_length:
                        is_finished[j] = True
                        generated_tokens_ = generated_tokens[j].cpu().numpy(
                        ).tolist()
                        generated_code = tokenizer.decode_code(
                            generated_tokens_[n_token_prompt:])
                        generated_code = ''.join(generated_code)
                        logger.info(
                            '================================= Generated code:'
                        )
                        logger.info(generated_code)
                if all(is_finished):
                    break

        logger.info('Generation finished.')
        return {OutputKeys.TEXT: generated_code}
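As implemented above, CodeGeeXForCodeGeneration.forward expects a dict with 'prompt' and 'language' keys and returns the generated code under OutputKeys.TEXT. A minimal direct call might look like the sketch below; the checkpoint directory is a placeholder, and a CUDA GPU is required since the model is moved to half precision on GPU at load time:

from modelscope.models.nlp.codegeex import CodeGeeXForCodeGeneration
from modelscope.outputs import OutputKeys

# model_dir must contain the tokenizer/ folder and the checkpoint file
# referenced in __init__ above (placeholder path).
model = CodeGeeXForCodeGeneration('/path/to/codegeex-13b')
result = model({
    'prompt': 'def quick_sort(arr):\n',
    'language': 'Python',
})
print(result[OutputKeys.TEXT])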
modelscope/models/nlp/codegeex/codegeex_for_code_translation.py (new executable file, 109 lines)
@@ -0,0 +1,109 @@
# Copyright (c) 2022 Zhipu.AI
import copy
from typing import Any, Dict

import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .codegeex import CodeGeeXModel
from .inference import get_token_stream
from .tokenizer import CodeGeeXTokenizer


def model_provider():
    """Build the model."""

    hidden_size = 5120
    num_attention_heads = 40
    num_layers = 39
    padded_vocab_size = 52224
    max_position_embeddings = 2048

    model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads,
                          padded_vocab_size, max_position_embeddings)

    return model


@MODELS.register_module(Tasks.code_translation, module_name=Models.codegeex)
class CodeGeeXForCodeTranslation(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the CodeGeeX code translation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        logger = get_logger()
        # loading tokenizer
        logger.info('Loading tokenizer ...')
        self.tokenizer = CodeGeeXTokenizer(
            tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b')
        # loading model
        state_dict_path = model_dir + '/ckpt_ms_translation_0817.pt'
        logger.info('Loading state dict ...')
        state_dict = torch.load(state_dict_path, map_location='cpu')
        state_dict = state_dict['module']

        logger.info('Building CodeGeeX model ...')
        self.model = model_provider()
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.half()
        self.model.cuda()

    def forward(self, input: Dict[str, str]) -> Dict[str, str]:
        micro_batch_size = 1
        seq_length = 2048
        out_seq_length = 256
        bad_ids = None
        src_lang = input['source language']
        dst_lang = input['target language']
        prompt = input['prompt']
        prompt = f'code translation\n{src_lang}:\n{prompt}\n{dst_lang}:\n'
        logger = get_logger()
        tokenizer = self.tokenizer
        model = self.model
        for prompt in [prompt]:
            tokens = tokenizer.encode_code(prompt)
            n_token_prompt = len(tokens)
            token_stream = get_token_stream(
                model,
                tokenizer,
                seq_length,
                out_seq_length,
                [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
                micro_batch_size=micro_batch_size,
                bad_ids=bad_ids,
                greedy=True,
            )
            is_finished = [False for _ in range(micro_batch_size)]
            for i, generated in enumerate(token_stream):
                generated_tokens = generated[0]
                for j in range(micro_batch_size):
                    if is_finished[j]:
                        continue
                    if generated_tokens[j].cpu().numpy(
                    )[-1] == tokenizer.eos_token_id or len(
                            generated_tokens[j]) >= out_seq_length:
                        is_finished[j] = True
                        generated_tokens_ = generated_tokens[j].cpu().numpy(
                        ).tolist()
                        generated_code = tokenizer.decode_code(
                            generated_tokens_[n_token_prompt:])
                        generated_code = ''.join(generated_code)
                        logger.info(
                            '================================= Generated code:'
                        )
                        logger.info(generated_code)
                if all(is_finished):
                    break

        logger.info('Generation finished.')
        return {OutputKeys.TEXT: generated_code}
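The translation model's contract mirrors the generation one, except that the input dict uses the keys 'prompt', 'source language' and 'target language', and forward internally builds the prompt template 'code translation\n{src}:\n{code}\n{dst}:\n'. A sketch of a direct call, with a placeholder checkpoint directory and a CUDA GPU assumed:

from modelscope.models.nlp.codegeex import CodeGeeXForCodeTranslation
from modelscope.outputs import OutputKeys

model = CodeGeeXForCodeTranslation('/path/to/codegeex-13b-translation')  # placeholder dir
result = model({
    'prompt': 'def add(a, b):\n    return a + b\n',
    'source language': 'Python',
    'target language': 'C++',
})
print(result[OutputKeys.TEXT])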
modelscope/models/nlp/codegeex/inference.py (new executable file, 301 lines)
@@ -0,0 +1,301 @@
# Copyright (c) 2022 Zhipu.AI

from typing import List

import torch
import torch.nn.functional as F


def get_ltor_masks_and_position_ids(
    data,
    eod_token,
    reset_position_ids,
    reset_attention_mask,
):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(
        torch.ones((att_mask_batch, seq_length, seq_length),
                   device=data.device)).view(att_mask_batch, 1, seq_length,
                                             seq_length)

    # Position ids.
    position_ids = torch.arange(
        seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5

    return attention_mask, position_ids


def get_batch(
    context_tokens,
    micro_batch_size,
    eod_token,
    reset_position_ids=False,
    reset_attention_mask=False,
):
    """Generate batch from context tokens."""
    tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        eod_token,
        reset_position_ids,
        reset_attention_mask,
    )

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """This function has been mostly taken from huggingface conversational
    ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
    conversational-ai-with-transfer-learning-2d818ac26313"""

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
                                                                  None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def pad_batch(batch, pad_id, seq_length):
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < seq_length:
            tokens.extend([pad_id] * (seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(
    model,
    tokenizer,
    seq_length,
    out_seq_length,
    context_tokens,
    return_scores: bool = False,
    prompt_length: int = None,
    micro_batch_size: int = None,
    bad_ids: List = None,
    temperature: float = 1.0,
    topp: float = 1.0,
    topk: int = 0,
    greedy: bool = False,
):
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eos_token_id,
                                                seq_length)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(
        context_tokens_tensor,
        micro_batch_size,
        tokenizer.eos_token_id,
    )

    batch_token_iterator = sample_sequence_batch(
        model,
        tokenizer,
        context_tokens_tensor,
        context_length_tensor,
        attention_mask,
        position_ids,
        seq_length=seq_length,
        out_seq_length=out_seq_length,
        return_scores=return_scores,
        prompt_length=prompt_length,
        bad_ids=bad_ids,
        temperature=temperature,
        topp=topp,
        topk=topk,
        greedy=greedy,
    )

    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def sample_sequence_batch(
    model,
    tokenizer,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    seq_length,
    out_seq_length,
    maxlen=None,
    return_scores: bool = False,
    prompt_length: int = None,
    bad_ids: List = None,
    temperature: float = 1.0,
    topp: float = 1.0,
    topk: int = 0,
    recompute: bool = False,
    greedy: bool = False,
):
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eos_token_id

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = seq_length - 1
            if maxlen > (org_context_length + out_seq_length):
                maxlen = org_context_length + out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen
        if return_scores:
            scores = torch.zeros([batch_size]).float().cuda()

        while context_length <= (maxlen):

            if recompute:
                logits = model(
                    tokens,
                    position_ids,
                    attention_mask,
                    prompt_length=prompt_length,
                    context_length=context_length,
                )
                logits = logits[:, context_length - 1, :]
            else:
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                logits, layer_past = model(
                    tokens2use,
                    positions2use,
                    attention_mask,
                    layer_past=layer_past,
                    get_key_value=True,
                    prompt_length=prompt_length,
                    context_length=context_length,
                )
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if bad_ids is not None:
                for bad_id in bad_ids:
                    logits[:, bad_id] = -10000
            if greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
                logits = logits.float()
                if return_scores:
                    orig_log_probs = torch.log_softmax(logits, dim=-1)
                logits /= temperature
                logits = top_k_logits(logits, top_k=topk, top_p=topp)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            started = context_lengths <= context_length

            new_tokens = switch(tokens[:, context_length].view(-1), prev,
                                started)

            if not greedy and return_scores:
                indices = prev.view(-1, 1)
                new_scores = orig_log_probs.gather(1, indices).view(-1)
                new_scores = new_scores * started
                new_scores = new_scores * is_done.bool().logical_not()
                scores += new_scores

            tokens[:, context_length] = new_tokens
            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            is_done = is_done | done_token
            done = torch.all(is_done)

            if return_scores:
                yield tokens, (lengths, scores)
            else:
                yield tokens, lengths

            context_length += 1
            counter += 1
            if done:
                break
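The two pure-tensor helpers above can be sanity-checked on a toy batch without loading the 13B model: get_ltor_masks_and_position_ids builds a standard causal (lower-triangular) mask plus position ids, and top_k_logits filters logits in place. A small illustrative check, assuming ModelScope with this change is installed (CPU is fine here, even though the pipeline itself runs on GPU):

import torch

from modelscope.models.nlp.codegeex.inference import (
    get_ltor_masks_and_position_ids, top_k_logits)

data = torch.tensor([[5, 7, 2, 9]])  # (micro_batch_size=1, seq_length=4)
mask, pos = get_ltor_masks_and_position_ids(
    data, eod_token=2, reset_position_ids=False, reset_attention_mask=False)
print(mask.shape, pos.shape)  # torch.Size([1, 1, 4, 4]) torch.Size([1, 4])
print(mask[0, 0])             # True strictly above the diagonal, i.e. masked future positions

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_k_logits(logits.clone(), top_k=2))  # everything outside the top-2 becomes -inf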
modelscope/models/nlp/codegeex/tokenizer.py (new executable file, 187 lines)
@@ -0,0 +1,187 @@
# Copyright (c) 2022 Zhipu.AI
from typing import List, Union

import torch
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast


def encode_whitespaces(text, start_extra_id: int, max_len: int):
    """ Encode whitespaces to extra tokens in GPT-J.

    >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
    'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
    """

    def push_acc_space(acc_len: int, text: str):
        if acc_len == 0:
            return text
        if acc_len == 1:
            return text + ' '
        assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
        extra_id = start_extra_id - 2 + acc_len
        extra_token = f'<|extratoken_{extra_id}|>'
        return text + extra_token

    acc_len = 0
    res = ''
    for ch in text:
        if ch == ' ':
            acc_len += 1
            if acc_len == max_len:
                res = push_acc_space(acc_len, res)
                acc_len = 0
        else:
            res = push_acc_space(acc_len, res)
            acc_len = 0
            res = res + ch

    res = push_acc_space(acc_len, res)

    return res


def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
    """ Decode the whitespace-encoded strings produced by encode_whitespace.

    >>> text = 'a\\n  b\\n   c'
    >>> s, l = 10, 10
    >>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
    True
    """
    for l in range(2, max_len + 1):  # noqa
        token_id = start_extra_id - 2 + l
        token = f'<|extratoken_{token_id}|>'
        text = text.replace(token, ' ' * l)
    return text


class Code13BDictionary(object):

    def __init__(
        self,
        dict_file: str,
        extra_token_ids: List[str] = None,
        pad_to_vocab_size: int = -1,
    ):
        self._idx = dict()
        self._count = dict()
        self._num_symbols = 0
        self._symbols = []

        self._add_symbol('<s>', 0)
        self._add_symbol('<pad>', 0)
        self._add_symbol('</s>', 0)
        self._add_symbol('<unk>', 0)
        self._load_dict(dict_file)

        if extra_token_ids is None:
            extra_token_ids = [str(x) for x in range(50257, 50400)
                               ]  # follows GPT-J settings

        for token_id in extra_token_ids:
            self._add_symbol(token_id, 0)

        if pad_to_vocab_size > 0:
            self._pad_to_vocab_size(pad_to_vocab_size)

    def _pad_to_vocab_size(self, vocab_size: int):
        num_pad = vocab_size - len(self)
        if num_pad <= 0:
            return
        for i in range(1, num_pad + 1):
            self._add_symbol('vocab_pad_token{}'.format(i), 0)

    def _load_dict(self, dict_file: str):
        with open(dict_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line == '' or line.startswith('#'):
                    continue
                sym, count = line.split()
                self._add_symbol(sym, int(count))

    def _add_symbol(self, sym: str, count: int):
        self._idx[sym] = self._num_symbols
        self._count[sym] = count
        self._symbols.append(sym)
        self._num_symbols += 1

    def __len__(self):
        return self._num_symbols

    def index(self, sym: str):
        return self._idx[sym]

    def string(self, idx: int):
        return self._symbols[idx]

    def map_token(self, token: Union[int, str]):
        if isinstance(token, int):
            token = str(token)
        return self.index(token)

    def map_tokens(self, tokens):
        return [self.map_token(token) for token in tokens]

    def decode_tokens(self, tokens):
        decoded = [
            '50256' if token == 50256 else self.string(token)
            for token in tokens
        ]
        return [int(x) for x in decoded if not x.startswith('vocab_pad_token')]


class CodeGeeXTokenizer(object):

    def __init__(
        self,
        tokenizer: GPT2TokenizerFast = None,
        tokenizer_path: str = 'EleutherAI/gpt-j-6B',
        start_extra_id: int = 10,
        max_len: int = 10,
        mode='codegeex-13b',
        dict_file: str = None,
    ):
        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
            tokenizer_path)
        if mode not in ['codegeex-13b', 'codegeex-python-13b']:
            raise ValueError(
                f"Invalid mode {mode}, choose from ['codegeex-13b', 'codegeex-python-13b']"
            )
        self.start_extra_id = start_extra_id
        self.max_len = max_len
        self.mode = mode
        if dict_file is not None:
            self.code_dict = Code13BDictionary(
                dict_file, pad_to_vocab_size=51200
            ) if self.mode == 'codegeex-python-13b' else None
        else:
            self.code_dict = None
        self.eos_token_id = self.tokenizer.eos_token_id

    def encode_code(self, code: str):
        if self.mode == 'codegeex-13b':
            code = encode_whitespaces(code, self.start_extra_id, self.max_len)
            input_ids = self.tokenizer(
                code, is_split_into_words=False).input_ids

        elif self.mode == 'codegeex-python-13b':
            code = encode_whitespaces(code, self.start_extra_id, self.max_len)
            input_ids = self.code_dict.map_tokens(self.tokenizer.encode(code))
            input_ids = torch.LongTensor(input_ids).reshape(1, -1)

        return input_ids

    def decode_code(self, input_ids):
        if self.mode == 'codegeex-13b':
            text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
            output_code = decode_whitespaces(text, self.start_extra_id,
                                             self.max_len)
        elif self.mode == 'codegeex-python-13b':
            input_ids = [self.code_dict.decode_tokens(input_ids.tolist()[0])]
            text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
            output_code = decode_whitespaces(text, self.start_extra_id,
                                             self.max_len)

        return output_code
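The whitespace helpers are the part of the tokenizer most specific to code, so a quick round-trip check (mirroring the doctests above) may be useful; runs of two or more spaces map onto GPT-J's unused <|extratoken_NN|> ids. Assuming ModelScope with this change is installed:

from modelscope.models.nlp.codegeex.tokenizer import (
    decode_whitespaces, encode_whitespaces)

text = 'def f():\n    return 1'
encoded = encode_whitespaces(text, start_extra_id=10, max_len=10)
print(encoded)  # the 4-space indent becomes <|extratoken_12|> (10 - 2 + 4)
assert decode_whitespaces(encoded, 10, 10) == text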
@@ -32,6 +32,8 @@ if TYPE_CHECKING:
     from .word_segmentation_pipeline import WordSegmentationPipeline
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
     from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
+    from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
+    from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
     from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \
         WordSegmentationThaiPipeline

@@ -73,6 +75,10 @@ else:
         'zero_shot_classification_pipeline':
         ['ZeroShotClassificationPipeline'],
         'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'],
+        'codegeex_code_translation_pipeline':
+        ['CodeGeeXCodeTranslationPipeline'],
+        'codegeex_code_generation_pipeline':
+        ['CodeGeeXCodeGenerationPipeline'],
         'multilingual_word_segmentation_pipeline': [
             'MultilingualWordSegmentationPipeline',
             'WordSegmentationThaiPipeline'
modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py (new executable file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright (c) 2022 Zhipu.AI

from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.models.nlp import CodeGeeXForCodeGeneration
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    group_key=Tasks.code_generation,
    module_name=Pipelines.codegeex_code_generation)
class CodeGeeXCodeGenerationPipeline(Pipeline):

    def __init__(self,
                 model: Union[CodeGeeXForCodeGeneration, str],
                 preprocessor: Preprocessor = None,
                 *args,
                 **kwargs):
        model = CodeGeeXForCodeGeneration(model) if isinstance(model,
                                                               str) else model
        self.model = model
        self.model.eval()
        self.model.half()
        self.model.cuda()

        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        return inputs

    # define the forward pass
    def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:
        # check input format
        for para in ['prompt', 'language']:
            if para not in inputs:
                raise Exception('Please check your input format.')
        if inputs['language'] not in [
                'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
                'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
                'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
                'Pascal', 'R', 'Fortran', 'Lean'
        ]:  # noqa
            raise Exception(
                'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]'  # noqa
            )  # noqa

        return self.model(inputs)

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        return input
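A hypothetical end-to-end call through the pipeline registered above; the model id is a placeholder (use whatever CodeGeeX id is published on ModelScope), and a CUDA GPU is required:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

generator = pipeline(
    Tasks.code_generation, model='<codegeex-code-generation-model-id>')
result = generator({
    'prompt': '# write a function that reverses a string\n',
    'language': 'Python',
})
print(result[OutputKeys.TEXT])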
modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py (new executable file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) 2022 Zhipu.AI

from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.models.nlp import CodeGeeXForCodeTranslation
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    group_key=Tasks.code_translation,
    module_name=Pipelines.codegeex_code_translation)
class CodeGeeXCodeTranslationPipeline(Pipeline):

    def __init__(self,
                 model: Union[CodeGeeXForCodeTranslation, str],
                 preprocessor: Preprocessor = None,
                 *args,
                 **kwargs):
        model = CodeGeeXForCodeTranslation(model) if isinstance(model,
                                                                str) else model
        self.model = model
        self.model.eval()
        self.model.half()
        self.model.cuda()

        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        return inputs

    # define the forward pass
    def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:
        # check input format
        for para in ['prompt', 'source language', 'target language']:
            if para not in inputs:
                raise Exception('Please check your input format.')
        if inputs['source language'] not in [
                'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
                'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
                'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
                'Pascal', 'R', 'Fortran', 'Lean'
        ]:
            raise Exception(
                'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]'  # noqa
            )  # noqa

        if inputs['target language'] not in [
                'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
                'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
                'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
                'Pascal', 'R', 'Fortran', 'Lean'
        ]:
            raise Exception(
                'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]'  # noqa
            )  # noqa

        return self.model(inputs)

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        return input
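And the corresponding sketch for translation, again with a placeholder model id; the language names must come from the list the pipeline validates against:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

translator = pipeline(
    Tasks.code_translation, model='<codegeex-code-translation-model-id>')
result = translator({
    'prompt': 'def add(a, b):\n    return a + b\n',
    'source language': 'Python',
    'target language': 'C++',
})
print(result[OutputKeys.TEXT])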
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
         SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor,
         TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor,
         TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize,
-        WordSegmentationBlankSetToLabelPreprocessor,
+        WordSegmentationBlankSetToLabelPreprocessor, CodeGeeXPreprocessor,
         MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor,
         TextGenerationJiebaPreprocessor, SentencePiecePreprocessor,
         DialogIntentPredictionPreprocessor, DialogModelingPreprocessor,

@@ -57,7 +57,7 @@ else:
             'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor',
             'Tokenize', 'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
-            'MGLMSummarizationPreprocessor',
+            'MGLMSummarizationPreprocessor', 'CodeGeeXPreprocessor',
             'ZeroShotClassificationPreprocessor',
             'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor',
             'NERPreprocessorViet', 'NERPreprocessorThai',
@@ -120,6 +120,8 @@ class NLPTasks(object):
     fill_mask = 'fill-mask'
     text_summarization = 'text-summarization'
     question_answering = 'question-answering'
+    code_translation = 'code-translation'
+    code_generation = 'code-generation'
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
     text_error_correction = 'text-error-correction'