This commit is contained in:
yzhao
2022-08-25 18:39:48 +08:00
parent 5e3446db4b
commit e1f13bcf7a
6 changed files with 25 additions and 18 deletions

View File

@@ -8,12 +8,18 @@ from modelscope.models.nlp.structbert import SbertTokenizer
from modelscope.models.nlp.utils.distributed import DistributedTorchModel
from . import DistributedPlug
from ...base import Tensor
from modelscope.utils.hub import read_config
__all__ = ['PlugForTextGeneration']
@MODELS.register_module(Tasks.text_generation, module_name=Models.plug)
class PlugForTextGeneration(DistributedTorchModel):
def __init__(self, model_dir, cls_token_id, **kwargs):
assert cls_token_id is not None
super().__init__(model_dir, **kwargs)
self.cls_token_id = cls_token_id
def _forward_one(self, input: Dict[str, Any]) -> Dict[str, Tensor]:
    """Run a single forward pass by expanding *input* as keyword args to the wrapped model."""
    model = self.model
    return model(**input)
@@ -25,10 +31,9 @@ class PlugForTextGeneration(DistributedTorchModel):
res = self.model_pool.map(DistributedPlug.forward, [input]*self.world_size)
return res[0]
def _instantiate_one(self, model_dir, rank):
tokenizer = SbertTokenizer.from_pretrained(model_dir)
self.cls_token_id = tokenizer.cls_token_id
self.model = DistributedPlug.instantiate(model_dir, rank)
def _instantiate_one(self, rank, model_dir):
cfg = read_config(model_dir)
self.model = DistributedPlug(model_dir, rank, **cfg.model)

View File

@@ -15,19 +15,24 @@ class DistributedTorchModel(Model):
super().__init__(model_dir, *args, **kwargs)
self.model_pool = None
self.world_size = None
def __getstate__(self):
    """Return picklable state: a shallow copy of ``__dict__`` without the process pool.

    ``multiprocessing.Pool`` objects cannot be pickled, so the pool is
    dropped before this instance is sent to worker processes.
    """
    state = dict(self.__dict__)
    state.pop('model_pool')
    return state
@classmethod
def _instantiate(cls, model_dir):
model = DistributedTorchModel(model_dir=model_dir)
def _instantiate(cls, model_dir, **kwargs):
model = cls(model_dir=model_dir, **kwargs)
torch.multiprocessing.set_start_method("spawn")
cfg = read_config(model_dir)
model.world_size = cfg.model.word_size
model.world_size = cfg.model.world_size
ranks = list(range(model.world_size))
model.model_pool = Pool(model.world_size)
model.model_pool.map(partial(model._instantiate_one, model_dir=model_dir), ranks)
return model
def _instantiate_one(self, model_dir, rank):
def _instantiate_one(self, rank, model_dir):
pass
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -55,7 +55,7 @@ class TextGenerationPipeline(Pipeline):
first_sequence=first_sequence,
second_sequence=None,
sequence_length=kwargs.pop('sequence_length', 128))
model.eval()
# model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.tokenizer = preprocessor.tokenizer

View File

@@ -18,8 +18,7 @@ import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from sofa.utils import mpu
from modelscope.utils.nlp import mpu
class tofp16(nn.Module):
"""

View File

@@ -48,10 +48,6 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
def _init_dist_pytorch(backend: str, **kwargs) -> None:
# rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
#torch.cuda.set_device(local_rank)
dist.init_process_group(backend=backend, **kwargs)

View File

@@ -9,6 +9,7 @@ from modelscope.pipelines.nlp import TextGenerationPipeline
from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
from transformers import BertTokenizer
class TextGenerationTest(unittest.TestCase):
@@ -38,11 +39,12 @@ class TextGenerationTest(unittest.TestCase):
self.plug_input = '段誉轻挥折扇,摇了摇头,说'
def test_plug(self):
cache_path = "/home/suluyan.sly/model/plug_model"
model = PlugForTextGeneration(cache_path)
cache_path = "/home/yuze.zyz/MaaS-lib/plug_model"
tokenizer = BertTokenizer.from_pretrained(cache_path)
model = PlugForTextGeneration._instantiate(cache_path, cls_token_id=tokenizer.cls_token_id)
preprocessor = TextGenerationPreprocessor(
cache_path,
model.tokenizer,
tokenizer,
first_sequence='sentence',
second_sequence=None)
pipeline1 = TextGenerationPipeline(model, preprocessor)