[to #42322933] plug finetune
plug finetune: regression-tested back to the best result on the DuReader-robust dataset
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10916382
@@ -338,6 +338,7 @@ class Trainers(object):
     nlp_veco_trainer = 'nlp-veco-trainer'
     nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
     text_generation_trainer = 'text-generation-trainer'
+    nlp_plug_trainer = 'nlp-plug-trainer'
 
     # audio trainers
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
@@ -500,6 +501,9 @@ class Hooks(object):
     # CLIP logit_scale clamp
     ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'
 
+    # train
+    DeepspeedHook = 'DeepspeedHook'
+
 
 class LR_Schedulers(object):
     """learning rate scheduler is defined here
modelscope/models/nlp/plug/AnnealingLR.py (new executable file, 88 lines)
@@ -0,0 +1,88 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Learning rate annealing schedules."""

import math

import torch
from torch.optim.lr_scheduler import _LRScheduler


class AnnealingLR(_LRScheduler):
    """Anneals the learning rate from start to zero along a cosine curve."""

    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']

    def __init__(self,
                 optimizer,
                 start_lr,
                 warmup_iter,
                 num_iters,
                 decay_style=None,
                 last_iter=-1):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.warmup_iter = warmup_iter
        self._step_count = last_iter + 1
        self.end_iter = num_iters
        self.decay_style = decay_style.lower() if isinstance(decay_style,
                                                             str) else None
        self.step(self._step_count)
        if torch.distributed.get_rank() == 0:
            print('learning rate decaying', decay_style)

    def get_lr(self):
        # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
        if self.warmup_iter > 0 and self._step_count <= self.warmup_iter:
            return float(self.start_lr) * self._step_count / self.warmup_iter
        else:
            if self.decay_style == self.DECAY_STYLES[0]:
                return self.start_lr * ((
                    self.end_iter -  # noqa W504
                    (self._step_count - self.warmup_iter)) / self.end_iter)
            elif self.decay_style == self.DECAY_STYLES[1]:
                return self.start_lr / 2.0 * (
                    math.cos(math.pi * (self._step_count - self.warmup_iter)
                             / self.end_iter) + 1)
            elif self.decay_style == self.DECAY_STYLES[2]:
                # TODO: implement exponential decay
                return self.start_lr
            else:
                return self.start_lr

    def step(self, step_num=None):
        if step_num is None:
            step_num = self._step_count + 1
        self._step_count = step_num
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

    def state_dict(self):
        sd = {
            'start_lr': self.start_lr,
            'warmup_iter': self.warmup_iter,
            '_step_count': self._step_count,
            'decay_style': self.decay_style,
            'end_iter': self.end_iter
        }
        return sd

    def load_state_dict(self, sd):
        self.start_lr = sd['start_lr']
        self.warmup_iter = sd['warmup_iter']
        self._step_count = sd['_step_count']
        self.end_iter = sd['end_iter']
        self.decay_style = sd['decay_style']
        self.step(self._step_count)
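A minimal usage sketch of this scheduler (illustrative values, not part of the commit; note that AnnealingLR calls torch.distributed.get_rank() in its constructor, so a process group must already be initialized):

import torch
import torch.distributed as dist

from modelscope.models.nlp.plug.AnnealingLR import AnnealingLR

# AnnealingLR queries torch.distributed.get_rank() in __init__, so make sure a
# (possibly single-process) process group exists before constructing it.
if not dist.is_initialized():
    dist.init_process_group(
        'gloo', init_method='tcp://127.0.0.1:29501', rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = AnnealingLR(
    optimizer,
    start_lr=1e-4,        # peak learning rate reached at the end of warmup
    warmup_iter=100,      # linear warmup over the first 100 iterations
    num_iters=1000,       # total number of training iterations
    decay_style='linear')

for _ in range(1000):
    # ... forward / backward / optimizer.step() ...
    scheduler.step()      # writes the new lr into optimizer.param_groups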
@@ -1009,6 +1009,118 @@ class PlugModel(torch.nn.Module):
            sequence_output=sequence_output,
            parallel_output=parallel_output)

    @staticmethod
    def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        # This function has been mostly taken from huggingface conversational ai code at
        # https://medium.com/huggingface/how-to-build-a-state-of-the-art-
        # conversational-ai-with-transfer-learning-2d818ac26313

        if top_k > 0:
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
                                                                      None]
            logits[indices_to_remove] = filter_value

        if top_p > 0.0:
            # convert to 1D
            logits = logits.view(logits.size()[1]).contiguous()
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(
                F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = filter_value
            # going back to 2D
            logits = logits.view(1, -1).contiguous()
        return logits

    def generate(self, input, out_length=128, model_cfg=None, *kwargs):
        device = torch.cuda.current_device()
        batch_size = input['input_ids'].shape[0]
        tokens = input['input_ids'].view(1, -1).contiguous().to(device)
        dec_input_ids = input['dec_input_ids'].to(device)
        attention_mask = input['attention_mask'].to(device)
        self.model.eval()
        with torch.no_grad():
            # Only supports batch_size=1
            all_generate_tokens = []
            generate_tokens = []
            counter = 0
            sequence_output = None
            vocab_size = self.config.original_vocab_size
            sep_token_idx = 102  # index of [SEP] token in BertTokenizer
            while counter < out_length:
                if counter % 128 == 0 and counter != 0:
                    # Sliding window
                    generate_tokens.append(sep_token_idx)
                    start = (tokens == sep_token_idx).nonzero(
                        as_tuple=True)[-1]
                    if start + len(generate_tokens) >= 512:
                        tokens = torch.cat([
                            tokens[:start],
                            torch.cuda.LongTensor(generate_tokens)
                        ], -1)[-512:]
                    else:
                        tokens[0][start:start + len(generate_tokens
                                                    )] = torch.cuda.LongTensor(
                                                        generate_tokens)

                    attention_mask = (tokens != 0)
                    dec_input_ids = input['dec_input_ids'].to(device)
                    generate_tokens = []
                    sequence_output = None

                position_ids = torch.full([batch_size, 1],
                                          len(generate_tokens),
                                          dtype=torch.long,
                                          device=device)
                _, logits, sequence_output = self.model(
                    tokens,
                    None,
                    attention_mask,
                    dec_input_ids,
                    attention_mask,
                    position_ids,
                    is_infer=True,
                    sequence_output=sequence_output,
                    parallel_output=False)
                logits = logits[:, -1, :]
                logits = logits / model_cfg['temperature']
                logits = self.top_k_logits(
                    logits, top_k=model_cfg['top_k'], top_p=model_cfg['top_p'])
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.argmax(log_probs, 1).unsqueeze(1)
                # prev = torch.multinomial(log_probs, num_samples=1)
                prev_token = prev[0].item()
                if prev_token >= vocab_size:
                    prev_token = 100
                    prev[0] = 100
                if prev_token == 102 and len(all_generate_tokens) > int(
                        max(1, out_length) * 0.8):
                    break
                if prev_token == 102:
                    counter += 1
                    continue
                dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
                generate_tokens.append(prev_token)
                all_generate_tokens.append(prev_token)
                counter += 1

            generate_context = []
            for token in all_generate_tokens:
                if generate_context and generate_context[
                        -1] == 100 and token == 100:
                    continue
                else:
                    generate_context.append(token)
            return {'generate_context': generate_context}

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)
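The top_k_logits helper added above is standard top-k / nucleus (top-p) filtering. A small self-contained illustration of the top-k branch on a toy logit row (not part of the commit; masked_fill is used here instead of in-place indexing purely for brevity):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top_k = 2
filter_value = -float('Inf')
# Mask every logit smaller than the k-th largest one.
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
filtered = logits.masked_fill(indices_to_remove, filter_value)
probs = F.softmax(filtered, dim=-1)
# probs is now non-zero only for the two highest-scoring tokens.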
@@ -225,7 +225,7 @@ class PlugNLGConfig(PlugNLUConfig):
                  fp32_layernorm=True,
                  fp32_embedding=False,
                  fp32_tokentypes=False,
-                 layernorm_epsilon=1e-5,
+                 layernorm_epsilon=1e-12,
                  attn_separate=False,
                  **kwargs):
         super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
@@ -75,7 +75,7 @@ class DistributedPlug(TorchModel):
         seed = 42 if 'seed' not in kwargs else kwargs['seed']
         set_random_seed_mpu(seed)
         self.iteration = 0
-        self.dist_model = self.initialize_model(path_load_tag='model')
+        self.model = self.initialize_model(path_load_tag='model')
 
     def initialize_model(self, path_load_tag='model'):
         """Build the model."""
@@ -120,115 +120,28 @@ class DistributedPlug(TorchModel):
         model.module.model.load_state_dict(load_model, strict=False)
         return model
 
-    @staticmethod
-    def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-        # This function has been mostly taken from huggingface conversational ai code at
-        # https://medium.com/huggingface/how-to-build-a-state-of-the-art-
-        # conversational-ai-with-transfer-learning-2d818ac26313
-
-        if top_k > 0:
-            # Remove all tokens with a probability less than the last token of the top-k
-            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
-                                                                      None]
-            logits[indices_to_remove] = filter_value
-
-        if top_p > 0.0:
-            # convert to 1D
-            logits = logits.view(logits.size()[1]).contiguous()
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            cumulative_probs = torch.cumsum(
-                F.softmax(sorted_logits, dim=-1), dim=-1)
-
-            # Remove tokens with cumulative probability above the threshold
-            sorted_indices_to_remove = cumulative_probs > top_p
-            # Shift the indices to the right to keep also the first token above the threshold
-            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
-                ..., :-1].clone()
-            sorted_indices_to_remove[..., 0] = 0
-            indices_to_remove = sorted_indices[sorted_indices_to_remove]
-            logits[indices_to_remove] = filter_value
-            # going back to 2D
-            logits = logits.view(1, -1).contiguous()
-        return logits
     def forward(self,
                 input_tokens,
                 token_type_ids=None,
                 attention_mask=None,
                 target_tokens=None,
                 position_ids=None,
                 decode_attention_mask=None,
                 checkpoint_activations=False,
                 is_infer=False,
                 sequence_output=None,
                 parallel_output=True):
         return self.model(
             input_tokens,
             token_type_ids,
             attention_mask,
             target_tokens,
             position_ids,
             decode_attention_mask,
             checkpoint_activations=checkpoint_activations,
             is_infer=is_infer,
             sequence_output=sequence_output,
             parallel_output=parallel_output)
 
     def generate(self, input: Dict[str, Tensor], out_length=128, *kwargs):
-        device = torch.cuda.current_device()
-        batch_size = input['input_ids'].shape[0]
-        tokens = input['input_ids'].view(1, -1).contiguous().to(device)
-        dec_input_ids = input['dec_input_ids'].to(device)
-        attention_mask = input['attention_mask'].to(device)
-        self.dist_model.eval()
-        with torch.no_grad():
-            # Only supports batch_size=1
-            all_generate_tokens = []
-            generate_tokens = []
-            counter = 0
-            sequence_output = None
-            vocab_size = self.config.original_vocab_size
-            sep_token_idx = 102  # index of [SEP] token in BertTokenizer
-            while counter < out_length:
-                if counter % 128 == 0 and counter != 0:
-                    # Sliding window
-                    generate_tokens.append(sep_token_idx)
-                    start = (tokens == sep_token_idx).nonzero(
-                        as_tuple=True)[-1]
-                    if start + len(generate_tokens) >= 512:
-                        tokens = torch.cat([
-                            tokens[:start],
-                            torch.cuda.LongTensor(generate_tokens)
-                        ], -1)[-512:]
-                    else:
-                        tokens[0][start:start + len(generate_tokens
-                                                    )] = torch.cuda.LongTensor(
-                                                        generate_tokens)
-
-                    attention_mask = (tokens != 0)
-                    dec_input_ids = input['dec_input_ids'].to(device)
-                    generate_tokens = []
-                    sequence_output = None
-
-                position_ids = torch.full([batch_size, 1],
-                                          len(generate_tokens),
-                                          dtype=torch.long,
-                                          device=device)
-                _, logits, sequence_output = self.dist_model(
-                    tokens,
-                    None,
-                    attention_mask,
-                    dec_input_ids,
-                    attention_mask,
-                    position_ids,
-                    is_infer=True,
-                    sequence_output=sequence_output,
-                    parallel_output=False)
-                logits = logits[:, -1, :]
-                logits = logits / self.model_cfg['temperature']
-                logits = self.top_k_logits(
-                    logits,
-                    top_k=self.model_cfg['top_k'],
-                    top_p=self.model_cfg['top_p'])
-                log_probs = F.softmax(logits, dim=-1)
-                prev = torch.multinomial(log_probs, num_samples=1)
-                prev_token = prev[0].item()
-                if prev_token >= vocab_size:
-                    prev_token = 100
-                    prev[0] = 100
-                if prev_token == 102 and len(all_generate_tokens) > int(
-                        max(1, out_length) * 0.8):
-                    break
-                if prev_token == 102:
-                    counter += 1
-                    continue
-                dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
-                generate_tokens.append(prev_token)
-                all_generate_tokens.append(prev_token)
-                counter += 1
-
-            generate_context = []
-            for token in all_generate_tokens:
-                if generate_context and generate_context[
-                        -1] == 100 and token == 100:
-                    continue
-                else:
-                    generate_context.append(token)
-            return {'generate_context': generate_context}
+        return self.model.generate(input, out_length, self.model_cfg, *kwargs)
modelscope/models/nlp/plug/generator.py (new file, 225 lines)
@@ -0,0 +1,225 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch


class TextGenerator(object):

    def __init__(self,
                 model,
                 vocab,
                 symbols,
                 global_scorer=None,
                 logger=None,
                 dump_beam=''):
        self.alpha = 0.6

        self.logger = logger
        self.cuda = (torch.cuda.device_count() > 0)

        self.model = model
        # TODO generator
        self.vocab = vocab
        self.symbols = symbols
        self.start_token = 101  # index of [CLS] in BertTokenizer
        self.end_token = 102  # index of [SEP] in BertTokenizer

        self.global_scorer = global_scorer
        self.beam_size = 5
        self.min_length = 5
        self.max_length = 384

        self.dump_beam = dump_beam

        # for debugging
        self.beam_trace = self.dump_beam != ''
        self.beam_accum = None

        if self.beam_trace:
            self.beam_accum = {
                'predicted_ids': [],
                'beam_parent_ids': [],
                'scores': [],
                'log_probs': []
            }

    def _build_target_tokens(self, pred):
        tokens = []
        for tok in pred:
            tok = int(tok)
            tokens.append(tok)
            if tokens[-1] == self.end_token:
                tokens = tokens[:-1]
                break
        tokens = [t for t in tokens if t < len(self.vocab)]
        tokens = self.vocab.DecodeIds(tokens).split(' ')
        return tokens

    def tile(self, x, count, dim=0):
        """
        Tiles x on dimension dim count times.
        """
        perm = list(range(len(x.size())))
        if dim != 0:
            perm[0], perm[dim] = perm[dim], perm[0]
            x = x.permute(perm).contiguous()
        out_size = list(x.size())
        out_size[0] *= count
        batch = x.size(0)
        x = x.view(batch, -1) \
            .transpose(0, 1) \
            .repeat(count, 1) \
            .transpose(0, 1) \
            .contiguous() \
            .view(*out_size)
        if dim != 0:
            x = x.permute(perm).contiguous()
        return x

    def translate_batch(self, encoder_inputs, fast=False):
        with torch.no_grad():
            return self._fast_translate_batch(
                encoder_inputs, self.max_length, min_length=self.min_length)

    def _fast_translate_batch(self, encoder_inputs, max_length, min_length=0):

        assert not self.dump_beam

        beam_size = self.beam_size
        tokens, types, padding_mask = encoder_inputs
        batch_size = tokens.size(0)
        device = tokens.device
        tmp_alive_seq = torch.full([batch_size, 1],
                                   self.start_token,
                                   dtype=torch.long,
                                   device=device)
        prediction_scores, dec_feat_seq, sequence_output = self.model(
            tokens,
            types,
            padding_mask,
            tmp_alive_seq,
            None,
            None,
            checkpoint_activations=False,
            is_infer=True,
            parallel_output=False,
            sequence_output=None)
        src_features = sequence_output

        src_features = self.tile(src_features, beam_size, dim=0)
        attention_mask = self.tile(padding_mask, beam_size, dim=0)
        batch_offset = torch.arange(
            batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=device)
        alive_seq = torch.full([batch_size * beam_size, 1],
                               self.start_token,
                               dtype=torch.long,
                               device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor(
                [0.0] + [float('-inf')] * (beam_size - 1),
                device=device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results['predictions'] = [[] for _ in range(batch_size)]  # noqa: F812
        results['scores'] = [[] for _ in range(batch_size)]  # noqa: F812
        results['gold_score'] = [0] * batch_size
        results['batch'] = []
        dec_attn_mask = None
        dec_position_ids = None

        for step in range(max_length):
            prediction_scores, dec_feat_seq, _ = self.model(
                tokens,
                types,
                attention_mask,
                alive_seq,
                dec_position_ids,
                dec_attn_mask,
                checkpoint_activations=False,
                is_infer=True,
                parallel_output=False,
                sequence_output=src_features)

            dec_feat_seq = dec_feat_seq[:, -1, :]
            vocab_size = dec_feat_seq.size(-1)
            log_probs = torch.log(
                torch.softmax(dec_feat_seq.view(-1, vocab_size), dim=-1))

            if step < min_length:
                log_probs[:, self.end_token] = -1e20
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.alpha  # global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
            curr_scores = log_probs / length_penalty

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size, rounding_mode='trunc')
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                topk_beam_index
                + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([
                alive_seq.index_select(0, select_indices),
                topk_ids.view(-1, 1)
            ], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)  # self.end_token)
            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)  # self.end_token)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)  # self.end_token)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        score, pred = best_hyp[0]
                        results['scores'][b].append(score)
                        results['predictions'][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            attention_mask = attention_mask.index_select(0, select_indices)

        return results
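For reference, tile() above duplicates every batch entry beam_size times along dim 0 so that each source sentence gets its own group of beams. A minimal check of that semantics with plain torch (illustrative only, not part of the commit):

import torch

x = torch.tensor([[1, 2], [3, 4]])
count = 3
batch = x.size(0)
# Same reshuffling as TextGenerator.tile(x, count, dim=0):
tiled = x.view(batch, -1).transpose(0, 1).repeat(count, 1) \
         .transpose(0, 1).contiguous().view(batch * count, -1)
# Each row is repeated `count` times consecutively, i.e. repeat_interleave:
assert torch.equal(tiled, torch.repeat_interleave(x, count, dim=0))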
@@ -122,6 +122,8 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase):
         kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
                                                      False)
         kwargs['max_length'] = sequence_length
+        self.src_length = kwargs['max_length']
+        self.tgt_length = kwargs.pop('target_max_length', kwargs['max_length'])
         model_type = None
         if model_dir is not None:
             model_type = get_model_type(model_dir)
@@ -154,10 +156,14 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase):
            'return_tensors'] = 'pt' if self.mode == ModeKeys.INFERENCE else None

        output = self.nlp_tokenizer(sequence1, **kwargs)

        if self.mode != ModeKeys.INFERENCE:
            if sequence2 is not None:
                self.nlp_tokenizer.tokenize_kwargs[
                    'max_length'] = self.tgt_length
                labels = self.nlp_tokenizer(sequence2)['input_ids']
                self.nlp_tokenizer.tokenize_kwargs[
                    'max_length'] = self.src_length

            src_input_ids = output['input_ids']
            src_attention_mask = output['attention_mask']
        else:
@@ -25,7 +25,7 @@ else:
         'hook': ['Hook'],
         'iter_timer_hook': ['IterTimerHook'],
         'logger': ['TensorboardHook', 'TextLoggerHook'],
-        'lr_scheduler_hook': ['LrSchedulerHook'],
+        'lr_scheduler_hook': ['LrSchedulerHook', 'NoneLrSchedulerHook'],
         'optimizer_hook': [
             'ApexAMPOptimizerHook', 'NoneOptimizerHook', 'OptimizerHook',
             'TorchAMPOptimizerHook'
@@ -104,7 +104,8 @@ class CheckpointHook(Hook):
             return
 
         if self._should_save(trainer):
-            if is_master():
+            if is_master() or trainer.cfg.model.get('model_parallel_size',
+                                                    1) != 1:
                 self.logger.info(
                     f'Saving checkpoint at {trainer.epoch + 1} epoch')
                 self._save_checkpoint(trainer)
@@ -260,7 +261,8 @@ class CheckpointHook(Hook):
             return
 
         if self._should_save(trainer):
-            if is_master():
+            if is_master() or trainer.cfg.model.get('model_parallel_size',
+                                                    1) != 1:
                 self.logger.info(
                     f'Saving checkpoint at {trainer.iter + 1} iterations')
                 self._save_checkpoint(trainer)
modelscope/trainers/hooks/deepspeed_hook.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from types import MethodType

import deepspeed
from megatron import mpu

from modelscope.metainfo import Hooks
from modelscope.trainers.hooks import (BestCkptSaverHook, CheckpointHook,
                                       LrSchedulerHook, NoneLrSchedulerHook,
                                       NoneOptimizerHook, OptimizerHook)
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.torch_utils import is_master
from .builder import HOOKS
from .hook import Hook
from .priority import Priority


@HOOKS.register_module(module_name=Hooks.DeepspeedHook)
class DeepspeedHook(Hook):
    PRIORITY = Priority.VERY_HIGH

    def __init__(self,
                 deepspeed_activation_checkpointing=True,
                 save_zero_checkpoint=False,
                 loss_key='loss'):
        self.save_zero_checkpoint = save_zero_checkpoint
        self.loss_key = loss_key
        self.deepspeed_activation_checkpointing = deepspeed_activation_checkpointing

    def before_run(self, trainer):
        # deepspeed init
        args = trainer.cfg.train
        args.deepspeed_config = os.path.join(trainer.model_dir,
                                             args.deepspeed_config)

        trainer.model, _, _, _ = deepspeed.initialize(
            model=trainer.model,
            optimizer=trainer.optimizer,
            args=args,
            lr_scheduler=trainer.lr_scheduler,
            mpu=mpu,
            dist_init_required=False)
        trainer.model.save_zero_checkpoint = self.save_zero_checkpoint

        if self.deepspeed_activation_checkpointing:
            model = trainer.model
            while hasattr(model, 'module'):
                model = model.module
            deepspeed.checkpointing.configure(
                mpu,
                deepspeed_config=args.deepspeed_config,
                num_checkpoints=model.config.num_hidden_layers)

            mpu.checkpoint = deepspeed.checkpointing.checkpoint
            mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed

        # modify hooks
        for i, hook in enumerate(trainer._hooks):
            # backward & step
            if isinstance(hook, OptimizerHook):
                trainer._hooks[i] = NoneOptimizerHook()
            if isinstance(hook, LrSchedulerHook):
                trainer._hooks[i] = NoneLrSchedulerHook()

            # save checkpoint
            if isinstance(hook, CheckpointHook):

                def _save_checkpoint(self, trainer):
                    if self.by_epoch:
                        cur_save_dir = os.path.join(
                            self.save_dir,
                            f'{LogKeys.EPOCH}_{trainer.epoch + 1}')
                    else:
                        cur_save_dir = os.path.join(
                            self.save_dir,
                            f'{LogKeys.ITER}_{trainer.iter + 1}')
                    if (self.is_last_epoch(trainer)
                            and self.by_epoch) or (self.is_last_iter(trainer)
                                                   and not self.by_epoch):
                        cur_save_dir = os.path.join(self.save_dir,
                                                    ModelFile.TRAIN_OUTPUT_DIR)
                    trainer.model.save_checkpoint(cur_save_dir)

                trainer._hooks[i]._save_checkpoint = MethodType(
                    _save_checkpoint, trainer._hooks[i])

            if isinstance(hook, BestCkptSaverHook):

                def _save_checkpoint(self, trainer):
                    if self.by_epoch:
                        cur_save_dir = os.path.join(
                            self.save_dir,
                            f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}'
                        )
                    else:
                        cur_save_dir = os.path.join(
                            self.save_dir,
                            f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
                        )
                    trainer.model.save_checkpoint(cur_save_dir)
                    self._best_ckpt_file = cur_save_dir

                trainer._hooks[i]._save_checkpoint = MethodType(
                    _save_checkpoint, trainer._hooks[i])

    def after_train_iter(self, trainer):
        # The `trainer.model` here is actually a deepspeed engine object.
        # backward step
        loss = trainer.train_outputs[self.loss_key]
        trainer.model.backward(loss)

        # update parameters
        trainer.model.step()
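The hook above rewires checkpoint saving by rebinding _save_checkpoint on the already-registered hook instances with types.MethodType. A tiny standalone sketch of that pattern (hypothetical Saver class, not from the commit):

from types import MethodType

class Saver:
    def save(self):
        return 'default save'

hook = Saver()

def _patched_save(self):
    return 'deepspeed-aware save'

# Rebinds the method on this single instance only; other instances keep
# the class implementation.
hook.save = MethodType(_patched_save, hook)
assert hook.save() == 'deepspeed-aware save'
assert Saver().save() == 'default save'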
@@ -80,7 +80,8 @@ class TextLoggerHook(LoggerHook):
             dtype=torch.int,
             device=device)
         _, world_size = get_dist_info()
-        if world_size > 1:
+        if world_size > 1 and getattr(trainer.cfg.model, 'model_parallel_size',
+                                      1) < world_size:
             dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
         return mem_mb.item()
modelscope/trainers/nlp/plug_trainer.py (new file, 195 lines)
@@ -0,0 +1,195 @@
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from megatron import mpu
from torch import nn

from modelscope.metainfo import Trainers
from modelscope.models.base import Model, TorchModel
from modelscope.models.nlp.plug import DistributedPlug
from modelscope.models.nlp.plug.backbone import BertLayerNorm
from modelscope.models.nlp.plug.generator import TextGenerator
from modelscope.utils.constant import ModeKeys
from ..base import TRAINERS
from ..nlp_trainer import NlpEpochBasedTrainer


@TRAINERS.register_module(module_name=Trainers.nlp_plug_trainer)
class PlugTrainer(NlpEpochBasedTrainer):

    def build_model(self) -> Union[nn.Module, TorchModel]:
        rank = int(os.environ.get('LOCAL_RANK', -1))
        master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
        master_port = os.environ.get('MASTER_PORT', '29500')
        model = DistributedPlug(
            self.model_dir,
            rank,
            master_ip=master_ip,
            master_port=master_port,
            **self.cfg.model)
        return model.model

    def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
        from modelscope.utils.nlp.distributed import DistributedDataParallel as DDP
        return DDP(model)

    def _get_params_for_weight_decay_optimization(self, module):

        weight_decay_params = {'params': []}
        no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
        for module_ in module.modules():
            if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)):
                no_weight_decay_params['params'].extend([
                    p for p in list(module_._parameters.values())
                    if p is not None
                ])
            else:
                weight_decay_params['params'].extend([
                    p for n, p in list(module_._parameters.items())
                    if p is not None and 'mask_score' not in n
                    and 'mask' not in n and n != 'bias'
                ])
                no_weight_decay_params['params'].extend([
                    p for n, p in list(module_._parameters.items())
                    if p is not None and n == 'bias'
                ])

        return weight_decay_params, no_weight_decay_params

    def create_optimizer_and_scheduler(self):
        optimizer, lr_scheduler = self.optimizers
        optimizer_cfg = self.cfg.train.get('optimizer', None)
        # optim_options = {}
        if optimizer_cfg is not None:
            optim_options = optimizer_cfg.pop('options', {})
        from deepspeed.ops.adam import DeepSpeedCPUAdam
        model = self.model

        embeddings = model.module.module.model.bert.embeddings
        layers = model.module.module.model.bert.encoder.layer
        dec_layers = model.module.module.model.decoder.decoder
        param_groups = []
        param_groups += list(
            self._get_params_for_weight_decay_optimization(layers))
        param_groups += list(
            self._get_params_for_weight_decay_optimization(embeddings))
        param_groups += list(
            self._get_params_for_weight_decay_optimization(dec_layers))

        for param_group in param_groups:
            for param in param_group['params']:
                if not hasattr(param, 'model_parallel'):
                    param.model_parallel = False
        optimizer = DeepSpeedCPUAdam(
            param_groups,
            lr=optimizer_cfg.lr,
            weight_decay=optimizer_cfg.weight_decay)

        lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)

        if lr_scheduler_cfg is not None:
            assert optimizer is not None
            lr_options = lr_scheduler_cfg.pop('options', {})
            from modelscope.models.nlp.plug.AnnealingLR import AnnealingLR
            num_iters = self.max_iters
            lr_scheduler = AnnealingLR(
                optimizer,
                start_lr=optimizer_cfg.lr,
                warmup_iter=lr_scheduler_cfg.warmup * num_iters,
                num_iters=num_iters,
                decay_style=lr_scheduler_cfg.decay_style,
                last_iter=-1)

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        return self.optimizer, self.lr_scheduler, optim_options, lr_options

    def _get_masks_and_position_ids(self, data, eod_token):
        # Extract batch size and sequence length.
        batch_size, seq_length = data.size()

        # Attention mask (lower triangular).
        att_mask_batch = 1
        attention_mask = torch.tril(
            torch.ones((att_mask_batch, seq_length, seq_length),
                       device=data.device)).view(att_mask_batch, 1, seq_length,
                                                 seq_length)

        # Loss mask.
        loss_mask = torch.ones(
            data.size(), dtype=torch.float, device=data.device)
        loss_mask[data == eod_token] = 0.0

        # Position ids.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=data.device)
        position_ids = position_ids.unsqueeze(0).expand_as(data)
        return attention_mask, loss_mask, position_ids

    def train_step(self, model, inputs):
        self._mode = ModeKeys.TRAIN
        # format inputs
        checkpoint_activations = getattr(self.cfg.train,
                                         'checkpoint_activations', True)
        tgt_tokens = inputs['labels'][:, :-1].contiguous()
        tgt_labels = inputs['labels'][:, 1:].contiguous()
        tgt_attention_mask, dec_loss_mask, position_ids = self._get_masks_and_position_ids(
            tgt_tokens, 0)
        if getattr(self.cfg.train, 'fp16', None):
            tgt_attention_mask = tgt_attention_mask.half()

        # forward step
        _, output = model(
            inputs['input_ids'],
            None,
            inputs['attention_mask'],
            tgt_tokens,
            position_ids,
            tgt_attention_mask,
            checkpoint_activations=checkpoint_activations)

        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                                  tgt_labels)
        dec_loss_mask = dec_loss_mask.view(-1)
        loss = torch.sum(losses.view(-1) * dec_loss_mask) / dec_loss_mask.sum()

        # add model output info to log
        self.train_outputs = {'loss': loss}
        self.log_buffer.update(self.train_outputs)

    def evaluation_step(self, data):
        # wrapper 1: DeepspeedEngine, wrapper 2: DDP
        model = self.model.module.module
        model.eval()

        # model: fp16 wrapper; model.module: DistributedPlug
        vocab_size = model.module.config.original_vocab_size
        batch_size = data['input_ids'].shape[0]
        beam_generator = TextGenerator(model,
                                       self.eval_preprocessor.nlp_tokenizer,
                                       None)

        with torch.no_grad():
            tokens = data['input_ids'].long()
            padding_mask = data['attention_mask'].byte()
            target_ids = data['labels'].long()
            target_labels = target_ids[:, 1:].contiguous()
            encoder_inputs = [tokens, None, padding_mask]
            result = beam_generator.translate_batch(encoder_inputs)
            pred_list = result['predictions']
            target_list = target_labels.cpu().numpy().tolist()
            result['preds'] = []
            data['tgts'] = []
            for i in range(batch_size):
                pred_ids = pred_list[i][0]
                pred_ids[pred_ids > vocab_size - 1] = 100
                pred_ids = pred_ids.cpu().numpy().tolist()

                gold_string = self.eval_preprocessor.decode(
                    target_list[i], skip_special_tokens=True)
                pred_string = self.eval_preprocessor.decode(
                    pred_ids, skip_special_tokens=True)
                result['preds'].append(pred_string)
                data['tgts'].append(gold_string)
            return result
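As an aside, the lower-triangular mask built in _get_masks_and_position_ids above is an ordinary causal mask; a toy rendering for seq_length = 4 (illustrative only, not part of the commit):

import torch

seq_length = 4
mask = torch.tril(torch.ones(1, seq_length, seq_length)).view(
    1, 1, seq_length, seq_length)
# mask[0, 0] ==
# [[1., 0., 0., 0.],
#  [1., 1., 0., 0.],
#  [1., 1., 1., 0.],
#  [1., 1., 1., 1.]]   -> position i may attend only to positions <= i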
@@ -845,7 +845,10 @@ class EpochBasedTrainer(BaseTrainer):
             batch_size = batch_size_per_gpu
             num_workers = workers_per_gpu
 
-        if dist and not isinstance(dataset, torch.utils.data.IterableDataset):
+        if dist and not isinstance(
+                dataset,
+                torch.utils.data.IterableDataset) and self.cfg.model.get(
+                    'model_parallel_size', 1) == 1:
             sampler = DistributedSampler(
                 dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
         else:
@@ -935,7 +938,7 @@ class EpochBasedTrainer(BaseTrainer):
         """ Evaluation loop used by `EpochBasedTrainer.evaluate()`.
 
         """
-        if self._dist:
+        if self._dist and self.cfg.model.get('model_parallel_size', 1) == 1:
             from modelscope.trainers.utils.inference import multi_gpu_test
             metric_values = multi_gpu_test(
                 self,
tests/trainers/test_plug_finetune_text_generation.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


def test_trainer_with_model_and_args():

    def concat_answer_context(dataset):
        dataset['src_txt'] = dataset['answers']['text'][0] + '[SEP]' + dataset[
            'context']
        return dataset

    from datasets import load_dataset
    dataset_dict = load_dataset('luozhouyang/dureader', 'robust')

    train_dataset = dataset_dict['train'].map(concat_answer_context) \
        .rename_columns({'question': 'tgt_txt'}).remove_columns('context') \
        .remove_columns('id').remove_columns('answers')
    eval_dataset = dataset_dict['validation'].map(concat_answer_context) \
        .rename_columns({'question': 'tgt_txt'}).remove_columns('context') \
        .remove_columns('id').remove_columns('answers')

    tmp_dir = tempfile.TemporaryDirectory().name
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir)

    model_id = 'damo/nlp_plug_text-generation_27B'

    kwargs = dict(
        model=model_id,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        work_dir=tmp_dir)

    trainer = build_trainer(
        name=Trainers.nlp_plug_trainer, default_args=kwargs)
    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank')
    test_trainer_with_model_and_args()