Mirror of https://github.com/modelscope/modelscope.git
[to #42322933] fix for plug
1. Update the pipeline for the new preprocessor.
2. Update the trainer for dist_info (remove the megatron-DDP wrapper).

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10969594
@@ -65,8 +65,7 @@ class DistributedPlugPipeline(DistributedPipeline):
             sequence_length=sequence_length,
             **kwargs)
         super().__init__(model, preprocessor=preprocessor, **kwargs)
-        assert hasattr(preprocessor, 'tokenizer')
-        self.cls_token_id = preprocessor.tokenizer.cls_token_id
+        self.cls_token_id = preprocessor.nlp_tokenizer.tokenizer.cls_token_id

     @classmethod
     def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
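For readers outside the codebase, a minimal runnable sketch of the preprocessor layout this hunk adapts to. All class names below are stand-ins; only the tokenizer and nlp_tokenizer.tokenizer attribute paths come from the diff.

class FakeTokenizer:
    cls_token_id = 101  # BERT-style [CLS] id, for illustration only

class NLPTokenizer:
    # new-style wrapper that holds the raw tokenizer
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

class OldPreprocessor:
    def __init__(self):
        self.tokenizer = FakeTokenizer()  # old flat layout

class NewPreprocessor:
    def __init__(self):
        # new layout: the tokenizer sits one level deeper
        self.nlp_tokenizer = NLPTokenizer(FakeTokenizer())

# old access (what the removed lines did):
assert OldPreprocessor().tokenizer.cls_token_id == 101
# new access (what the added line does):
assert NewPreprocessor().nlp_tokenizer.tokenizer.cls_token_id == 101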
@@ -105,6 +104,6 @@ class DistributedPlugPipeline(DistributedPipeline):
         from modelscope.outputs import OutputKeys
         generate_context = inputs['generate_context']
         generate_context = ''.join(
-            self.preprocessor.tokenizer.convert_ids_to_tokens(
+            self.preprocessor.nlp_tokenizer.tokenizer.convert_ids_to_tokens(
                 generate_context)).replace('[UNK]', '“').replace('##', '')
         return {OutputKeys.TEXT: generate_context}
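The postprocessing above can be exercised in isolation. This sketch uses a hypothetical toy vocabulary; the '##' stripping joins WordPiece continuation tokens, and the '[UNK]' to '“' substitution mirrors the pipeline's own heuristic, presumably restoring a quote character missing from the vocab.

# toy id-to-token table, hypothetical values
vocab = {0: '[UNK]', 1: '模', 2: '##型', 3: '很', 4: '##好'}

def convert_ids_to_tokens(ids):
    # stand-in for the tokenizer method used in the hunk above
    return [vocab[i] for i in ids]

generate_context = [0, 1, 2, 3, 4]
text = ''.join(convert_ids_to_tokens(generate_context)) \
    .replace('[UNK]', '“').replace('##', '')
print(text)  # “模型很好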
@@ -66,9 +66,9 @@ class PlugTrainer(NlpEpochBasedTrainer):
         from deepspeed.ops.adam import DeepSpeedCPUAdam
         model = self.model

-        embeddings = model.module.module.model.bert.embeddings
-        layers = model.module.module.model.bert.encoder.layer
-        dec_layers = model.module.module.model.decoder.decoder
+        embeddings = model.module.model.bert.embeddings
+        layers = model.module.model.bert.encoder.layer
+        dec_layers = model.module.model.decoder.decoder
         param_groups = []
         param_groups += list(
             self._get_params_for_weight_decay_optimization(layers))
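The hunk above feeds the encoder layers into a Megatron-style weight-decay split before building the optimizer. A hedged sketch of what a helper like _get_params_for_weight_decay_optimization typically does (the real Megatron/modelscope implementation may differ): biases and 1-D norm parameters go into a no-decay group, everything else gets decay.

import torch

def get_params_for_weight_decay_optimization(modules):
    # split parameters into decay / no-decay groups; values here
    # (e.g. weight_decay=0.01) are illustrative assumptions
    decay = {'params': [], 'weight_decay': 0.01}
    no_decay = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for name, p in module.named_parameters():
            if p.ndim == 1 or name.endswith('bias'):
                no_decay['params'].append(p)  # bias / norm params
            else:
                decay['params'].append(p)
    return [decay, no_decay]

layers = [torch.nn.Linear(8, 8), torch.nn.LayerNorm(8)]
param_groups = []
param_groups += get_params_for_weight_decay_optimization(layers)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)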
@@ -160,7 +160,7 @@ class PlugTrainer(NlpEpochBasedTrainer):

     def evaluation_step(self, data):
         # wapper 1: DeepspeedEngine, wapper 2: DDP
-        model = self.model.module.module
+        model = self.model.module
         model.eval()

         # model: fp16 wapper; model.module : distributedPlug
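Why one '.module' disappears here and in the previous hunk: with the megatron-DDP wrapper removed, only the DeepspeedEngine layer wraps the fp16 model. A sketch with stand-in classes (not deepspeed's real API):

import torch

class FP16Module(torch.nn.Module):
    # stand-in for the fp16 wrapper named in the comment above;
    # self.module would be the distributed PLUG model
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)

class Engine:
    # stand-in for deepspeed's engine, which exposes the wrapped
    # model as .module
    def __init__(self, module):
        self.module = module

engine = Engine(FP16Module(torch.nn.Linear(4, 2)))

# before this commit the chain was Engine -> DDP -> fp16 model,
# hence self.model.module.module; now a single unwrap suffices:
model = engine.module
model.eval()  # disable dropout etc. for evaluation
with torch.no_grad():
    _ = model(torch.randn(1, 4))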