shuaigezhu
2022-11-24 11:20:25 +08:00
parent f171552ee3
commit 1ab8a1f764
5 changed files with 11 additions and 10 deletions

View File

@@ -98,7 +98,9 @@ class CodeGeeXForCodeTranslation(TorchModel):
generated_code = tokenizer.decode_code(
    generated_tokens_[n_token_prompt:])
generated_code = ''.join(generated_code)
-logger.info('================================= Generated code:')
+logger.info(
+    '================================= Generated code:'
+)
logger.info(generated_code)
if all(is_finished):
    break
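The hunk above slices off the prompt tokens before decoding and then joins the decoded pieces into one string. A tiny self-contained sketch of that pattern; the decoder used here is a stand-in, not the CodeGeeX tokenizer:

def decode_generated(decode_code, generated_tokens_, n_token_prompt):
    # Keep only the tokens produced after the prompt, then join the decoded pieces.
    pieces = decode_code(generated_tokens_[n_token_prompt:])
    return ''.join(pieces)

# Stand-in decoder mapping token ids to characters, purely for illustration.
toy_decode = lambda ids: [chr(i) for i in ids]
print(decode_generated(toy_decode, [65, 66, 104, 105], n_token_prompt=2))  # prints "hi"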

View File

@@ -1,8 +1,9 @@
# Copyright (c) 2022 Zhipu.AI
+from typing import List
import torch
import torch.nn.functional as F
-from typing import List
def get_ltor_masks_and_position_ids(
@@ -124,7 +125,7 @@ def pad_batch(batch, pad_id, seq_length):
tokens.extend([pad_id] * (seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
def get_token_stream(
model,

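The pad_batch hunk above shows only the tail of the helper. A sketch of how the full function plausibly reads, inferred from the visible lines rather than copied from this file: each prompt is right-padded with pad_id up to seq_length and its original length is recorded for later use.

def pad_batch(batch, pad_id, seq_length):
    # Right-pad every token list to seq_length and remember the original lengths,
    # which the generation loop later uses to skip past the prompt tokens.
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < seq_length:
            tokens.extend([pad_id] * (seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

# Example: two prompts padded to length 6 with pad_id=0.
padded, lengths = pad_batch([[1, 2, 3], [4, 5]], pad_id=0, seq_length=6)
# padded == [[1, 2, 3, 0, 0, 0], [4, 5, 0, 0, 0, 0]], lengths == [3, 2]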
View File

@@ -1,8 +1,9 @@
# Copyright (c) 2022 Zhipu.AI
+from typing import List, Union
import torch
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
-from typing import List, Union
def encode_whitespaces(text, start_extra_id: int, max_len: int):
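encode_whitespaces turns runs of spaces into single special tokens so that code indentation does not inflate the sequence length. A minimal sketch of the usual CodeGeeX-style scheme, assuming a run of i spaces (2 <= i <= max_len) maps to the extra token numbered start_extra_id + i - 2; the <|extratoken_N|> spelling is an assumption, not taken from this diff:

def encode_whitespaces(text: str, start_extra_id: int, max_len: int) -> str:
    # Replace the longest runs first so shorter patterns do not split them.
    for i in range(max_len, 1, -1):
        text = text.replace(' ' * i, f'<|extratoken_{start_extra_id + i - 2}|>')
    return text

# Example: a 4-space indent collapses into one extra token (id 10 + 4 - 2 = 12).
print(encode_whitespaces('def f():\n    pass', start_extra_id=10, max_len=10))
# def f():
# <|extratoken_12|>pass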

View File

@@ -28,9 +28,9 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
        self.model.cuda()
        super().__init__(model=model, **kwargs)
    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
-        return inputs
+        return inputs
    # define the forward pass
    def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:
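The test in the last file of this commit drives this pipeline end to end; a usage sketch based on those visible lines follows. It assumes the standard ModelScope entry points, requires downloading the 13B checkpoint onto a CUDA device, and reproduces only the input keys visible in the truncated diff:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(task=Tasks.code_translation,
                model='ZhipuAI/CodeGeeX-Code-Translation-13B')
inputs = {
    'prompt': 'for i in range(10):\n\tprint(i)\n',
    'source language': 'Python',
    # further keys (e.g. the target language) are cut off in the diff below
}
result = pipe(inputs)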

View File

@@ -17,10 +17,7 @@ class CodeGeeXCodeTranslationTest(unittest.TestCase, DemoCompatibilityCheck):
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_CodeGeeX_with_name(self):
        model = 'ZhipuAI/CodeGeeX-Code-Translation-13B'
-        pipe = pipeline(
-            task=Tasks.code_translation,
-            model=model
-        )
+        pipe = pipeline(task=Tasks.code_translation, model=model)
        inputs = {
            'prompt': 'for i in range(10):\n\tprint(i)\n',
            'source language': 'Python',