diff --git a/modelscope/models/nlp/plug/distributed_plug.py b/modelscope/models/nlp/plug/distributed_plug.py
index edfe5a4d..43501650 100644
--- a/modelscope/models/nlp/plug/distributed_plug.py
+++ b/modelscope/models/nlp/plug/distributed_plug.py
@@ -42,7 +42,7 @@ class DistributedPlug(TorchModel):
                 mpu.get_model_parallel_rank(),
                 sum([p.nelement() for p in model.parameters()])))
 
-        if self.config.deepspeed and self.args.fp16:
+        if self.config.deepspeed and self.config.fp16:
             model.half()
 
         # GPU allocation.
diff --git a/modelscope/models/nlp/plug/plug_for_text_generation.py b/modelscope/models/nlp/plug/plug_for_text_generation.py
index 5dc3bbe0..35565cc9 100644
--- a/modelscope/models/nlp/plug/plug_for_text_generation.py
+++ b/modelscope/models/nlp/plug/plug_for_text_generation.py
@@ -22,18 +22,18 @@ class PlugForTextGeneration(DistributedTorchModel):
         self.cls_token_id = cls_token_id
 
     def _forward_one(self, input: Dict[str, Any]) -> Dict[str, Tensor]:
-        return self.model(**input)
+        return self.__class__.model(input)
 
     def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         batch_size = input['input_ids'].shape[0]
         dec_input_ids = torch.full([batch_size, 1],
                                    self.cls_token_id,
                                    dtype=torch.long)
         input["dec_input_ids"] = dec_input_ids
-        res = self.model_pool.map(DistributedPlug.forward, [input]*self.world_size)
-        return res[0]
+        res = self.forward(input)
+        return res
 
     def _instantiate_one(self, rank, model_dir):
         cfg = read_config(model_dir)
-        self.model = DistributedPlug(model_dir, rank, **cfg.model)
+        self.__class__.model = DistributedPlug(model_dir, rank, **cfg.model)
 
diff --git a/modelscope/models/nlp/utils/distributed.py b/modelscope/models/nlp/utils/distributed.py
index d4302941..c934dddf 100644
--- a/modelscope/models/nlp/utils/distributed.py
+++ b/modelscope/models/nlp/utils/distributed.py
@@ -52,8 +52,8 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size):
     init_method = 'tcp://'
     master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
     master_port = os.getenv('MASTER_PORT', '29500')
-    if not _is_free_port(int(master_port)):
-        master_port = str(_find_free_port())
+    # if not _is_free_port(int(master_port)):
+    #     master_port = str(_find_free_port())
     init_method += master_ip + ':' + master_port
     init_dist('pytorch', world_size=world_size, rank=rank, init_method=init_method)
     # Set the model-parallel communicators.
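
A minimal, self-contained sketch of the pattern the plug_for_text_generation.py hunk moves to, assuming nothing about the modelscope API: each worker process builds its model shard once in _instantiate_one and binds it to the class, so _forward_one can call that per-process model directly instead of mapping the input over a multiprocessing pool. FakeShardedModel and Generator below are hypothetical stand-ins for DistributedPlug and PlugForTextGeneration, not the real classes.

from typing import Any, Dict


class FakeShardedModel:
    """Hypothetical stand-in for DistributedPlug: one shard per process/rank."""

    def __init__(self, rank: int):
        self.rank = rank

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # A real shard would run model-parallel inference here.
        return {'rank': self.rank, 'echo': inputs}


class Generator:
    # Class attribute: set once per worker process, then reused on every call.
    model = None

    def _instantiate_one(self, rank: int) -> None:
        # Mirrors `self.__class__.model = DistributedPlug(...)` in the diff.
        self.__class__.model = FakeShardedModel(rank)

    def _forward_one(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Mirrors `return self.__class__.model(input)` in the diff.
        return self.__class__.model(inputs)


if __name__ == '__main__':
    gen = Generator()
    gen._instantiate_one(rank=0)
    print(gen._forward_one({'input_ids': [101]}))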