mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
fix bug and run
This commit is contained in:
@@ -42,7 +42,7 @@ class DistributedPlug(TorchModel):
|
||||
mpu.get_model_parallel_rank(),
|
||||
sum([p.nelement() for p in model.parameters()])))
|
||||
|
||||
- if self.config.deepspeed and self.args.fp16:
|
||||
+ if self.config.deepspeed and self.config.fp16:
|
||||
model.half()
|
||||
|
||||
# GPU allocation.
|
||||
|
||||
@@ -22,18 +22,18 @@ class PlugForTextGeneration(DistributedTorchModel):
|
||||
self.cls_token_id = cls_token_id
|
||||
|
||||
def _forward_one(self, input: Dict[str, Any]) -> Dict[str, Tensor]:
|
||||
- return self.model(**input)
|
||||
+ return self.__class__.model(input)
|
||||
|
||||
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
batch_size = input['input_ids'].shape[0]
|
||||
dec_input_ids = torch.full([batch_size, 1], self.cls_token_id, dtype=torch.long)
|
||||
input["dec_input_ids"] = dec_input_ids
|
||||
- res = self.model_pool.map(DistributedPlug.forward, [input]*self.world_size)
|
||||
- return res[0]
|
||||
+ res = self.forward(input)
|
||||
+ return res
|
||||
|
||||
def _instantiate_one(self, rank, model_dir):
|
||||
cfg = read_config(model_dir)
|
||||
- self.model = DistributedPlug(model_dir, rank, **cfg.model)
|
||||
+ self.__class__.model = DistributedPlug(model_dir, rank, **cfg.model)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size):
|
||||
init_method = 'tcp://'
|
||||
master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
|
||||
master_port = os.getenv('MASTER_PORT', '29500')
|
||||
- if not _is_free_port(int(master_port)):
|
||||
- master_port = str(_find_free_port())
|
||||
+ # if not _is_free_port(int(master_port)):
|
||||
+ # master_port = str(_find_free_port())
|
||||
init_method += master_ip + ':' + master_port
|
||||
init_dist('pytorch', world_size=world_size, rank=rank, init_method=init_method)
|
||||
# Set the model-parallel communicators.
|
||||
|
||||
Reference in New Issue
Block a user