mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 12:09:22 +01:00
model forward ready
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from maas_lib.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
|
||||
from maas_lib.utils.constant import Tasks
|
||||
from ...base import Model, Tensor
|
||||
from ...builder import MODELS
|
||||
@@ -32,6 +33,22 @@ class DialogGenerationModel(Model):
|
||||
reader=self.text_field,
|
||||
generator=self.generator)
|
||||
|
||||
def to_tensor(array):
|
||||
"""
|
||||
numpy array -> tensor
|
||||
"""
|
||||
import torch
|
||||
array = torch.tensor(array)
|
||||
return array.cuda() if self.config.use_gpu else array
|
||||
|
||||
self.trainer = MultiWOZTrainer(
|
||||
model=self.model,
|
||||
to_tensor=to_tensor,
|
||||
config=self.config,
|
||||
reader=self.text_field,
|
||||
evaluator=None)
|
||||
self.trainer.load()
|
||||
|
||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""return the result by the model
|
||||
|
||||
@@ -48,10 +65,38 @@ class DialogGenerationModel(Model):
|
||||
}
|
||||
"""
|
||||
from numpy import array, float32
|
||||
import torch
|
||||
|
||||
return {
|
||||
'predictions': array([1]), # lable 0-negative 1-positive
|
||||
'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
|
||||
'logits': array([[-0.53860897, 1.5029076]],
|
||||
dtype=float32) # true value
|
||||
turn_1 = {
|
||||
'user': [
|
||||
13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
|
||||
1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
|
||||
]
|
||||
}
|
||||
old_pv_turn_1 = {}
|
||||
|
||||
turn_2 = {
|
||||
'user':
|
||||
[13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7]
|
||||
}
|
||||
old_pv_turn_2 = {
|
||||
'labels': [[
|
||||
13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
|
||||
1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
|
||||
]],
|
||||
'resp': [
|
||||
14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010,
|
||||
2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681,
|
||||
2030, 7180, 2011, 1029, 8
|
||||
],
|
||||
'bspn': [
|
||||
15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002,
|
||||
2198, 1005, 1055, 2267, 9
|
||||
],
|
||||
'db': [19, 24, 21, 20],
|
||||
'aspn': [16, 43, 48, 2681, 7180, 10]
|
||||
}
|
||||
|
||||
pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2)
|
||||
|
||||
return pv_turn
|
||||
|
||||
0
maas_lib/trainers/nlp/space/__init__.py
Normal file
0
maas_lib/trainers/nlp/space/__init__.py
Normal file
0
maas_lib/trainers/nlp/space/metrics/__init__.py
Normal file
0
maas_lib/trainers/nlp/space/metrics/__init__.py
Normal file
73
maas_lib/trainers/nlp/space/metrics/metrics_tracker.py
Normal file
73
maas_lib/trainers/nlp/space/metrics/metrics_tracker.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
MetricsTracker class
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class MetricsTracker(object):
|
||||
""" Tracking metrics. """
|
||||
|
||||
def __init__(self):
|
||||
self.metrics_val = defaultdict(float) # 记录最新一个batch返回的指标
|
||||
self.metrics_avg = defaultdict(float) # 维护一个epoch内已训练batches的平均指标
|
||||
self.num_samples = 0
|
||||
|
||||
def update(self, metrics, num_samples):
|
||||
for key, val in metrics.items():
|
||||
if val is not None:
|
||||
val = float(val) # [val] -> val
|
||||
self.metrics_val[key] = val
|
||||
avg_val = (self.metrics_avg.get(key, 0) * self.num_samples +
|
||||
val * num_samples) / (
|
||||
self.num_samples + num_samples)
|
||||
self.metrics_avg[key] = avg_val
|
||||
self.num_samples += num_samples
|
||||
|
||||
def clear(self):
|
||||
self.metrics_val = defaultdict(float)
|
||||
self.metrics_avg = defaultdict(float)
|
||||
self.num_samples = 0
|
||||
|
||||
def items(self):
|
||||
return self.metrics_avg.items()
|
||||
|
||||
def get(self, name):
|
||||
if self.num_samples == 0:
|
||||
raise ValueError('There is no data in Metrics.')
|
||||
return self.metrics_avg.get(name)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
'metrics_val': self.metrics_val,
|
||||
'metrics_avg': self.metrics_avg,
|
||||
'num_samples': self.num_samples,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.metrics_val = state_dict['metrics_val']
|
||||
self.metrics_avg = state_dict['metrics_avg']
|
||||
self.num_samples = state_dict['num_samples']
|
||||
|
||||
def value(self):
|
||||
metric_strs = []
|
||||
for key, val in self.metrics_val.items():
|
||||
metric_str = f'{key.upper()}-{val:.3f}'
|
||||
metric_strs.append(metric_str)
|
||||
if 'token_nll' in self.metrics_val:
|
||||
metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}"
|
||||
metric_strs.append(metric_str)
|
||||
metric_strs = ' '.join(metric_strs)
|
||||
return metric_strs
|
||||
|
||||
def summary(self):
|
||||
metric_strs = []
|
||||
for key, val in self.metrics_avg.items():
|
||||
metric_str = f'{key.upper()}-{val:.3f}'
|
||||
metric_strs.append(metric_str)
|
||||
if 'token_nll' in self.metrics_avg:
|
||||
metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}"
|
||||
metric_strs.append(metric_str)
|
||||
metric_strs = ' '.join(metric_strs)
|
||||
return metric_strs
|
||||
0
maas_lib/trainers/nlp/space/trainers/__init__.py
Normal file
0
maas_lib/trainers/nlp/space/trainers/__init__.py
Normal file
725
maas_lib/trainers/nlp/space/trainers/gen_trainer.py
Normal file
725
maas_lib/trainers/nlp/space/trainers/gen_trainer.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""
|
||||
Trainer class.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from ..metrics.metrics_tracker import MetricsTracker
|
||||
|
||||
|
||||
def get_logger(log_path, name='default'):
|
||||
logger = logging.getLogger(name)
|
||||
logger.propagate = False
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter('%(message)s')
|
||||
|
||||
sh = logging.StreamHandler(sys.stdout)
|
||||
sh.setFormatter(formatter)
|
||||
logger.addHandler(sh)
|
||||
|
||||
fh = logging.FileHandler(log_path, mode='w')
|
||||
fh.setFormatter(formatter)
|
||||
logger.addHandler(fh)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
to_tensor,
|
||||
config,
|
||||
logger=None,
|
||||
lr_scheduler=None,
|
||||
optimizer=None,
|
||||
reader=None,
|
||||
evaluator=None):
|
||||
self.to_tensor = to_tensor
|
||||
|
||||
self.do_train = config.do_train
|
||||
self.do_infer = config.do_infer
|
||||
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
|
||||
0] == '-'
|
||||
self.valid_metric_name = config.Trainer.valid_metric_name[1:]
|
||||
self.num_epochs = config.Trainer.num_epochs
|
||||
# self.save_dir = config.Trainer.save_dir
|
||||
self.log_steps = config.Trainer.log_steps
|
||||
self.valid_steps = config.Trainer.valid_steps
|
||||
self.save_checkpoint = config.Trainer.save_checkpoint
|
||||
self.save_summary = config.Trainer.save_summary
|
||||
self.lr = config.Model.lr
|
||||
self.weight_decay = config.Model.weight_decay
|
||||
self.batch_size = config.Trainer.batch_size
|
||||
self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps
|
||||
self.warmup_steps = config.Model.warmup_steps
|
||||
self.gpu = config.Trainer.gpu
|
||||
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.model = model
|
||||
self.func_model = self.model.module if self.gpu > 1 else self.model
|
||||
self.reader = reader
|
||||
self.evaluator = evaluator
|
||||
self.tokenizer = reader.tokenizer
|
||||
|
||||
# if not os.path.exists(self.save_dir):
|
||||
# os.makedirs(self.save_dir)
|
||||
|
||||
# self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer")
|
||||
self.logger = logger or get_logger('trainer.log', 'trainer')
|
||||
|
||||
self.batch_metrics_tracker = MetricsTracker()
|
||||
self.token_metrics_tracker = MetricsTracker()
|
||||
|
||||
self.best_valid_metric = float(
|
||||
'inf' if self.is_decreased_valid_metric else '-inf')
|
||||
self.epoch = 0
|
||||
|
||||
def decode_generated_bspn_resp(self, generated):
|
||||
"""
|
||||
decode generated
|
||||
return decoded ('bspn', 'resp')
|
||||
"""
|
||||
decoded = {}
|
||||
eos_r_id = self.reader.eos_r_id
|
||||
eos_b_id = self.reader.eos_b_id
|
||||
|
||||
# eos_r may not exists if gpt2 generated repetitive words.
|
||||
if eos_r_id in generated:
|
||||
eos_r_idx = generated.index(eos_r_id)
|
||||
else:
|
||||
eos_r_idx = len(generated) - 1
|
||||
# self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated))
|
||||
|
||||
# predicted bspn, resp
|
||||
eos_b_idx = generated.index(eos_b_id)
|
||||
decoded['bspn'] = generated[:eos_b_idx + 1]
|
||||
decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1]
|
||||
return decoded
|
||||
|
||||
def decode_generated_act_resp(self, generated):
|
||||
"""
|
||||
decode generated
|
||||
return decoded['resp'] ('bspn', 'aspn')
|
||||
"""
|
||||
decoded = {}
|
||||
eos_a_id = self.reader.eos_a_id
|
||||
eos_r_id = self.reader.eos_r_id
|
||||
eos_b_id = self.reader.eos_b_id
|
||||
|
||||
# eos_r may not exists if gpt2 generated repetitive words.
|
||||
if eos_r_id in generated:
|
||||
eos_r_idx = generated.index(eos_r_id)
|
||||
else:
|
||||
eos_r_idx = len(generated) - 1
|
||||
self.logger.info('eos_r not in generated: ' +
|
||||
self.tokenizer.decode(generated))
|
||||
|
||||
if self.reader.use_true_curr_aspn: # only predict resp
|
||||
decoded['resp'] = generated[:eos_r_idx + 1]
|
||||
else: # predicted aspn, resp
|
||||
eos_a_idx = generated.index(eos_a_id)
|
||||
decoded['aspn'] = generated[:eos_a_idx + 1]
|
||||
decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1]
|
||||
return decoded
|
||||
|
||||
def decode_generated_bspn(self, generated):
|
||||
eos_b_id = self.reader.eos_b_id
|
||||
if eos_b_id in generated:
|
||||
eos_b_idx = generated.index(eos_b_id)
|
||||
else:
|
||||
eos_b_idx = len(generated) - 1
|
||||
return generated[:eos_b_idx + 1]
|
||||
|
||||
def set_optimizers(self):
|
||||
"""
|
||||
Setup the optimizer and the learning rate scheduler.
|
||||
|
||||
from transformers.Trainer
|
||||
|
||||
parameters from cfg: lr (1e-3); warmup_steps
|
||||
"""
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'norm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
self.weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
|
||||
|
||||
num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] * \
|
||||
self.num_epochs // self.gradient_accumulation_steps
|
||||
num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int(
|
||||
num_training_steps * 0.1)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps)
|
||||
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
def train(self, train_data, dev_data):
|
||||
# log info
|
||||
set_stats = self.reader.set_stats['train']
|
||||
self.logger.info('***** Running training *****')
|
||||
self.logger.info(
|
||||
' Num Training steps(one turn in a batch of dialogs) per epoch = %d',
|
||||
set_stats['num_training_steps_per_epoch'])
|
||||
self.logger.info(' Num Turns = %d', set_stats['num_turns'])
|
||||
self.logger.info(' Num Dialogs = %d', set_stats['num_dials'])
|
||||
self.logger.info(' Num Epochs = %d', self.num_epochs)
|
||||
self.logger.info(' Batch size = %d', self.batch_size)
|
||||
self.logger.info(' Gradient Accumulation steps = %d',
|
||||
self.gradient_accumulation_steps)
|
||||
self.logger.info(
|
||||
' Total optimization steps = %d',
|
||||
set_stats['num_training_steps_per_epoch'] * self.num_epochs //
|
||||
self.gradient_accumulation_steps)
|
||||
|
||||
# begin training
|
||||
num_epochs = self.num_epochs - self.epoch
|
||||
for epoch in range(num_epochs):
|
||||
self.train_epoch(train_data=train_data, dev_data=dev_data)
|
||||
|
||||
def train_epoch(self, train_data, dev_data):
|
||||
"""
|
||||
Train an epoch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def infer(self, data_type):
|
||||
"""
|
||||
Inference interface.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, turn, old_pv_turn):
|
||||
"""
|
||||
one turn inference
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, is_best=False):
|
||||
""" save """
|
||||
train_state = {
|
||||
'epoch': self.epoch,
|
||||
'best_valid_metric': self.best_valid_metric,
|
||||
'optimizer': self.optimizer.state_dict()
|
||||
}
|
||||
if self.lr_scheduler is not None:
|
||||
train_state['lr_scheduler'] = self.lr_scheduler.state_dict()
|
||||
|
||||
# Save checkpoint
|
||||
if self.save_checkpoint:
|
||||
model_file = os.path.join(self.save_dir,
|
||||
f'state_epoch_{self.epoch}.model')
|
||||
torch.save(self.model.state_dict(), model_file)
|
||||
self.logger.info(f"Saved model state to '{model_file}'")
|
||||
|
||||
train_file = os.path.join(self.save_dir,
|
||||
f'state_epoch_{self.epoch}.train')
|
||||
torch.save(train_state, train_file)
|
||||
self.logger.info(f"Saved train state to '{train_file}'")
|
||||
|
||||
# Save current best model
|
||||
if is_best:
|
||||
best_model_file = os.path.join(self.save_dir, 'best.model')
|
||||
torch.save(self.model.state_dict(), best_model_file)
|
||||
best_train_file = os.path.join(self.save_dir, 'best.train')
|
||||
torch.save(train_state, best_train_file)
|
||||
self.logger.info(
|
||||
f"Saved best model state to '{best_model_file}' with new best valid metric "
|
||||
f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}'
|
||||
)
|
||||
|
||||
def load(self):
|
||||
""" load """
|
||||
|
||||
def _load_model_state():
|
||||
model_state_dict = torch.load(
|
||||
f'{self.func_model.init_checkpoint}',
|
||||
map_location=lambda storage, loc: storage)
|
||||
|
||||
if 'module.' in list(model_state_dict.keys())[0]:
|
||||
new_model_state_dict = OrderedDict()
|
||||
for k, v in model_state_dict.items():
|
||||
assert k[:7] == 'module.'
|
||||
new_model_state_dict[k[7:]] = v
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
new_model_state_dict = OrderedDict()
|
||||
parameters = {
|
||||
name: param
|
||||
for name, param in self.func_model.named_parameters()
|
||||
}
|
||||
for name, param in model_state_dict.items():
|
||||
if name in parameters:
|
||||
if param.shape != parameters[name].shape:
|
||||
assert hasattr(param, 'numpy')
|
||||
arr = param.numpy()
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
if name == 'embedder.token_embedding.weight':
|
||||
z[-param.shape[0]:] = arr
|
||||
print(
|
||||
f'part of parameter({name}) random normlize initialize'
|
||||
)
|
||||
else:
|
||||
if z.shape[0] < param.shape[0]:
|
||||
z = arr[:z.shape[0]]
|
||||
print(f'part of parameter({name}) are dropped')
|
||||
else:
|
||||
z[:param.shape[0]] = arr
|
||||
print(
|
||||
f'part of parameter({name}) random normlize initialize'
|
||||
)
|
||||
dtype, device = param.dtype, param.device
|
||||
z = torch.tensor(z, dtype=dtype, device=device)
|
||||
new_model_state_dict[name] = z
|
||||
else:
|
||||
new_model_state_dict[name] = param
|
||||
else:
|
||||
print(f'parameter({name}) are dropped')
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
for name in parameters:
|
||||
if name not in model_state_dict:
|
||||
if parameters[name].requires_grad:
|
||||
print(f'parameter({name}) random normlize initialize')
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
dtype, device = parameters[name].dtype, parameters[
|
||||
name].device
|
||||
model_state_dict[name] = torch.tensor(
|
||||
z, dtype=dtype, device=device)
|
||||
else:
|
||||
model_state_dict[name] = parameters[name]
|
||||
|
||||
self.func_model.load_state_dict(model_state_dict)
|
||||
self.logger.info(
|
||||
f"Loaded model state from '{self.func_model.init_checkpoint}.model'"
|
||||
)
|
||||
|
||||
def _load_train_state():
|
||||
train_file = f'{self.func_model.init_checkpoint}.train'
|
||||
if os.path.exists(train_file):
|
||||
train_state_dict = torch.load(
|
||||
train_file, map_location=lambda storage, loc: storage)
|
||||
self.epoch = train_state_dict['epoch']
|
||||
self.best_valid_metric = train_state_dict['best_valid_metric']
|
||||
if self.optimizer is not None and 'optimizer' in train_state_dict:
|
||||
self.optimizer.load_state_dict(
|
||||
train_state_dict['optimizer'])
|
||||
if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
|
||||
self.lr_scheduler.load_state_dict(
|
||||
train_state_dict['lr_scheduler'])
|
||||
self.logger.info(
|
||||
f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
|
||||
f'best_valid_metric={self.best_valid_metric:.3f})')
|
||||
else:
|
||||
self.logger.info(f'Loaded no train state')
|
||||
|
||||
if self.func_model.init_checkpoint is None:
|
||||
self.logger.info(f'Loaded no model !!!')
|
||||
return
|
||||
|
||||
if self.do_train:
|
||||
_load_model_state()
|
||||
return
|
||||
|
||||
if self.do_infer:
|
||||
_load_model_state()
|
||||
_load_train_state()
|
||||
|
||||
|
||||
class MultiWOZTrainer(Trainer):
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
to_tensor,
|
||||
config,
|
||||
logger=None,
|
||||
lr_scheduler=None,
|
||||
optimizer=None,
|
||||
reader=None,
|
||||
evaluator=None):
|
||||
super(MultiWOZTrainer,
|
||||
self).__init__(model, to_tensor, config, logger, lr_scheduler,
|
||||
optimizer, reader, evaluator)
|
||||
|
||||
def train_epoch(self, train_data, dev_data):
|
||||
"""
|
||||
Train an epoch.
|
||||
"""
|
||||
times = []
|
||||
epoch_step = 0
|
||||
global_step = 0
|
||||
tr_batch_loss = 0.0
|
||||
tr_token_loss = 0.0
|
||||
self.epoch += 1
|
||||
self.batch_metrics_tracker.clear()
|
||||
self.token_metrics_tracker.clear()
|
||||
num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] // \
|
||||
self.gradient_accumulation_steps # similar to the original num_batches
|
||||
|
||||
self.model.zero_grad()
|
||||
data_iterator = self.reader.get_data_iterator(all_batches=train_data)
|
||||
|
||||
for batch_idx, dial_batch in enumerate(data_iterator):
|
||||
pv_batch = []
|
||||
for turn_num, turn_batch in enumerate(dial_batch):
|
||||
first_turn = (turn_num == 0)
|
||||
samples, pv_batch = self.reader.convert_batch_turn(
|
||||
turn_batch, pv_batch, first_turn)
|
||||
batch, batch_size = self.reader.collate_fn_multi_turn(
|
||||
samples=samples)
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
batch.items()))
|
||||
|
||||
# Do a training iteration
|
||||
start_time = time.time()
|
||||
metrics = self.model(batch, is_training=True)
|
||||
if self.gpu > 1:
|
||||
for metric in metrics:
|
||||
if metric is not None:
|
||||
assert len(metric) == self.gpu
|
||||
nll, token_nll, token_num = metrics
|
||||
metrics = {}
|
||||
|
||||
token_num = torch.sum(token_num)
|
||||
token_nll = torch.sum(nll) * (batch_size /
|
||||
self.gpu) / token_num
|
||||
nll = torch.mean(nll)
|
||||
metrics['token_num'] = token_num
|
||||
metrics['token_nll'] = token_nll
|
||||
metrics['nll'] = nll
|
||||
loss = token_nll if self.func_model.token_loss else nll
|
||||
|
||||
metrics['loss'] = loss
|
||||
else:
|
||||
loss = metrics['loss']
|
||||
self.func_model._optimize(
|
||||
loss, do_update=False, optimizer=self.optimizer)
|
||||
metrics = {
|
||||
k: v.cpu().detach().numpy()
|
||||
if isinstance(v, torch.Tensor) else v
|
||||
for k, v in metrics.items()
|
||||
}
|
||||
token_num = metrics.pop('token_num', None)
|
||||
# bow_num = metrics.pop("bow_num", None)
|
||||
elapsed = time.time() - start_time
|
||||
times.append(elapsed)
|
||||
epoch_step += 1
|
||||
|
||||
tr_batch_loss += metrics['nll']
|
||||
tr_token_loss += metrics['token_nll']
|
||||
batch_metrics = {
|
||||
k: v
|
||||
for k, v in metrics.items() if 'token' not in k
|
||||
}
|
||||
token_metrics = {
|
||||
k: v
|
||||
for k, v in metrics.items() if 'token' in k
|
||||
}
|
||||
self.batch_metrics_tracker.update(batch_metrics, batch_size)
|
||||
self.token_metrics_tracker.update(token_metrics, token_num)
|
||||
|
||||
if (epoch_step % self.gradient_accumulation_steps == 0) or \
|
||||
(epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']):
|
||||
self.optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
self.optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if self.log_steps > 0 and global_step % self.log_steps == 0:
|
||||
batch_metrics_message = self.batch_metrics_tracker.value(
|
||||
)
|
||||
token_metrics_message = self.token_metrics_tracker.value(
|
||||
)
|
||||
message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]'
|
||||
avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}'
|
||||
message = ' '.join([
|
||||
message_prefix, batch_metrics_message,
|
||||
token_metrics_message, avg_time
|
||||
])
|
||||
self.logger.info(message)
|
||||
|
||||
self.logger.info('-' * 150)
|
||||
avg_batch_loss = tr_batch_loss / epoch_step
|
||||
avg_token_loss = tr_token_loss / epoch_step
|
||||
batch_metrics_message = self.batch_metrics_tracker.summary()
|
||||
token_metrics_message = self.token_metrics_tracker.summary()
|
||||
message_prefix = f'[Valid][{self.epoch}]'
|
||||
message = ' '.join([
|
||||
message_prefix, batch_metrics_message, token_metrics_message,
|
||||
str(avg_batch_loss),
|
||||
str(avg_token_loss)
|
||||
])
|
||||
self.logger.info(message)
|
||||
|
||||
cur_valid_metric = self.batch_metrics_tracker.get(
|
||||
self.valid_metric_name)
|
||||
if self.is_decreased_valid_metric:
|
||||
is_best = cur_valid_metric < self.best_valid_metric
|
||||
else:
|
||||
is_best = cur_valid_metric > self.best_valid_metric
|
||||
if is_best:
|
||||
self.best_valid_metric = cur_valid_metric
|
||||
self.save(is_best)
|
||||
self.logger.info('-' * 150)
|
||||
|
||||
return
|
||||
|
||||
def infer(self, data_type='test'):
|
||||
"""
|
||||
Inference interface.
|
||||
"""
|
||||
self.logger.info('Generation starts ...')
|
||||
infer_save_file = os.path.join(self.save_dir,
|
||||
f'infer_{self.epoch}.result.json')
|
||||
infer_samples_save_file = os.path.join(
|
||||
self.save_dir, f'infer_samples_{self.epoch}.result.json')
|
||||
|
||||
# Inference
|
||||
result_collection = {}
|
||||
begin_time = time.time()
|
||||
|
||||
eval_data = self.reader.get_eval_data(data_type)
|
||||
set_stats = self.reader.set_stats[data_type]
|
||||
self.logger.info('***** Running Evaluation *****')
|
||||
self.logger.info(' Num Turns = %d', set_stats['num_turns'])
|
||||
|
||||
with torch.no_grad():
|
||||
pbar = tqdm(eval_data)
|
||||
for dial_idx, dialog in enumerate(pbar):
|
||||
pv_turn = {}
|
||||
for turn_idx, turn in enumerate(dialog):
|
||||
first_turn = (turn_idx == 0)
|
||||
inputs, prompt_id = self.reader.convert_turn_eval(
|
||||
turn, pv_turn, first_turn)
|
||||
batch, batch_size = self.reader.collate_fn_multi_turn(
|
||||
samples=[inputs])
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
batch.items()))
|
||||
if self.reader.use_true_curr_bspn: # generate act, response
|
||||
max_len = 60
|
||||
if not self.reader.use_true_curr_aspn:
|
||||
max_len = 80
|
||||
outputs = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_r_id,
|
||||
max_gen_len=max_len)
|
||||
# resp_gen, need to trim previous context
|
||||
generated = outputs[0].cpu().numpy().tolist()
|
||||
try:
|
||||
decoded = self.decode_generated_act_resp(generated)
|
||||
except ValueError as exception:
|
||||
self.logger.info(str(exception))
|
||||
self.logger.info(self.tokenizer.decode(generated))
|
||||
decoded = {'resp': [], 'bspn': [], 'aspn': []}
|
||||
else: # predict bspn, access db, then generate act and resp
|
||||
outputs = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_b_id,
|
||||
max_gen_len=60)
|
||||
generated_bs = outputs[0].cpu().numpy().tolist()
|
||||
bspn_gen = self.decode_generated_bspn(generated_bs)
|
||||
# check DB result
|
||||
if self.reader.use_true_db_pointer: # 控制当前轮的db是否为ground truth
|
||||
db = turn['db']
|
||||
else:
|
||||
db_result = self.reader.bspan_to_DBpointer(
|
||||
self.tokenizer.decode(bspn_gen),
|
||||
turn['turn_domain'])
|
||||
assert len(turn['db']) == 4
|
||||
book_result = turn['db'][2]
|
||||
assert isinstance(db_result, str)
|
||||
db = [self.reader.sos_db_id] + \
|
||||
self.tokenizer.convert_tokens_to_ids([db_result]) + \
|
||||
[book_result] + \
|
||||
[self.reader.eos_db_id]
|
||||
prompt_id = self.reader.sos_a_id
|
||||
|
||||
prev_input = torch.tensor(bspn_gen + db)
|
||||
if self.func_model.use_gpu:
|
||||
prev_input = prev_input.cuda()
|
||||
outputs_db = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_r_id,
|
||||
max_gen_len=80,
|
||||
prev_input=prev_input)
|
||||
generated_ar = outputs_db[0].cpu().numpy().tolist()
|
||||
try:
|
||||
decoded = self.decode_generated_act_resp(
|
||||
generated_ar)
|
||||
decoded['bspn'] = bspn_gen
|
||||
except ValueError as exception:
|
||||
self.logger.info(str(exception))
|
||||
self.logger.info(
|
||||
self.tokenizer.decode(generated_ar))
|
||||
decoded = {'resp': [], 'bspn': [], 'aspn': []}
|
||||
|
||||
turn['resp_gen'] = decoded['resp']
|
||||
turn['bspn_gen'] = turn[
|
||||
'bspn'] if self.reader.use_true_curr_bspn else decoded[
|
||||
'bspn']
|
||||
turn['aspn_gen'] = turn[
|
||||
'aspn'] if self.reader.use_true_curr_aspn else decoded[
|
||||
'aspn']
|
||||
turn['dspn_gen'] = turn['dspn']
|
||||
|
||||
pv_turn['labels'] = inputs[
|
||||
'labels'] # all true previous context
|
||||
pv_turn['resp'] = turn[
|
||||
'resp'] if self.reader.use_true_prev_resp else decoded[
|
||||
'resp']
|
||||
if not self.reader.use_true_curr_bspn:
|
||||
pv_turn['bspn'] = turn[
|
||||
'bspn'] if self.reader.use_true_prev_bspn else decoded[
|
||||
'bspn']
|
||||
pv_turn['db'] = turn[
|
||||
'db'] if self.reader.use_true_prev_bspn else db
|
||||
pv_turn['aspn'] = turn[
|
||||
'aspn'] if self.reader.use_true_prev_aspn else decoded[
|
||||
'aspn']
|
||||
|
||||
tmp_dialog_result = self.reader.inverse_transpose_turn(dialog)
|
||||
result_collection.update(tmp_dialog_result)
|
||||
|
||||
# compute tmp scores
|
||||
results, _ = self.reader.wrap_result_lm(tmp_dialog_result)
|
||||
bleu, success, match = self.evaluator.validation_metric(
|
||||
results)
|
||||
score = 0.5 * (success + match) + bleu
|
||||
pbar.set_description(
|
||||
'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %
|
||||
(match, success, bleu, score))
|
||||
|
||||
# compute scores
|
||||
results, _ = self.reader.wrap_result_lm(result_collection)
|
||||
bleu, success, match = self.evaluator.validation_metric(results)
|
||||
score = 0.5 * (success + match) + bleu
|
||||
|
||||
# log results
|
||||
metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\
|
||||
(match, success, bleu, score)
|
||||
message_prefix = f'[Infer][{self.epoch}]'
|
||||
time_cost = f'TIME-{time.time() - begin_time:.3f}'
|
||||
message = ' '.join([message_prefix, metrics_message, time_cost])
|
||||
self.logger.info(message)
|
||||
|
||||
# save results
|
||||
eval_results = {
|
||||
'bleu': bleu,
|
||||
'success': success,
|
||||
'match': match,
|
||||
'score': score,
|
||||
'result': message
|
||||
}
|
||||
with open(infer_save_file, 'w') as fp:
|
||||
json.dump(eval_results, fp, indent=2)
|
||||
self.logger.info(f'Saved inference results to {infer_save_file}')
|
||||
with open(infer_samples_save_file, 'w') as fp:
|
||||
for sample in results:
|
||||
line = json.dumps(sample)
|
||||
fp.write(line)
|
||||
fp.write('\n')
|
||||
self.logger.info(
|
||||
f'Saved inference samples to {infer_samples_save_file}')
|
||||
|
||||
return
|
||||
|
||||
def forward(self, turn, old_pv_turn):
|
||||
with torch.no_grad():
|
||||
first_turn = True if len(old_pv_turn) == 0 else False
|
||||
inputs, prompt_id = self.reader.convert_turn_eval(
|
||||
turn, old_pv_turn, first_turn)
|
||||
batch, batch_size = self.reader.collate_fn_multi_turn(
|
||||
samples=[inputs])
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
|
||||
pv_turn = {}
|
||||
print(batch)
|
||||
|
||||
outputs = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_b_id,
|
||||
max_gen_len=60)
|
||||
generated_bs = outputs[0].cpu().numpy().tolist()
|
||||
bspn_gen = self.decode_generated_bspn(generated_bs)
|
||||
bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen)
|
||||
print(bspn_gen)
|
||||
print(bspn_token)
|
||||
turn_domain = []
|
||||
for item in bspn_token:
|
||||
if item.startswith('[') and item.endswith(']'):
|
||||
turn_domain.append(item)
|
||||
print(turn_domain)
|
||||
db_result = self.reader.bspan_to_DBpointer(
|
||||
self.tokenizer.decode(bspn_gen), ['[taxi]'])
|
||||
print(db_result)
|
||||
book_result = 21
|
||||
db = [self.reader.sos_db_id] + \
|
||||
self.tokenizer.convert_tokens_to_ids([db_result]) + \
|
||||
[book_result] + \
|
||||
[self.reader.eos_db_id]
|
||||
prompt_id = self.reader.sos_a_id
|
||||
|
||||
prev_input = torch.tensor(bspn_gen + db)
|
||||
if self.func_model.use_gpu:
|
||||
prev_input = prev_input.cuda()
|
||||
|
||||
outputs_db = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_r_id,
|
||||
max_gen_len=80,
|
||||
prev_input=prev_input)
|
||||
generated_ar = outputs_db[0].cpu().numpy().tolist()
|
||||
decoded = self.decode_generated_act_resp(generated_ar)
|
||||
decoded['bspn'] = bspn_gen
|
||||
print(decoded)
|
||||
print(self.tokenizer.convert_ids_to_tokens(decoded['resp']))
|
||||
|
||||
pv_turn['labels'] = None
|
||||
pv_turn['resp'] = decoded['resp']
|
||||
pv_turn['bspn'] = decoded['bspn']
|
||||
pv_turn['db'] = None
|
||||
pv_turn['aspn'] = None
|
||||
|
||||
return pv_turn
|
||||
@@ -26,18 +26,19 @@ class DialogGenerationTest(unittest.TestCase):
|
||||
model_dir=modeldir,
|
||||
text_field=preprocessor.text_field,
|
||||
config=preprocessor.config)
|
||||
# pipeline = DialogGenerationPipeline(model, preprocessor)
|
||||
|
||||
history_dialog = {}
|
||||
for step, item in enumerate(test_case['sng0073']['log']):
|
||||
user_question = item['user']
|
||||
print('user: {}'.format(user_question))
|
||||
|
||||
# history_dialog_info = merge(history_dialog_info,
|
||||
# result) if step > 0 else {}
|
||||
# result = pipeline(user_question, history=history_dialog_info)
|
||||
#
|
||||
# print('sys : {}'.format(result['pred_answer']))
|
||||
print(model.forward(None))
|
||||
# pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor)
|
||||
#
|
||||
# history_dialog_info = {}
|
||||
# for step, item in enumerate(test_case['sng0073']['log']):
|
||||
# user_question = item['user']
|
||||
# print('user: {}'.format(user_question))
|
||||
#
|
||||
# # history_dialog_info = merge(history_dialog_info,
|
||||
# # result) if step > 0 else {}
|
||||
# result = pipeline(user_question, history=history_dialog_info)
|
||||
# #
|
||||
# # print('sys : {}'.format(result['pred_answer']))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user