mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
updated
This commit is contained in:
@@ -98,7 +98,9 @@ class CodeGeeXForCodeTranslation(TorchModel):
|
||||
generated_code = tokenizer.decode_code(
|
||||
generated_tokens_[n_token_prompt:])
|
||||
generated_code = ''.join(generated_code)
|
||||
logger.info('================================= Generated code:')
|
||||
logger.info(
|
||||
'================================= Generated code:'
|
||||
)
|
||||
logger.info(generated_code)
|
||||
if all(is_finished):
|
||||
break
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) 2022 Zhipu.AI
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_ltor_masks_and_position_ids(
|
||||
@@ -124,7 +125,7 @@ def pad_batch(batch, pad_id, seq_length):
|
||||
tokens.extend([pad_id] * (seq_length - context_length))
|
||||
context_lengths.append(context_length)
|
||||
return batch, context_lengths
|
||||
|
||||
|
||||
|
||||
def get_token_stream(
|
||||
model,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) 2022 Zhipu.AI
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.gpt2 import GPT2TokenizerFast
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def encode_whitespaces(text, start_extra_id: int, max_len: int):
|
||||
|
||||
@@ -28,9 +28,9 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
|
||||
self.model.cuda()
|
||||
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
|
||||
def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
|
||||
return inputs
|
||||
return inputs
|
||||
|
||||
# define the forward pass
|
||||
def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:
|
||||
|
||||
@@ -17,10 +17,7 @@ class CodeGeeXCodeTranslationTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_CodeGeeX_with_name(self):
|
||||
model = 'ZhipuAI/CodeGeeX-Code-Translation-13B'
|
||||
pipe = pipeline(
|
||||
task=Tasks.code_translation,
|
||||
model=model
|
||||
)
|
||||
pipe = pipeline(task=Tasks.code_translation, model=model)
|
||||
inputs = {
|
||||
'prompt': 'for i in range(10):\n\tprint(i)\n',
|
||||
'source language': 'Python',
|
||||
|
||||
Reference in New Issue
Block a user