diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 5f8b88f9..3d4f8c7d 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -109,7 +109,8 @@ else:
         'sentence_embedding': ['SentenceEmbedding'],
         'T5': ['T5ForConditionalGeneration'],
         'mglm': ['MGLMForTextSummarization'],
-        'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
+        'codegeex':
+        ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
         'gpt_neo': ['GPTNeoModel'],
         'bloom': ['BloomModel'],
     }
diff --git a/modelscope/models/nlp/codegeex/codegeex_for_code_generation.py b/modelscope/models/nlp/codegeex/codegeex_for_code_generation.py
index dbe6d4a4..ff191cba 100755
--- a/modelscope/models/nlp/codegeex/codegeex_for_code_generation.py
+++ b/modelscope/models/nlp/codegeex/codegeex_for_code_generation.py
@@ -65,7 +65,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
         bad_ids = None
         lang = input['language']
         prompt = input['prompt']
-        prompt = f"# language: {lang}\n{prompt}"
+        prompt = f'# language: {lang}\n{prompt}'
         logger = get_logger()
         tokenizer = self.tokenizer
         model = self.model
@@ -83,8 +83,7 @@ class CodeGeeXForCodeGeneration(TorchModel):
             topk=1,
             topp=0.9,
             temperature=0.9,
-            greedy=True
-        )
+            greedy=True)
         is_finished = [False for _ in range(micro_batch_size)]
         for i, generated in enumerate(token_stream):
             generated_tokens = generated[0]
diff --git a/modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py b/modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py
index 2eaebca3..f23461b1 100755
--- a/modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/codegeex_code_generation_pipeline.py
@@ -21,7 +21,7 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
                  *args,
                  **kwargs):
         model = CodeGeeXForCodeGeneration(model) if isinstance(model,
-                                                              str) else model
+                                                               str) else model
         self.model = model
         self.model.eval()
         self.model.half()
@@ -38,8 +38,15 @@ class CodeGeeXCodeGenerationPipeline(Pipeline):
         for para in ['prompt', 'language']:
             if para not in inputs:
                 raise Exception('Please check your input format.')
-        if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
-            raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa
+        if inputs['language'] not in [
+                'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
+                'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
+                'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
+                'Pascal', 'R', 'Fortran', 'Lean'
+        ]:  # noqa
+            raise Exception(
+                'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]'  # noqa
+            )  # noqa
 
         return self.model(inputs)
 
diff --git a/modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py b/modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py
index 61be5620..8bd5a6da 100755
--- a/modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py
+++ b/modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py
@@ -38,11 +38,25 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
         for para in ['prompt', 'source language', 'target language']:
             if para not in inputs:
                 raise Exception('please check your input format.')
-        if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
-            raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa
+        if inputs['source language'] not in [
+                'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
+                'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
+                'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
+                'Pascal', 'R', 'Fortran', 'Lean'
+        ]:
+            raise Exception(
+                'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]'  # noqa
+            )  # noqa
 
-        if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
-            raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa
+        if inputs['target language'] not in [
+                'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++',
+                'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript',
+                'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin',
+                'Pascal', 'R', 'Fortran', 'Lean'
+        ]:
+            raise Exception(
+                'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]'  # noqa
+            )  # noqa
 
         return self.model(inputs)
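
For reference, a minimal usage sketch of the two pipelines touched above. The input-dict keys ('prompt', 'language', 'source language', 'target language') and the accepted language names come from the validation in this diff; the task names and the model id below are placeholders, not confirmed by this diff, so check the registered task constants and the CodeGeeX model cards on ModelScope before relying on them.

# Minimal usage sketch. Assumptions: the 'code-generation' / 'code-translation'
# task names and '<codegeex-model-id>' are placeholders; only the input keys
# and language names are taken from the validation code in this diff.
from modelscope.pipelines import pipeline

generator = pipeline('code-generation', model='<codegeex-model-id>')
print(generator({
    'prompt': '# write a bubble sort function\n',
    'language': 'Python',  # must be one of the 24 languages validated above
}))

translator = pipeline('code-translation', model='<codegeex-model-id>')
print(translator({
    'prompt': 'def add(a, b):\n    return a + b\n',
    'source language': 'Python',
    'target language': 'C++',
}))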