mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
add code_generation files
This commit is contained in:
@@ -109,7 +109,8 @@ else:
|
||||
'sentence_embedding': ['SentenceEmbedding'],
|
||||
'T5': ['T5ForConditionalGeneration'],
|
||||
'mglm': ['MGLMForTextSummarization'],
|
||||
'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
|
||||
'codegeex':
|
||||
['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
|
||||
'gpt_neo': ['GPTNeoModel'],
|
||||
'bloom': ['BloomModel'],
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
|
||||
bad_ids = None
|
||||
lang = input['language']
|
||||
prompt = input['prompt']
|
||||
prompt = f"# language: {lang}\n{prompt}"
|
||||
prompt = f'# language: {lang}\n{prompt}'
|
||||
logger = get_logger()
|
||||
tokenizer = self.tokenizer
|
||||
model = self.model
|
||||
@@ -83,8 +83,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
|
||||
topk=1,
|
||||
topp=0.9,
|
||||
temperature=0.9,
|
||||
greedy=True
|
||||
)
|
||||
greedy=True)
|
||||
is_finished = [False for _ in range(micro_batch_size)]
|
||||
for i, generated in enumerate(token_stream):
|
||||
generated_tokens = generated[0]
|
||||
|
||||
@@ -21,7 +21,7 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
|
||||
*args,
|
||||
**kwargs):
|
||||
model = CodeGeeXForCodeGeneration(model) if isinstance(model,
|
||||
str) else model
|
||||
str) else model
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
self.model.half()
|
||||
@@ -38,8 +38,15 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
|
||||
for para in ['prompt', 'language']:
|
||||
if para not in inputs:
|
||||
raise Exception('Please check your input format.')
|
||||
if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
|
||||
raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
|
||||
if inputs['language'] not in [
|
||||
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
|
||||
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
|
||||
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
|
||||
'Pascal', 'R', 'Fortran', 'Lean'
|
||||
]: # noqa
|
||||
raise Exception(
|
||||
'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
|
||||
) # noqa
|
||||
|
||||
return self.model(inputs)
|
||||
|
||||
|
||||
@@ -38,11 +38,25 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
|
||||
for para in ['prompt', 'source language', 'target language']:
|
||||
if para not in inputs:
|
||||
raise Exception('please check your input format.')
|
||||
if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
|
||||
raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
|
||||
if inputs['source language'] not in [
|
||||
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
|
||||
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
|
||||
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
|
||||
'Pascal', 'R', 'Fortran', 'Lean'
|
||||
]:
|
||||
raise Exception(
|
||||
'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
|
||||
) # noqa
|
||||
|
||||
if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
|
||||
raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
|
||||
if inputs['target language'] not in [
|
||||
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
|
||||
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
|
||||
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
|
||||
'Pascal', 'R', 'Fortran', 'Lean'
|
||||
]:
|
||||
raise Exception(
|
||||
'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
|
||||
) # noqa
|
||||
|
||||
return self.model(inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user