Fix bug and run

This commit is contained in:
yzhao
2022-08-25 20:04:21 +08:00
parent e1f13bcf7a
commit fc6f292fef
3 changed files with 7 additions and 7 deletions

View File

@@ -42,7 +42,7 @@ class DistributedPlug(TorchModel):
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])))
if self.config.deepspeed and self.args.fp16:
if self.config.deepspeed and self.config.fp16:
model.half()
# GPU allocation.

View File

@@ -22,18 +22,18 @@ class PlugForTextGeneration(DistributedTorchModel):
self.cls_token_id = cls_token_id
def _forward_one(self, input: Dict[str, Any]) -> Dict[str, Tensor]:
return self.model(**input)
return self.__class__.model(input)
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
batch_size = input['input_ids'].shape[0]
dec_input_ids = torch.full([batch_size, 1], self.cls_token_id, dtype=torch.long)
input["dec_input_ids"] = dec_input_ids
res = self.model_pool.map(DistributedPlug.forward, [input]*self.world_size)
return res[0]
res = self.forward(input)
return res
def _instantiate_one(self, rank, model_dir):
cfg = read_config(model_dir)
self.model = DistributedPlug(model_dir, rank, **cfg.model)
self.__class__.model = DistributedPlug(model_dir, rank, **cfg.model)

View File

@@ -52,8 +52,8 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size):
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
master_port = os.getenv('MASTER_PORT', '29500')
if not _is_free_port(int(master_port)):
master_port = str(_find_free_port())
# if not _is_free_port(int(master_port)):
# master_port = str(_find_free_port())
init_method += master_ip + ':' + master_port
init_dist('pytorch', world_size=world_size, rank=rank, init_method=init_method)
# Set the model-parallel communicators.