add code_generation files

This commit is contained in:
shuaigezhu
2022-11-25 16:41:44 +08:00
parent c9064caa58
commit 028551cd62
4 changed files with 32 additions and 11 deletions

View File

@@ -109,7 +109,8 @@ else:
'sentence_embedding': ['SentenceEmbedding'],
'T5': ['T5ForConditionalGeneration'],
'mglm': ['MGLMForTextSummarization'],
'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'codegeex':
['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
'gpt_neo': ['GPTNeoModel'],
'bloom': ['BloomModel'],
}

View File

@@ -65,7 +65,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
bad_ids = None
lang = input['language']
prompt = input['prompt']
prompt = f"# language: {lang}\n{prompt}"
prompt = f'# language: {lang}\n{prompt}'
logger = get_logger()
tokenizer = self.tokenizer
model = self.model
@@ -83,8 +83,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
topk=1,
topp=0.9,
temperature=0.9,
greedy=True
)
greedy=True)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]

View File

@@ -21,7 +21,7 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
*args,
**kwargs):
model = CodeGeeXForCodeGeneration(model) if isinstance(model,
str) else model
str) else model
self.model = model
self.model.eval()
self.model.half()
@@ -38,8 +38,15 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
for para in ['prompt', 'language']:
if para not in inputs:
raise Exception('Please check your input format.')
if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]: # noqa
raise Exception(
'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa
return self.model(inputs)

View File

@@ -38,11 +38,25 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
for para in ['prompt', 'source language', 'target language']:
if para not in inputs:
raise Exception('please check your input format.')
if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['source language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]:
raise Exception(
'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa
if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa
raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa
if inputs['target language'] not in [
'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
'Pascal', 'R', 'Fortran', 'Lean'
]:
raise Exception(
'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa
) # noqa
return self.model(inputs)