Add bloom text generation model (#492)

* add bloom text-generation

* pre-commit passed
This commit is contained in:
tastelikefeet
2023-08-22 16:56:15 +08:00
committed by GitHub
parent cccd502fa2
commit df53a6a89f
3 changed files with 20 additions and 3 deletions

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
BertConfig,
SiameseUieModel,
)
from .bloom import BloomModel
from .bloom import BloomModel, BloomForTextGeneration
from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration
from .glm_130b import GLM130bForTextGeneration
from .csanmt import CsanmtForTranslation
@@ -79,7 +79,6 @@ if TYPE_CHECKING:
from .llama import LlamaForTextGeneration, LlamaConfig, LlamaModel, LlamaTokenizer, LlamaTokenizerFast
from .llama2 import Llama2ForTextGeneration, Llama2Config, Llama2Model, Llama2Tokenizer, Llama2TokenizerFast
from .qwen import QWenForTextGeneration, QWenConfig, QWenModel, QWenTokenizer
else:
_import_structure = {
'bart': ['BartForTextErrorCorrection'],
@@ -94,7 +93,7 @@ else:
'BertConfig',
'SiameseUieModel',
],
'bloom': ['BloomModel'],
'bloom': ['BloomModel', 'BloomForTextGeneration'],
'csanmt': ['CsanmtForTranslation'],
'canmt': ['CanmtForTranslation'],
'polylm': ['PolyLMForTextGeneration'],

View File

@@ -5,9 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .backbone import BloomModel
from .text_generation import BloomForTextGeneration
else:
_import_structure = {
'backbone': ['BloomModel'],
'text_generation': ['BloomForTextGeneration'],
}
import sys
sys.modules[__name__] = LazyImportModule(

View File

@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from transformers import BloomConfig
from transformers import BloomForCausalLM as BloomForCausalLMTransform
from modelscope.metainfo import Models
from modelscope.models import MODELS
from modelscope.utils.constant import Tasks
@MODELS.register_module(
group_key=Tasks.text_generation, module_name=Models.bloom)
class BloomForTextGeneration(BloomForCausalLMTransform):
def __init__(self, **kwargs):
config = BloomConfig(**kwargs)
super().__init__(config)