mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
fix bugs
This commit is contained in:
@@ -8,12 +8,18 @@ from modelscope.models.nlp.structbert import SbertTokenizer
|
||||
from modelscope.models.nlp.utils.distributed import DistributedTorchModel
|
||||
from . import DistributedPlug
|
||||
from ...base import Tensor
|
||||
from modelscope.utils.hub import read_config
|
||||
|
||||
__all__ = ['PlugForTextGeneration']
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.text_generation, module_name=Models.plug)
|
||||
class PlugForTextGeneration(DistributedTorchModel):
|
||||
|
||||
def __init__(self, model_dir, cls_token_id, **kwargs):
|
||||
assert cls_token_id is not None
|
||||
super().__init__(model_dir, **kwargs)
|
||||
self.cls_token_id = cls_token_id
|
||||
|
||||
def _forward_one(self, input: Dict[str, Any]) -> Dict[str, Tensor]:
|
||||
return self.model(**input)
|
||||
@@ -25,10 +31,9 @@ class PlugForTextGeneration(DistributedTorchModel):
|
||||
res = self.model_pool.map(DistributedPlug.forward, [input]*self.world_size)
|
||||
return res[0]
|
||||
|
||||
def _instantiate_one(self, model_dir, rank):
|
||||
tokenizer = SbertTokenizer.from_pretrained(model_dir)
|
||||
self.cls_token_id = tokenizer.cls_token_id
|
||||
self.model = DistributedPlug.instantiate(model_dir, rank)
|
||||
def _instantiate_one(self, rank, model_dir):
|
||||
cfg = read_config(model_dir)
|
||||
self.model = DistributedPlug(model_dir, rank, **cfg.model)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,19 +15,24 @@ class DistributedTorchModel(Model):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.model_pool = None
|
||||
self.world_size = None
|
||||
|
||||
def __getstate__(self):
|
||||
self_dict = self.__dict__.copy()
|
||||
del self_dict['model_pool']
|
||||
return self_dict
|
||||
|
||||
@classmethod
|
||||
def _instantiate(cls, model_dir):
|
||||
model = DistributedTorchModel(model_dir=model_dir)
|
||||
def _instantiate(cls, model_dir, **kwargs):
|
||||
model = cls(model_dir=model_dir, **kwargs)
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
cfg = read_config(model_dir)
|
||||
model.world_size = cfg.model.word_size
|
||||
model.world_size = cfg.model.world_size
|
||||
ranks = list(range(model.world_size))
|
||||
model.model_pool = Pool(model.world_size)
|
||||
model.model_pool.map(partial(model._instantiate_one, model_dir=model_dir), ranks)
|
||||
return model
|
||||
|
||||
def _instantiate_one(self, model_dir, rank):
|
||||
def _instantiate_one(self, rank, model_dir):
|
||||
pass
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@@ -55,7 +55,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=None,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
model.eval()
|
||||
# model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.tokenizer = preprocessor.tokenizer
|
||||
|
||||
|
||||
@@ -18,8 +18,7 @@ import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from sofa.utils import mpu
|
||||
|
||||
from modelscope.utils.nlp import mpu
|
||||
|
||||
class tofp16(nn.Module):
|
||||
"""
|
||||
|
||||
@@ -48,10 +48,6 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
|
||||
|
||||
|
||||
def _init_dist_pytorch(backend: str, **kwargs) -> None:
|
||||
# rank = int(os.environ['RANK'])
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
|
||||
#torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from modelscope.pipelines.nlp import TextGenerationPipeline
|
||||
from modelscope.preprocessors import TextGenerationPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
class TextGenerationTest(unittest.TestCase):
|
||||
@@ -38,11 +39,12 @@ class TextGenerationTest(unittest.TestCase):
|
||||
self.plug_input = '段誉轻挥折扇,摇了摇头,说'
|
||||
|
||||
def test_plug(self):
|
||||
cache_path = "/home/suluyan.sly/model/plug_model"
|
||||
model = PlugForTextGeneration(cache_path)
|
||||
cache_path = "/home/yuze.zyz/MaaS-lib/plug_model"
|
||||
tokenizer = BertTokenizer.from_pretrained(cache_path)
|
||||
model = PlugForTextGeneration._instantiate(cache_path, cls_token_id=tokenizer.cls_token_id)
|
||||
preprocessor = TextGenerationPreprocessor(
|
||||
cache_path,
|
||||
model.tokenizer,
|
||||
tokenizer,
|
||||
first_sequence='sentence',
|
||||
second_sequence=None)
|
||||
pipeline1 = TextGenerationPipeline(model, preprocessor)
|
||||
|
||||
Reference in New Issue
Block a user