merge with space
@@ -5,4 +5,6 @@ from .palm_for_text_generation import * # noqa F403
from .sbert_for_sentence_similarity import * # noqa F403
from .sbert_for_token_classification import * # noqa F403
from .sentiment_classification_model import * # noqa F403
from .space.dialog_intent_prediction_model import * # noqa F403
from .space.dialog_modeling_model import * # noqa F403
from .zero_shot_classification_model import * # noqa F403
0
modelscope/models/nlp/space/__init__.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from modelscope.preprocessors.space.fields.intent_field import \
|
||||
IntentBPETextField
|
||||
from modelscope.trainers.nlp.space.trainers.intent_trainer import IntentTrainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ...base import Model, Tensor
|
||||
from ...builder import MODELS
|
||||
from .model.generator import Generator
|
||||
from .model.model_base import ModelBase
|
||||
|
||||
__all__ = ['DialogIntentModel']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.dialog_intent_prediction, module_name=r'space-intent')
|
||||
class DialogIntentModel(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
"""initialize the test generation model from the `model_dir` path.
|
||||
|
||||
Args:
model_dir (str): the model path.
config (Config, optional): model configuration; by default it is read
from `configuration.json` under `model_dir`.
text_field (IntentBPETextField, optional): text field used to build the
generator and trainer; a default one is created from `model_dir`.
"""
|
||||
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.model_dir = model_dir
|
||||
self.config = kwargs.pop(
|
||||
'config',
|
||||
Config.from_file(
|
||||
os.path.join(self.model_dir, 'configuration.json')))
|
||||
self.text_field = kwargs.pop(
|
||||
'text_field',
|
||||
IntentBPETextField(self.model_dir, config=self.config))
|
||||
|
||||
self.generator = Generator.create(self.config, reader=self.text_field)
|
||||
self.model = ModelBase.create(
|
||||
model_dir=model_dir,
|
||||
config=self.config,
|
||||
reader=self.text_field,
|
||||
generator=self.generator)
|
||||
|
||||
def to_tensor(array):
|
||||
"""
|
||||
numpy array -> tensor
|
||||
"""
|
||||
import torch
|
||||
array = torch.tensor(array)
|
||||
return array.cuda() if self.config.use_gpu else array
|
||||
|
||||
self.trainer = IntentTrainer(
|
||||
model=self.model,
|
||||
to_tensor=to_tensor,
|
||||
config=self.config,
|
||||
reader=self.text_field)
|
||||
self.trainer.load()
|
||||
|
||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""return the result by the model
|
||||
|
||||
Args:
|
||||
input (Dict[str, Any]): the preprocessed data
|
||||
|
||||
Returns:
Dict[str, np.ndarray]: results with a single key 'pred' holding the
predicted intent distribution for the input.
"""
|
||||
import numpy as np
|
||||
pred = self.trainer.forward(input)
|
||||
pred = np.squeeze(pred[0], 0)
|
||||
|
||||
return {'pred': pred}
|
||||
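As a point of reference, a minimal driver sketch for the class above (not part of this commit; the function name and the idea that `features` comes from the matching SPACE intent preprocessor are assumptions):

from modelscope.models.nlp.space.dialog_intent_prediction_model import DialogIntentModel

def predict_intent(model_dir, features):
    # features: Dict[str, Tensor] produced by the corresponding intent preprocessor (assumed)
    model = DialogIntentModel(model_dir)   # reads configuration.json under model_dir
    return model.forward(features)         # -> {'pred': np.ndarray with the intent distribution}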
82
modelscope/models/nlp/space/dialog_modeling_model.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from modelscope.preprocessors.space.fields.gen_field import \
|
||||
MultiWOZBPETextField
|
||||
from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ...base import Model, Tensor
|
||||
from ...builder import MODELS
|
||||
from .model.generator import Generator
|
||||
from .model.model_base import ModelBase
|
||||
|
||||
__all__ = ['DialogModelingModel']
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.dialog_modeling, module_name=r'space-modeling')
|
||||
class DialogModelingModel(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
"""initialize the test generation model from the `model_dir` path.
|
||||
|
||||
Args:
model_dir (str): the model path.
config (Config, optional): model configuration; by default it is read
from `configuration.json` under `model_dir`.
text_field (MultiWOZBPETextField, optional): text field used to build the
generator and trainer; a default one is created from `model_dir`.
"""
|
||||
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.model_dir = model_dir
|
||||
self.config = kwargs.pop(
|
||||
'config',
|
||||
Config.from_file(
|
||||
os.path.join(self.model_dir, 'configuration.json')))
|
||||
self.text_field = kwargs.pop(
|
||||
'text_field',
|
||||
MultiWOZBPETextField(self.model_dir, config=self.config))
|
||||
self.generator = Generator.create(self.config, reader=self.text_field)
|
||||
self.model = ModelBase.create(
|
||||
model_dir=model_dir,
|
||||
config=self.config,
|
||||
reader=self.text_field,
|
||||
generator=self.generator)
|
||||
|
||||
def to_tensor(array):
|
||||
"""
|
||||
numpy array -> tensor
|
||||
"""
|
||||
import torch
|
||||
array = torch.tensor(array)
|
||||
return array.cuda() if self.config.use_gpu else array
|
||||
|
||||
self.trainer = MultiWOZTrainer(
|
||||
model=self.model,
|
||||
to_tensor=to_tensor,
|
||||
config=self.config,
|
||||
reader=self.text_field,
|
||||
evaluator=None)
|
||||
self.trainer.load()
|
||||
|
||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""return the result by the model
|
||||
|
||||
Args:
|
||||
input (Dict[str, Any]): the preprocessed data
|
||||
|
||||
Returns:
Dict[str, Tensor]: the updated dialogue state (`pv_turn`) returned by the
trainer for the current turn.
"""
|
||||
|
||||
turn = {'user': input['user']}
|
||||
old_pv_turn = input['history']
|
||||
|
||||
pv_turn = self.trainer.forward(turn=turn, old_pv_turn=old_pv_turn)
|
||||
|
||||
return pv_turn
|
||||
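A similar sketch for multi-turn use of the class above (illustrative only; the helper name is an assumption). The keys 'user' and 'history' follow the forward() implementation above, and the returned pv_turn is fed back in as the next turn's history:

from modelscope.models.nlp.space.dialog_modeling_model import DialogModelingModel

def run_dialog(model_dir, user_turns):
    # user_turns: list of tokenized user inputs (assumed to come from the MultiWOZ preprocessor)
    model = DialogModelingModel(model_dir)
    pv_turn = {}                                   # empty history before the first turn
    for user in user_turns:
        pv_turn = model.forward({'user': user, 'history': pv_turn})
    return pv_turn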
3
modelscope/models/nlp/space/model/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .gen_unified_transformer import GenUnifiedTransformer
from .intent_unified_transformer import IntentUnifiedTransformer
from .unified_transformer import UnifiedTransformer
285
modelscope/models/nlp/space/model/gen_unified_transformer.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
GenUnifiedTransformer
|
||||
"""
|
||||
import torch
|
||||
|
||||
from modelscope.models.nlp.space.model.unified_transformer import \
|
||||
UnifiedTransformer
|
||||
|
||||
|
||||
class GenUnifiedTransformer(UnifiedTransformer):
|
||||
"""
|
||||
Implement generation unified transformer.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, config, reader, generator):
|
||||
super(GenUnifiedTransformer, self).__init__(model_dir, config, reader,
|
||||
generator)
|
||||
self.understand = config.BPETextField.understand
|
||||
|
||||
if self.use_gpu:
|
||||
self.cuda()
|
||||
return
|
||||
|
||||
def _forward(self, inputs, is_training, with_label):
|
||||
""" Real forward process of model in different mode(train/test). """
|
||||
|
||||
def cat(x, y, dim=1):
|
||||
return torch.cat([x, y], dim=dim)
|
||||
|
||||
outputs = {}
|
||||
|
||||
if self.understand or self.policy:
|
||||
if self.understand:
|
||||
prompt_token = inputs['understand_token']
|
||||
prompt_mask = inputs['understand_mask']
|
||||
if self.policy:
|
||||
prompt_token = cat(prompt_token, inputs['policy_token'])
|
||||
prompt_mask = cat(prompt_mask, inputs['policy_mask'])
|
||||
else:
|
||||
prompt_token = inputs['policy_token']
|
||||
prompt_mask = inputs['policy_mask']
|
||||
|
||||
enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network(
|
||||
src_token=inputs['src_token'],
|
||||
src_mask=inputs['src_mask'],
|
||||
tgt_token=inputs['tgt_token'][:, :-1],
|
||||
tgt_mask=inputs['tgt_mask'][:, :-1],
|
||||
prompt_token=prompt_token,
|
||||
prompt_mask=prompt_mask,
|
||||
src_pos=inputs['src_pos'],
|
||||
src_type=inputs['src_type'],
|
||||
src_turn=inputs['src_turn'],
|
||||
tgt_pos=inputs['tgt_pos'][:, :-1],
|
||||
tgt_type=inputs['tgt_type'][:, :-1],
|
||||
tgt_turn=inputs['tgt_turn'][:, :-1])
|
||||
else:
|
||||
enc_embed, dec_embed = self._encoder_decoder_network(
|
||||
src_token=inputs['src_token'],
|
||||
src_mask=inputs['src_mask'],
|
||||
tgt_token=inputs['tgt_token'][:, :-1],
|
||||
tgt_mask=inputs['tgt_mask'][:, :-1],
|
||||
src_pos=inputs['src_pos'],
|
||||
src_type=inputs['src_type'],
|
||||
src_turn=inputs['src_turn'],
|
||||
tgt_pos=inputs['tgt_pos'][:, :-1],
|
||||
tgt_type=inputs['tgt_type'][:, :-1],
|
||||
tgt_turn=inputs['tgt_turn'][:, :-1])
|
||||
|
||||
outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed)
|
||||
return outputs
|
||||
|
||||
def _collect_metrics(self, inputs, outputs, with_label, data_file):
|
||||
|
||||
metrics = {}
|
||||
loss = 0.
|
||||
|
||||
label = inputs['tgt_token'][:, 1:]
|
||||
token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1)
|
||||
nll = self.nll_loss(
|
||||
torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label)
|
||||
nll = torch.sum(nll, dim=1)
|
||||
token_nll = torch.sum(nll) / token_num
|
||||
nll = torch.mean(nll)
|
||||
metrics['nll'] = nll
|
||||
metrics['token_nll'] = token_nll
|
||||
metrics['token_num'] = token_num
|
||||
loss = loss + (token_nll if self.token_loss else nll)
|
||||
|
||||
metrics['loss'] = loss
|
||||
if self.gpu > 1:
|
||||
return nll, token_nll, token_num
|
||||
else:
|
||||
return metrics
|
||||
|
||||
def _optimize(self, loss, do_update=False, optimizer=None):
|
||||
""" Optimize loss function and update model. """
|
||||
assert optimizer is not None
|
||||
|
||||
if self.gradient_accumulation_steps > 1:
|
||||
loss = loss / self.gradient_accumulation_steps
|
||||
|
||||
loss.backward()
|
||||
|
||||
if self.grad_clip is not None and self.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
parameters=self.parameters(), max_norm=self.grad_clip)
|
||||
|
||||
if do_update:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return
|
||||
|
||||
def _init_state(self,
|
||||
src_token,
|
||||
src_mask,
|
||||
src_pos=None,
|
||||
src_type=None,
|
||||
src_turn=None):
|
||||
""" Initialize decode state. """
|
||||
state = {}
|
||||
batch_size = src_token.shape[0]
|
||||
|
||||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
|
||||
src_embed = self.embed_layer_norm(src_embed)
|
||||
|
||||
mask = self._create_mask(src_mask, append_head=False)
|
||||
|
||||
enc_out = src_embed
|
||||
|
||||
cache = {}
|
||||
for _l, layer in enumerate(self.layers):
|
||||
cache[f'layer_{_l}'] = {}
|
||||
enc_out = layer(enc_out, mask, cache[f'layer_{_l}'])
|
||||
|
||||
state['cache'] = cache
|
||||
state['mask'] = mask[:, :1]
|
||||
state['batch_size'] = batch_size
|
||||
shape = [batch_size, 1, 1]
|
||||
state['pred_mask'] = torch.ones(shape, dtype=torch.float32)
|
||||
state['pred_pos'] = torch.zeros(shape, dtype=torch.int64)
|
||||
state['pred_type'] = torch.zeros(shape, dtype=torch.int64)
|
||||
state['pred_turn'] = torch.zeros(shape, dtype=torch.int64)
|
||||
if self.use_gpu:
|
||||
state['pred_mask'] = state['pred_mask'].cuda()
|
||||
state['pred_pos'] = state['pred_pos'].cuda()
|
||||
state['pred_type'] = state['pred_type'].cuda()
|
||||
state['pred_turn'] = state['pred_turn'].cuda()
|
||||
|
||||
return state
|
||||
|
||||
def _init_prompt_state(self,
|
||||
src_token,
|
||||
src_mask,
|
||||
prompt_token,
|
||||
prompt_mask,
|
||||
src_pos=None,
|
||||
src_type=None,
|
||||
src_turn=None,
|
||||
prompt_pos=None,
|
||||
prompt_type=None,
|
||||
prompt_turn=None):
|
||||
""" Initialize decode state. """
|
||||
state = {}
|
||||
batch_size = src_token.shape[0]
|
||||
|
||||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
|
||||
prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type,
|
||||
prompt_turn)
|
||||
embed = torch.cat([src_embed, prompt_embed], dim=1)
|
||||
embed = self.embed_layer_norm(embed)
|
||||
enc_out = embed
|
||||
|
||||
enc_mask = self._create_mask(src_mask, auto_regressive=False)
|
||||
dec_mask = self._create_mask(prompt_mask, auto_regressive=True)
|
||||
mask = self._join_mask(enc_mask, dec_mask)
|
||||
|
||||
cache = {}
|
||||
for _l, layer in enumerate(self.layers):
|
||||
cache[f'layer_{_l}'] = {}
|
||||
enc_out = layer(enc_out, mask, cache[f'layer_{_l}'])
|
||||
|
||||
state['cache'] = cache
|
||||
state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1]
|
||||
state['batch_size'] = batch_size
|
||||
shape = [batch_size, 1, 1]
|
||||
state['pred_mask'] = torch.ones(shape, dtype=torch.float32)
|
||||
state['pred_pos'] = torch.zeros(shape, dtype=torch.int64)
|
||||
state['pred_type'] = torch.zeros(shape, dtype=torch.int64)
|
||||
state['pred_turn'] = torch.zeros(shape, dtype=torch.int64)
|
||||
if self.use_gpu:
|
||||
state['pred_mask'] = state['pred_mask'].cuda()
|
||||
state['pred_pos'] = state['pred_pos'].cuda()
|
||||
state['pred_type'] = state['pred_type'].cuda()
|
||||
state['pred_turn'] = state['pred_turn'].cuda()
|
||||
|
||||
return state
|
||||
|
||||
def _decode(self, state):
|
||||
""" Decoding one time stamp. """
|
||||
|
||||
# shape: [batch_size, 1, seq_len]
|
||||
mask = state['mask']
|
||||
|
||||
# shape: [batch_size, 1, 1]
|
||||
pred_token = state['pred_token']
|
||||
pred_mask = state['pred_mask']
|
||||
pred_pos = state['pred_pos']
|
||||
pred_type = state['pred_type']
|
||||
pred_turn = state['pred_turn']
|
||||
|
||||
# list of shape(len: num_layers): [batch_size, seq_len, hidden_dim]
|
||||
cache = state['cache']
|
||||
|
||||
pred_embed = self.embedder(pred_token, pred_pos, pred_type,
|
||||
pred_turn).squeeze(-2)
|
||||
pred_embed = self.embed_layer_norm(pred_embed)
|
||||
|
||||
# shape: [batch_size, 1, seq_len + 1]
|
||||
mask = torch.cat([mask, 1 - pred_mask], dim=2)
|
||||
|
||||
# shape: [batch_size, 1, hidden_dim]
|
||||
for _l, layer in enumerate(self.layers):
|
||||
pred_embed = layer(pred_embed, mask, cache[f'layer_{_l}'])
|
||||
|
||||
# shape: [batch_size, vocab_size]
|
||||
pred_probs = self._dec_head(dec_embed=pred_embed[:, 0])
|
||||
pred_logits = torch.log(pred_probs)
|
||||
|
||||
state['mask'] = mask
|
||||
return pred_logits, state
|
||||
|
||||
def _infer(self,
|
||||
inputs,
|
||||
start_id=None,
|
||||
eos_id=None,
|
||||
max_gen_len=None,
|
||||
prev_input=None):
|
||||
""" Real inference process of model. """
|
||||
|
||||
def cat(x, y, dim=1):
|
||||
return torch.cat([x, y], dim=dim)
|
||||
|
||||
# Initial decode state.
|
||||
if self.understand or self.policy:
|
||||
if self.understand:
|
||||
prompt_token = inputs['understand_token']
|
||||
prompt_mask = inputs['understand_mask']
|
||||
if self.policy:
|
||||
prompt_token = cat(prompt_token, inputs['policy_token'])
|
||||
prompt_mask = cat(prompt_mask, inputs['policy_mask'])
|
||||
else:
|
||||
prompt_token = inputs['policy_token']
|
||||
prompt_mask = inputs['policy_mask']
|
||||
|
||||
state = self._init_prompt_state(
|
||||
src_token=inputs['src_token'],
|
||||
src_mask=inputs['src_mask'],
|
||||
prompt_token=prompt_token,
|
||||
prompt_mask=prompt_mask,
|
||||
src_pos=inputs['src_pos'],
|
||||
src_type=inputs['src_type'],
|
||||
src_turn=inputs['src_turn'])
|
||||
else:
|
||||
state = self._init_state(
|
||||
src_token=inputs['src_token'],
|
||||
src_mask=inputs['src_mask'],
|
||||
src_pos=inputs['src_pos'],
|
||||
src_type=inputs['src_type'],
|
||||
src_turn=inputs['src_turn'])
|
||||
|
||||
# Generation process.
|
||||
gen_results = self.generator(
|
||||
step_fn=self._decode,
|
||||
state=state,
|
||||
start_id=start_id,
|
||||
eos_id=eos_id,
|
||||
max_gen_len=max_gen_len,
|
||||
prev_input=prev_input)
|
||||
|
||||
outputs = gen_results['preds']
|
||||
return outputs
|
||||
|
||||
|
||||
GenUnifiedTransformer.register('GenUnifiedTransformer')
|
||||
290
modelscope/models/nlp/space/model/generator.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Generator class.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def repeat(var, times):
|
||||
if isinstance(var, list):
|
||||
return [repeat(x, times) for x in var]
|
||||
elif isinstance(var, dict):
|
||||
return {k: repeat(v, times) for k, v in var.items()}
|
||||
elif isinstance(var, torch.Tensor):
|
||||
var = var.unsqueeze(1)
|
||||
expand_times = [1] * len(var.shape)
|
||||
expand_times[1] = times
|
||||
dtype = var.dtype
|
||||
var = var.float()
|
||||
var = var.repeat(*expand_times)
|
||||
shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:])
|
||||
var = var.reshape(*shape)
|
||||
var = torch.tensor(var, dtype=dtype)
|
||||
return var
|
||||
else:
|
||||
return var
|
||||
|
||||
|
||||
def gather(var, idx):
|
||||
if isinstance(var, list):
|
||||
return [gather(x, idx) for x in var]
|
||||
elif isinstance(var, dict):
|
||||
return {k: gather(v, idx) for k, v in var.items()}
|
||||
elif isinstance(var, torch.Tensor):
|
||||
out = var.index_select(dim=0, index=idx)
|
||||
return out
|
||||
else:
|
||||
return var
|
||||
|
||||
|
||||
class Generator(object):
|
||||
""" Genrator class. """
|
||||
|
||||
_registry = dict()
|
||||
|
||||
@classmethod
|
||||
def register(cls, name):
|
||||
Generator._registry[name] = cls
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def by_name(name):
|
||||
return Generator._registry[name]
|
||||
|
||||
@staticmethod
|
||||
def create(config, *args, **kwargs):
|
||||
""" Create generator. """
|
||||
generator_cls = Generator.by_name(config.Generator.generator)
|
||||
return generator_cls(config, *args, **kwargs)
|
||||
|
||||
def __init__(self, config, reader):
|
||||
self.vocab_size = reader.vocab_size
|
||||
self.bos_id = reader.bos_id
|
||||
self.eos_id = reader.eos_id
|
||||
self.unk_id = reader.unk_id
|
||||
self.pad_id = reader.pad_id
|
||||
self.min_gen_len = config.Generator.min_gen_len
|
||||
self.max_gen_len = config.Generator.max_gen_len
|
||||
self.use_gpu = config.use_gpu
|
||||
assert 1 <= self.min_gen_len <= self.max_gen_len
|
||||
return
|
||||
|
||||
def __call__(self, step_fn, state):
|
||||
"""
|
||||
Running generation.
|
||||
|
||||
@param : step_fn : decoding one step
|
||||
@type : function
|
||||
|
||||
@param : state : initial state
|
||||
@type : dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BeamSearch(Generator):
|
||||
""" BeamSearch generator. """
|
||||
|
||||
def __init__(self, config, reader):
|
||||
super().__init__(config, reader)
|
||||
self.beam_size = config.Generator.beam_size
|
||||
self.length_average = config.Generator.length_average
|
||||
self.length_penalty = config.Generator.length_penalty
|
||||
self.ignore_unk = config.Generator.ignore_unk
|
||||
return
|
||||
|
||||
def __call__(self,
|
||||
step_fn,
|
||||
state,
|
||||
start_id=None,
|
||||
eos_id=None,
|
||||
max_gen_len=None,
|
||||
prev_input=None):
|
||||
"""
|
||||
Running beam search.
|
||||
|
||||
@param : step_fn : decoding one step
|
||||
@type : function
|
||||
|
||||
@param : state : initial state
|
||||
@type : dict
|
||||
"""
|
||||
if prev_input is not None:
|
||||
|
||||
if isinstance(prev_input, list):
|
||||
length = max(list(map(lambda x: len(x), prev_input)))
|
||||
prev_input_numpy = np.full((len(prev_input), length),
|
||||
self.pad_id)
|
||||
for i, x in enumerate(prev_input):
|
||||
prev_input_numpy[i, :len(x)] = x
|
||||
prev_input_tensor = torch.from_numpy(prev_input_numpy)
|
||||
if self.use_gpu:
|
||||
prev_input_tensor = prev_input_tensor.cuda()
|
||||
|
||||
for i in range(length):
|
||||
state['pred_token'] = prev_input_tensor[:, i].unsqueeze(
|
||||
-1).unsqueeze(-1)
|
||||
if i != 0:
|
||||
state['pred_mask'] = torch.not_equal(
|
||||
state['pred_token'], self.pad_id).float()
|
||||
state['pred_pos'] = state['pred_pos'] + state[
|
||||
'pred_mask'].int()
|
||||
_, state = step_fn(state)
|
||||
else:
|
||||
assert isinstance(prev_input, torch.Tensor)
|
||||
for i, input in enumerate(prev_input):
|
||||
state['pred_token'] = input.expand(1, 1, 1)
|
||||
if i != 0:
|
||||
state['pred_mask'] = torch.not_equal(
|
||||
state['pred_token'], self.pad_id).float()
|
||||
state['pred_pos'] = state['pred_pos'] + 1
|
||||
_, state = step_fn(state)
|
||||
|
||||
batch_size = state['batch_size']
|
||||
beam_size = self.beam_size
|
||||
|
||||
# shape: [batch_size, 1]
|
||||
pos_index = torch.arange(
|
||||
0, batch_size, 1, dtype=torch.int64) * beam_size
|
||||
pos_index = pos_index.unsqueeze(1)
|
||||
|
||||
# shape: [batch_size, beam_size, 1]
|
||||
if start_id is None:
|
||||
start_id = self.bos_id
|
||||
if eos_id is None:
|
||||
eos_id = self.eos_id
|
||||
predictions = torch.ones([batch_size, beam_size, 1],
|
||||
dtype=torch.int64) * start_id
|
||||
|
||||
if self.use_gpu:
|
||||
pos_index = pos_index.cuda()
|
||||
predictions = predictions.cuda()
|
||||
|
||||
# initial input (start_id)
|
||||
state['pred_token'] = predictions[:, :1]
|
||||
if prev_input is not None:
|
||||
state['pred_mask'] = torch.not_equal(state['pred_token'],
|
||||
self.pad_id).float()
|
||||
state['pred_pos'] = state['pred_pos'] + 1
|
||||
|
||||
# shape: [batch_size, vocab_size]
|
||||
scores, state = step_fn(state)
|
||||
|
||||
unk_penalty = np.zeros(self.vocab_size, dtype='float32')
|
||||
unk_penalty[self.unk_id] = -1e10
|
||||
unk_penalty = torch.from_numpy(unk_penalty)
|
||||
|
||||
eos_penalty = np.zeros(self.vocab_size, dtype='float32')
|
||||
eos_penalty[eos_id] = -1e10
|
||||
eos_penalty = torch.from_numpy(eos_penalty)
|
||||
|
||||
scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32')
|
||||
scores_after_end[
|
||||
self.pad_id] = 0  # after <eos> only <pad> should be generated, so give log(p(<pad>)) the highest value (0) in the vocab
|
||||
scores_after_end = torch.from_numpy(scores_after_end)
|
||||
|
||||
if self.use_gpu:
|
||||
unk_penalty = unk_penalty.cuda()
|
||||
eos_penalty = eos_penalty.cuda()
|
||||
scores_after_end = scores_after_end.cuda()
|
||||
|
||||
if self.ignore_unk:
|
||||
scores = scores + unk_penalty
|
||||
scores = scores + eos_penalty
|
||||
|
||||
# shape: [batch_size, beam_size]
|
||||
sequence_scores, preds = torch.topk(scores, self.beam_size)
|
||||
|
||||
predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
|
||||
state = repeat(state, beam_size)
|
||||
|
||||
if max_gen_len is None:
|
||||
max_gen_len = self.max_gen_len
|
||||
for step in range(2, max_gen_len + 1):
|
||||
pre_ids = predictions[:, :, -1:]
|
||||
state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1)
|
||||
state['pred_mask'] = torch.not_equal(state['pred_token'],
|
||||
self.pad_id).float()
|
||||
state['pred_pos'] = state['pred_pos'] + 1
|
||||
scores, state = step_fn(state)
|
||||
|
||||
# Generate next
|
||||
# scores shape: [batch_size * beam_size, vocab_size]
|
||||
if self.ignore_unk:
|
||||
scores = scores + unk_penalty
|
||||
|
||||
if step <= self.min_gen_len:
|
||||
scores = scores + eos_penalty
|
||||
|
||||
# scores shape: [batch_size, beam_size, vocab_size]
|
||||
scores = scores.reshape(batch_size, beam_size, self.vocab_size)
|
||||
|
||||
# previous token is [PAD] or [EOS]
|
||||
pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
|
||||
(1 - torch.not_equal(pre_ids, self.pad_id).float())
|
||||
|
||||
scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat(
|
||||
1, 1, self.vocab_size) * scores_after_end
|
||||
if self.length_average:
|
||||
scaled_value = \
|
||||
pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step)
|
||||
sequence_scores = sequence_scores.unsqueeze(2) * scaled_value
|
||||
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
|
||||
scores = scores * scaled_value
|
||||
elif self.length_penalty >= 0.0:
|
||||
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
|
||||
(math.pow((4 + step) / (5 + step), self.length_penalty))
|
||||
sequence_scores = scaled_value * sequence_scores
|
||||
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
|
||||
(math.pow(1 / (5 + step), self.length_penalty))
|
||||
scores = scores * scaled_value
|
||||
scores = scores + sequence_scores.unsqueeze(-1)
|
||||
scores = scores.reshape(batch_size, beam_size * self.vocab_size)
|
||||
|
||||
topk_scores, topk_indices = torch.topk(scores, beam_size)
|
||||
# topk_indices: [batch_size, beam_size * self.vocab_size] (already reshaped)
# integer-divide by vocab_size to find which beam the previous token of the current step came from
parent_idx = topk_indices.floor_divide(self.vocab_size)
# take the remainder modulo vocab_size to recover the token id
preds = topk_indices % self.vocab_size
|
||||
|
||||
# Gather state / sequence_scores
|
||||
parent_idx = parent_idx + pos_index
|
||||
parent_idx = parent_idx.reshape(batch_size * beam_size)
|
||||
state = gather(state, parent_idx)
|
||||
sequence_scores = topk_scores
|
||||
|
||||
predictions = predictions.reshape(batch_size * beam_size, step)
|
||||
predictions = gather(predictions, parent_idx)
|
||||
predictions = predictions.reshape(batch_size, beam_size, step)
|
||||
predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
|
||||
|
||||
# the generated sentence should already be complete, so the last token must be <eos> or <pad> (following <eos>); otherwise penalize it
|
||||
pre_ids = predictions[:, :, -1]
|
||||
pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
|
||||
(1 - torch.not_equal(pre_ids, self.pad_id).float())
|
||||
sequence_scores = sequence_scores * pre_eos_mask + (
|
||||
1 - pre_eos_mask) * (-1e10)
|
||||
|
||||
# get ascending sort indices first, so predictions and sequence_scores can be reordered along the beam-size axis
|
||||
indices = torch.argsort(sequence_scores, dim=1)
|
||||
indices = indices + pos_index
|
||||
indices = indices.reshape(-1)
|
||||
sequence_scores = sequence_scores.reshape(batch_size * beam_size)
|
||||
predictions = predictions.reshape(batch_size * beam_size, -1)
|
||||
sequence_scores = gather(sequence_scores, indices)
|
||||
predictions = gather(predictions, indices)
|
||||
sequence_scores = sequence_scores.reshape(batch_size, beam_size)
|
||||
predictions = predictions.reshape(batch_size, beam_size, -1)
|
||||
|
||||
results = {
|
||||
'preds': predictions[:, -1],
|
||||
'scores': sequence_scores[:, -1]
|
||||
}
|
||||
return results
|
||||
|
||||
|
||||
BeamSearch.register('BeamSearch')
|
||||
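The registry above is how the surrounding code picks a decoding strategy. A small lookup sketch, assuming a config whose Generator.generator field is 'BeamSearch' and a SPACE text field as reader (illustration only, not part of this commit):

generator_cls = Generator.by_name('BeamSearch')             # class registered at the bottom of this file
generator = Generator.create(config, reader=text_field)     # same thing, driven by config.Generator.generator
gen_results = generator(step_fn=model._decode, state=state)  # -> {'preds': ..., 'scores': ...}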
198
modelscope/models/nlp/space/model/intent_unified_transformer.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
IntentUnifiedTransformer
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.utils.nlp.space.criterions import compute_kl_loss
|
||||
from .unified_transformer import UnifiedTransformer
|
||||
|
||||
|
||||
class IntentUnifiedTransformer(UnifiedTransformer):
|
||||
"""
|
||||
Implement intent unified transformer.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, config, reader, generator):
|
||||
super(IntentUnifiedTransformer, self).__init__(model_dir, config,
|
||||
reader, generator)
|
||||
self.example = config.Model.example
|
||||
self.num_intent = config.Model.num_intent
|
||||
self.with_rdrop = config.Model.with_rdrop
|
||||
self.kl_ratio = config.Model.kl_ratio
|
||||
self.loss_fct = nn.CrossEntropyLoss()
|
||||
if self.example:
|
||||
self.loss_fct = nn.NLLLoss()
|
||||
else:
|
||||
self.intent_classifier = nn.Linear(self.hidden_dim,
|
||||
self.num_intent)
|
||||
self.loss_fct = nn.CrossEntropyLoss()
|
||||
|
||||
if self.use_gpu:
|
||||
self.cuda()
|
||||
return
|
||||
|
||||
def _forward(self, inputs, is_training, with_label):
|
||||
""" Real forward process of model in different mode(train/test). """
|
||||
|
||||
def aug(v):
|
||||
assert isinstance(v, torch.Tensor)
|
||||
return torch.cat([v, v], dim=0)
|
||||
|
||||
outputs = {}
|
||||
|
||||
if self.with_mlm:
|
||||
mlm_embed = self._encoder_network(
|
||||
input_token=inputs['mlm_token'],
|
||||
input_mask=inputs['src_mask'],
|
||||
input_pos=inputs['src_pos'],
|
||||
input_type=inputs['src_type'],
|
||||
input_turn=inputs['src_turn'])
|
||||
outputs['mlm_probs'] = self._mlm_head(mlm_embed=mlm_embed)
|
||||
|
||||
if self.with_rdrop or self.with_contrastive:
|
||||
enc_embed, dec_embed = self._encoder_decoder_network(
|
||||
src_token=aug(inputs['src_token']),
|
||||
src_mask=aug(inputs['src_mask']),
|
||||
tgt_token=aug(inputs['tgt_token']),
|
||||
tgt_mask=aug(inputs['tgt_mask']),
|
||||
src_pos=aug(inputs['src_pos']),
|
||||
src_type=aug(inputs['src_type']),
|
||||
src_turn=aug(inputs['src_turn']))
|
||||
else:
|
||||
enc_embed, dec_embed = self._encoder_decoder_network(
|
||||
src_token=inputs['src_token'],
|
||||
src_mask=inputs['src_mask'],
|
||||
tgt_token=inputs['tgt_token'],
|
||||
tgt_mask=inputs['tgt_mask'],
|
||||
src_pos=inputs['src_pos'],
|
||||
src_type=inputs['src_type'],
|
||||
src_turn=inputs['src_turn'])
|
||||
features = dec_embed[:, -1]
|
||||
features = self.pooler(features) if self.with_pool else features
|
||||
|
||||
if self.example:
|
||||
assert not self.with_rdrop
|
||||
ex_enc_embed, ex_dec_embed = self._encoder_decoder_network(
|
||||
src_token=inputs['example_src_token'],
|
||||
src_mask=inputs['example_src_mask'],
|
||||
tgt_token=inputs['example_tgt_token'],
|
||||
tgt_mask=inputs['example_tgt_mask'],
|
||||
src_pos=inputs['example_src_pos'],
|
||||
src_type=inputs['example_src_type'],
|
||||
src_turn=inputs['example_src_turn'])
|
||||
ex_features = ex_dec_embed[:, -1]
|
||||
ex_features = self.pooler(
|
||||
ex_features) if self.with_pool else ex_features
|
||||
|
||||
probs = self.softmax(features.mm(ex_features.t()))
|
||||
example_intent = inputs['example_intent'].unsqueeze(0)
|
||||
intent_probs = torch.zeros(probs.size(0), self.num_intent)
|
||||
intent_probs = intent_probs.cuda(
|
||||
) if self.use_gpu else intent_probs
|
||||
intent_probs = intent_probs.scatter_add(
|
||||
-1, example_intent.repeat(probs.size(0), 1), probs)
|
||||
outputs['intent_probs'] = intent_probs
|
||||
else:
|
||||
intent_logits = self.intent_classifier(features)
|
||||
outputs['intent_logits'] = intent_logits
|
||||
|
||||
if self.with_contrastive:
|
||||
features = features if self.with_pool else self.pooler(features)
|
||||
batch_size = features.size(0) // 2
|
||||
features = \
|
||||
torch.cat(
|
||||
[features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)],
|
||||
dim=1
|
||||
)
|
||||
features = F.normalize(features, dim=-1, p=2)
|
||||
outputs['features'] = features
|
||||
|
||||
return outputs
|
||||
|
||||
def _collect_metrics(self, inputs, outputs, with_label, data_file):
|
||||
|
||||
metrics = {}
|
||||
batch_size = inputs['src_token'].size(0)
|
||||
|
||||
intent_label = torch.cat([inputs['intent_label'], inputs['intent_label']], dim=0) \
|
||||
if self.with_rdrop or self.with_contrastive else inputs['intent_label']
|
||||
|
||||
if self.example:
|
||||
intent_loss = self.loss_fct(
|
||||
torch.log(outputs['intent_probs'] + 1e-12).view(
|
||||
-1, self.num_intent), intent_label.type(torch.long))
|
||||
else:
|
||||
intent_loss = self.loss_fct(
|
||||
outputs['intent_logits'].view(-1, self.num_intent),
|
||||
intent_label.type(torch.long))
|
||||
metrics['intent_loss'] = intent_loss
|
||||
loss = intent_loss
|
||||
|
||||
if self.with_mlm:
|
||||
mlm_num = torch.sum(torch.sum(inputs['mlm_mask'], dim=1))
|
||||
mlm = self.nll_loss(
|
||||
torch.log(outputs['mlm_probs'] + 1e-12).permute(0, 2, 1),
|
||||
inputs['mlm_label'])
|
||||
mlm = torch.sum(mlm, dim=1)
|
||||
token_mlm = torch.sum(mlm) / mlm_num
|
||||
mlm = torch.mean(mlm)
|
||||
metrics['mlm'] = mlm
|
||||
metrics['token_mlm'] = token_mlm
|
||||
metrics['mlm_num'] = mlm_num
|
||||
loss = loss + (token_mlm
|
||||
if self.token_loss else mlm) * self.mlm_ratio
|
||||
else:
|
||||
mlm, token_mlm, mlm_num = None, None, None
|
||||
|
||||
if self.with_rdrop:
|
||||
kl = compute_kl_loss(
|
||||
p=outputs['intent_logits'][:batch_size],
|
||||
q=outputs['intent_logits'][batch_size:])
|
||||
metrics['kl'] = kl
|
||||
loss = loss + kl * self.kl_ratio
|
||||
else:
|
||||
kl = None
|
||||
|
||||
if self.with_contrastive:
|
||||
pass
|
||||
con = None
|
||||
else:
|
||||
con = None
|
||||
|
||||
metrics['loss'] = loss
|
||||
|
||||
if self.gpu > 1:
|
||||
return intent_loss, mlm, token_mlm, mlm_num, kl, con
|
||||
else:
|
||||
return metrics
|
||||
|
||||
def _infer(self,
|
||||
inputs,
|
||||
start_id=None,
|
||||
eos_id=None,
|
||||
max_gen_len=None,
|
||||
prev_input=None):
|
||||
""" Real inference process of model. """
|
||||
results = {}
|
||||
enc_embed, dec_embed = self._encoder_decoder_network(
|
||||
src_token=inputs['src_token'],
|
||||
src_mask=inputs['src_mask'],
|
||||
tgt_token=inputs['tgt_token'],
|
||||
tgt_mask=inputs['tgt_mask'],
|
||||
src_pos=inputs['src_pos'],
|
||||
src_type=inputs['src_type'],
|
||||
src_turn=inputs['src_turn'])
|
||||
features = dec_embed[:, -1]
|
||||
features = self.pooler(features) if self.with_pool else features
|
||||
if self.example:
|
||||
results['features'] = features
|
||||
else:
|
||||
intent_logits = self.intent_classifier(features)
|
||||
intent_probs = self.softmax(intent_logits)
|
||||
results['intent_probs'] = intent_probs
|
||||
return results
|
||||
|
||||
|
||||
IntentUnifiedTransformer.register('IntentUnifiedTransformer')
|
||||
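For readability, the R-Drop branch above boils down to a cross-entropy term on the duplicated batch plus a KL term between the two dropout-perturbed passes. A compact restatement of the same quantities used in _collect_metrics (no new behaviour; variable names shortened for illustration):

# intent_logits has 2 * batch_size rows because the inputs were duplicated by aug()
ce = loss_fct(intent_logits.view(-1, num_intent), intent_label.long())
kl = compute_kl_loss(p=intent_logits[:batch_size], q=intent_logits[batch_size:])
loss = ce + kl_ratio * kl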
99
modelscope/models/nlp/space/model/model_base.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Model base
|
||||
"""
|
||||
import os
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModelBase(nn.Module):
|
||||
"""
|
||||
Basic model wrapper for static graph and dygraph.
|
||||
"""
|
||||
_registry = dict()
|
||||
|
||||
@classmethod
|
||||
def register(cls, name):
|
||||
ModelBase._registry[name] = cls
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def by_name(name):
|
||||
return ModelBase._registry[name]
|
||||
|
||||
@staticmethod
|
||||
def create(model_dir, config, *args, **kwargs):
|
||||
model_cls = ModelBase.by_name(config.Model.model)
|
||||
return model_cls(model_dir, config, *args, **kwargs)
|
||||
|
||||
def __init__(self, model_dir, config):
|
||||
super(ModelBase, self).__init__()
|
||||
self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin')
|
||||
self.abandon_label = config.Dataset.abandon_label
|
||||
self.use_gpu = config.use_gpu
|
||||
self.gpu = config.Trainer.gpu
|
||||
return
|
||||
|
||||
def _create_parameters(self):
|
||||
""" Create model's paramters. """
|
||||
raise NotImplementedError
|
||||
|
||||
def _forward(self, inputs, is_training, with_label):
|
||||
""" NO LABEL: Real forward process of model in different mode(train/test). """
|
||||
raise NotImplementedError
|
||||
|
||||
def _collect_metrics(self, inputs, outputs, with_label, data_file):
|
||||
""" NO LABEL: Calculate loss function by using inputs and outputs. """
|
||||
raise NotImplementedError
|
||||
|
||||
def _optimize(self, loss, optimizer, lr_scheduler):
|
||||
""" Optimize loss function and update model. """
|
||||
raise NotImplementedError
|
||||
|
||||
def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input):
|
||||
""" Real inference process of model. """
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self,
|
||||
inputs,
|
||||
is_training=False,
|
||||
with_label=False,
|
||||
data_file=None):
|
||||
"""
|
||||
Forward process, including the real forward pass, metrics collection and optional optimization.
|
||||
|
||||
@params : inputs : input data
|
||||
@type : dict of numpy.ndarray/int/float/...
|
||||
"""
|
||||
if is_training:
|
||||
self.train()
|
||||
else:
|
||||
self.eval()
|
||||
|
||||
with_label = False if self.abandon_label else with_label
|
||||
outputs = self._forward(inputs, is_training, with_label=with_label)
|
||||
metrics = self._collect_metrics(
|
||||
inputs, outputs, with_label=with_label, data_file=data_file)
|
||||
|
||||
return metrics
|
||||
|
||||
def infer(self,
|
||||
inputs,
|
||||
start_id=None,
|
||||
eos_id=None,
|
||||
max_gen_len=None,
|
||||
prev_input=None):
|
||||
"""
|
||||
Inference process.
|
||||
|
||||
@params : inputs : input data
|
||||
@type : dict of numpy.ndarray/int/float/...
|
||||
"""
|
||||
self.eval()
|
||||
results = self._infer(
|
||||
inputs,
|
||||
start_id=start_id,
|
||||
eos_id=eos_id,
|
||||
max_gen_len=max_gen_len,
|
||||
prev_input=prev_input)
|
||||
return results
|
||||
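ModelBase.create resolves the concrete class from config.Model.model through the registry above. A schematic subclass showing the hooks a new model has to fill in (names are illustrative, not part of this commit):

class MySpaceModel(ModelBase):                  # hypothetical example
    def _forward(self, inputs, is_training, with_label):
        raise NotImplementedError               # real forward pass goes here
    def _collect_metrics(self, inputs, outputs, with_label, data_file):
        raise NotImplementedError               # loss / metrics dict goes here
    def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input):
        raise NotImplementedError               # decoding / prediction goes here

MySpaceModel.register('MySpaceModel')           # then config.Model.model = 'MySpaceModel' selects it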
322
modelscope/models/nlp/space/model/unified_transformer.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
UnifiedTransformer
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.nlp.space.model.model_base import ModelBase
|
||||
from modelscope.models.nlp.space.modules.embedder import Embedder
|
||||
from modelscope.models.nlp.space.modules.transformer_block import \
|
||||
TransformerBlock
|
||||
|
||||
|
||||
class UnifiedTransformer(ModelBase):
|
||||
"""
|
||||
Implement unified transformer.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, config, reader, generator, dtype='float32'):
|
||||
super(UnifiedTransformer, self).__init__(model_dir, config)
|
||||
self.reader = reader
|
||||
self.generator = generator
|
||||
self.policy = config.BPETextField.policy
|
||||
self.generation = config.BPETextField.generation
|
||||
self.num_token_embeddings = config.Model.num_token_embeddings
|
||||
self.num_pos_embeddings = config.Model.num_pos_embeddings
|
||||
self.num_type_embeddings = config.Model.num_type_embeddings
|
||||
self.num_turn_embeddings = config.Model.num_turn_embeddings
|
||||
self.temperature = config.Model.temperature
|
||||
self.hidden_dim = config.Model.hidden_dim
|
||||
self.num_heads = config.Model.num_heads
|
||||
self.num_layers = config.Model.num_layers
|
||||
self.padding_idx = config.Model.padding_idx
|
||||
self.dropout = config.Model.dropout
|
||||
self.embed_dropout = config.Model.embed_dropout
|
||||
self.attn_dropout = config.Model.attn_dropout
|
||||
self.ff_dropout = config.Model.ff_dropout
|
||||
self.mlm_ratio = config.Model.mlm_ratio
|
||||
self.mmd_ratio = config.Model.mmd_ratio
|
||||
self.pos_trainable = config.Model.pos_trainable
|
||||
self.label_smooth = config.Model.label_smooth
|
||||
self.initializer_range = config.Model.initializer_range
|
||||
self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps
|
||||
self.token_loss = config.Trainer.token_loss
|
||||
self.learning_method = config.Dataset.learning_method
|
||||
self.with_contrastive = config.Dataset.with_contrastive
|
||||
self.with_query_bow = config.BPETextField.with_query_bow
|
||||
self.with_resp_bow = config.BPETextField.with_resp_bow
|
||||
self.with_pool = config.Model.with_pool
|
||||
self.with_mlm = config.Dataset.with_mlm
|
||||
self._dtype = dtype
|
||||
|
||||
self.embedder = Embedder(
|
||||
self.hidden_dim,
|
||||
self.num_token_embeddings,
|
||||
self.num_pos_embeddings,
|
||||
self.num_type_embeddings,
|
||||
self.num_turn_embeddings,
|
||||
padding_idx=self.padding_idx,
|
||||
dropout=self.embed_dropout,
|
||||
pos_trainable=self.pos_trainable)
|
||||
self.embed_layer_norm = nn.LayerNorm(
|
||||
normalized_shape=self.hidden_dim,
|
||||
eps=1e-12,
|
||||
elementwise_affine=True)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerBlock(self.hidden_dim, self.num_heads, self.dropout,
|
||||
self.attn_dropout, self.ff_dropout)
|
||||
for _ in range(config.Model.num_layers)
|
||||
])
|
||||
|
||||
if self.with_mlm:
|
||||
self.mlm_transform = nn.Sequential(
|
||||
nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(),
|
||||
nn.LayerNorm(
|
||||
normalized_shape=self.hidden_dim,
|
||||
eps=1e-12,
|
||||
elementwise_affine=True))
|
||||
self.mlm_bias = nn.Parameter(
|
||||
torch.zeros(self.num_token_embeddings))
|
||||
|
||||
self.pooler = nn.Sequential(
|
||||
nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh())
|
||||
|
||||
if self.with_query_bow or self.with_resp_bow:
|
||||
self.bow_predictor = nn.Linear(
|
||||
self.hidden_dim, self.num_token_embeddings, bias=False)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.bce_loss = nn.BCELoss(reduction='none')
|
||||
self.nll_loss = nn.NLLLoss(
|
||||
ignore_index=self.padding_idx, reduction='none')
|
||||
self._create_parameters()
|
||||
|
||||
self.max_grad_norm = config.Model.max_grad_norm
|
||||
if self.max_grad_norm is not None:
|
||||
self.grad_clip = self.max_grad_norm
|
||||
else:
|
||||
self.grad_clip = None
|
||||
self.weight_decay = config.Model.weight_decay
|
||||
|
||||
if self.use_gpu:
|
||||
self.cuda()
|
||||
|
||||
return
|
||||
|
||||
def _create_parameters(self):
|
||||
""" Create model's paramters. """
|
||||
sequence_mask = np.tri(
|
||||
self.num_pos_embeddings,
|
||||
self.num_pos_embeddings,
|
||||
dtype=self._dtype)
|
||||
self.sequence_mask = torch.tensor(sequence_mask)
|
||||
return
|
||||
|
||||
def _create_mask(self,
|
||||
input_mask,
|
||||
append_head=False,
|
||||
auto_regressive=False):
|
||||
"""
|
||||
Create attention mask.
|
||||
Expand the mask from sequence form to matrix form: [batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len].
Besides the (auto-regressive) attention mask, the <pad> mask also has to be taken into account (both auto-regressive and bidirectional).
Note:
1. A non-<pad> token attends over the whole sentence; only the <pad> tokens of that sentence are masked.
2. A <pad> token attends over the whole sentence; every token of that sentence is masked.
|
||||
|
||||
@param : input_mask
|
||||
@type : Variable(shape: [batch_size, max_seq_len])
|
||||
|
||||
@param : auto_regressive
|
||||
@type : bool
|
||||
"""
|
||||
seq_len = input_mask.shape[1]
|
||||
|
||||
input_mask = input_mask.float()
|
||||
mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len)
|
||||
mask2 = mask1.permute(0, 2, 1)
|
||||
mask = mask1 * mask2
|
||||
|
||||
if append_head:
|
||||
# prepend the mask for the leading position ([M]/z)
|
||||
mask = torch.cat([mask[:, :1, :], mask], dim=1)
|
||||
mask = torch.cat([mask[:, :, :1], mask], dim=2)
|
||||
seq_len += 1
|
||||
|
||||
if auto_regressive:
|
||||
# merge the target-side <pad> mask with the auto-regressive attention mask
|
||||
seq_mask = self.sequence_mask[:seq_len, :seq_len]
|
||||
seq_mask = seq_mask.to(mask.device)
|
||||
mask = mask * seq_mask
|
||||
|
||||
mask = 1 - mask
|
||||
return mask
|
||||
|
||||
def _join_mask(self, mask1, mask2):
|
||||
"""
|
||||
Merge source attention mask and target attention mask.
|
||||
The merged mask matrix consists of four parts: upper-left (lu), upper-right (ru), lower-left (lb) and lower-right (rb).
|
||||
|
||||
@param : mask1 : source attention mask
|
||||
@type : Variable(shape: [batch_size, max_src_len, max_src_len])
|
||||
|
||||
@param : mask2 : target attention mask
|
||||
@type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len])
|
||||
"""
|
||||
batch_size = mask1.shape[0]
|
||||
seq_len1 = mask1.shape[1]
|
||||
seq_len2 = mask2.shape[1]
|
||||
# seq_len = seq_len1 + seq_len2
|
||||
|
||||
mask_lu = mask1
|
||||
mask_ru = torch.ones(batch_size, seq_len1, seq_len2)
|
||||
if self.use_gpu:
|
||||
mask_ru = mask_ru.cuda()
|
||||
mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1)
|
||||
mask4 = mask1[:, :1].repeat(1, seq_len2, 1)
|
||||
mask_lb = mask3 + mask4 - mask3 * mask4
|
||||
mask_rb = mask2
|
||||
mask_u = torch.cat([mask_lu, mask_ru], dim=2)
|
||||
mask_b = torch.cat([mask_lb, mask_rb], dim=2)
|
||||
mask = torch.cat([mask_u, mask_b], dim=1)
|
||||
return mask
|
||||
|
||||
def _mlm_head(self, mlm_embed):
|
||||
mlm_embed = self.mlm_transform(mlm_embed)
|
||||
mlm_logits = torch.matmul(
|
||||
mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias
|
||||
mlm_probs = self.softmax(mlm_logits)
|
||||
return mlm_probs
|
||||
|
||||
def _dec_head(self, dec_embed):
|
||||
dec_logits = torch.matmul(dec_embed,
|
||||
self.embedder.token_embedding.weight.T)
|
||||
dec_probs = self.softmax(dec_logits)
|
||||
return dec_probs
|
||||
|
||||
def _refactor_feature(self, features):
|
||||
features = self.pooler(features) if self.with_pool else features
|
||||
batch_size = features.size(0) // 2
|
||||
features = \
|
||||
torch.cat(
|
||||
[features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)],
|
||||
dim=1
|
||||
)
|
||||
features = F.normalize(features, dim=-1, p=2)
|
||||
return features
|
||||
|
||||
def _encoder_network(self,
|
||||
input_token,
|
||||
input_mask,
|
||||
input_pos=None,
|
||||
input_type=None,
|
||||
input_turn=None):
|
||||
embed = self.embedder(input_token, input_pos, input_type, input_turn)
|
||||
embed = self.embed_layer_norm(embed)
|
||||
mask = self._create_mask(input_mask, auto_regressive=False)
|
||||
|
||||
for layer in self.layers:
|
||||
embed = layer(embed, mask, None)
|
||||
|
||||
return embed
|
||||
|
||||
def _encoder_decoder_network(self,
|
||||
src_token,
|
||||
src_mask,
|
||||
tgt_token,
|
||||
tgt_mask,
|
||||
src_pos=None,
|
||||
src_type=None,
|
||||
src_turn=None,
|
||||
tgt_pos=None,
|
||||
tgt_type=None,
|
||||
tgt_turn=None):
|
||||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
|
||||
tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
|
||||
embed = torch.cat([src_embed, tgt_embed], dim=1)
|
||||
embed = self.embed_layer_norm(embed)
|
||||
|
||||
enc_mask = self._create_mask(src_mask, auto_regressive=False)
|
||||
dec_mask = self._create_mask(tgt_mask, auto_regressive=True)
|
||||
mask = self._join_mask(enc_mask, dec_mask)
|
||||
|
||||
for layer in self.layers:
|
||||
embed = layer(embed, mask, None)
|
||||
|
||||
tgt_len = tgt_token.shape[1]
|
||||
enc_embed = embed[:, :-tgt_len]
|
||||
dec_embed = embed[:, -tgt_len:]
|
||||
|
||||
return enc_embed, dec_embed
|
||||
|
||||
def _encoder_prompt_decoder_network(self,
|
||||
src_token,
|
||||
src_mask,
|
||||
tgt_token,
|
||||
tgt_mask,
|
||||
prompt_token,
|
||||
prompt_mask,
|
||||
src_pos=None,
|
||||
src_type=None,
|
||||
src_turn=None,
|
||||
tgt_pos=None,
|
||||
tgt_type=None,
|
||||
tgt_turn=None,
|
||||
prompt_pos=None,
|
||||
prompt_type=None,
|
||||
prompt_turn=None):
|
||||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
|
||||
tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
|
||||
prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type,
|
||||
prompt_turn)
|
||||
|
||||
embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1)
|
||||
embed = self.embed_layer_norm(embed)
|
||||
|
||||
enc_mask = self._create_mask(src_mask, auto_regressive=False)
|
||||
dec_mask = self._create_mask(
|
||||
torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True)
|
||||
mask = self._join_mask(enc_mask, dec_mask)
|
||||
|
||||
for layer in self.layers:
|
||||
embed = layer(embed, mask, None)
|
||||
|
||||
src_len = src_token.shape[1]
|
||||
tgt_len = tgt_token.shape[1]
|
||||
enc_embed = embed[:, :src_len]
|
||||
dec_embed = embed[:, -tgt_len:]
|
||||
prompt_embed = embed[:, src_len:-tgt_len]
|
||||
|
||||
return enc_embed, dec_embed, prompt_embed
|
||||
|
||||
def _optimize(self, loss, optimizer=None, lr_scheduler=None):
|
||||
""" Optimize loss function and update model. """
|
||||
assert optimizer is not None
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
if self.grad_clip is not None and self.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
parameters=self.parameters(), max_norm=self.grad_clip)
|
||||
optimizer.step()
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
return
|
||||
|
||||
def _infer(self,
|
||||
inputs,
|
||||
start_id=None,
|
||||
eos_id=None,
|
||||
max_gen_len=None,
|
||||
prev_input=None):
|
||||
""" Real inference process of model. """
|
||||
results = {}
|
||||
return results
|
||||
|
||||
|
||||
UnifiedTransformer.register('UnifiedTransformer')
|
||||
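A tiny standalone illustration of the masking convention used by _create_mask above, reproducing only the bidirectional case outside the class (assumption: 1 in input_mask marks a real token, 0 marks <pad>):

import torch

input_mask = torch.tensor([[1., 1., 0.]])        # one sentence of length 3 with one <pad>
mask1 = input_mask.unsqueeze(-1).repeat(1, 1, 3)
mask = mask1 * mask1.permute(0, 2, 1)            # bidirectional visibility, as in _create_mask
print(1 - mask)                                  # 1 finally marks the positions to be masked out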
0
modelscope/models/nlp/space/modules/__init__.py
Normal file
67
modelscope/models/nlp/space/modules/embedder.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Embedder class.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Embedder(nn.Module):
|
||||
"""
|
||||
Composite embedding layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_dim,
|
||||
num_token_embeddings,
|
||||
num_pos_embeddings,
|
||||
num_type_embeddings,
|
||||
num_turn_embeddings,
|
||||
padding_idx=None,
|
||||
dropout=0.1,
|
||||
pos_trainable=False):
|
||||
super(Embedder, self).__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim)
|
||||
self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim)
|
||||
self.pos_embedding.weight.requires_grad = pos_trainable
|
||||
self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim)
|
||||
self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim)
|
||||
self.dropout_layer = nn.Dropout(p=dropout)
|
||||
|
||||
# follow the default xavier_uniform initializer of the paddle version;
# otherwise the dec_probs computation breaks under the weight-tying setting
# (the default normal initializer of nn.Embedding in pytorch samples larger values)
|
||||
nn.init.xavier_uniform_(self.token_embedding.weight)
|
||||
nn.init.xavier_uniform_(self.pos_embedding.weight)
|
||||
nn.init.xavier_uniform_(self.type_embedding.weight)
|
||||
nn.init.xavier_uniform_(self.turn_embedding.weight)
|
||||
return
|
||||
|
||||
def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None):
|
||||
embed = self.token_embedding(token_inp)
|
||||
if pos_inp is not None:
|
||||
embed += self.pos_embedding(pos_inp)
|
||||
if type_inp is not None:
|
||||
embed += self.type_embedding(type_inp)
|
||||
if turn_inp is not None:
|
||||
embed += self.turn_embedding(turn_inp)
|
||||
embed = self.dropout_layer(embed)
|
||||
return embed
|
||||
|
||||
|
||||
def main():
|
||||
import numpy as np
|
||||
|
||||
model = Embedder(10, 20, 20, 20, 20)
|
||||
token_inp = torch.tensor(
|
||||
np.random.randint(0, 19, [10, 10]).astype('int64'))
|
||||
pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
|
||||
type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
|
||||
turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
|
||||
out = model(token_inp, pos_inp, type_inp, turn_inp)
|
||||
print(out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
43
modelscope/models/nlp/space/modules/feedforward.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
FeedForward class.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
Positional feed forward layer.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim, inner_dim, dropout):
|
||||
super(FeedForward, self).__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.inner_dim = inner_dim
|
||||
self.linear_hidden = nn.Sequential(
|
||||
nn.Linear(hidden_dim, inner_dim), nn.GELU())
|
||||
self.linear_out = nn.Linear(inner_dim, hidden_dim)
|
||||
self.dropout_layer = nn.Dropout(p=dropout)
|
||||
return
|
||||
|
||||
def forward(self, x):
|
||||
out = self.linear_hidden(x)
|
||||
out = self.dropout_layer(out)
|
||||
out = self.linear_out(out)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
import numpy as np
|
||||
|
||||
model = FeedForward(10, 20, 0.5)
|
||||
inp = np.random.rand(2, 3, 10).astype('float32')
|
||||
inp = torch.tensor(inp)
|
||||
out = model(inp)
|
||||
print(out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
64
modelscope/models/nlp/space/modules/functions.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Helpful functions.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def unsqueeze(input, dims):
|
||||
""" Implement multi-dimension unsqueeze function. """
|
||||
if isinstance(dims, (list, tuple)):
|
||||
dims = [
|
||||
dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims
|
||||
]
|
||||
dims = sorted(dims, reverse=True)
|
||||
shape = list(input.shape)
|
||||
for dim in dims:
|
||||
shape.insert(dim, 1)
|
||||
return torch.reshape(input, shape)
|
||||
elif isinstance(dims, int):
|
||||
return input.unsqueeze(dims)
|
||||
else:
|
||||
raise ValueError('Warning: type(dims) must be in (list, tuple, int)!')
|
||||
|
||||
|
||||
def gumbel_softmax(input, tau=1, eps=1e-10):
|
||||
""" Basic implement of gumbel_softmax. """
|
||||
U = torch.tensor(np.random.rand(*input.shape))
|
||||
gumbel = 0.0 - torch.log(eps - torch.log(U + eps))
|
||||
y = input + gumbel
|
||||
return F.softmax(y / tau)
|
||||
|
||||
|
||||
def equal(x, y, dtype=None):
|
||||
""" Implement equal in dygraph mode. (paddle) """
|
||||
if dtype is None:
|
||||
dtype = 'float32'
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.numpy()
|
||||
if isinstance(y, torch.Tensor):
|
||||
y = y.numpy()
|
||||
out = np.equal(x, y).astype(dtype)
|
||||
return torch.tensor(out)
|
||||
|
||||
|
||||
def not_equal(x, y, dtype=None):
|
||||
""" Implement not_equal in dygraph mode. (paddle) """
|
||||
return 1 - equal(x, y, dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
a = torch.tensor([[1, 1], [3, 4]])
|
||||
b = torch.tensor([[1, 1], [3, 4]])
|
||||
c = torch.equal(a, a)
|
||||
c1 = equal(a, 3)
|
||||
d = 1 - torch.not_equal(a, 3).float()
|
||||
print(c)
|
||||
print(c1)
|
||||
print(d)
|
||||
e = F.gumbel_softmax(a.float())
f = a.unsqueeze(0)
|
||||
g = unsqueeze(a, dims=[0, 0, 1])
|
||||
print(g, g.shape)
|
||||
109
modelscope/models/nlp/space/modules/multihead_attention.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
MultiheadAttention class.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""
|
||||
Multi head attention layer.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim, num_heads, dropout):
|
||||
assert hidden_dim % num_heads == 0
|
||||
super(MultiheadAttention, self).__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3)
|
||||
self.linear_out = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.dropout_layer = nn.Dropout(p=dropout)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
return
|
||||
|
||||
def _split_heads(self, x, is_key=False):
|
||||
x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim)
|
||||
x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3)
|
||||
return x
|
||||
|
||||
def _merge_heads(self, x):
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
x = x.reshape(x.size(0), x.size(1), self.hidden_dim)
|
||||
return x
|
||||
|
||||
def _attn(self, query, key, value, mask):
|
||||
# shape: [batch_size, num_head, seq_len, seq_len]
|
||||
scores = torch.matmul(query, key)
|
||||
scores = scores * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1)
|
||||
mask = mask.repeat(1, self.num_heads, 1, 1)
|
||||
scores.masked_fill_(
|
||||
mask.bool(),
|
||||
float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10)
|
||||
|
||||
attn = self.softmax(scores)
|
||||
attn = self.dropout_layer(attn)
|
||||
|
||||
if mask is not None:
|
||||
'''
|
||||
mask: [batch size, num_heads, seq_len, seq_len]
|
||||
Looking at the last two dims (seq_len, seq_len) of the mask, some rows may be entirely true (1) -- the rows seen from <pad> positions --
so after softmax every attention prob in such a row becomes 1/n instead of 0, which is why it is reset to 0 here.
|
||||
|
||||
>>> F.softmax([-1e10, -100, -100])
|
||||
>>> [0.00, 0.50, 0.50]
|
||||
>>> F.softmax([-1e10, -1e10, -1e10])
|
||||
>>> [0.33, 0.33, 0.33]
|
||||
==> [0.00, 0.00, 0.00]
|
||||
'''
|
||||
attn.masked_fill_(mask.bool(), 0.) # attn = (1 - mask) * attn
|
||||
|
||||
out = torch.matmul(attn, value)
|
||||
return out
|
||||
|
||||
def forward(self, inp, mask=None, cache=None):
|
||||
""" Forward process of self attention. """
|
||||
# shape: [batch_size, seq_len, 3 * hidden_dim]
|
||||
qkv = self.linear_qkv(inp)
|
||||
query, key, value = torch.split(qkv, self.hidden_dim, dim=2)
|
||||
|
||||
# shape: [batch_size, num_head, seq_len, head_dim]
|
||||
query = self._split_heads(query)
|
||||
# shape: [batch_size, num_head, head_dim, seq_len]
|
||||
key = self._split_heads(key, is_key=True)
|
||||
# shape: [batch_size, num_head, seq_len, head_dim]
|
||||
value = self._split_heads(value)
|
||||
|
||||
if cache is not None:
|
||||
if 'key' in cache and 'value' in cache:
|
||||
key = torch.cat([cache['key'], key], dim=3)
|
||||
value = torch.cat([cache['value'], value], dim=2)
|
||||
cache['key'] = key
|
||||
cache['value'] = value
|
||||
|
||||
out = self._attn(query, key, value, mask)
|
||||
out = self._merge_heads(out)
|
||||
out = self.linear_out(out)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
import numpy as np
|
||||
|
||||
model = MultiheadAttention(10, 2, 0.5)
|
||||
inp = np.random.rand(2, 3, 10).astype('float32')
|
||||
inp = torch.tensor(inp)
|
||||
mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32')
|
||||
mask = torch.tensor(mask)
|
||||
out = model(inp, mask=mask, cache=None)
|
||||
print(out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
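A quick illustration of the masked-row issue described in the `_attn` comment above: when every position in a row is masked, filling the scores with -inf and applying softmax yields NaN (or a uniform row for large finite negatives) rather than zeros, which is why the attention probabilities get a second masked_fill to 0. The snippet below is a standalone sketch of that behaviour, not taken from the patch; the tensor values are made up.

import torch
import torch.nn.functional as F

scores = torch.tensor([[0.2, -1.3, 0.7],     # ordinary query row
                       [0.0, 0.0, 0.0]])     # row belonging to a <pad> query
mask = torch.tensor([[False, False, False],
                     [True, True, True]])    # fully-masked row

scores = scores.masked_fill(mask, float('-inf'))
attn = F.softmax(scores, dim=-1)
print(attn)                     # second row is NaN, not zeros
attn = attn.masked_fill(mask, 0.)
print(attn)                     # second row is now all zeros, as the model expects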
73
modelscope/models/nlp/space/modules/transformer_block.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
TransformerBlock class.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.models.nlp.space.modules.feedforward import FeedForward
|
||||
from modelscope.models.nlp.space.modules.multihead_attention import \
|
||||
MultiheadAttention
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""
|
||||
Transformer block module.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim, num_heads, dropout, attn_dropout,
|
||||
ff_dropout):
|
||||
super(TransformerBlock, self).__init__()
|
||||
|
||||
self.attn = MultiheadAttention(
|
||||
hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout)
|
||||
self.attn_norm = nn.LayerNorm(
|
||||
normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True)
|
||||
self.ff = FeedForward(
|
||||
hidden_dim=hidden_dim,
|
||||
inner_dim=4 * hidden_dim,
|
||||
dropout=ff_dropout)
|
||||
self.ff_norm = nn.LayerNorm(
|
||||
normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True)
|
||||
self.dropout_layer = nn.Dropout(p=dropout)
|
||||
return
|
||||
|
||||
def forward(self, inp, mask=None, cache=None):
|
||||
"""
|
||||
Forward process on one transformer layer.
|
||||
|
||||
@param : x
|
||||
@type : Variable(shape: [batch_size, seq_len, hidden_size])
|
||||
|
||||
@param : memory
|
||||
@type : Variable(shape: [batch_size, seq_len, hidden_size])
|
||||
|
||||
@param : mask
|
||||
|
||||
@param : cache
|
||||
"""
|
||||
attn_out = self.attn(inp, mask, cache)
|
||||
attn_out = self.dropout_layer(attn_out)
|
||||
attn_out = self.attn_norm(attn_out + inp)
|
||||
|
||||
ff_out = self.ff(attn_out)
|
||||
ff_out = self.dropout_layer(ff_out)
|
||||
ff_out = self.ff_norm(ff_out + attn_out)
|
||||
|
||||
return ff_out
|
||||
|
||||
|
||||
def main():
|
||||
import numpy as np
|
||||
|
||||
model = TransformerBlock(10, 2, 0.5, 0.5, 0.5)
|
||||
inp = np.random.rand(2, 3, 10).astype('float32')
|
||||
inp = torch.tensor(inp)
|
||||
mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32')
|
||||
mask = torch.tensor(mask)
|
||||
out = model(inp, mask=mask, cache=None)
|
||||
print(out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
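The `cache` argument threaded through TransformerBlock.forward and MultiheadAttention.forward supports incremental decoding: on the first call an empty dict is filled with the projected keys/values, and on later calls the new keys/values are concatenated onto it so earlier tokens are not re-encoded. A minimal usage sketch, assuming TransformerBlock is importable from the module added above; shapes and values are illustrative only.

import torch

block = TransformerBlock(hidden_dim=10, num_heads=2, dropout=0.1,
                         attn_dropout=0.1, ff_dropout=0.1)
cache = {}
prefix = torch.randn(1, 4, 10)               # 4 tokens already decoded
_ = block(prefix, mask=None, cache=cache)    # cache now holds their keys/values
next_tok = torch.randn(1, 1, 10)             # decode one more step
out = block(next_tok, mask=None, cache=cache)
print(out.shape)                             # torch.Size([1, 1, 10])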
@@ -4,3 +4,4 @@ from .builder import pipeline
|
||||
from .cv import * # noqa F403
|
||||
from .multi_modal import * # noqa F403
|
||||
from .nlp import * # noqa F403
|
||||
from .nlp.space import * # noqa F403
|
||||
|
||||
@@ -14,7 +14,7 @@ from .outputs import TASK_OUTPUTS
|
||||
from .util import is_model_name
|
||||
|
||||
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
||||
Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
|
||||
Input = Union[str, tuple, dict, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
|
||||
InputModel = Union[str, Model]
|
||||
|
||||
output_keys = [
|
||||
@@ -120,6 +120,7 @@ class Pipeline(ABC):
|
||||
out = self.preprocess(input, **preprocess_params)
|
||||
out = self.forward(out, **forward_params)
|
||||
out = self.postprocess(out, **postprocess_params)
|
||||
|
||||
self._check_output(out)
|
||||
return out
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ from .nli_pipeline import * # noqa F403
|
||||
from .sentence_similarity_pipeline import * # noqa F403
|
||||
from .sentiment_classification_pipeline import * # noqa F403
|
||||
from .sequence_classification_pipeline import * # noqa F403
|
||||
from .space.dialog_intent_prediction_pipeline import * # noqa F403
|
||||
from .space.dialog_modeling_pipeline import * # noqa F403
|
||||
from .text_generation_pipeline import * # noqa F403
|
||||
from .word_segmentation_pipeline import * # noqa F403
|
||||
from .zero_shot_classification_pipeline import * # noqa F403
|
||||
|
||||
0
modelscope/pipelines/nlp/space/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from modelscope.models.nlp import DialogIntentModel
|
||||
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ...base import Input, Pipeline
|
||||
from ...builder import PIPELINES
|
||||
|
||||
__all__ = ['DialogIntentPredictionPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.dialog_intent_prediction, module_name=r'space-intent')
|
||||
class DialogIntentPredictionPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: DialogIntentModel,
|
||||
preprocessor: DialogIntentPredictionPreprocessor, **kwargs):
|
||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
|
||||
|
||||
Args:
|
||||
model (DialogIntentModel): a model instance
preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance
|
||||
"""
|
||||
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model = model
|
||||
# self.tokenizer = preprocessor.tokenizer
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): the model prediction results to be post-processed
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: the prediction results
|
||||
"""
|
||||
import numpy as np
|
||||
pred = inputs['pred']
|
||||
pos = np.where(pred == np.max(pred))
|
||||
|
||||
result = {'pred': pred, 'label': pos[0]}
|
||||
|
||||
return result
|
||||
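The postprocess step above selects the predicted intent by locating the maximum score with np.where, which is functionally the argmax of the score vector. A small illustration with made-up scores:

import numpy as np

pred = np.array([0.05, 0.72, 0.23])
pos = np.where(pred == np.max(pred))
print(pos[0])                 # [1] -- the same index np.argmax would return
print(int(np.argmax(pred)))   # 1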
46
modelscope/pipelines/nlp/space/dialog_modeling_pipeline.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from modelscope.models.nlp import DialogModelingModel
|
||||
from modelscope.preprocessors import DialogModelingPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ...base import Pipeline, Tensor
|
||||
from ...builder import PIPELINES
|
||||
|
||||
__all__ = ['DialogModelingPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.dialog_modeling, module_name=r'space-modeling')
|
||||
class DialogModelingPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: DialogModelingModel,
|
||||
preprocessor: DialogModelingPreprocessor, **kwargs):
|
||||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
|
||||
|
||||
Args:
|
||||
model (DialogModelingModel): a model instance
preprocessor (DialogModelingPreprocessor): a preprocessor instance
|
||||
"""
|
||||
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.model = model
|
||||
self.preprocessor = preprocessor
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): the model prediction results to be post-processed
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: the prediction results
|
||||
"""
|
||||
sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens(
|
||||
inputs['resp'])
|
||||
assert len(sys_rsp) > 2
|
||||
sys_rsp = sys_rsp[1:len(sys_rsp) - 1]
|
||||
# sys_rsp = self.preprocessor.text_field.tokenizer.
|
||||
|
||||
inputs['sys'] = sys_rsp
|
||||
|
||||
return inputs
|
||||
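The postprocess step above decodes the generated response ids back to tokens and then drops the first and last items, which are the start/end markers wrapping the system response. A tiny sketch of that slicing; the tokens are made up for illustration.

sys_rsp = ['<sos_r>', 'the', 'hotel', 'is', 'in', 'the', 'north', '<eos_r>']
assert len(sys_rsp) > 2
sys_rsp = sys_rsp[1:len(sys_rsp) - 1]
print(sys_rsp)   # ['the', 'hotel', 'is', 'in', 'the', 'north']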
@@ -7,4 +7,6 @@ from .common import Compose
|
||||
from .image import LoadImage, load_image
|
||||
from .multi_model import OfaImageCaptionPreprocessor
|
||||
from .nlp import * # noqa F403
|
||||
from .space.dialog_intent_prediction_preprocessor import * # noqa F403
|
||||
from .space.dialog_modeling_preprocessor import * # noqa F403
|
||||
from .text_to_speech import * # noqa F403
|
||||
|
||||
0
modelscope/preprocessors/space/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from modelscope.preprocessors.space.fields.intent_field import \
|
||||
IntentBPETextField
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Fields
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
from ..base import Preprocessor
|
||||
from ..builder import PREPROCESSORS
|
||||
|
||||
__all__ = ['DialogIntentPredictionPreprocessor']
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-intent')
|
||||
class DialogIntentPredictionPreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
"""preprocess the data via the vocab.txt from the `model_dir` path
|
||||
|
||||
Args:
|
||||
model_dir (str): model path
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.model_dir: str = model_dir
|
||||
self.config = Config.from_file(
|
||||
os.path.join(self.model_dir, 'configuration.json'))
|
||||
self.text_field = IntentBPETextField(
|
||||
self.model_dir, config=self.config)
|
||||
|
||||
@type_assert(object, str)
|
||||
def __call__(self, data: str) -> Dict[str, Any]:
|
||||
"""process the raw input data
|
||||
|
||||
Args:
|
||||
data (str): a sentence
|
||||
Example:
|
||||
'you are so handsome.'
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: the preprocessed data
|
||||
"""
|
||||
samples = self.text_field.preprocessor([data])
|
||||
samples, _ = self.text_field.collate_fn_multi_turn(samples)
|
||||
|
||||
return samples
|
||||
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from modelscope.preprocessors.space.fields.gen_field import \
|
||||
MultiWOZBPETextField
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Fields, InputFields
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
from ..base import Preprocessor
|
||||
from ..builder import PREPROCESSORS
|
||||
|
||||
__all__ = ['DialogModelingPreprocessor']
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-modeling')
|
||||
class DialogModelingPreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
"""preprocess the data via the vocab.txt from the `model_dir` path
|
||||
|
||||
Args:
|
||||
model_dir (str): model path
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.model_dir: str = model_dir
|
||||
self.config = Config.from_file(
|
||||
os.path.join(self.model_dir, 'configuration.json'))
|
||||
self.text_field = MultiWOZBPETextField(
|
||||
self.model_dir, config=self.config)
|
||||
|
||||
@type_assert(object, Dict)
|
||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""process the raw input data
|
||||
|
||||
Args:
|
||||
data (Dict[str, Any]): a dict carrying the dialogue state plus the current user utterance under 'user_input'
Example:
{'user_input': 'i need a cheap hotel'}
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: the preprocessed data
|
||||
"""
|
||||
|
||||
user_ids = self.text_field.get_ids(data['user_input'])
|
||||
data['user'] = user_ids
|
||||
|
||||
return data
|
||||
0
modelscope/preprocessors/space/fields/__init__.py
Normal file
1522
modelscope/preprocessors/space/fields/dst_processors.py
Normal file
File diff suppressed because it is too large
687
modelscope/preprocessors/space/fields/gen_field.py
Normal file
@@ -0,0 +1,687 @@
|
||||
"""
|
||||
Field class
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modelscope.preprocessors.space.tokenizer import Tokenizer
|
||||
from modelscope.utils.nlp.space import ontology, utils
|
||||
from modelscope.utils.nlp.space.db_ops import MultiWozDB
|
||||
from modelscope.utils.nlp.space.utils import list2np
|
||||
|
||||
|
||||
class BPETextField(object):
|
||||
|
||||
pad_token = '[PAD]'
|
||||
bos_token = '[BOS]'
|
||||
eos_token = '[EOS]'
|
||||
unk_token = '[UNK]'
|
||||
sos_u_token = '<sos_u>'
|
||||
eos_u_token = '<eos_u>'
|
||||
sos_b_token = '<sos_b>'
|
||||
eos_b_token = '<eos_b>'
|
||||
sos_d_token = '<sos_d>'
|
||||
eos_d_token = '<eos_d>'
|
||||
sos_a_token = '<sos_a>'
|
||||
eos_a_token = '<eos_a>'
|
||||
sos_db_token = '<sos_db>'
|
||||
eos_db_token = '<eos_db>'
|
||||
sos_r_token = '<sos_r>'
|
||||
eos_r_token = '<eos_r>'
|
||||
|
||||
@property
|
||||
def bot_id(self):
|
||||
"""
|
||||
Used to distinguish the two roles, user and bot.
1 and 0 are not vocabulary indices but dedicated role indices (size 2), matching the hyperparameter 'num_type_embeddings'.
|
||||
"""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def user_id(self):
|
||||
"""
|
||||
Used to distinguish the two roles, user and bot.
1 and 0 are not vocabulary indices but dedicated role indices (size 2), matching the hyperparameter 'num_type_embeddings'.
|
||||
"""
|
||||
return 1
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.vocab_size
|
||||
|
||||
@property
|
||||
def num_specials(self):
|
||||
return len(self.tokenizer.special_tokens)
|
||||
|
||||
@property
|
||||
def pad_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.pad_token])[0]
|
||||
|
||||
@property
|
||||
def bos_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.bos_token])[0]
|
||||
|
||||
@property
|
||||
def eos_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.eos_token])[0]
|
||||
|
||||
@property
|
||||
def unk_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.unk_token])[0]
|
||||
|
||||
@property
|
||||
def sos_u_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.sos_u_token])[0]
|
||||
|
||||
@property
|
||||
def eos_u_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.eos_u_token])[0]
|
||||
|
||||
@property
|
||||
def sos_b_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.sos_b_token])[0]
|
||||
|
||||
@property
|
||||
def eos_b_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.eos_b_token])[0]
|
||||
|
||||
@property
|
||||
def sos_db_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.sos_db_token])[0]
|
||||
|
||||
@property
|
||||
def eos_db_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.eos_db_token])[0]
|
||||
|
||||
@property
|
||||
def sos_a_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.sos_a_token])[0]
|
||||
|
||||
@property
|
||||
def eos_a_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.eos_a_token])[0]
|
||||
|
||||
@property
|
||||
def sos_r_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.sos_r_token])[0]
|
||||
|
||||
@property
|
||||
def eos_r_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.eos_r_token])[0]
|
||||
|
||||
@property
|
||||
def sos_d_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.sos_d_token])[0]
|
||||
|
||||
@property
|
||||
def eos_d_id(self):
|
||||
return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0]
|
||||
|
||||
def __init__(self, config):
|
||||
self.gpu = 0
|
||||
self.tokenizer = None
|
||||
self.vocab = None
|
||||
self.db = None
|
||||
self.set_stats = {}
|
||||
|
||||
self.prompt_num_for_understand = config.BPETextField.prompt_num_for_understand
|
||||
self.prompt_num_for_policy = config.BPETextField.prompt_num_for_policy
|
||||
self.understand_tokens = ontology.get_understand_tokens(
|
||||
self.prompt_num_for_understand)
|
||||
self.policy_tokens = ontology.get_policy_tokens(
|
||||
self.prompt_num_for_policy)
|
||||
|
||||
self.with_query_bow = config.BPETextField.with_query_bow
|
||||
self.understand = config.BPETextField.understand
|
||||
self.policy = config.BPETextField.policy
|
||||
|
||||
self.batch_size = config.Trainer.batch_size
|
||||
self.filtered = config.BPETextField.filtered
|
||||
self.max_len = config.BPETextField.max_len
|
||||
self.min_utt_len = config.BPETextField.min_utt_len
|
||||
self.max_utt_len = config.BPETextField.max_utt_len
|
||||
self.min_ctx_turn = config.BPETextField.min_ctx_turn
|
||||
self.max_ctx_turn = config.BPETextField.max_ctx_turn - 1 # subtract reply turn
|
||||
|
||||
self.use_true_prev_bspn = config.Generator.use_true_prev_bspn
|
||||
self.use_true_prev_aspn = config.Generator.use_true_prev_aspn
|
||||
self.use_true_db_pointer = config.Generator.use_true_db_pointer
|
||||
self.use_true_prev_resp = config.Generator.use_true_prev_resp
|
||||
self.use_true_curr_bspn = config.Generator.use_true_curr_bspn
|
||||
self.use_true_curr_aspn = config.Generator.use_true_curr_aspn
|
||||
self.use_all_previous_context = config.Generator.use_all_previous_context
|
||||
self.use_true_bspn_for_ctr_eval = config.Generator.use_true_bspn_for_ctr_eval
|
||||
self.use_true_domain_for_ctr_eval = config.Generator.use_true_domain_for_ctr_eval
|
||||
|
||||
def collate_fn_multi_turn(self, samples):
|
||||
batch_size = len(samples)
|
||||
batch = {}
|
||||
|
||||
src = [sp['src'][-self.max_ctx_turn:] for sp in samples]
|
||||
query_token, src_token, src_pos, src_turn, src_role = [], [], [], [], []
|
||||
for utts in src:
|
||||
query_token.append(utts[-1])
|
||||
utt_lens = [len(utt) for utt in utts]
|
||||
|
||||
# Token ids
|
||||
src_token.append(list(chain(*utts))[-self.max_len:])
|
||||
|
||||
# Position ids
|
||||
pos = [list(range(utt_len)) for utt_len in utt_lens]
|
||||
src_pos.append(list(chain(*pos))[-self.max_len:])
|
||||
|
||||
# Turn ids
|
||||
turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)]
|
||||
src_turn.append(list(chain(*turn))[-self.max_len:])
|
||||
|
||||
# Role ids
|
||||
role = [
|
||||
[self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id] * l
|
||||
for i, l in enumerate(utt_lens)
|
||||
]
|
||||
src_role.append(list(chain(*role))[-self.max_len:])
|
||||
|
||||
# Pad the src and tgt sequences separately so that the first decoded token stays aligned
|
||||
src_token = list2np(src_token, padding=self.pad_id)
|
||||
src_pos = list2np(src_pos, padding=self.pad_id)
|
||||
src_turn = list2np(src_turn, padding=self.pad_id)
|
||||
src_role = list2np(src_role, padding=self.pad_id)
|
||||
batch['src_token'] = src_token
|
||||
batch['src_pos'] = src_pos
|
||||
batch['src_type'] = src_role
|
||||
batch['src_turn'] = src_turn
|
||||
batch['src_mask'] = (src_token != self.pad_id).astype('int64')
|
||||
|
||||
if self.with_query_bow:
|
||||
query_token = list2np(query_token, padding=self.pad_id)
|
||||
batch['query_token'] = query_token
|
||||
batch['query_mask'] = (query_token != self.pad_id).astype('int64')
|
||||
|
||||
if self.understand_ids and self.understand:
|
||||
understand = [self.understand_ids for _ in samples]
|
||||
understand_token = np.array(understand).astype('int64')
|
||||
batch['understand_token'] = understand_token
|
||||
batch['understand_mask'] = \
|
||||
(understand_token != self.pad_id).astype('int64')
|
||||
|
||||
if self.policy_ids and self.policy:
|
||||
policy = [self.policy_ids for _ in samples]
|
||||
policy_token = np.array(policy).astype('int64')
|
||||
batch['policy_token'] = policy_token
|
||||
batch['policy_mask'] = \
|
||||
(policy_token != self.pad_id).astype('int64')
|
||||
|
||||
if 'tgt' in samples[0]:
|
||||
tgt = [sp['tgt'] for sp in samples]
|
||||
|
||||
# Token ids & Label ids
|
||||
tgt_token = list2np(tgt, padding=self.pad_id)
|
||||
|
||||
# Position ids
|
||||
tgt_pos = np.zeros_like(tgt_token)
|
||||
tgt_pos[:] = np.arange(tgt_token.shape[1], dtype=tgt_token.dtype)
|
||||
|
||||
# Turn ids
|
||||
tgt_turn = np.zeros_like(tgt_token)
|
||||
|
||||
# Role ids
|
||||
tgt_role = np.full_like(tgt_token, self.bot_id)
|
||||
|
||||
batch['tgt_token'] = tgt_token
|
||||
batch['tgt_pos'] = tgt_pos
|
||||
batch['tgt_type'] = tgt_role
|
||||
batch['tgt_turn'] = tgt_turn
|
||||
batch['tgt_mask'] = (tgt_token != self.pad_id).astype('int64')
|
||||
|
||||
return batch, batch_size
|
||||
|
||||
def _bucket_by_turn(self, encoded_data):
|
||||
turn_bucket = {}
|
||||
for dial in encoded_data:
|
||||
turn_len = len(dial)
|
||||
if turn_len not in turn_bucket:
|
||||
turn_bucket[turn_len] = []
|
||||
turn_bucket[turn_len].append(dial)
|
||||
return OrderedDict(sorted(turn_bucket.items(), key=lambda i: i[0]))
|
||||
|
||||
def _construct_mini_batch(self, data):
|
||||
all_batches = []
|
||||
batch = []
|
||||
for dial in data:
|
||||
batch.append(dial)
|
||||
if len(batch) == self.batch_size:
|
||||
# print('batch size: %d, batch num +1'%(len(batch)))
|
||||
all_batches.append(batch)
|
||||
batch = []
|
||||
# if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch
|
||||
# print('last batch size: %d, batch num +1'%(len(batch)))
|
||||
# if (len(batch) % len(cfg.cuda_device)) != 0:
|
||||
# batch = batch[:-(len(batch) % len(cfg.cuda_device))]
|
||||
# TODO deal with deleted data
|
||||
if self.gpu <= 1:
|
||||
if len(batch) > 0.5 * self.batch_size:
|
||||
all_batches.append(batch)
|
||||
elif len(all_batches):
|
||||
all_batches[-1].extend(batch)
|
||||
else:
|
||||
all_batches.append(batch)
|
||||
|
||||
return all_batches
|
||||
|
||||
def transpose_batch(self, batch):
|
||||
dial_batch = []
|
||||
turn_num = len(batch[0])
|
||||
for turn in range(turn_num):
|
||||
turn_l = {}
|
||||
for dial in batch:
|
||||
this_turn = dial[turn]
|
||||
for k in this_turn:
|
||||
if k not in turn_l:
|
||||
turn_l[k] = []
|
||||
turn_l[k].append(this_turn[k])
|
||||
dial_batch.append(turn_l)
|
||||
return dial_batch
|
||||
|
||||
def get_eval_data(self, set_name='dev'):
|
||||
name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev}
|
||||
dial = name_to_set[set_name]
|
||||
|
||||
if set_name not in self.set_stats:
|
||||
self.set_stats[set_name] = {}
|
||||
num_turns = 0
|
||||
num_dials = len(dial)
|
||||
for d in dial:
|
||||
num_turns += len(d)
|
||||
|
||||
self.set_stats[set_name]['num_turns'] = num_turns
|
||||
self.set_stats[set_name]['num_dials'] = num_dials
|
||||
|
||||
return dial
|
||||
|
||||
def get_nontranspose_data_iterator(self, all_batches):
|
||||
for i, batch in enumerate(all_batches):
|
||||
yield batch
|
||||
|
||||
def get_data_iterator(self, all_batches):
|
||||
for i, batch in enumerate(all_batches):
|
||||
yield self.transpose_batch(batch)
|
||||
|
||||
|
||||
class MultiWOZBPETextField(BPETextField):
|
||||
|
||||
def __init__(self, model_dir, config):
|
||||
super(MultiWOZBPETextField, self).__init__(config)
|
||||
import spacy
|
||||
self.nlp = spacy.load('en_core_web_sm')
|
||||
|
||||
self.db = MultiWozDB(
|
||||
model_dir, {
|
||||
'attraction': 'db/attraction_db_processed.json',
|
||||
'hospital': 'db/hospital_db_processed.json',
|
||||
'hotel': 'db/hotel_db_processed.json',
|
||||
'police': 'db/police_db_processed.json',
|
||||
'restaurant': 'db/restaurant_db_processed.json',
|
||||
'taxi': 'db/taxi_db_processed.json',
|
||||
'train': 'db/train_db_processed.json',
|
||||
})
|
||||
self._build_vocab(model_dir)
|
||||
|
||||
special_tokens = [
|
||||
self.pad_token, self.bos_token, self.eos_token, self.unk_token
|
||||
]
|
||||
special_tokens.extend(self.add_sepcial_tokens())
|
||||
self.tokenizer = Tokenizer(
|
||||
vocab_path=os.path.join(model_dir, 'vocab.txt'),
|
||||
special_tokens=special_tokens,
|
||||
tokenizer_type=config.BPETextField.tokenizer_type)
|
||||
self.understand_ids = self.tokenizer.convert_tokens_to_ids(
|
||||
self.understand_tokens)
|
||||
self.policy_ids = self.tokenizer.convert_tokens_to_ids(
|
||||
self.policy_tokens)
|
||||
|
||||
return
|
||||
|
||||
def get_ids(self, data: str):
|
||||
result = [self.sos_u_id] + self.tokenizer.convert_tokens_to_ids(
|
||||
self.tokenizer.tokenize(
|
||||
self._get_convert_str(data))) + [self.eos_u_id]
|
||||
return result
|
||||
|
||||
def inverse_transpose_turn(self, turn_list):
|
||||
"""
|
||||
eval, one dialog at a time
|
||||
"""
|
||||
dialogs = {}
|
||||
turn_num = len(turn_list)
|
||||
dial_id = turn_list[0]['dial_id']
|
||||
dialogs[dial_id] = []
|
||||
for turn_idx in range(turn_num):
|
||||
dial_turn = {}
|
||||
turn = turn_list[turn_idx]
|
||||
for key, value in turn.items():
|
||||
if key == 'dial_id':
|
||||
continue
|
||||
if key == 'pointer' and self.db is not None:
|
||||
turn_domain = turn['turn_domain'][-1]
|
||||
value = self.db.pointerBack(value, turn_domain)
|
||||
dial_turn[key] = value
|
||||
dialogs[dial_id].append(dial_turn)
|
||||
return dialogs
|
||||
|
||||
def inverse_transpose_batch(self, turn_batch_list):
|
||||
"""
|
||||
:param turn_batch_list: list of transpose dial batch
|
||||
"""
|
||||
dialogs = {}
|
||||
total_turn_num = len(turn_batch_list)
|
||||
# initialize
|
||||
for idx_in_batch, dial_id in enumerate(turn_batch_list[0]['dial_id']):
|
||||
dialogs[dial_id] = []
|
||||
for turn_n in range(total_turn_num):
|
||||
dial_turn = {}
|
||||
turn_batch = turn_batch_list[turn_n]
|
||||
for key, v_list in turn_batch.items():
|
||||
if key == 'dial_id':
|
||||
continue
|
||||
value = v_list[idx_in_batch]
|
||||
if key == 'pointer' and self.db is not None:
|
||||
turn_domain = turn_batch['turn_domain'][idx_in_batch][
|
||||
-1]
|
||||
value = self.db.pointerBack(value, turn_domain)
|
||||
dial_turn[key] = value
|
||||
dialogs[dial_id].append(dial_turn)
|
||||
return dialogs
|
||||
|
||||
def get_batches(self, set_name):
|
||||
"""
|
||||
compute dataset stats.
|
||||
"""
|
||||
global dia_count
|
||||
log_str = ''
|
||||
name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev}
|
||||
dial = name_to_set[set_name]
|
||||
turn_bucket = self._bucket_by_turn(dial)
|
||||
# self._shuffle_turn_bucket(turn_bucket)
|
||||
all_batches = []
|
||||
|
||||
if set_name not in self.set_stats:
|
||||
self.set_stats[set_name] = {}
|
||||
num_training_steps = 0
|
||||
num_turns = 0
|
||||
num_dials = 0
|
||||
|
||||
for k in turn_bucket:
|
||||
if set_name != 'test' and k == 1 or k >= 17:
|
||||
continue
|
||||
batches = self._construct_mini_batch(turn_bucket[k])
|
||||
try:
|
||||
log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % (
|
||||
k, len(turn_bucket[k]), len(batches), len(batches[-1]))
|
||||
except Exception:
|
||||
log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % (
|
||||
k, len(turn_bucket[k]), len(batches), 0.0)
|
||||
# print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches)))
|
||||
num_training_steps += k * len(batches)
|
||||
num_turns += k * len(turn_bucket[k])
|
||||
num_dials += len(turn_bucket[k])
|
||||
all_batches += batches
|
||||
log_str += 'total batch num: %d\n' % len(all_batches)
|
||||
# print('total batch num: %d'%len(all_batches))
|
||||
# print('dialog count: %d'%dia_count)
|
||||
# return all_batches
|
||||
|
||||
# log stats
|
||||
# logging.info(log_str)
|
||||
# cfg.num_training_steps = num_training_steps * cfg.epoch_num
|
||||
self.set_stats[set_name][
|
||||
'num_training_steps_per_epoch'] = num_training_steps  # turn-level steps
|
||||
self.set_stats[set_name]['num_turns'] = num_turns
|
||||
self.set_stats[set_name]['num_dials'] = num_dials
|
||||
|
||||
if set_name == 'train':
|
||||
random.shuffle(all_batches)
|
||||
return all_batches
|
||||
|
||||
def add_sepcial_tokens(self):
|
||||
"""
|
||||
add special tokens to gpt tokenizer
|
||||
serves a similar role to Vocab.construct()
|
||||
make a dict of special tokens
|
||||
"""
|
||||
special_tokens = []
|
||||
prompt_tokens = self.understand_tokens + self.policy_tokens
|
||||
special_tokens.extend(
|
||||
ontology.get_special_tokens(other_tokens=prompt_tokens))
|
||||
|
||||
for word in ontology.all_domains + ['general']:
|
||||
word = '[' + word + ']'
|
||||
special_tokens.append(word)
|
||||
for word in ontology.all_acts:
|
||||
word = '[' + word + ']'
|
||||
special_tokens.append(word)
|
||||
for word in self.vocab._word2idx.keys():
|
||||
if word.startswith('[value_') and word.endswith(']'):
|
||||
special_tokens.append(word)
|
||||
|
||||
return special_tokens
|
||||
|
||||
def _build_vocab(self, model_dir: str):
|
||||
self.vocab = utils.MultiWOZVocab(3000)
|
||||
vp = os.path.join('{}/vocab'.format(model_dir))
|
||||
self.vocab.load_vocab(vp)
|
||||
return self.vocab.vocab_size
|
||||
|
||||
def _get_convert_str(self, sent):
|
||||
assert isinstance(sent, str)
|
||||
return ' '.join([
|
||||
self.tokenizer.spec_convert_dict.get(tok, tok)
|
||||
for tok in sent.split()
|
||||
])
|
||||
|
||||
def bspan_to_DBpointer(self, bspan, turn_domain):
|
||||
constraint_dict = self.bspan_to_constraint_dict(bspan)
|
||||
# print(constraint_dict)
|
||||
matnums = self.db.get_match_num(constraint_dict)
|
||||
match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
|
||||
match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom
|
||||
match = matnums[match_dom]
|
||||
# vector = self.db.addDBPointer(match_dom, match)
|
||||
vector = self.db.addDBIndicator(match_dom, match)
|
||||
return vector
|
||||
|
||||
def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'):
|
||||
"""
|
||||
['[hotel]', 'pricerange', 'cheap', 'type', 'hotel'] -> {'hotel': {'pricerange': 'cheap', 'type': 'hotel'}}
|
||||
"""
|
||||
bspan = bspan.split() if isinstance(bspan, str) else bspan
|
||||
constraint_dict = {}
|
||||
domain = None
|
||||
conslen = len(bspan)
|
||||
for idx, cons in enumerate(bspan):
|
||||
cons = self.vocab.decode(cons) if type(cons) is not str else cons
|
||||
if cons == '<eos_b>':
|
||||
break
|
||||
if '[' in cons:
|
||||
if cons[1:-1] not in ontology.all_domains:
|
||||
continue
|
||||
domain = cons[1:-1]
|
||||
elif cons in ontology.get_slot:
|
||||
if domain is None:
|
||||
continue
|
||||
if cons == 'people':
|
||||
# handle confusion of value name "people's portraits..." and slot people
|
||||
try:
|
||||
ns = bspan[idx + 1]
|
||||
ns = self.vocab.decode(ns) if type(
|
||||
ns) is not str else ns
|
||||
if ns == "'s":
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
if not constraint_dict.get(domain):
|
||||
constraint_dict[domain] = {}
|
||||
if bspn_mode == 'bsdx':
|
||||
constraint_dict[domain][cons] = 1
|
||||
continue
|
||||
vidx = idx + 1
|
||||
if vidx == conslen:
|
||||
break
|
||||
vt_collect = []
|
||||
vt = bspan[vidx]
|
||||
vt = self.vocab.decode(vt) if type(vt) is not str else vt
|
||||
while vidx < conslen and vt != '<eos_b>' and '[' not in vt and vt not in ontology.get_slot:
|
||||
vt_collect.append(vt)
|
||||
vidx += 1
|
||||
if vidx == conslen:
|
||||
break
|
||||
vt = bspan[vidx]
|
||||
vt = self.vocab.decode(vt) if type(vt) is not str else vt
|
||||
if vt_collect:
|
||||
constraint_dict[domain][cons] = ' '.join(vt_collect)
|
||||
|
||||
return constraint_dict
|
||||
|
||||
def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False):
|
||||
"""
|
||||
URURU: here this refers to turn-level training (data assembly), as opposed to session-level training (convert_batch_session);
this differs from its meaning at eval time, where both are generated turn by turn; see the comments of the related eval functions for that meaning.
|
||||
|
||||
convert the current and the last turn
|
||||
concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t]
|
||||
first turn: [U_t, B_t, A_t, R_t]
|
||||
try: [user, bspn, db, aspn, resp]
|
||||
|
||||
"""
|
||||
inputs = []
|
||||
if first_turn:
|
||||
batch_zipped = zip(turn_batch['user'], turn_batch['bspn'],
|
||||
turn_batch['db'], turn_batch['aspn'],
|
||||
turn_batch['resp'])
|
||||
for u, b, db, a, r in batch_zipped:
|
||||
if self.use_true_curr_bspn:
|
||||
src = [u + b + db]
|
||||
tgt = a + r
|
||||
else:
|
||||
src = [u]
|
||||
tgt = b + db + a + r
|
||||
inputs.append({'src': src, 'tgt': tgt})
|
||||
pv = [src[-1], tgt]
|
||||
pv_batch.append(pv)
|
||||
else:
|
||||
batch_zipped = zip(pv_batch, turn_batch['user'],
|
||||
turn_batch['bspn'], turn_batch['db'],
|
||||
turn_batch['aspn'], turn_batch['resp'])
|
||||
for i, (pv, u, b, db, a, r) in enumerate(batch_zipped):
|
||||
if self.use_true_curr_bspn:
|
||||
src = pv + [u + b + db]
|
||||
tgt = a + r
|
||||
else:
|
||||
src = pv + [u]
|
||||
tgt = b + db + a + r
|
||||
inputs.append({'src': src, 'tgt': tgt})
|
||||
pv = [src[-1], tgt]
|
||||
pv_batch[i].extend(pv)
|
||||
|
||||
return inputs, pv_batch
|
||||
|
||||
def wrap_result_lm(self, result_dict, eos_syntax=None):
|
||||
results = []
|
||||
eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax
|
||||
sos_syntax = ontology.sos_tokens
|
||||
# ground truth bs, as, ds.. generate response
|
||||
field = [
|
||||
'dial_id', 'turn_num', 'user', 'bspn_gen', 'bsdx', 'resp_gen',
|
||||
'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn', 'pointer',
|
||||
'qspn_gen', 'qspn'
|
||||
]
|
||||
|
||||
for dial_id, turns in result_dict.items():
|
||||
entry = {'dial_id': dial_id, 'turn_num': len(turns)}
|
||||
for f in field[2:]:
|
||||
entry[f] = '' # TODO ???
|
||||
results.append(entry)
|
||||
for turn_idx, turn in enumerate(turns):
|
||||
entry = {'dial_id': dial_id}
|
||||
for key in field:
|
||||
if key in ['dial_id']:
|
||||
continue
|
||||
v = turn.get(key, '')
|
||||
if key == 'turn_domain':
|
||||
v = ' '.join(v)
|
||||
|
||||
if key in eos_syntax and v != '':
|
||||
# remove eos tokens
|
||||
v = self.tokenizer.decode(v)
|
||||
v = v.split()
|
||||
# remove eos/sos in span
|
||||
if eos_syntax[key] in v:
|
||||
v.remove(eos_syntax[key])
|
||||
if sos_syntax[key] in v:
|
||||
v.remove(sos_syntax[key])
|
||||
v = ' '.join(v)
|
||||
else:
|
||||
pass # v = v
|
||||
entry[key] = v
|
||||
|
||||
results.append(entry)
|
||||
|
||||
return results, field
|
||||
|
||||
def convert_turn_eval(self, turn, pv_turn, first_turn=False):
|
||||
"""
|
||||
input: [all previous ubar, U_t, B_t, A_t] predict R_t
|
||||
first turn: [U_t, B_t, A_t] predict R_t
|
||||
|
||||
regarding the context, all previous ubar is too slow, try the previous ubar
|
||||
"""
|
||||
inputs = {}
|
||||
|
||||
context_list = []
|
||||
prompt_id = None
|
||||
if self.use_true_curr_bspn:
|
||||
if self.use_true_curr_aspn: # only predict resp
|
||||
context_list = ['user', 'bspn', 'db', 'aspn']
|
||||
prompt_id = self.sos_r_id
|
||||
else: # predicted aspn
|
||||
context_list = ['user', 'bspn', 'db']
|
||||
prompt_id = self.sos_a_id
|
||||
else: # predict bspn aspn resp. db are not predicted. this part tbd.
|
||||
context_list = ['user']
|
||||
prompt_id = self.sos_b_id
|
||||
|
||||
if first_turn:
|
||||
context = []
|
||||
for c in context_list:
|
||||
context += turn[c]
|
||||
|
||||
inputs['src'] = [context]
|
||||
inputs['labels'] = [context]
|
||||
else:
|
||||
context = []
|
||||
for c in context_list:
|
||||
context += turn[c]
|
||||
|
||||
if self.use_true_curr_bspn:
|
||||
pv_context = pv_turn['labels'] + [
|
||||
pv_turn['aspn'] + pv_turn['resp']
|
||||
]
|
||||
else:
|
||||
pv_info = pv_turn['bspn'] + pv_turn['db'] + pv_turn[
|
||||
'aspn'] + pv_turn['resp']
|
||||
pv_context = pv_turn['labels'] + [pv_info]
|
||||
|
||||
# prompt response, add sos_r
|
||||
inputs['src'] = pv_context + [context]
|
||||
|
||||
if self.use_all_previous_context:
|
||||
inputs['labels'] = pv_context + [
|
||||
context
|
||||
] # use all previous ubar history
|
||||
else:
|
||||
inputs['labels'] = [context] # use previous turn
|
||||
|
||||
return inputs, prompt_id
|
||||
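The belief-span parsing in MultiWOZBPETextField.bspan_to_constraint_dict walks the span token by token: a bracketed domain opens a new sub-dict, a known slot name starts collecting value tokens, and collection stops at the next slot, the next domain, or <eos_b>. A hedged sketch of that mapping, reusing the example from its docstring; text_field is assumed to be an already-constructed MultiWOZBPETextField.

bspan = ['[hotel]', 'pricerange', 'cheap', 'type', 'hotel', '<eos_b>']
constraints = text_field.bspan_to_constraint_dict(bspan)
print(constraints)   # {'hotel': {'pricerange': 'cheap', 'type': 'hotel'}}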
1093
modelscope/preprocessors/space/fields/intent_field.py
Normal file
File diff suppressed because it is too large
672
modelscope/preprocessors/space/tokenizer.py
Normal file
@@ -0,0 +1,672 @@
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unicodedata
|
||||
|
||||
import json
|
||||
import regex as re
|
||||
|
||||
|
||||
def clean_string(string):
|
||||
replace_mp = {
|
||||
' - ': '-',
|
||||
" ' ": "'",
|
||||
" n't": "n't",
|
||||
" 'm": "'m",
|
||||
' do not': " don't",
|
||||
" 's": "'s",
|
||||
" 've": "'ve",
|
||||
" 're": "'re"
|
||||
}
|
||||
for k, v in replace_mp.items():
|
||||
string = string.replace(k, v)
|
||||
return string
|
||||
|
||||
|
||||
class Tokenizer(object):
|
||||
|
||||
def __init__(self, vocab_path, special_tokens=[], tokenizer_type='Bert'):
|
||||
self.tokenizer_type = tokenizer_type
|
||||
if tokenizer_type == 'Bert':
|
||||
self.spec_convert_dict = {
|
||||
'[BOS]': '[unused0]',
|
||||
'[EOS]': '[unused1]'
|
||||
}
|
||||
for token in special_tokens:
|
||||
if token not in self.spec_convert_dict and token not in [
|
||||
'[PAD]', '[UNK]'
|
||||
]:
|
||||
self.spec_convert_dict[
|
||||
token] = f'[unused{len(self.spec_convert_dict)}]'
|
||||
self.spec_revert_dict = {
|
||||
v: k
|
||||
for k, v in self.spec_convert_dict.items()
|
||||
}
|
||||
special_tokens = [
|
||||
self.spec_convert_dict.get(tok, tok) for tok in special_tokens
|
||||
]
|
||||
self.special_tokens = ('[UNK]', '[SEP]', '[PAD]', '[CLS]',
|
||||
'[MASK]')
|
||||
self.special_tokens += tuple(x for x in special_tokens
|
||||
if x not in self.special_tokens)
|
||||
|
||||
self._tokenizer = BertTokenizer(
|
||||
vocab_path, never_split=self.special_tokens)
|
||||
for tok in self.special_tokens:
|
||||
'''
|
||||
special_tokens must already be present in the vocabulary; they are registered here so that each one occupies a single slot and is never split into sub-words;
if a token is not in the vocabulary, it can be mapped onto one of the vocabulary's [unused] symbols via spec_convert_dict.
|
||||
'''
|
||||
assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary"
|
||||
self.vocab_size = len(self._tokenizer.vocab)
|
||||
elif tokenizer_type == 'GPT2':
|
||||
self.spec_convert_dict = {'[UNK]': '<unk>'}
|
||||
self.spec_revert_dict = {
|
||||
v: k
|
||||
for k, v in self.spec_convert_dict.items()
|
||||
}
|
||||
special_tokens = [
|
||||
tok for tok in special_tokens
|
||||
if tok not in self.spec_convert_dict
|
||||
]
|
||||
vocab_file = os.path.join(vocab_path, 'vocab.json')
|
||||
merges_file = os.path.join(vocab_path, 'merges.txt')
|
||||
self._tokenizer = GPT2Tokenizer(
|
||||
vocab_file, merges_file, special_tokens=special_tokens)
|
||||
self.num_specials = len(special_tokens)
|
||||
self.vocab_size = len(self._tokenizer)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def tokenize(self, text):
|
||||
return self._tokenizer.tokenize(text)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
if self.tokenizer_type == 'Bert':
|
||||
tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens]
|
||||
ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
||||
return ids
|
||||
else:
|
||||
tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens]
|
||||
ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
||||
ids = [(i + self.num_specials) % self.vocab_size for i in ids]
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
if self.tokenizer_type == 'Bert':
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(ids)
|
||||
tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens]
|
||||
return tokens
|
||||
else:
|
||||
ids = [(i - self.num_specials) % self.vocab_size for i in ids]
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(ids)
|
||||
tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens]
|
||||
return tokens
|
||||
|
||||
def decode(self, ids, ignore_tokens=[]):
|
||||
tokens = self.convert_ids_to_tokens(ids)
|
||||
if len(ignore_tokens) > 0:
|
||||
ignore_tokens = set(ignore_tokens)
|
||||
tokens = [tok for tok in tokens if tok not in ignore_tokens]
|
||||
if self.tokenizer_type == 'Bert':
|
||||
string = ' '.join(tokens).replace(' ##', '')
|
||||
else:
|
||||
string = ''.join(tokens)
|
||||
string = bytearray([
|
||||
self._tokenizer.byte_decoder[c] for c in string
|
||||
]).decode('utf-8')
|
||||
string = clean_string(string)
|
||||
return string
|
||||
|
||||
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes."""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, 'r', encoding='utf-8') as reader:
|
||||
while True:
|
||||
token = reader.readline()
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class BertTokenizer(object):
|
||||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_file,
|
||||
do_lower_case=True,
|
||||
max_len=None,
|
||||
do_basic_tokenize=True,
|
||||
never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')):
|
||||
"""Constructs a BertTokenizer.
|
||||
|
||||
Args:
|
||||
vocab_file: Path to a one-wordpiece-per-line vocabulary file
|
||||
do_lower_case: Whether to lower case the input
|
||||
Only has an effect when do_wordpiece_only=False
|
||||
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
|
||||
max_len: An artificial maximum length to truncate tokenized sequences to;
|
||||
Effective maximum length is always the minimum of this
|
||||
value (if specified) and the underlying BERT model's
|
||||
sequence length.
|
||||
never_split: List of tokens which will never be split during tokenization.
|
||||
Only has an effect when do_wordpiece_only=False
|
||||
"""
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||
'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`'
|
||||
.format(vocab_file))
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.ids_to_tokens = collections.OrderedDict([
|
||||
(ids, tok) for tok, ids in self.vocab.items()
|
||||
])
|
||||
self.do_basic_tokenize = do_basic_tokenize
|
||||
if do_basic_tokenize:
|
||||
self.basic_tokenizer = BasicTokenizer(
|
||||
do_lower_case=do_lower_case, never_split=never_split)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
else:
|
||||
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(self.vocab[token])
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
'Token indices sequence length is longer than the specified maximum '
|
||||
' sequence length for this BERT model ({} > {}). Running this'
|
||||
' sequence through BERT will result in indexing errors'.format(
|
||||
len(ids), self.max_len))
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
tokens.append(self.ids_to_tokens[i])
|
||||
return tokens
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self,
|
||||
do_lower_case=True,
|
||||
never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
self.never_split = never_split
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = self._clean_text(text)
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case and token not in self.never_split:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(' '.join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize('NFD', text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == 'Mn':
|
||||
continue
|
||||
output.append(char)
|
||||
return ''.join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
if text in self.never_split:
|
||||
return [text]
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return [''.join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(' ')
|
||||
output.append(char)
|
||||
output.append(' ')
|
||||
else:
|
||||
output.append(char)
|
||||
return ''.join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like all of the other languages.
|
||||
tmp = (cp >= 0x4E00 and cp <= 0x9FFF)
|
||||
tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||
tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||
tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||
tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||
tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||
tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||
tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F)
|
||||
if tmp:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(' ')
|
||||
else:
|
||||
output.append(char)
|
||||
return ''.join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenization."""
|
||||
|
||||
def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer`.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = ''.join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = '##' + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == ' ' or char == '\t' or char == '\n' or char == '\r':
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == 'Zs':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == '\t' or char == '\n' or char == '\r':
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith('C'):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
tmp = (cp >= 33 and cp <= 47)
|
||||
tmp = tmp or (cp >= 58 and cp <= 64)
|
||||
tmp = tmp or (cp >= 91 and cp <= 96)
|
||||
tmp = tmp or (cp >= 123 and cp <= 126)
|
||||
if tmp:
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith('P'):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for OpenAI GPT."""
|
||||
|
||||
try:
|
||||
from functools import lru_cache
|
||||
except ImportError:
|
||||
# Just a dummy decorator to get the checks to run on python2
|
||||
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
|
||||
def lru_cache():
|
||||
return lambda func: func
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
_chr = unichr if sys.version_info[0] == 2 else chr
|
||||
bs = list(range(ord('!'),
|
||||
ord('~') + 1)) + list(range(
|
||||
ord('¡'),
|
||||
ord('¬') + 1)) + list(range(ord('®'),
|
||||
ord('ÿ') + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [_chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class GPT2Tokenizer(object):
|
||||
"""
|
||||
GPT-2 BPE tokenizer. Peculiarities:
|
||||
- Byte-level BPE
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
errors='replace',
|
||||
special_tokens=None,
|
||||
max_len=None):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.encoder = json.load(open(vocab_file))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
|
||||
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = re.compile(
|
||||
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i)
|
||||
for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {
|
||||
v: k
|
||||
for k, v in self.special_tokens.items()
|
||||
}
|
||||
logger.info('Special tokens {}'.format(self.special_tokens))
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(
|
||||
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[
|
||||
i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[ord(b)] for b in token
|
||||
if ord(b) in self.byte_encoder)
|
||||
if token == '':
|
||||
continue
|
||||
bpe_tokens.extend(
|
||||
bpe_token for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
python_version_3 = isinstance(tokens, str)
|
||||
python_version_2 = (
|
||||
sys.version_info[0] == 2 and isinstance(tokens, unicode))
|
||||
if python_version_3 or python_version_2:
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.encoder.get(tokens, 0)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
'Token indices sequence length is longer than the specified maximum '
|
||||
' sequence length for this OpenAI GPT model ({} > {}). Running this'
|
||||
' sequence through the model will result in indexing errors'.
|
||||
format(len(ids), self.max_len))
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
'utf-8', errors=self.errors)
|
||||
return text
|
||||
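# Illustrative usage sketch (not part of the original file; 'vocab.json' and 'merges.txt'
# are hypothetical placeholder paths): encode() runs the byte-level BPE and maps tokens to
# ids, decode() inverts both steps.
#
#     tokenizer = GPT2Tokenizer('vocab.json', 'merges.txt')
#     ids = tokenizer.encode("Hello world")
#     assert tokenizer.decode(ids) == "Hello world"
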
0
modelscope/trainers/nlp/space/__init__.py
Normal file
0
modelscope/trainers/nlp/space/metrics/__init__.py
Normal file
73
modelscope/trainers/nlp/space/metrics/metrics_tracker.py
Normal file
@@ -0,0 +1,73 @@
"""
MetricsTracker class
"""

import math
from collections import defaultdict


class MetricsTracker(object):
    """ Tracking metrics. """

    def __init__(self):
        self.metrics_val = defaultdict(float)  # metrics reported by the most recent batch
        self.metrics_avg = defaultdict(float)  # running averages over the batches trained so far in this epoch
        self.num_samples = 0

    def update(self, metrics, num_samples):
        for key, val in metrics.items():
            if val is not None:
                val = float(val)  # [val] -> val
                self.metrics_val[key] = val
                avg_val = \
                    (self.metrics_avg.get(key, 0) * self.num_samples + val * num_samples) / \
                    (self.num_samples + num_samples)
                self.metrics_avg[key] = avg_val
        self.num_samples += num_samples

    def clear(self):
        self.metrics_val = defaultdict(float)
        self.metrics_avg = defaultdict(float)
        self.num_samples = 0

    def items(self):
        return self.metrics_avg.items()

    def get(self, name):
        if self.num_samples == 0:
            raise ValueError('There is no data in Metrics.')
        return self.metrics_avg.get(name)

    def state_dict(self):
        return {
            'metrics_val': self.metrics_val,
            'metrics_avg': self.metrics_avg,
            'num_samples': self.num_samples,
        }

    def load_state_dict(self, state_dict):
        self.metrics_val = state_dict['metrics_val']
        self.metrics_avg = state_dict['metrics_avg']
        self.num_samples = state_dict['num_samples']

    def value(self):
        metric_strs = []
        for key, val in self.metrics_val.items():
            metric_str = f'{key.upper()}-{val:.3f}'
            metric_strs.append(metric_str)
        if 'token_nll' in self.metrics_val:
            metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}"
            metric_strs.append(metric_str)
        metric_strs = ' '.join(metric_strs)
        return metric_strs

    def summary(self):
        metric_strs = []
        for key, val in self.metrics_avg.items():
            metric_str = f'{key.upper()}-{val:.3f}'
            metric_strs.append(metric_str)
        if 'token_nll' in self.metrics_avg:
            metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}"
            metric_strs.append(metric_str)
        metric_strs = ' '.join(metric_strs)
        return metric_strs
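# Illustrative sketch (not part of the original file): update() keeps a sample-weighted
# running average, so two updates with nll=2.0 over 10 samples and nll=4.0 over 30 samples
# yield an average of (2.0 * 10 + 4.0 * 30) / 40 = 3.5.
#
#     tracker = MetricsTracker()
#     tracker.update({'nll': 2.0}, num_samples=10)
#     tracker.update({'nll': 4.0}, num_samples=30)
#     assert abs(tracker.get('nll') - 3.5) < 1e-6
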
0
modelscope/trainers/nlp/space/trainers/__init__.py
Normal file
761
modelscope/trainers/nlp/space/trainers/gen_trainer.py
Normal file
@@ -0,0 +1,761 @@
|
||||
"""
|
||||
Trainer class.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
import modelscope.utils.nlp.space.ontology as ontology
|
||||
from ..metrics.metrics_tracker import MetricsTracker
|
||||
|
||||
|
||||
def get_logger(log_path, name='default'):
|
||||
logger = logging.getLogger(name)
|
||||
logger.propagate = False
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter('%(message)s')
|
||||
|
||||
sh = logging.StreamHandler(sys.stdout)
|
||||
sh.setFormatter(formatter)
|
||||
logger.addHandler(sh)
|
||||
|
||||
fh = logging.FileHandler(log_path, mode='w')
|
||||
fh.setFormatter(formatter)
|
||||
logger.addHandler(fh)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
to_tensor,
|
||||
config,
|
||||
logger=None,
|
||||
lr_scheduler=None,
|
||||
optimizer=None,
|
||||
reader=None,
|
||||
evaluator=None):
|
||||
self.to_tensor = to_tensor
|
||||
|
||||
self.do_train = config.do_train
|
||||
self.do_infer = config.do_infer
|
||||
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
|
||||
0] == '-'
|
||||
self.valid_metric_name = config.Trainer.valid_metric_name[1:]
|
||||
self.num_epochs = config.Trainer.num_epochs
|
||||
# self.save_dir = config.Trainer.save_dir
|
||||
self.log_steps = config.Trainer.log_steps
|
||||
self.valid_steps = config.Trainer.valid_steps
|
||||
self.save_checkpoint = config.Trainer.save_checkpoint
|
||||
self.save_summary = config.Trainer.save_summary
|
||||
self.lr = config.Model.lr
|
||||
self.weight_decay = config.Model.weight_decay
|
||||
self.batch_size = config.Trainer.batch_size
|
||||
self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps
|
||||
self.warmup_steps = config.Model.warmup_steps
|
||||
self.gpu = config.Trainer.gpu
|
||||
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.model = model
|
||||
self.func_model = self.model.module if self.gpu > 1 else self.model
|
||||
self.reader = reader
|
||||
self.evaluator = evaluator
|
||||
self.tokenizer = reader.tokenizer
|
||||
|
||||
# if not os.path.exists(self.save_dir):
|
||||
# os.makedirs(self.save_dir)
|
||||
|
||||
# self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer")
|
||||
self.logger = logger or get_logger('trainer.log', 'trainer')
|
||||
|
||||
self.batch_metrics_tracker = MetricsTracker()
|
||||
self.token_metrics_tracker = MetricsTracker()
|
||||
|
||||
self.best_valid_metric = float(
|
||||
'inf' if self.is_decreased_valid_metric else '-inf')
|
||||
self.epoch = 0
|
||||
|
||||
def decode_generated_bspn_resp(self, generated):
|
||||
"""
|
||||
decode generated
|
||||
return decoded ('bspn', 'resp')
|
||||
"""
|
||||
decoded = {}
|
||||
eos_r_id = self.reader.eos_r_id
|
||||
eos_b_id = self.reader.eos_b_id
|
||||
|
||||
# eos_r may not exist if gpt2 generated repetitive words.
|
||||
if eos_r_id in generated:
|
||||
eos_r_idx = generated.index(eos_r_id)
|
||||
else:
|
||||
eos_r_idx = len(generated) - 1
|
||||
# self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated))
|
||||
|
||||
# predicted bspn, resp
|
||||
eos_b_idx = generated.index(eos_b_id)
|
||||
decoded['bspn'] = generated[:eos_b_idx + 1]
|
||||
decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1]
|
||||
return decoded
|
||||
|
||||
def decode_generated_act_resp(self, generated):
|
||||
"""
|
||||
decode generated
|
||||
return decoded dict with 'aspn' and 'resp'
|
||||
"""
|
||||
decoded = {}
|
||||
eos_a_id = self.reader.eos_a_id
|
||||
eos_r_id = self.reader.eos_r_id
|
||||
# eos_b_id = self.reader.eos_b_id
|
||||
|
||||
# eos_r may not exist if gpt2 generated repetitive words.
|
||||
if eos_r_id in generated:
|
||||
eos_r_idx = generated.index(eos_r_id)
|
||||
else:
|
||||
eos_r_idx = len(generated) - 1
|
||||
msg = 'eos_r not in generated: ' + self.tokenizer.decode(generated)
|
||||
self.logger.info(msg)
|
||||
|
||||
if self.reader.use_true_curr_aspn: # only predict resp
|
||||
decoded['resp'] = generated[:eos_r_idx + 1]
|
||||
else: # predicted aspn, resp
|
||||
eos_a_idx = generated.index(eos_a_id)
|
||||
decoded['aspn'] = generated[:eos_a_idx + 1]
|
||||
decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1]
|
||||
return decoded
|
||||
|
||||
def decode_generated_bspn(self, generated):
|
||||
eos_b_id = self.reader.eos_b_id
|
||||
if eos_b_id in generated:
|
||||
eos_b_idx = generated.index(eos_b_id)
|
||||
else:
|
||||
eos_b_idx = len(generated) - 1
|
||||
return generated[:eos_b_idx + 1]
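
# Illustrative sketch (not part of the original file; the ids are made up): the decode_*
# helpers simply split the flat id sequence the model generated at the special eos markers.
# For example, with eos_b_id = 5 and eos_r_id = 9 and `trainer` a Trainer instance:
#
#     generated = [11, 12, 5, 21, 22, 9]
#     decoded = trainer.decode_generated_bspn_resp(generated)
#     # decoded['bspn'] == [11, 12, 5]; decoded['resp'] == [21, 22, 9]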
|
||||
|
||||
def set_optimizers(self):
|
||||
"""
|
||||
Setup the optimizer and the learning rate scheduler.
|
||||
|
||||
from transformers.Trainer
|
||||
|
||||
parameters from cfg: lr (1e-3); warmup_steps
|
||||
"""
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'norm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
self.weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
|
||||
|
||||
num_training_steps = \
|
||||
self.reader.set_stats['train']['num_training_steps_per_epoch'] \
|
||||
* self.num_epochs \
|
||||
// self.gradient_accumulation_steps
|
||||
num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int(
|
||||
num_training_steps * 0.1)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps)
|
||||
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
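
# Illustrative note (not part of the original file; the numbers are made up): when
# warmup_steps < 0 the warmup defaults to 10% of the total optimization steps, e.g.
#
#     num_training_steps = 1000 * 10 // 2        # steps/epoch * epochs // grad accumulation
#     num_warmup_steps = int(num_training_steps * 0.1)   # 500 warmup steps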
|
||||
|
||||
def train(self, train_data, dev_data):
|
||||
# log info
|
||||
set_stats = self.reader.set_stats['train']
|
||||
self.logger.info('***** Running training *****')
|
||||
self.logger.info(
|
||||
' Num Training steps (one turn in a batch of dialogs) per epoch = %d',
|
||||
set_stats['num_training_steps_per_epoch'])
|
||||
self.logger.info(' Num Turns = %d', set_stats['num_turns'])
|
||||
self.logger.info(' Num Dialogs = %d', set_stats['num_dials'])
|
||||
self.logger.info(' Num Epochs = %d', self.num_epochs)
|
||||
self.logger.info(' Batch size = %d', self.batch_size)
|
||||
self.logger.info(' Gradient Accumulation steps = %d',
|
||||
self.gradient_accumulation_steps)
|
||||
steps = set_stats[
|
||||
'num_training_steps_per_epoch'] * self.num_epochs // self.gradient_accumulation_steps
|
||||
msg = ' Total optimization steps = %d' % steps
|
||||
self.logger.info(msg)
|
||||
|
||||
# begin training
|
||||
num_epochs = self.num_epochs - self.epoch
|
||||
for epoch in range(num_epochs):
|
||||
self.train_epoch(train_data=train_data, dev_data=dev_data)
|
||||
|
||||
def train_epoch(self, train_data, dev_data):
|
||||
"""
|
||||
Train an epoch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def infer(self, data_type):
|
||||
"""
|
||||
Inference interface.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, turn, old_pv_turn):
|
||||
"""
|
||||
one turn inference
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, is_best=False):
|
||||
""" save """
|
||||
train_state = {
|
||||
'epoch': self.epoch,
|
||||
'best_valid_metric': self.best_valid_metric,
|
||||
'optimizer': self.optimizer.state_dict()
|
||||
}
|
||||
if self.lr_scheduler is not None:
|
||||
train_state['lr_scheduler'] = self.lr_scheduler.state_dict()
|
||||
|
||||
# Save checkpoint
|
||||
if self.save_checkpoint:
|
||||
model_file = os.path.join(self.save_dir,
|
||||
f'state_epoch_{self.epoch}.model')
|
||||
torch.save(self.model.state_dict(), model_file)
|
||||
self.logger.info(f"Saved model state to '{model_file}'")
|
||||
|
||||
train_file = os.path.join(self.save_dir,
|
||||
f'state_epoch_{self.epoch}.train')
|
||||
torch.save(train_state, train_file)
|
||||
self.logger.info(f"Saved train state to '{train_file}'")
|
||||
|
||||
# Save current best model
|
||||
if is_best:
|
||||
best_model_file = os.path.join(self.save_dir, 'best.model')
|
||||
torch.save(self.model.state_dict(), best_model_file)
|
||||
best_train_file = os.path.join(self.save_dir, 'best.train')
|
||||
torch.save(train_state, best_train_file)
|
||||
self.logger.info(
|
||||
f"Saved best model state to '{best_model_file}' with new best valid metric "
|
||||
f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}'
|
||||
)
|
||||
|
||||
def load(self):
|
||||
""" load """
|
||||
|
||||
def _load_model_state():
|
||||
model_state_dict = torch.load(
|
||||
f'{self.func_model.init_checkpoint}',
|
||||
map_location=lambda storage, loc: storage)
|
||||
|
||||
if 'module.' in list(model_state_dict.keys())[0]:
|
||||
new_model_state_dict = OrderedDict()
|
||||
for k, v in model_state_dict.items():
|
||||
assert k[:7] == 'module.'
|
||||
new_model_state_dict[k[7:]] = v
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
new_model_state_dict = OrderedDict()
|
||||
parameters = {
|
||||
name: param
|
||||
for name, param in self.func_model.named_parameters()
|
||||
}
|
||||
for name, param in model_state_dict.items():
|
||||
if name in parameters:
|
||||
if param.shape != parameters[name].shape:
|
||||
assert hasattr(param, 'numpy')
|
||||
arr = param.numpy()
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
if name == 'embedder.token_embedding.weight':
|
||||
z[-param.shape[0]:] = arr
|
||||
print(
|
||||
f'part of parameter({name}) randomly initialized from a normal distribution'
|
||||
)
|
||||
else:
|
||||
if z.shape[0] < param.shape[0]:
|
||||
z = arr[:z.shape[0]]
|
||||
print(f'part of parameter({name}) is dropped')
|
||||
else:
|
||||
z[:param.shape[0]] = arr
|
||||
print(
|
||||
f'part of parameter({name}) randomly initialized from a normal distribution'
|
||||
)
|
||||
dtype, device = param.dtype, param.device
|
||||
z = torch.tensor(z, dtype=dtype, device=device)
|
||||
new_model_state_dict[name] = z
|
||||
else:
|
||||
new_model_state_dict[name] = param
|
||||
else:
|
||||
print(f'parameter({name}) is dropped')
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
for name in parameters:
|
||||
if name not in model_state_dict:
|
||||
if parameters[name].requires_grad:
|
||||
print(f'parameter({name}) randomly initialized from a normal distribution')
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
dtype, device = parameters[name].dtype, parameters[
|
||||
name].device
|
||||
model_state_dict[name] = torch.tensor(
|
||||
z, dtype=dtype, device=device)
|
||||
else:
|
||||
model_state_dict[name] = parameters[name]
|
||||
|
||||
self.func_model.load_state_dict(model_state_dict)
|
||||
self.logger.info(
|
||||
f"Loaded model state from '{self.func_model.init_checkpoint}.model'"
|
||||
)
|
||||
|
||||
def _load_train_state():
|
||||
train_file = f'{self.func_model.init_checkpoint}.train'
|
||||
if os.path.exists(train_file):
|
||||
train_state_dict = torch.load(
|
||||
train_file, map_location=lambda storage, loc: storage)
|
||||
self.epoch = train_state_dict['epoch']
|
||||
self.best_valid_metric = train_state_dict['best_valid_metric']
|
||||
if self.optimizer is not None and 'optimizer' in train_state_dict:
|
||||
self.optimizer.load_state_dict(
|
||||
train_state_dict['optimizer'])
|
||||
if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
|
||||
self.lr_scheduler.load_state_dict(
|
||||
train_state_dict['lr_scheduler'])
|
||||
self.logger.info(
|
||||
f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
|
||||
f'best_valid_metric={self.best_valid_metric:.3f})')
|
||||
else:
|
||||
self.logger.info('Loaded no train state')
|
||||
|
||||
if self.func_model.init_checkpoint is None:
|
||||
self.logger.info('Loaded no model !!!')
|
||||
return
|
||||
|
||||
if self.do_train:
|
||||
_load_model_state()
|
||||
return
|
||||
|
||||
if self.do_infer:
|
||||
_load_model_state()
|
||||
_load_train_state()
|
||||
|
||||
|
||||
class MultiWOZTrainer(Trainer):
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
to_tensor,
|
||||
config,
|
||||
logger=None,
|
||||
lr_scheduler=None,
|
||||
optimizer=None,
|
||||
reader=None,
|
||||
evaluator=None):
|
||||
super(MultiWOZTrainer,
|
||||
self).__init__(model, to_tensor, config, logger, lr_scheduler,
|
||||
optimizer, reader, evaluator)
|
||||
|
||||
def train_epoch(self, train_data, dev_data):
|
||||
"""
|
||||
Train an epoch.
|
||||
"""
|
||||
times = []
|
||||
epoch_step = 0
|
||||
global_step = 0
|
||||
tr_batch_loss = 0.0
|
||||
tr_token_loss = 0.0
|
||||
self.epoch += 1
|
||||
self.batch_metrics_tracker.clear()
|
||||
self.token_metrics_tracker.clear()
|
||||
num_training_steps = \
|
||||
self.reader.set_stats['train']['num_training_steps_per_epoch'] // \
|
||||
self.gradient_accumulation_steps # similar to the original num_batches
|
||||
|
||||
self.model.zero_grad()
|
||||
data_iterator = self.reader.get_data_iterator(all_batches=train_data)
|
||||
|
||||
for batch_idx, dial_batch in enumerate(data_iterator):
|
||||
pv_batch = []
|
||||
for turn_num, turn_batch in enumerate(dial_batch):
|
||||
first_turn = (turn_num == 0)
|
||||
samples, pv_batch = self.reader.convert_batch_turn(
|
||||
turn_batch, pv_batch, first_turn)
|
||||
batch, batch_size = self.reader.collate_fn_multi_turn(
|
||||
samples=samples)
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
batch.items()))
|
||||
|
||||
# Do a training iteration
|
||||
start_time = time.time()
|
||||
metrics = self.model(batch, is_training=True)
|
||||
if self.gpu > 1:
|
||||
for metric in metrics:
|
||||
if metric is not None:
|
||||
assert len(metric) == self.gpu
|
||||
nll, token_nll, token_num = metrics
|
||||
metrics = {}
|
||||
|
||||
token_num = torch.sum(token_num)
|
||||
token_nll = \
|
||||
torch.sum(nll) * (batch_size / self.gpu) / \
|
||||
token_num
|
||||
nll = torch.mean(nll)
|
||||
metrics['token_num'] = token_num
|
||||
metrics['token_nll'] = token_nll
|
||||
metrics['nll'] = nll
|
||||
loss = token_nll if self.func_model.token_loss else nll
|
||||
|
||||
metrics['loss'] = loss
|
||||
else:
|
||||
loss = metrics['loss']
|
||||
self.func_model._optimize(
|
||||
loss, do_update=False, optimizer=self.optimizer)
|
||||
metrics = {
|
||||
k: v.cpu().detach().numpy()
|
||||
if isinstance(v, torch.Tensor) else v
|
||||
for k, v in metrics.items()
|
||||
}
|
||||
token_num = metrics.pop('token_num', None)
|
||||
# bow_num = metrics.pop("bow_num", None)
|
||||
elapsed = time.time() - start_time
|
||||
times.append(elapsed)
|
||||
epoch_step += 1
|
||||
|
||||
tr_batch_loss += metrics['nll']
|
||||
tr_token_loss += metrics['token_nll']
|
||||
batch_metrics = {
|
||||
k: v
|
||||
for k, v in metrics.items() if 'token' not in k
|
||||
}
|
||||
token_metrics = {
|
||||
k: v
|
||||
for k, v in metrics.items() if 'token' in k
|
||||
}
|
||||
self.batch_metrics_tracker.update(batch_metrics, batch_size)
|
||||
self.token_metrics_tracker.update(token_metrics, token_num)
|
||||
|
||||
if (epoch_step % self.gradient_accumulation_steps == 0) or \
|
||||
(epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']):
|
||||
self.optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
self.optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if self.log_steps > 0 and global_step % self.log_steps == 0:
|
||||
batch_metrics_message = self.batch_metrics_tracker.value(
|
||||
)
|
||||
token_metrics_message = self.token_metrics_tracker.value(
|
||||
)
|
||||
message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]'
|
||||
avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}'
|
||||
message = ' '.join([
|
||||
message_prefix, batch_metrics_message,
|
||||
token_metrics_message, avg_time
|
||||
])
|
||||
self.logger.info(message)
|
||||
|
||||
self.logger.info('-' * 150)
|
||||
avg_batch_loss = tr_batch_loss / epoch_step
|
||||
avg_token_loss = tr_token_loss / epoch_step
|
||||
batch_metrics_message = self.batch_metrics_tracker.summary()
|
||||
token_metrics_message = self.token_metrics_tracker.summary()
|
||||
message_prefix = f'[Valid][{self.epoch}]'
|
||||
message = ' '.join([
|
||||
message_prefix, batch_metrics_message, token_metrics_message,
|
||||
str(avg_batch_loss),
|
||||
str(avg_token_loss)
|
||||
])
|
||||
self.logger.info(message)
|
||||
|
||||
cur_valid_metric = self.batch_metrics_tracker.get(
|
||||
self.valid_metric_name)
|
||||
if self.is_decreased_valid_metric:
|
||||
is_best = cur_valid_metric < self.best_valid_metric
|
||||
else:
|
||||
is_best = cur_valid_metric > self.best_valid_metric
|
||||
if is_best:
|
||||
self.best_valid_metric = cur_valid_metric
|
||||
self.save(is_best)
|
||||
self.logger.info('-' * 150)
|
||||
|
||||
return
|
||||
|
||||
def infer(self, data_type='test'):
|
||||
"""
|
||||
Inference interface.
|
||||
"""
|
||||
self.logger.info('Generation starts ...')
|
||||
infer_save_file = os.path.join(self.save_dir,
|
||||
f'infer_{self.epoch}.result.json')
|
||||
infer_samples_save_file = os.path.join(
|
||||
self.save_dir, f'infer_samples_{self.epoch}.result.json')
|
||||
|
||||
# Inference
|
||||
result_collection = {}
|
||||
begin_time = time.time()
|
||||
|
||||
eval_data = self.reader.get_eval_data(data_type)
|
||||
set_stats = self.reader.set_stats[data_type]
|
||||
self.logger.info('***** Running Evaluation *****')
|
||||
self.logger.info(' Num Turns = %d', set_stats['num_turns'])
|
||||
|
||||
with torch.no_grad():
|
||||
pbar = tqdm(eval_data)
|
||||
for dial_idx, dialog in enumerate(pbar):
|
||||
pv_turn = {}
|
||||
for turn_idx, turn in enumerate(dialog):
|
||||
first_turn = (turn_idx == 0)
|
||||
inputs, prompt_id = self.reader.convert_turn_eval(
|
||||
turn, pv_turn, first_turn)
|
||||
batch, batch_size = self.reader.collate_fn_multi_turn(
|
||||
samples=[inputs])
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
batch.items()))
|
||||
if self.reader.use_true_curr_bspn: # generate act, response
|
||||
max_len = 60
|
||||
if not self.reader.use_true_curr_aspn:
|
||||
max_len = 80
|
||||
outputs = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_r_id,
|
||||
max_gen_len=max_len)
|
||||
# resp_gen, need to trim previous context
|
||||
generated = outputs[0].cpu().numpy().tolist()
|
||||
try:
|
||||
decoded = self.decode_generated_act_resp(generated)
|
||||
except ValueError as exception:
|
||||
self.logger.info(str(exception))
|
||||
self.logger.info(self.tokenizer.decode(generated))
|
||||
decoded = {'resp': [], 'bspn': [], 'aspn': []}
|
||||
else: # predict bspn, access db, then generate act and resp
|
||||
outputs = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_b_id,
|
||||
max_gen_len=60)
|
||||
generated_bs = outputs[0].cpu().numpy().tolist()
|
||||
bspn_gen = self.decode_generated_bspn(generated_bs)
|
||||
# check DB result
|
||||
if self.reader.use_true_db_pointer:  # whether the DB result of the current turn is taken from the ground truth
|
||||
db = turn['db']
|
||||
else:
|
||||
db_result = self.reader.bspan_to_DBpointer(
|
||||
self.tokenizer.decode(bspn_gen),
|
||||
turn['turn_domain'])
|
||||
assert len(turn['db']) == 4
|
||||
book_result = turn['db'][2]
|
||||
assert isinstance(db_result, str)
|
||||
db = \
|
||||
[self.reader.sos_db_id] + \
|
||||
self.tokenizer.convert_tokens_to_ids([db_result]) + \
|
||||
[book_result] + \
|
||||
[self.reader.eos_db_id]
|
||||
prompt_id = self.reader.sos_a_id
|
||||
|
||||
prev_input = torch.tensor(bspn_gen + db)
|
||||
if self.func_model.use_gpu:
|
||||
prev_input = prev_input.cuda()
|
||||
outputs_db = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_r_id,
|
||||
max_gen_len=80,
|
||||
prev_input=prev_input)
|
||||
generated_ar = outputs_db[0].cpu().numpy().tolist()
|
||||
try:
|
||||
decoded = self.decode_generated_act_resp(
|
||||
generated_ar)
|
||||
decoded['bspn'] = bspn_gen
|
||||
except ValueError as exception:
|
||||
self.logger.info(str(exception))
|
||||
self.logger.info(
|
||||
self.tokenizer.decode(generated_ar))
|
||||
decoded = {'resp': [], 'bspn': [], 'aspn': []}
|
||||
|
||||
turn['resp_gen'] = decoded['resp']
|
||||
turn['bspn_gen'] = turn[
|
||||
'bspn'] if self.reader.use_true_curr_bspn else decoded[
|
||||
'bspn']
|
||||
turn['aspn_gen'] = turn[
|
||||
'aspn'] if self.reader.use_true_curr_aspn else decoded[
|
||||
'aspn']
|
||||
turn['dspn_gen'] = turn['dspn']
|
||||
|
||||
pv_turn['labels'] = inputs[
|
||||
'labels'] # all true previous context
|
||||
pv_turn['resp'] = turn[
|
||||
'resp'] if self.reader.use_true_prev_resp else decoded[
|
||||
'resp']
|
||||
if not self.reader.use_true_curr_bspn:
|
||||
pv_turn['bspn'] = turn[
|
||||
'bspn'] if self.reader.use_true_prev_bspn else decoded[
|
||||
'bspn']
|
||||
pv_turn['db'] = turn[
|
||||
'db'] if self.reader.use_true_prev_bspn else db
|
||||
pv_turn['aspn'] = turn[
|
||||
'aspn'] if self.reader.use_true_prev_aspn else decoded[
|
||||
'aspn']
|
||||
|
||||
tmp_dialog_result = self.reader.inverse_transpose_turn(dialog)
|
||||
result_collection.update(tmp_dialog_result)
|
||||
|
||||
# compute tmp scores
|
||||
results, _ = self.reader.wrap_result_lm(tmp_dialog_result)
|
||||
bleu, success, match = self.evaluator.validation_metric(
|
||||
results)
|
||||
score = 0.5 * (success + match) + bleu
|
||||
pbar.set_description(
|
||||
'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %
|
||||
(match, success, bleu, score))
|
||||
|
||||
# compute scores
|
||||
results, _ = self.reader.wrap_result_lm(result_collection)
|
||||
bleu, success, match = self.evaluator.validation_metric(results)
|
||||
score = 0.5 * (success + match) + bleu
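
# Illustrative note (not part of the original file; the numbers are made up): this is the
# standard MultiWOZ combined score, e.g. match = 80.0, success = 70.0, bleu = 16.0 give
# 0.5 * (70.0 + 80.0) + 16.0 = 91.0.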
|
||||
|
||||
# log results
|
||||
metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\
|
||||
(match, success, bleu, score)
|
||||
message_prefix = f'[Infer][{self.epoch}]'
|
||||
time_cost = f'TIME-{time.time() - begin_time:.3f}'
|
||||
message = ' '.join([message_prefix, metrics_message, time_cost])
|
||||
self.logger.info(message)
|
||||
|
||||
# save results
|
||||
eval_results = {
|
||||
'bleu': bleu,
|
||||
'success': success,
|
||||
'match': match,
|
||||
'score': score,
|
||||
'result': message
|
||||
}
|
||||
with open(infer_save_file, 'w') as fp:
|
||||
json.dump(eval_results, fp, indent=2)
|
||||
self.logger.info(f'Saved inference results to {infer_save_file}')
|
||||
with open(infer_samples_save_file, 'w') as fp:
|
||||
for sample in results:
|
||||
line = json.dumps(sample)
|
||||
fp.write(line)
|
||||
fp.write('\n')
|
||||
self.logger.info(
|
||||
f'Saved inference samples to {infer_samples_save_file}')
|
||||
|
||||
return
|
||||
|
||||
def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn):
|
||||
|
||||
def _get_slots(constraint):
|
||||
domain_name = ''
|
||||
slots = {}
|
||||
for item in constraint:
|
||||
if item in ontology.placeholder_tokens:
|
||||
continue
|
||||
if item in ontology.all_domains_with_bracket:
|
||||
domain_name = item
|
||||
slots[domain_name] = set()
|
||||
else:
|
||||
assert domain_name in ontology.all_domains_with_bracket
|
||||
slots[domain_name].add(item)
|
||||
return slots
|
||||
|
||||
turn_domain = []
|
||||
if first_turn and len(bspn_gen_ids) == 0:
|
||||
turn_domain = ['[general]']
|
||||
return turn_domain
|
||||
|
||||
bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids)
|
||||
turn_slots = _get_slots(bspn_token)
|
||||
if first_turn:
|
||||
return list(turn_slots.keys())
|
||||
|
||||
assert 'bspn' in old_pv_turn
|
||||
pv_bspn_token = self.tokenizer.convert_ids_to_tokens(
|
||||
old_pv_turn['bspn'])
|
||||
pv_turn_slots = _get_slots(pv_bspn_token)
|
||||
for domain, value in turn_slots.items():
|
||||
pv_value = pv_turn_slots[
|
||||
domain] if domain in pv_turn_slots else set()
|
||||
if len(value - pv_value) > 0 or len(pv_value - value):
|
||||
turn_domain.append(domain)
|
||||
if len(turn_domain) == 0:
|
||||
turn_domain = list(turn_slots.keys())
|
||||
|
||||
return turn_domain
|
||||
|
||||
def forward(self, turn, old_pv_turn):
|
||||
with torch.no_grad():
|
||||
first_turn = True if len(old_pv_turn) == 0 else False
|
||||
inputs, prompt_id = self.reader.convert_turn_eval(
|
||||
turn, old_pv_turn, first_turn)
|
||||
batch, batch_size = self.reader.collate_fn_multi_turn(
|
||||
samples=[inputs])
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
|
||||
pv_turn = {}
|
||||
|
||||
outputs = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_b_id,
|
||||
max_gen_len=60)
|
||||
generated_bs = outputs[0].cpu().numpy().tolist()
|
||||
bspn_gen = self.decode_generated_bspn(generated_bs)
|
||||
|
||||
turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen,
|
||||
first_turn)
|
||||
|
||||
db_result = self.reader.bspan_to_DBpointer(
|
||||
self.tokenizer.decode(bspn_gen), turn_domain)
|
||||
assert isinstance(db_result, str)
|
||||
db = \
|
||||
[self.reader.sos_db_id] + \
|
||||
self.tokenizer.convert_tokens_to_ids([db_result]) + \
|
||||
[self.reader.eos_db_id]
|
||||
prompt_id = self.reader.sos_a_id
|
||||
prev_input = torch.tensor(bspn_gen + db)
|
||||
if self.func_model.use_gpu:
|
||||
prev_input = prev_input.cuda()
|
||||
outputs_db = self.func_model.infer(
|
||||
inputs=batch,
|
||||
start_id=prompt_id,
|
||||
eos_id=self.reader.eos_r_id,
|
||||
max_gen_len=80,
|
||||
prev_input=prev_input)
|
||||
generated_ar = outputs_db[0].cpu().numpy().tolist()
|
||||
decoded = self.decode_generated_act_resp(generated_ar)
|
||||
decoded['bspn'] = bspn_gen
|
||||
|
||||
pv_turn['labels'] = inputs['labels']
|
||||
pv_turn['resp'] = decoded['resp']
|
||||
pv_turn['bspn'] = decoded['bspn']
|
||||
pv_turn['db'] = db
|
||||
pv_turn['aspn'] = decoded['aspn']
|
||||
|
||||
return pv_turn
|
||||
824
modelscope/trainers/nlp/space/trainers/intent_trainer.py
Normal file
@@ -0,0 +1,824 @@
|
||||
"""
|
||||
Trainer class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from modelscope.trainers.nlp.space.metrics.metrics_tracker import \
|
||||
MetricsTracker
|
||||
from modelscope.utils.nlp.space.args import str2bool
|
||||
|
||||
|
||||
def get_logger(log_path, name='default'):
|
||||
logger = logging.getLogger(name)
|
||||
logger.propagate = False
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter('%(message)s')
|
||||
|
||||
sh = logging.StreamHandler(sys.stdout)
|
||||
sh.setFormatter(formatter)
|
||||
logger.addHandler(sh)
|
||||
|
||||
fh = logging.FileHandler(log_path, mode='w')
|
||||
fh.setFormatter(formatter)
|
||||
logger.addHandler(fh)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
to_tensor,
|
||||
config,
|
||||
reader=None,
|
||||
logger=None,
|
||||
lr_scheduler=None,
|
||||
optimizer=None):
|
||||
self.model = model
|
||||
self.to_tensor = to_tensor
|
||||
self.do_train = config.do_train
|
||||
self.do_infer = config.do_infer
|
||||
|
||||
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
|
||||
0] == '-'
|
||||
self.valid_metric_name = config.Trainer.valid_metric_name[1:]
|
||||
self.num_epochs = config.Trainer.num_epochs
|
||||
self.save_dir = config.Trainer.save_dir
|
||||
self.log_steps = config.Trainer.log_steps
|
||||
self.valid_steps = config.Trainer.valid_steps
|
||||
self.save_checkpoint = config.Trainer.save_checkpoint
|
||||
self.save_summary = config.Trainer.save_summary
|
||||
self.learning_method = config.Dataset.learning_method
|
||||
self.weight_decay = config.Model.weight_decay
|
||||
self.warmup_steps = config.Model.warmup_steps
|
||||
self.batch_size_label = config.Trainer.batch_size_label
|
||||
self.batch_size_nolabel = config.Trainer.batch_size_nolabel
|
||||
self.gpu = config.Trainer.gpu
|
||||
self.lr = config.Model.lr
|
||||
|
||||
self.model = model
|
||||
self.func_model = self.model.module if self.gpu > 1 else self.model
|
||||
self.reader = reader
|
||||
self.tokenizer = reader.tokenizer
|
||||
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.optimizer = optimizer
|
||||
|
||||
# if not os.path.exists(self.save_dir):
|
||||
# os.makedirs(self.save_dir)
|
||||
|
||||
# self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer")
|
||||
self.logger = logger or get_logger('trainer.log', 'trainer')
|
||||
|
||||
self.batch_metrics_tracker_label = MetricsTracker()
|
||||
self.token_metrics_tracker_label = MetricsTracker()
|
||||
self.batch_metrics_tracker_nolabel = MetricsTracker()
|
||||
self.token_metrics_tracker_nolabel = MetricsTracker()
|
||||
|
||||
self.best_valid_metric = float(
|
||||
'inf' if self.is_decreased_valid_metric else '-inf')
|
||||
self.epoch = 0
|
||||
self.batch_num = 0
|
||||
|
||||
def set_optimizers(self, num_training_steps_per_epoch):
|
||||
"""
|
||||
Setup the optimizer and the learning rate scheduler.
|
||||
|
||||
from transformers.Trainer
|
||||
|
||||
parameters from cfg: lr (1e-3); warmup_steps
|
||||
"""
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'norm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
self.weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
|
||||
|
||||
num_training_steps = num_training_steps_per_epoch * self.num_epochs
|
||||
num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int(
|
||||
num_training_steps * 0.1)
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps)
|
||||
|
||||
# reset optimizer and lr_scheduler
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
# log info
|
||||
self.logger.info(
|
||||
f'***** Running training: {self.learning_method} *****')
|
||||
self.logger.info(' Num Epochs = %d', self.num_epochs)
|
||||
self.logger.info(
|
||||
' Num Training steps (one turn in a batch of dialogs) per epoch = %d',
|
||||
num_training_steps_per_epoch)
|
||||
self.logger.info(' Batch size for labeled data = %d',
|
||||
self.batch_size_label)
|
||||
self.logger.info(' Batch size for unlabeled data = %d',
|
||||
self.batch_size_nolabel)
|
||||
self.logger.info(' Total optimization steps = %d', num_training_steps)
|
||||
self.logger.info(' Total warmup steps = %d', num_warmup_steps)
|
||||
self.logger.info('************************************')
|
||||
|
||||
def train(self,
|
||||
train_label_iter,
|
||||
train_nolabel_iter=None,
|
||||
valid_label_iter=None,
|
||||
valid_nolabel_iter=None):
|
||||
# begin training
|
||||
num_epochs = self.num_epochs - self.epoch
|
||||
for epoch in range(num_epochs):
|
||||
self.train_epoch(
|
||||
train_label_iter=train_label_iter,
|
||||
train_nolabel_iter=train_nolabel_iter,
|
||||
valid_label_iter=valid_label_iter,
|
||||
valid_nolabel_iter=valid_nolabel_iter)
|
||||
|
||||
def train_epoch(self, train_label_iter, train_nolabel_iter,
|
||||
valid_label_iter, valid_nolabel_iter):
|
||||
"""
|
||||
Train an epoch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def evaluate(self, data_label_iter, data_nolabel_iter, need_save=True):
|
||||
raise NotImplementedError
|
||||
|
||||
def infer(self, data_iter, num_batches=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, is_best=False):
|
||||
""" save """
|
||||
train_state = {
|
||||
'epoch': self.epoch,
|
||||
'batch_num': self.batch_num,
|
||||
'best_valid_metric': self.best_valid_metric,
|
||||
'optimizer': self.optimizer.state_dict()
|
||||
}
|
||||
if self.lr_scheduler is not None:
|
||||
train_state['lr_scheduler'] = self.lr_scheduler.state_dict()
|
||||
|
||||
# Save checkpoint
|
||||
if self.save_checkpoint:
|
||||
model_file = os.path.join(self.save_dir,
|
||||
f'state_epoch_{self.epoch}.model')
|
||||
torch.save(self.model.state_dict(), model_file)
|
||||
self.logger.info(f"Saved model state to '{model_file}'")
|
||||
|
||||
train_file = os.path.join(self.save_dir,
|
||||
f'state_epoch_{self.epoch}.train')
|
||||
torch.save(train_state, train_file)
|
||||
self.logger.info(f"Saved train state to '{train_file}'")
|
||||
|
||||
# Save current best model
|
||||
if is_best:
|
||||
best_model_file = os.path.join(self.save_dir, 'best.model')
|
||||
torch.save(self.model.state_dict(), best_model_file)
|
||||
best_train_file = os.path.join(self.save_dir, 'best.train')
|
||||
torch.save(train_state, best_train_file)
|
||||
self.logger.info(
|
||||
f"Saved best model state to '{best_model_file}' with new best valid metric "
|
||||
f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}'
|
||||
)
|
||||
|
||||
def load(self):
|
||||
""" load """
|
||||
|
||||
def _load_model_state():
|
||||
model_state_dict = torch.load(
|
||||
f'{self.func_model.init_checkpoint}.model',
|
||||
map_location=lambda storage, loc: storage)
|
||||
|
||||
if 'module.' in list(model_state_dict.keys())[0]:
|
||||
new_model_state_dict = OrderedDict()
|
||||
for k, v in model_state_dict.items():
|
||||
assert k[:7] == 'module.'
|
||||
new_model_state_dict[k[7:]] = v
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
new_model_state_dict = OrderedDict()
|
||||
parameters = {
|
||||
name: param
|
||||
for name, param in self.func_model.named_parameters()
|
||||
}
|
||||
for name, param in model_state_dict.items():
|
||||
if name in parameters:
|
||||
if param.shape != parameters[name].shape:
|
||||
assert hasattr(param, 'numpy')
|
||||
arr = param.numpy()
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
if name == 'embedder.token_embedding.weight':
|
||||
z[-param.shape[0]:] = arr
|
||||
print(
|
||||
f'part of parameter({name}) randomly initialized from a normal distribution'
|
||||
)
|
||||
else:
|
||||
if z.shape[0] < param.shape[0]:
|
||||
z = arr[:z.shape[0]]
|
||||
print(f'part of parameter({name}) is dropped')
|
||||
else:
|
||||
z[:param.shape[0]] = arr
|
||||
print(
|
||||
f'part of parameter({name}) randomly initialized from a normal distribution'
|
||||
)
|
||||
dtype, device = param.dtype, param.device
|
||||
z = torch.tensor(z, dtype=dtype, device=device)
|
||||
new_model_state_dict[name] = z
|
||||
else:
|
||||
new_model_state_dict[name] = param
|
||||
else:
|
||||
print(f'parameter({name}) is dropped')
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
for name in parameters:
|
||||
if name not in model_state_dict:
|
||||
if parameters[name].requires_grad:
|
||||
print(f'parameter({name}) randomly initialized from a normal distribution')
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
dtype, device = parameters[name].dtype, parameters[
|
||||
name].device
|
||||
model_state_dict[name] = torch.tensor(
|
||||
z, dtype=dtype, device=device)
|
||||
else:
|
||||
model_state_dict[name] = parameters[name]
|
||||
|
||||
self.func_model.load_state_dict(model_state_dict)
|
||||
self.logger.info(
|
||||
f"Loaded model state from '{self.func_model.init_checkpoint}.model'"
|
||||
)
|
||||
|
||||
def _load_train_state():
|
||||
train_file = f'{self.func_model.init_checkpoint}.train'
|
||||
if os.path.exists(train_file):
|
||||
train_state_dict = torch.load(
|
||||
train_file, map_location=lambda storage, loc: storage)
|
||||
self.epoch = train_state_dict['epoch']
|
||||
self.best_valid_metric = train_state_dict['best_valid_metric']
|
||||
if self.optimizer is not None and 'optimizer' in train_state_dict:
|
||||
self.optimizer.load_state_dict(
|
||||
train_state_dict['optimizer'])
|
||||
if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
|
||||
self.lr_scheduler.load_state_dict(
|
||||
train_state_dict['lr_scheduler'])
|
||||
self.logger.info(
|
||||
f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
|
||||
f'best_valid_metric={self.best_valid_metric:.3f})')
|
||||
else:
|
||||
self.logger.info('Loaded no train state')
|
||||
|
||||
if self.func_model.init_checkpoint is None:
|
||||
self.logger.info('Loaded no model !!!')
|
||||
return
|
||||
|
||||
_load_model_state()
|
||||
_load_train_state()
|
||||
|
||||
|
||||
class IntentTrainer(Trainer):
|
||||
|
||||
def __init__(self, model, to_tensor, config, reader=None):
|
||||
super(IntentTrainer, self).__init__(model, to_tensor, config, reader)
|
||||
self.example = config.Model.example
|
||||
self.can_norm = config.Trainer.can_norm
|
||||
|
||||
def can_normalization(self, y_pred, y_true, ex_data_iter):
|
||||
# Predictions: compute the accuracy before correction
|
||||
acc_original = np.mean(y_pred.argmax(1) == y_true)
|
||||
message = 'original acc: %s' % acc_original
|
||||
|
||||
# Estimate the uncertainty of each prediction
|
||||
k = 3
|
||||
y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
|
||||
y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
|
||||
y_pred_uncertainty =\
|
||||
-(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)
|
||||
|
||||
# Pick a threshold and split predictions into high- and low-confidence groups
|
||||
# print(np.sort(y_pred_uncertainty)[-100:].tolist())
|
||||
threshold = 0.7
|
||||
y_pred_confident = y_pred[y_pred_uncertainty < threshold]
|
||||
y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
|
||||
y_true_confident = y_true[y_pred_uncertainty < threshold]
|
||||
y_true_unconfident = y_true[y_pred_uncertainty >= threshold]
|
||||
|
||||
# Report the accuracy of each group separately;
# typically the high-confidence group is far more accurate than the low-confidence one
|
||||
acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \
|
||||
if len(y_true_confident) else 0.
|
||||
acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \
|
||||
if len(y_true_unconfident) else 0.
|
||||
message += ' (%s) confident acc: %s' % (len(y_true_confident),
|
||||
acc_confident)
|
||||
message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident),
|
||||
acc_unconfident)
|
||||
|
||||
# Estimate the prior class distribution from the training set
|
||||
prior = np.zeros(self.func_model.num_intent)
|
||||
for _, (batch, batch_size) in ex_data_iter:
|
||||
for intent_label in batch['intent_label']:
|
||||
prior[intent_label] += 1.
|
||||
|
||||
prior /= prior.sum()
|
||||
|
||||
# Correct the low-confidence samples one by one and re-evaluate the accuracy
|
||||
right, alpha, iters = 0, 1, 1
|
||||
for i, y in enumerate(y_pred_unconfident):
|
||||
Y = np.concatenate([y_pred_confident, y[None]], axis=0)
|
||||
for j in range(iters):
|
||||
Y = Y**alpha
|
||||
Y /= Y.mean(axis=0, keepdims=True)
|
||||
Y *= prior[None]
|
||||
Y /= Y.sum(axis=1, keepdims=True)
|
||||
y = Y[-1]
|
||||
if y.argmax() == y_true_unconfident[i]:
|
||||
right += 1
|
||||
|
||||
# Report the corrected accuracy
|
||||
acc_final = \
|
||||
(acc_confident * len(y_pred_confident) + right) / \
|
||||
len(y_pred)
|
||||
if len(y_pred_unconfident):
|
||||
message += ' new unconfident acc: %s' % (
|
||||
right / len(y_pred_unconfident))
|
||||
else:
|
||||
message += ' no unconfident predictions'
|
||||
message += ' final acc: %s' % acc_final
|
||||
return acc_original, acc_final, message
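
# Illustrative sketch (not part of the original file; the numbers are made up): the
# uncertainty score used above is the entropy of the renormalized top-k probabilities,
# scaled to [0, 1] by log(k). For k = 3:
#
#     import numpy as np
#     y_pred_topk = np.array([[0.05, 0.15, 0.80]])   # already sorted and summing to 1
#     uncertainty = -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(3)
#     # uncertainty ~= 0.56 < 0.7, so this prediction stays in the confident group;
#     # only predictions above the threshold are re-normalized against the label prior.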
|
||||
|
||||
def train_epoch(self, train_label_iter, train_nolabel_iter,
|
||||
valid_label_iter, valid_nolabel_iter):
|
||||
"""
|
||||
Train an epoch.
|
||||
"""
|
||||
times = []
|
||||
self.epoch += 1
|
||||
self.batch_metrics_tracker_label.clear()
|
||||
self.token_metrics_tracker_label.clear()
|
||||
self.batch_metrics_tracker_nolabel.clear()
|
||||
self.token_metrics_tracker_nolabel.clear()
|
||||
|
||||
num_label_batches = len(train_label_iter)
|
||||
num_nolabel_batches = len(
|
||||
train_nolabel_iter) if train_nolabel_iter is not None else 0
|
||||
num_batches = max(num_label_batches, num_nolabel_batches)
|
||||
|
||||
train_label_iter_loop = iter(train_label_iter)
|
||||
train_nolabel_iter_loop = iter(
|
||||
train_nolabel_iter) if train_nolabel_iter is not None else None
|
||||
report_for_unlabeled_data = True if train_nolabel_iter is not None else False
|
||||
|
||||
for batch_id in range(1, num_batches + 1):
|
||||
# Do a training iteration
|
||||
start_time = time.time()
|
||||
batch_list, batch_size_list, with_label_list, loss_list, metrics_list = [], [], [], [], []
|
||||
data_file_list = []
|
||||
|
||||
# collect batch for labeled data
|
||||
try:
|
||||
data_file_label, (
|
||||
batch_label,
|
||||
batch_size_label) = next(train_label_iter_loop)
|
||||
except StopIteration:
|
||||
train_label_iter_loop = iter(train_label_iter)
|
||||
data_file_label, (
|
||||
batch_label,
|
||||
batch_size_label) = next(train_label_iter_loop)
|
||||
batch_list.append(batch_label)
|
||||
batch_size_list.append(batch_size_label)
|
||||
with_label_list.append(True)
|
||||
data_file_list.append(data_file_label)
|
||||
|
||||
# collect batch for unlabeled data
|
||||
if train_nolabel_iter is not None:
|
||||
try:
|
||||
data_file_nolabel, (
|
||||
batch_nolabel,
|
||||
batch_size_nolabel) = next(train_nolabel_iter_loop)
|
||||
except StopIteration:
|
||||
train_nolabel_iter_loop = iter(train_nolabel_iter)
|
||||
data_file_nolabel, (
|
||||
batch_nolabel,
|
||||
batch_size_nolabel) = next(train_nolabel_iter_loop)
|
||||
batch_list.append(batch_nolabel)
|
||||
batch_size_list.append(batch_size_nolabel)
|
||||
with_label_list.append(False)
|
||||
data_file_list.append(data_file_nolabel)
|
||||
|
||||
# forward labeled batch and unlabeled batch and collect outputs, respectively
|
||||
for (batch, batch_size, with_label, data_file) in \
|
||||
zip(batch_list, batch_size_list, with_label_list, data_file_list):
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
batch.items()))
|
||||
if self.example and with_label:
|
||||
current_dataset = train_label_iter.data_file_to_dataset[
|
||||
data_file]
|
||||
example_batch = self.reader.retrieve_examples(
|
||||
dataset=current_dataset,
|
||||
labels=batch['intent_label'],
|
||||
inds=batch['ids'],
|
||||
task='intent')
|
||||
example_batch = type(example_batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
example_batch.items()))
|
||||
for k, v in example_batch.items():
|
||||
batch[k] = v
|
||||
batch['epoch'] = self.epoch
|
||||
batch['num_steps'] = self.batch_num
|
||||
metrics = self.model(
|
||||
batch,
|
||||
is_training=True,
|
||||
with_label=with_label,
|
||||
data_file=data_file)
|
||||
loss, metrics = self.balance_metrics(
|
||||
metrics=metrics, batch_size=batch_size)
|
||||
loss_list.append(loss)
|
||||
metrics_list.append(metrics)
|
||||
|
||||
# combine loss for labeled data and unlabeled data
|
||||
# TODO change the computation of combined loss of labeled batch and unlabeled batch
|
||||
loss = loss_list[0] if len(
|
||||
loss_list) == 1 else loss_list[0] + loss_list[1]
|
||||
|
||||
# optimization procedure
|
||||
self.func_model._optimize(
|
||||
loss, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler)
|
||||
elapsed = time.time() - start_time
|
||||
times.append(elapsed)
|
||||
self.batch_num += 1
|
||||
|
||||
# track metrics and log temporary message
|
||||
for (batch_size, metrics,
|
||||
with_label) in zip(batch_size_list, metrics_list,
|
||||
with_label_list):
|
||||
self.track_and_log_message(
|
||||
metrics=metrics,
|
||||
batch_id=batch_id,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
times=times,
|
||||
with_label=with_label)
|
||||
|
||||
# evaluate
|
||||
if self.valid_steps > 0 and valid_label_iter is not None and valid_nolabel_iter is not None \
|
||||
and batch_id % self.valid_steps == 0:
|
||||
self.evaluate(
|
||||
data_label_iter=valid_label_iter,
|
||||
data_nolabel_iter=valid_nolabel_iter)
|
||||
|
||||
# compute accuracy for valid dataset
|
||||
accuracy = self.infer(
|
||||
data_iter=valid_label_iter, ex_data_iter=train_label_iter)
|
||||
|
||||
# report summary message and save checkpoints
|
||||
self.save_and_log_message(
|
||||
report_for_unlabeled_data, cur_valid_metric=-accuracy)
|
||||
|
||||
def forward(self, batch):
|
||||
pred = []
|
||||
|
||||
with torch.no_grad():
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
|
||||
result = self.model.infer(inputs=batch)
|
||||
result = {
|
||||
name: result[name].cpu().detach().numpy()
|
||||
for name in result
|
||||
}
|
||||
intent_probs = result['intent_probs']
|
||||
if self.can_norm:
|
||||
pred += [intent_probs]
|
||||
else:
|
||||
pred += np.argmax(intent_probs, axis=1).tolist()
|
||||
|
||||
return pred
|
||||
|
||||
def infer(self, data_iter, num_batches=None, ex_data_iter=None):
|
||||
"""
|
||||
Inference interface.
|
||||
"""
|
||||
self.logger.info('Generation starts ...')
|
||||
infer_save_file = os.path.join(self.save_dir,
|
||||
f'infer_{self.epoch}.result.json')
|
||||
|
||||
# Inference
|
||||
batch_cnt = 0
|
||||
pred, true = [], []
|
||||
outputs, labels = [], []
|
||||
begin_time = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.example:
|
||||
for _, (batch, batch_size) in tqdm(
|
||||
ex_data_iter, desc='Building train memory.'):
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
batch.items()))
|
||||
result = self.model.infer(inputs=batch)
|
||||
result = {
|
||||
name: result[name].cpu().detach().numpy()
|
||||
for name in result
|
||||
}
|
||||
outputs.append(torch.from_numpy(result['features']))
|
||||
labels += batch['intent_label'].tolist()
|
||||
|
||||
mem = torch.cat(outputs, dim=0)
|
||||
mem = mem.cuda() if self.func_model.use_gpu else mem
|
||||
labels = torch.LongTensor(labels).unsqueeze(0)
|
||||
labels = labels.cuda() if self.func_model.use_gpu else labels
|
||||
self.logger.info(f'Memory size: {mem.size()}')
|
||||
|
||||
for _, (batch, batch_size) in tqdm(data_iter, total=num_batches):
|
||||
batch = type(batch)(
|
||||
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
|
||||
batch.items()))
|
||||
result = self.model.infer(inputs=batch)
|
||||
result = {
|
||||
name: result[name].cpu().detach().numpy()
|
||||
for name in result
|
||||
}
|
||||
|
||||
if self.example:
|
||||
features = torch.from_numpy(result['features'])
|
||||
features = features.cuda(
|
||||
) if self.func_model.use_gpu else features
|
||||
probs = torch.softmax(features.mm(mem.t()), dim=-1)
|
||||
intent_probs = torch.zeros(
|
||||
probs.size(0), self.func_model.num_intent)
|
||||
intent_probs = intent_probs.cuda(
|
||||
) if self.func_model.use_gpu else intent_probs
|
||||
intent_probs = intent_probs.scatter_add(
|
||||
-1, labels.repeat(probs.size(0), 1), probs)
|
||||
intent_probs = intent_probs.cpu().detach().numpy()
|
||||
else:
|
||||
intent_probs = result['intent_probs']
|
||||
|
||||
if self.can_norm:
|
||||
pred += [intent_probs]
|
||||
true += batch['intent_label'].cpu().detach().tolist()
|
||||
else:
|
||||
pred += np.argmax(intent_probs, axis=1).tolist()
|
||||
true += batch['intent_label'].cpu().detach().tolist()
|
||||
|
||||
batch_cnt += 1
|
||||
if batch_cnt == num_batches:
|
||||
break
|
||||
|
||||
if self.can_norm:
|
||||
true = np.array(true)
|
||||
pred = np.concatenate(pred, axis=0)
|
||||
acc_original, acc_final, message = self.can_normalization(
|
||||
y_pred=pred, y_true=true, ex_data_iter=ex_data_iter)
|
||||
accuracy = max(acc_original, acc_final)
|
||||
infer_results = {
|
||||
'accuracy': accuracy,
|
||||
'pred_labels': pred.tolist(),
|
||||
'message': message
|
||||
}
|
||||
metrics_message = f'Accuracy: {accuracy} {message}'
|
||||
else:
|
||||
accuracy = sum(p == t for p, t in zip(pred, true)) / len(pred)
|
||||
infer_results = {'accuracy': accuracy, 'pred_labels': pred}
|
||||
metrics_message = f'Accuracy: {accuracy}'
|
||||
|
||||
self.logger.info(f'Saved inference results to {infer_save_file}')
|
||||
with open(infer_save_file, 'w') as fp:
|
||||
json.dump(infer_results, fp, indent=2)
|
||||
message_prefix = f'[Infer][{self.epoch}]'
|
||||
time_cost = f'TIME-{time.time() - begin_time:.3f}'
|
||||
message = ' '.join([message_prefix, metrics_message, time_cost])
|
||||
self.logger.info(message)
|
||||
return accuracy
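
# Illustrative sketch (not part of the original file; `num_intent` stands for
# self.func_model.num_intent): when self.example is set, class probabilities come from a
# soft nearest-neighbour lookup over the training memory rather than the classifier head:
#
#     probs = torch.softmax(features.mm(mem.t()), dim=-1)    # [batch, num_memory]
#     intent_probs = torch.zeros(probs.size(0), num_intent)
#     intent_probs = intent_probs.scatter_add(
#         -1, labels.repeat(probs.size(0), 1), probs)
#     # each memory entry votes for its own intent label with its similarity weight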
|
||||
|
||||
def track_and_log_message(self, metrics, batch_id, batch_size, num_batches,
|
||||
times, with_label):
|
||||
# track metrics
|
||||
batch_metrics_tracker = self.batch_metrics_tracker_label if with_label else self.batch_metrics_tracker_nolabel
|
||||
token_metrics_tracker = self.token_metrics_tracker_label if with_label else self.token_metrics_tracker_nolabel
|
||||
|
||||
metrics = {
|
||||
k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in metrics.items()
|
||||
}
|
||||
mlm_num = metrics.pop('mlm_num', 0)
|
||||
|
||||
batch_metrics = {k: v for k, v in metrics.items() if 'token' not in k}
|
||||
token_metrics = {k: v for k, v in metrics.items() if 'token' in k}
|
||||
batch_metrics_tracker.update(batch_metrics, batch_size)
|
||||
token_metrics_tracker.update(token_metrics, mlm_num)
|
||||
|
||||
# log message
|
||||
if self.log_steps > 0 and batch_id % self.log_steps == 0:
|
||||
batch_metrics_message = batch_metrics_tracker.value()
|
||||
token_metrics_message = token_metrics_tracker.value()
|
||||
label_prefix = 'Labeled' if with_label else 'Unlabeled'
|
||||
message_prefix = f'[Train][{self.epoch}][{batch_id}/{num_batches}][{label_prefix}]'
|
||||
avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}'
|
||||
message = ' '.join([
|
||||
message_prefix, batch_metrics_message, token_metrics_message,
|
||||
avg_time
|
||||
])
|
||||
self.logger.info(message)
|
||||
|
||||
def save_and_log_message(self,
|
||||
report_for_unlabeled_data,
|
||||
cur_valid_metric=None):
|
||||
# report message
|
||||
batch_metrics_message = self.batch_metrics_tracker_label.summary()
|
||||
token_metrics_message = self.token_metrics_tracker_label.summary()
|
||||
message_prefix = f'[Valid][{self.epoch}][Labeled]'
|
||||
message = ' '.join(
|
||||
[message_prefix, batch_metrics_message, token_metrics_message])
|
||||
self.logger.info(message)
|
||||
if report_for_unlabeled_data:
|
||||
batch_metrics_message = self.batch_metrics_tracker_nolabel.summary(
|
||||
)
|
||||
token_metrics_message = self.token_metrics_tracker_nolabel.summary(
|
||||
)
|
||||
message_prefix = f'[Valid][{self.epoch}][Unlabeled]'
|
||||
message = ' '.join(
|
||||
[message_prefix, batch_metrics_message, token_metrics_message])
|
||||
self.logger.info(message)
|
||||
|
||||
# save checkpoints
|
||||
assert cur_valid_metric is not None
|
||||
if self.is_decreased_valid_metric:
|
||||
is_best = cur_valid_metric < self.best_valid_metric
|
||||
else:
|
||||
is_best = cur_valid_metric > self.best_valid_metric
|
||||
if is_best:
|
||||
self.best_valid_metric = cur_valid_metric
|
||||
self.save(is_best)
|
||||
|
||||
def balance_metrics(self, metrics, batch_size):
|
||||
if self.gpu > 1:
|
||||
for metric in metrics:
|
||||
if metric is not None:
|
||||
assert len(metric) == self.gpu
|
||||
|
||||
intent_loss, mlm, token_mlm, mlm_num, kl, con = metrics
|
||||
metrics = {}
|
||||
|
||||
intent_loss = torch.mean(intent_loss)
|
||||
metrics['intent_loss'] = intent_loss
|
||||
loss = intent_loss
|
||||
|
||||
if mlm is not None:
|
||||
mlm_num = torch.sum(mlm_num)
|
||||
token_mlm = torch.sum(mlm) * (batch_size / self.gpu) / mlm_num
|
||||
mlm = torch.mean(mlm)
|
||||
metrics['mlm_num'] = mlm_num
|
||||
metrics['token_mlm'] = token_mlm
|
||||
metrics['mlm'] = mlm
|
||||
loss = loss + (token_mlm if self.func_model.token_loss else
|
||||
mlm) * self.func_model.mlm_ratio
|
||||
|
||||
if kl is not None:
|
||||
kl = torch.mean(kl)
|
||||
metrics['kl'] = kl
|
||||
loss = loss + kl * self.func_model.kl_ratio
|
||||
|
||||
if con is not None:
|
||||
con = torch.mean(con)
|
||||
metrics['con'] = con
|
||||
loss = loss + con
|
||||
|
||||
metrics['loss'] = loss
|
||||
|
||||
assert 'loss' in metrics
|
||||
return metrics['loss'], metrics
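# Objective assembled by balance_metrics above (sketch):
#   loss = intent_loss
#        + mlm_ratio * (token_mlm if token_loss else mlm)
#        + kl_ratio * kl
#        + con
# where the mlm / kl / con terms are only added when the corresponding
# metric is not None.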
|
||||
|
||||
def load(self):
|
||||
""" load """
|
||||
|
||||
def _load_model_state():
|
||||
model_state_dict = torch.load(
|
||||
f'{self.func_model.init_checkpoint}',
|
||||
map_location=lambda storage, loc: storage)
|
||||
|
||||
if 'module.' in list(model_state_dict.keys())[0]:
|
||||
new_model_state_dict = OrderedDict()
|
||||
for k, v in model_state_dict.items():
|
||||
assert k[:7] == 'module.'
|
||||
new_model_state_dict[k[7:]] = v
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
new_model_state_dict = OrderedDict()
|
||||
parameters = {
|
||||
name: param
|
||||
for name, param in self.func_model.named_parameters()
|
||||
}
|
||||
for name, param in model_state_dict.items():
|
||||
if name in parameters:
|
||||
if param.shape != parameters[name].shape:
|
||||
assert hasattr(param, 'numpy')
|
||||
arr = param.numpy()
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
if name == 'embedder.token_embedding.weight':
|
||||
z[-param.shape[0]:] = arr
|
||||
print(
|
||||
f'part of parameter({name}) randomly initialized with normal distribution'
|
||||
)
|
||||
else:
|
||||
if z.shape[0] < param.shape[0]:
|
||||
z = arr[:z.shape[0]]
|
||||
print(f'part of parameter({name}) is dropped')
|
||||
else:
|
||||
z[:param.shape[0]] = arr
|
||||
print(
|
||||
f'part of parameter({name}) randomly initialized with normal distribution'
|
||||
)
|
||||
dtype, device = param.dtype, param.device
|
||||
z = torch.tensor(z, dtype=dtype, device=device)
|
||||
new_model_state_dict[name] = z
|
||||
else:
|
||||
new_model_state_dict[name] = param
|
||||
else:
|
||||
print(f'parameter({name}) is dropped')
|
||||
model_state_dict = new_model_state_dict
|
||||
|
||||
for name in parameters:
|
||||
if name not in model_state_dict:
|
||||
if parameters[name].requires_grad:
|
||||
print(f'parameter({name}) randomly initialized with normal distribution')
|
||||
z = np.random.normal(
|
||||
scale=self.func_model.initializer_range,
|
||||
size=parameters[name].shape).astype('float32')
|
||||
dtype, device = parameters[name].dtype, parameters[
|
||||
name].device
|
||||
model_state_dict[name] = torch.tensor(
|
||||
z, dtype=dtype, device=device)
|
||||
else:
|
||||
model_state_dict[name] = parameters[name]
|
||||
|
||||
self.func_model.load_state_dict(model_state_dict)
|
||||
self.logger.info(
|
||||
f"Loaded model state from '{self.func_model.init_checkpoint}.model'"
|
||||
)
|
||||
|
||||
def _load_train_state():
|
||||
train_file = f'{self.func_model.init_checkpoint}.train'
|
||||
if os.path.exists(train_file):
|
||||
train_state_dict = torch.load(
|
||||
train_file, map_location=lambda storage, loc: storage)
|
||||
self.epoch = train_state_dict['epoch']
|
||||
self.best_valid_metric = train_state_dict['best_valid_metric']
|
||||
if self.optimizer is not None and 'optimizer' in train_state_dict:
|
||||
self.optimizer.load_state_dict(
|
||||
train_state_dict['optimizer'])
|
||||
if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
|
||||
self.lr_scheduler.load_state_dict(
|
||||
train_state_dict['lr_scheduler'])
|
||||
self.logger.info(
|
||||
f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
|
||||
f'best_valid_metric={self.best_valid_metric:.3f})')
|
||||
else:
|
||||
self.logger.info('Loaded no train state')
|
||||
|
||||
if self.func_model.init_checkpoint is None:
|
||||
self.logger.info('No init checkpoint provided, loaded no model !!!')
|
||||
return
|
||||
|
||||
if self.do_train:
|
||||
_load_model_state()
|
||||
return
|
||||
|
||||
if self.do_infer:
|
||||
_load_model_state()
|
||||
_load_train_state()
|
||||
@@ -44,6 +44,8 @@ class Tasks(object):
|
||||
token_classification = 'token-classification'
|
||||
conversational = 'conversational'
|
||||
text_generation = 'text-generation'
|
||||
dialog_modeling = 'dialog-modeling'
|
||||
dialog_intent_prediction = 'dialog-intent-prediction'
|
||||
table_question_answering = 'table-question-answering'
|
||||
feature_extraction = 'feature-extraction'
|
||||
sentence_similarity = 'sentence-similarity'
|
||||
|
||||
0
modelscope/utils/nlp/__init__.py
Normal file
0
modelscope/utils/nlp/space/__init__.py
Normal file
66
modelscope/utils/nlp/space/args.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Parse arguments.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||
return True
|
||||
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('Unsupported value encountered.')
|
||||
|
||||
|
||||
class HParams(dict):
|
||||
""" Hyper-parameters class
|
||||
|
||||
Store hyper-parameters in training / infer / ... scripts.
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self.keys():
|
||||
return self[name]
|
||||
for v in self.values():
|
||||
if isinstance(v, HParams):
|
||||
if name in v:
|
||||
return v[name]
|
||||
raise AttributeError(f"'HParams' object has no attribute '{name}'")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self[name] = value
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, 'w', encoding='utf-8') as fp:
|
||||
json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)
|
||||
|
||||
def load(self, filename):
|
||||
with open(filename, 'r', encoding='utf-8') as fp:
|
||||
params_dict = json.load(fp)
|
||||
for k, v in params_dict.items():
|
||||
if isinstance(v, dict):
|
||||
self[k].update(HParams(v))
|
||||
else:
|
||||
self[k] = v
|
||||
|
||||
|
||||
def parse_args(parser):
|
||||
""" Parse hyper-parameters from cmdline. """
|
||||
parsed = parser.parse_args()
|
||||
args = HParams()
|
||||
optional_args = parser._action_groups[1]
|
||||
for action in optional_args._group_actions[1:]:
|
||||
arg_name = action.dest
|
||||
args[arg_name] = getattr(parsed, arg_name)
|
||||
for group in parser._action_groups[2:]:
|
||||
group_args = HParams()
|
||||
for action in group._group_actions:
|
||||
arg_name = action.dest
|
||||
group_args[arg_name] = getattr(parsed, arg_name)
|
||||
if len(group_args) > 0:
|
||||
args[group.title] = group_args
|
||||
return args
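# Usage sketch (the argument names and the 'Trainer' group below are illustrative):
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--do_train', type=str2bool, default=False)
#   trainer_group = parser.add_argument_group('Trainer')
#   trainer_group.add_argument('--batch_size', type=int, default=32)
#   hparams = parse_args(parser)
#   # top-level flags live on hparams, grouped flags under hparams.<group title>
#   assert hparams.Trainer.batch_size == hparams.batch_size == 32
#   hparams.save('hparams.json')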
|
||||
52
modelscope/utils/nlp/space/criterions.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
def compute_kl_loss(p, q, filter_scores=None):
|
||||
p_loss = F.kl_div(
|
||||
F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
|
||||
q_loss = F.kl_div(
|
||||
F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
|
||||
|
||||
# Choose between "sum" and "mean" reduction depending on your task
|
||||
p_loss = p_loss.sum(dim=-1)
|
||||
q_loss = q_loss.sum(dim=-1)
|
||||
|
||||
# filter_scores acts as a mask in the filtering mechanism
|
||||
if filter_scores is not None:
|
||||
p_loss = filter_scores * p_loss
|
||||
q_loss = filter_scores * q_loss
|
||||
|
||||
p_loss = p_loss.mean()
|
||||
q_loss = q_loss.mean()
|
||||
|
||||
loss = (p_loss + q_loss) / 2
|
||||
return loss
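# Usage sketch (a minimal, self-contained check; shapes are illustrative):
#
#   import torch
#   logits1 = torch.randn(4, 8)   # first forward pass, [batch, num_classes]
#   logits2 = torch.randn(4, 8)   # second pass (e.g. a different dropout mask)
#   kl = compute_kl_loss(logits1, logits2)   # scalar regularization term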
|
||||
|
||||
|
||||
class CatKLLoss(_Loss):
|
||||
"""
|
||||
CatKLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, reduction='mean'):
|
||||
super(CatKLLoss, self).__init__()
|
||||
assert reduction in ['none', 'sum', 'mean']
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, log_qy, log_py):
|
||||
"""
|
||||
KL(qy|py) = Eq[qy * log(q(y) / p(y))]
|
||||
|
||||
log_qy: (batch_size, latent_size)
|
||||
log_py: (batch_size, latent_size)
|
||||
"""
|
||||
qy = torch.exp(log_qy)
|
||||
kl = torch.sum(qy * (log_qy - log_py), dim=1)
|
||||
|
||||
if self.reduction == 'mean':
|
||||
kl = kl.mean()
|
||||
elif self.reduction == 'sum':
|
||||
kl = kl.sum()
|
||||
return kl
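# Usage sketch (log-probabilities over a categorical latent, shapes illustrative):
#
#   log_qy = F.log_softmax(posterior_logits, dim=-1)   # (batch_size, latent_size)
#   log_py = F.log_softmax(prior_logits, dim=-1)       # (batch_size, latent_size)
#   kl = CatKLLoss(reduction='mean')(log_qy, log_py)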
|
||||
321
modelscope/utils/nlp/space/db_ops.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import os
|
||||
import random
|
||||
import sqlite3
|
||||
|
||||
import json
|
||||
|
||||
from .ontology import all_domains, db_domains
|
||||
|
||||
|
||||
class MultiWozDB(object):
|
||||
|
||||
def __init__(self, db_dir, db_paths):
|
||||
self.dbs = {}
|
||||
self.sql_dbs = {}
|
||||
for domain in all_domains:
|
||||
with open(os.path.join(db_dir, db_paths[domain]), 'r') as f:
|
||||
self.dbs[domain] = json.loads(f.read().lower())
|
||||
|
||||
def oneHotVector(self, domain, num):
|
||||
"""Return number of available entities for particular domain."""
|
||||
vector = [0, 0, 0, 0]
|
||||
if num == '':
|
||||
return vector
|
||||
if domain != 'train':
|
||||
if num == 0:
|
||||
vector = [1, 0, 0, 0]
|
||||
elif num == 1:
|
||||
vector = [0, 1, 0, 0]
|
||||
elif num <= 3:
|
||||
vector = [0, 0, 1, 0]
|
||||
else:
|
||||
vector = [0, 0, 0, 1]
|
||||
else:
|
||||
if num == 0:
|
||||
vector = [1, 0, 0, 0]
|
||||
elif num <= 5:
|
||||
vector = [0, 1, 0, 0]
|
||||
elif num <= 10:
|
||||
vector = [0, 0, 1, 0]
|
||||
else:
|
||||
vector = [0, 0, 0, 1]
|
||||
return vector
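# e.g. (assuming integer match counts):
#   oneHotVector('restaurant', 2) -> [0, 0, 1, 0]   # 2-3 matches bucket
#   oneHotVector('train', 7)      -> [0, 0, 1, 0]   # 6-10 matches bucket
#   oneHotVector('hotel', '')     -> [0, 0, 0, 0]   # no match number available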
|
||||
|
||||
def addBookingPointer(self, turn_da):
|
||||
"""Add information about availability of the booking option."""
|
||||
# Booking pointer
|
||||
# Do not consider booking two things in a single turn.
|
||||
vector = [0, 0]
|
||||
if turn_da.get('booking-nobook'):
|
||||
vector = [1, 0]
|
||||
if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
|
||||
vector = [0, 1]
|
||||
return vector
|
||||
|
||||
def addDBPointer(self, domain, match_num, return_num=False):
|
||||
"""Create database pointer for all related domains."""
|
||||
# if turn_domains is None:
|
||||
# turn_domains = db_domains
|
||||
if domain in db_domains:
|
||||
vector = self.oneHotVector(domain, match_num)
|
||||
else:
|
||||
vector = [0, 0, 0, 0]
|
||||
return vector
|
||||
|
||||
def addDBIndicator(self, domain, match_num, return_num=False):
|
||||
"""Create database indicator for all related domains."""
|
||||
# if turn_domains is None:
|
||||
# turn_domains = db_domains
|
||||
if domain in db_domains:
|
||||
vector = self.oneHotVector(domain, match_num)
|
||||
else:
|
||||
vector = [0, 0, 0, 0]
|
||||
|
||||
# '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
|
||||
if vector == [0, 0, 0, 0]:
|
||||
indicator = '[db_nores]'
|
||||
else:
|
||||
indicator = '[db_%s]' % vector.index(1)
|
||||
return indicator
|
||||
|
||||
def get_match_num(self, constraints, return_entry=False):
|
||||
"""Create database pointer for all related domains."""
|
||||
match = {'general': ''}
|
||||
entry = {}
|
||||
# if turn_domains is None:
|
||||
# turn_domains = db_domains
|
||||
for domain in all_domains:
|
||||
match[domain] = ''
|
||||
if domain in db_domains and constraints.get(domain):
|
||||
matched_ents = self.queryJsons(domain, constraints[domain])
|
||||
match[domain] = len(matched_ents)
|
||||
if return_entry:
|
||||
entry[domain] = matched_ents
|
||||
if return_entry:
|
||||
return entry
|
||||
return match
|
||||
|
||||
def pointerBack(self, vector, domain):
|
||||
# multi domain implementation
|
||||
# domnum = cfg.domain_num
|
||||
if domain.endswith(']'):
|
||||
domain = domain[1:-1]
|
||||
if domain != 'train':
|
||||
nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'}
|
||||
else:
|
||||
nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'}
|
||||
if vector[:4] == [0, 0, 0, 0]:
|
||||
report = ''
|
||||
else:
|
||||
num = vector.index(1)
|
||||
report = domain + ': ' + nummap[num] + '; '
|
||||
|
||||
if vector[-2] == 0 and vector[-1] == 1:
|
||||
report += 'booking: ok'
|
||||
if vector[-2] == 1 and vector[-1] == 0:
|
||||
report += 'booking: unable'
|
||||
|
||||
return report
|
||||
|
||||
def queryJsons(self,
|
||||
domain,
|
||||
constraints,
|
||||
exactly_match=True,
|
||||
return_name=False):
|
||||
"""Returns the list of entities for a given domain
|
||||
based on the annotation of the belief state
|
||||
constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
|
||||
"""
|
||||
# query the db
|
||||
if domain == 'taxi':
|
||||
return [{
|
||||
'taxi_colors':
|
||||
random.choice(self.dbs[domain]['taxi_colors']),
|
||||
'taxi_types':
|
||||
random.choice(self.dbs[domain]['taxi_types']),
|
||||
'taxi_phone': [random.randint(1, 9) for _ in range(10)]
|
||||
}]
|
||||
if domain == 'police':
|
||||
return self.dbs['police']
|
||||
if domain == 'hospital':
|
||||
if constraints.get('department'):
|
||||
for entry in self.dbs['hospital']:
|
||||
if entry.get('department') == constraints.get(
|
||||
'department'):
|
||||
return [entry]
|
||||
else:
|
||||
return []
|
||||
|
||||
valid_cons = False
|
||||
for v in constraints.values():
|
||||
if v not in ['not mentioned', '']:
|
||||
valid_cons = True
|
||||
if not valid_cons:
|
||||
return []
|
||||
|
||||
match_result = []
|
||||
|
||||
if 'name' in constraints:
|
||||
for db_ent in self.dbs[domain]:
|
||||
if 'name' in db_ent:
|
||||
cons = constraints['name']
|
||||
dbn = db_ent['name']
|
||||
if cons == dbn:
|
||||
db_ent = db_ent if not return_name else db_ent['name']
|
||||
match_result.append(db_ent)
|
||||
return match_result
|
||||
|
||||
for db_ent in self.dbs[domain]:
|
||||
match = True
|
||||
for s, v in constraints.items():
|
||||
if s == 'name':
|
||||
continue
|
||||
if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
|
||||
(domain == 'restaurant' and s in ['day', 'time']):
|
||||
# These inform slots belong to booking info and do not exist in the database;
|
||||
# whether booking is possible is decided from the user goal, not by a database query.
|
||||
continue
|
||||
|
||||
skip_case = {
|
||||
"don't care": 1,
|
||||
"do n't care": 1,
|
||||
'dont care': 1,
|
||||
'not mentioned': 1,
|
||||
'dontcare': 1,
|
||||
'': 1
|
||||
}
|
||||
if skip_case.get(v):
|
||||
continue
|
||||
|
||||
if s not in db_ent:
|
||||
# logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
|
||||
match = False
|
||||
break
|
||||
|
||||
# v = 'guesthouse' if v == 'guest house' else v
|
||||
# v = 'swimmingpool' if v == 'swimming pool' else v
|
||||
v = 'yes' if v == 'free' else v
|
||||
|
||||
if s in ['arrive', 'leave']:
|
||||
try:
|
||||
h, m = v.split(
|
||||
':'
|
||||
) # raise error if time value is not xx:xx format
|
||||
v = int(h) * 60 + int(m)
|
||||
except Exception:
|
||||
match = False
|
||||
break
|
||||
time = int(db_ent[s].split(':')[0]) * 60 + int(
|
||||
db_ent[s].split(':')[1])
|
||||
if s == 'arrive' and v > time:
|
||||
match = False
|
||||
if s == 'leave' and v < time:
|
||||
match = False
|
||||
else:
|
||||
if exactly_match and v != db_ent[s]:
|
||||
match = False
|
||||
break
|
||||
elif v not in db_ent[s]:
|
||||
match = False
|
||||
break
|
||||
|
||||
if match:
|
||||
match_result.append(db_ent)
|
||||
|
||||
if not return_name:
|
||||
return match_result
|
||||
else:
|
||||
if domain == 'train':
|
||||
match_result = [e['id'] for e in match_result]
|
||||
else:
|
||||
match_result = [e['name'] for e in match_result]
|
||||
return match_result
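# Usage sketch (paths and constraint values are illustrative):
#   db = MultiWozDB(db_dir, db_paths)
#   db.queryJsons('restaurant', {'pricerange': 'cheap', 'area': 'west'})
#   -> list of matching entity dicts, or just their names with return_name=True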
|
||||
|
||||
def querySQL(self, domain, constraints):
|
||||
if not self.sql_dbs:
|
||||
for dom in db_domains:
|
||||
db = 'db/{}-dbase.db'.format(dom)
|
||||
conn = sqlite3.connect(db)
|
||||
c = conn.cursor()
|
||||
self.sql_dbs[dom] = c
|
||||
|
||||
sql_query = 'select * from {}'.format(domain)
|
||||
|
||||
flag = True
|
||||
for key, val in constraints.items():
|
||||
if val == '' \
|
||||
or val == 'dontcare' \
|
||||
or val == 'not mentioned' \
|
||||
or val == "don't care" \
|
||||
or val == 'dont care' \
|
||||
or val == "do n't care":
|
||||
pass
|
||||
else:
|
||||
if flag:
|
||||
sql_query += ' where '
|
||||
val2 = val.replace("'", "''")
|
||||
# val2 = normalize(val2)
|
||||
if key == 'leaveAt':
|
||||
sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
|
||||
elif key == 'arriveBy':
|
||||
sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
|
||||
else:
|
||||
sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
|
||||
flag = False
|
||||
else:
|
||||
val2 = val.replace("'", "''")
|
||||
# val2 = normalize(val2)
|
||||
if key == 'leaveAt':
|
||||
sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'"
|
||||
elif key == 'arriveBy':
|
||||
sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'"
|
||||
else:
|
||||
sql_query += r' and ' + key + '=' + r"'" + val2 + r"'"
|
||||
|
||||
try: # "select * from attraction where name = 'queens college'"
|
||||
print(sql_query)
|
||||
return self.sql_dbs[domain].execute(sql_query).fetchall()
|
||||
except Exception:
|
||||
return [] # TODO test it
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dbPATHs = {
|
||||
'attraction': 'db/attraction_db_processed.json',
|
||||
'hospital': 'db/hospital_db_processed.json',
|
||||
'hotel': 'db/hotel_db_processed.json',
|
||||
'police': 'db/police_db_processed.json',
|
||||
'restaurant': 'db/restaurant_db_processed.json',
|
||||
'taxi': 'db/taxi_db_processed.json',
|
||||
'train': 'db/train_db_processed.json',
|
||||
}
|
||||
db = MultiWozDB('.', dbPATHs)  # db_dir assumed to be the working directory
|
||||
while True:
|
||||
constraints = {}
|
||||
inp = input(
|
||||
'input belief state in format: domain-slot1=value1;slot2=value2...\n'
|
||||
)
|
||||
domain, cons = inp.split('-')
|
||||
for sv in cons.split(';'):
|
||||
s, v = sv.split('=')
|
||||
constraints[s] = v
|
||||
# res = db.querySQL(domain, constraints)
|
||||
res = db.queryJsons(domain, constraints, return_name=True)
|
||||
report = []
|
||||
reidx = {
|
||||
'hotel': 8,
|
||||
'restaurant': 6,
|
||||
'attraction': 5,
|
||||
'train': 1,
|
||||
}
|
||||
# for ent in res:
|
||||
# if reidx.get(domain):
|
||||
# report.append(ent[reidx[domain]])
|
||||
# for ent in res:
|
||||
# if 'name' in ent:
|
||||
# report.append(ent['name'])
|
||||
# if 'trainid' in ent:
|
||||
# report.append(ent['trainid'])
|
||||
print(constraints)
|
||||
print(res)
|
||||
print('count:', len(res), '\nnames:', report)
|
||||
217
modelscope/utils/nlp/space/ontology.py
Normal file
@@ -0,0 +1,217 @@
|
||||
all_domains = [
|
||||
'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital'
|
||||
]
|
||||
all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains]
|
||||
db_domains = ['restaurant', 'hotel', 'attraction', 'train']
|
||||
placeholder_tokens = [
|
||||
'<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', '<eos_b>',
|
||||
'<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', '<sos_b>',
|
||||
'<sos_a>', '<sos_d>', '<sos_q>'
|
||||
]
|
||||
|
||||
normlize_slot_names = {
|
||||
'car type': 'car',
|
||||
'entrance fee': 'price',
|
||||
'duration': 'time',
|
||||
'leaveat': 'leave',
|
||||
'arriveby': 'arrive',
|
||||
'trainid': 'id'
|
||||
}
|
||||
|
||||
requestable_slots = {
|
||||
'taxi': ['car', 'phone'],
|
||||
'police': ['postcode', 'address', 'phone'],
|
||||
'hospital': ['address', 'phone', 'postcode'],
|
||||
'hotel': [
|
||||
'address', 'postcode', 'internet', 'phone', 'parking', 'type',
|
||||
'pricerange', 'stars', 'area', 'reference'
|
||||
],
|
||||
'attraction':
|
||||
['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'],
|
||||
'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'],
|
||||
'restaurant': [
|
||||
'phone', 'postcode', 'address', 'pricerange', 'food', 'area',
|
||||
'reference'
|
||||
]
|
||||
}
|
||||
all_reqslot = [
|
||||
'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type',
|
||||
'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave',
|
||||
'price', 'arrive', 'id'
|
||||
]
|
||||
|
||||
informable_slots = {
|
||||
'taxi': ['leave', 'destination', 'departure', 'arrive'],
|
||||
'police': [],
|
||||
'hospital': ['department'],
|
||||
'hotel': [
|
||||
'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
|
||||
'area', 'stars', 'name'
|
||||
],
|
||||
'attraction': ['area', 'type', 'name'],
|
||||
'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'],
|
||||
'restaurant':
|
||||
['food', 'pricerange', 'area', 'name', 'time', 'day', 'people']
|
||||
}
|
||||
all_infslot = [
|
||||
'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
|
||||
'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive',
|
||||
'department', 'food', 'time'
|
||||
]
|
||||
|
||||
all_slots = all_reqslot + [
|
||||
'stay', 'day', 'people', 'name', 'destination', 'departure', 'department'
|
||||
]
|
||||
get_slot = {}
|
||||
for s in all_slots:
|
||||
get_slot[s] = 1
|
||||
|
||||
# mapping slots in dialogue act to original goal slot names
|
||||
da_abbr_to_slot_name = {
|
||||
'addr': 'address',
|
||||
'fee': 'price',
|
||||
'post': 'postcode',
|
||||
'ref': 'reference',
|
||||
'ticket': 'price',
|
||||
'depart': 'departure',
|
||||
'dest': 'destination',
|
||||
}
|
||||
|
||||
dialog_acts = {
|
||||
'restaurant': [
|
||||
'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
|
||||
'offerbooked', 'nobook'
|
||||
],
|
||||
'hotel': [
|
||||
'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
|
||||
'offerbooked', 'nobook'
|
||||
],
|
||||
'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'],
|
||||
'train':
|
||||
['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'],
|
||||
'taxi': ['inform', 'request'],
|
||||
'police': ['inform', 'request'],
|
||||
'hospital': ['inform', 'request'],
|
||||
# 'booking': ['book', 'inform', 'nobook', 'request'],
|
||||
'general': ['bye', 'greet', 'reqmore', 'welcome'],
|
||||
}
|
||||
all_acts = []
|
||||
for acts in dialog_acts.values():
|
||||
for act in acts:
|
||||
if act not in all_acts:
|
||||
all_acts.append(act)
|
||||
|
||||
dialog_act_params = {
|
||||
'inform': all_slots + ['choice', 'open'],
|
||||
'request': all_infslot + ['choice', 'price'],
|
||||
'nooffer': all_slots + ['choice'],
|
||||
'recommend': all_reqslot + ['choice', 'open'],
|
||||
'select': all_slots + ['choice'],
|
||||
# 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
|
||||
'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
|
||||
'offerbook': all_slots + ['choice'],
|
||||
'offerbooked': all_slots + ['choice'],
|
||||
'reqmore': [],
|
||||
'welcome': [],
|
||||
'bye': [],
|
||||
'greet': [],
|
||||
}
|
||||
|
||||
dialog_act_all_slots = all_slots + ['choice', 'open']
|
||||
|
||||
# special slot tokens in belief span
|
||||
# no need for this, just convert slot to [slot] e.g. pricerange -> [pricerange]
|
||||
slot_name_to_slot_token = {}
|
||||
|
||||
# special slot tokens in responses
|
||||
# not used at the moment
|
||||
slot_name_to_value_token = {
|
||||
# 'entrance fee': '[value_price]',
|
||||
# 'pricerange': '[value_price]',
|
||||
# 'arriveby': '[value_time]',
|
||||
# 'leaveat': '[value_time]',
|
||||
# 'departure': '[value_place]',
|
||||
# 'destination': '[value_place]',
|
||||
# 'stay': 'count',
|
||||
# 'people': 'count'
|
||||
}
|
||||
|
||||
# eos tokens definition
|
||||
eos_tokens = {
|
||||
'user': '<eos_u>',
|
||||
'user_delex': '<eos_u>',
|
||||
'resp': '<eos_r>',
|
||||
'resp_gen': '<eos_r>',
|
||||
'pv_resp': '<eos_r>',
|
||||
'bspn': '<eos_b>',
|
||||
'bspn_gen': '<eos_b>',
|
||||
'pv_bspn': '<eos_b>',
|
||||
'bsdx': '<eos_b>',
|
||||
'bsdx_gen': '<eos_b>',
|
||||
'pv_bsdx': '<eos_b>',
|
||||
'qspn': '<eos_q>',
|
||||
'qspn_gen': '<eos_q>',
|
||||
'pv_qspn': '<eos_q>',
|
||||
'aspn': '<eos_a>',
|
||||
'aspn_gen': '<eos_a>',
|
||||
'pv_aspn': '<eos_a>',
|
||||
'dspn': '<eos_d>',
|
||||
'dspn_gen': '<eos_d>',
|
||||
'pv_dspn': '<eos_d>'
|
||||
}
|
||||
|
||||
# sos tokens definition
|
||||
sos_tokens = {
|
||||
'user': '<sos_u>',
|
||||
'user_delex': '<sos_u>',
|
||||
'resp': '<sos_r>',
|
||||
'resp_gen': '<sos_r>',
|
||||
'pv_resp': '<sos_r>',
|
||||
'bspn': '<sos_b>',
|
||||
'bspn_gen': '<sos_b>',
|
||||
'pv_bspn': '<sos_b>',
|
||||
'bsdx': '<sos_b>',
|
||||
'bsdx_gen': '<sos_b>',
|
||||
'pv_bsdx': '<sos_b>',
|
||||
'qspn': '<sos_q>',
|
||||
'qspn_gen': '<sos_q>',
|
||||
'pv_qspn': '<sos_q>',
|
||||
'aspn': '<sos_a>',
|
||||
'aspn_gen': '<sos_a>',
|
||||
'pv_aspn': '<sos_a>',
|
||||
'dspn': '<sos_d>',
|
||||
'dspn_gen': '<sos_d>',
|
||||
'pv_dspn': '<sos_d>'
|
||||
}
|
||||
|
||||
# db tokens definition
|
||||
db_tokens = [
|
||||
'<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]',
|
||||
'[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
|
||||
]
|
||||
|
||||
|
||||
# understand tokens definition
|
||||
def get_understand_tokens(prompt_num_for_understand):
|
||||
understand_tokens = []
|
||||
for i in range(prompt_num_for_understand):
|
||||
understand_tokens.append(f'<understand_{i}>')
|
||||
return understand_tokens
|
||||
|
||||
|
||||
# policy tokens definition
|
||||
def get_policy_tokens(prompt_num_for_policy):
|
||||
policy_tokens = []
|
||||
for i in range(prompt_num_for_policy):
|
||||
policy_tokens.append(f'<policy_{i}>')
|
||||
return policy_tokens
|
||||
|
||||
|
||||
# all special tokens definition
|
||||
def get_special_tokens(other_tokens):
|
||||
special_tokens = [
|
||||
'<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>',
|
||||
'<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>',
|
||||
'<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'
|
||||
] + db_tokens + other_tokens
|
||||
return special_tokens
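# e.g.
#   get_understand_tokens(2) -> ['<understand_0>', '<understand_1>']
#   get_policy_tokens(1)     -> ['<policy_0>']
#   get_special_tokens(get_understand_tokens(2) + get_policy_tokens(1))
#   returns the base go/eos/sos tokens plus db_tokens plus the extra prompt tokens.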
|
||||
6
modelscope/utils/nlp/space/scores.py
Normal file
@@ -0,0 +1,6 @@
|
||||
def hierarchical_set_score(frame1, frame2):
|
||||
# deal with empty frame
|
||||
if not (frame1 and frame2):
|
||||
return 0.
|
||||
pass
|
||||
return 0.
|
||||
206
modelscope/utils/nlp/space/utils.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from . import ontology
|
||||
|
||||
|
||||
def max_lens(X):
|
||||
lens = [len(X)]
|
||||
while isinstance(X[0], list):
|
||||
lens.append(max(map(len, X)))
|
||||
X = [x for xs in X for x in xs]
|
||||
return lens
|
||||
|
||||
|
||||
def list2np(X: object, padding: object = 0, dtype: object = 'int64') -> object:
|
||||
shape = max_lens(X)
|
||||
ret = np.full(shape, padding, dtype=np.int32)
|
||||
|
||||
if len(shape) == 1:
|
||||
ret = np.array(X)
|
||||
elif len(shape) == 2:
|
||||
for i, x in enumerate(X):
|
||||
ret[i, :len(x)] = np.array(x)
|
||||
elif len(shape) == 3:
|
||||
for i, xs in enumerate(X):
|
||||
for j, x in enumerate(xs):
|
||||
ret[i, j, :len(x)] = np.array(x)
|
||||
return ret.astype(dtype)
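# e.g. list2np([[1, 2, 3], [4]], padding=0) ->
#   array([[1, 2, 3],
#          [4, 0, 0]])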
|
||||
|
||||
|
||||
def clean_replace(s, r, t, forward=True, backward=False):
|
||||
|
||||
def clean_replace_single(s, r, t, forward, backward, sidx=0):
|
||||
# idx = s[sidx:].find(r)
|
||||
idx = s.find(r)
|
||||
if idx == -1:
|
||||
return s, -1
|
||||
idx_r = idx + len(r)
|
||||
if backward:
|
||||
while idx > 0 and s[idx - 1]:
|
||||
idx -= 1
|
||||
elif idx > 0 and s[idx - 1] != ' ':
|
||||
return s, -1
|
||||
|
||||
if forward:
|
||||
while \
|
||||
idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
|
||||
idx_r += 1
|
||||
elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
|
||||
return s, -1
|
||||
return s[:idx] + t + s[idx_r:], idx_r
|
||||
|
||||
# source, replace, target = s, r, t
|
||||
# count = 0
|
||||
sidx = 0
|
||||
while sidx != -1:
|
||||
s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
|
||||
# count += 1
|
||||
# print(s, sidx)
|
||||
# if count == 20:
|
||||
# print(source, '\n', replace, '\n', target)
|
||||
# quit()
|
||||
return s
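# e.g. clean_replace('leave at 17:15 please', '17:15', '[value_leave]')
#   -> 'leave at [value_leave] please'
# (with the defaults, a match preceded by a non-space character is left untouched,
#  and trailing alphanumerics are absorbed into the replacement)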
|
||||
|
||||
|
||||
def py2np(list):
|
||||
return np.array(list)
|
||||
|
||||
|
||||
def write_dict(fn, dic):
|
||||
with open(fn, 'w') as f:
|
||||
json.dump(dic, f, indent=2)
|
||||
|
||||
|
||||
def f1_score(label_list, pred_list):
|
||||
tp = len([t for t in pred_list if t in label_list])
|
||||
fp = max(0, len(pred_list) - tp)
|
||||
fn = max(0, len(label_list) - tp)
|
||||
precision = tp / (tp + fp + 1e-10)
|
||||
recall = tp / (tp + fn + 1e-10)
|
||||
f1 = 2 * precision * recall / (precision + recall + 1e-10)
|
||||
return f1
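# e.g. f1_score(['phone', 'address'], ['phone']) ~= 0.667
#      (precision 1.0, recall 0.5)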
|
||||
|
||||
|
||||
class MultiWOZVocab(object):
|
||||
|
||||
def __init__(self, vocab_size=0):
|
||||
"""
|
||||
vocab for multiwoz dataset
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_size_oov = 0 # get after construction
|
||||
self._idx2word = {} # word + oov
|
||||
self._word2idx = {} # word
|
||||
self._freq_dict = {} # word + oov
|
||||
for w in [
|
||||
'[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>',
|
||||
'<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>'
|
||||
]:
|
||||
self._absolute_add_word(w)
|
||||
|
||||
def _absolute_add_word(self, w):
|
||||
idx = len(self._idx2word)
|
||||
self._idx2word[idx] = w
|
||||
self._word2idx[w] = idx
|
||||
|
||||
def add_word(self, word):
|
||||
if word not in self._freq_dict:
|
||||
self._freq_dict[word] = 0
|
||||
self._freq_dict[word] += 1
|
||||
|
||||
def has_word(self, word):
|
||||
return self._freq_dict.get(word)
|
||||
|
||||
def _add_to_vocab(self, word):
|
||||
if word not in self._word2idx:
|
||||
idx = len(self._idx2word)
|
||||
self._idx2word[idx] = word
|
||||
self._word2idx[word] = idx
|
||||
|
||||
def construct(self):
|
||||
freq_dict_sorted = sorted(
|
||||
self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
|
||||
print('Vocabulary size including oov: %d' %
|
||||
(len(freq_dict_sorted) + len(self._idx2word)))
|
||||
if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size:
|
||||
logging.warning(
|
||||
'actual label set smaller than that configured: {}/{}'.format(
|
||||
len(freq_dict_sorted) + len(self._idx2word),
|
||||
self.vocab_size))
|
||||
for word in ontology.all_domains + ['general']:
|
||||
word = '[' + word + ']'
|
||||
self._add_to_vocab(word)
|
||||
for word in ontology.all_acts:
|
||||
word = '[' + word + ']'
|
||||
self._add_to_vocab(word)
|
||||
for word in ontology.all_slots:
|
||||
self._add_to_vocab(word)
|
||||
for word in freq_dict_sorted:
|
||||
if word.startswith('[value_') and word.endswith(']'):
|
||||
self._add_to_vocab(word)
|
||||
for word in freq_dict_sorted:
|
||||
self._add_to_vocab(word)
|
||||
self.vocab_size_oov = len(self._idx2word)
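# Resulting index layout (sketch): the 11 special tokens added in __init__ come first,
# then [domain]/[act] tokens and slot names, then [value_*] tokens, and finally the
# remaining words ordered by frequency; vocab_size_oov is the total count.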
|
||||
|
||||
def load_vocab(self, vocab_path):
|
||||
self._freq_dict = json.loads(
|
||||
open(vocab_path + '.freq.json', 'r').read())
|
||||
self._word2idx = json.loads(
|
||||
open(vocab_path + '.word2idx.json', 'r').read())
|
||||
self._idx2word = {}
|
||||
for w, idx in self._word2idx.items():
|
||||
self._idx2word[idx] = w
|
||||
self.vocab_size_oov = len(self._idx2word)
|
||||
print('vocab file loaded from "' + vocab_path + '"')
|
||||
print('Vocabulary size including oov: %d' % (self.vocab_size_oov))
|
||||
|
||||
def save_vocab(self, vocab_path):
|
||||
_freq_dict = OrderedDict(
|
||||
sorted(
|
||||
self._freq_dict.items(), key=lambda kv: kv[1], reverse=True))
|
||||
write_dict(vocab_path + '.word2idx.json', self._word2idx)
|
||||
write_dict(vocab_path + '.freq.json', _freq_dict)
|
||||
|
||||
def encode(self, word, include_oov=True):
|
||||
if include_oov:
|
||||
if self._word2idx.get(word, None) is None:
|
||||
raise ValueError(
|
||||
'Unknown word: %s. Vocabulary should include oovs here.'
|
||||
% word)
|
||||
return self._word2idx[word]
|
||||
else:
|
||||
word = '<unk>' if word not in self._word2idx else word
|
||||
return self._word2idx[word]
|
||||
|
||||
def sentence_encode(self, word_list):
|
||||
return [self.encode(_) for _ in word_list]
|
||||
|
||||
def oov_idx_map(self, idx):
|
||||
return 2 if idx > self.vocab_size else idx
|
||||
|
||||
def sentence_oov_map(self, index_list):
|
||||
return [self.oov_idx_map(_) for _ in index_list]
|
||||
|
||||
def decode(self, idx, indicate_oov=False):
|
||||
if not self._idx2word.get(idx):
|
||||
raise ValueError(
|
||||
'Error idx: %d. Vocabulary should include oovs here.' % idx)
|
||||
if not indicate_oov or idx < self.vocab_size:
|
||||
return self._idx2word[idx]
|
||||
else:
|
||||
return self._idx2word[idx] + '(o)'
|
||||
|
||||
# def sentence_decode(self, index_list, eos=None, indicate_oov=False):
|
||||
# l = [self.decode(_, indicate_oov) for _ in index_list]
|
||||
# if not eos or eos not in l:
|
||||
# return ' '.join(l)
|
||||
# else:
|
||||
# idx = l.index(eos)
|
||||
# return ' '.join(l[:idx])
|
||||
#
|
||||
# def nl_decode(self, l, eos=None):
|
||||
# return [self.sentence_decode(_, eos) + '\n' for _ in l]
|
||||
@@ -1 +1,4 @@
|
||||
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl
|
||||
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
|
||||
spacy>=2.3.5
|
||||
# python -m spacy download en_core_web_sm
|
||||
|
||||
0
tests/pipelines/nlp/__init__.py
Normal file
60
tests/pipelines/nlp/test_dialog_intent_prediction.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import DialogIntentModel
|
||||
from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline
|
||||
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
class DialogIntentPredictionTest(unittest.TestCase):
|
||||
model_id = 'damo/nlp_space_dialog-intent-prediction'
|
||||
test_case = [
|
||||
'How do I locate my card?',
|
||||
'I still have not received my new card, I ordered over a week ago.'
|
||||
]
|
||||
|
||||
@unittest.skip('test with snapshot_download')
|
||||
def test_run(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
|
||||
model = DialogIntentModel(
|
||||
model_dir=cache_path,
|
||||
text_field=preprocessor.text_field,
|
||||
config=preprocessor.config)
|
||||
|
||||
pipelines = [
|
||||
DialogIntentPredictionPipeline(
|
||||
model=model, preprocessor=preprocessor),
|
||||
pipeline(
|
||||
task=Tasks.dialog_intent_prediction,
|
||||
model=model,
|
||||
preprocessor=preprocessor)
|
||||
]
|
||||
|
||||
for my_pipeline, item in list(zip(pipelines, self.test_case)):
|
||||
print(my_pipeline(item))
|
||||
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
preprocessor = DialogIntentPredictionPreprocessor(
|
||||
model_dir=model.model_dir)
|
||||
|
||||
pipelines = [
|
||||
DialogIntentPredictionPipeline(
|
||||
model=model, preprocessor=preprocessor),
|
||||
pipeline(
|
||||
task=Tasks.dialog_intent_prediction,
|
||||
model=model,
|
||||
preprocessor=preprocessor)
|
||||
]
|
||||
|
||||
for my_pipeline, item in list(zip(pipelines, self.test_case)):
|
||||
print(my_pipeline(item))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
149
tests/pipelines/nlp/test_dialog_modeling.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import DialogModelingModel
|
||||
from modelscope.pipelines import DialogModelingPipeline, pipeline
|
||||
from modelscope.preprocessors import DialogModelingPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
class DialogModelingTest(unittest.TestCase):
|
||||
model_id = 'damo/nlp_space_dialog-modeling'
|
||||
test_case = {
|
||||
'sng0073': {
|
||||
'goal': {
|
||||
'taxi': {
|
||||
'info': {
|
||||
'leaveat': '17:15',
|
||||
'destination': 'pizza hut fen ditton',
|
||||
'departure': "saint john's college"
|
||||
},
|
||||
'reqt': ['car', 'phone'],
|
||||
'fail_info': {}
|
||||
}
|
||||
},
|
||||
'log': [{
|
||||
'user':
|
||||
"i would like a taxi from saint john 's college to pizza hut fen ditton .",
|
||||
'user_delex':
|
||||
'i would like a taxi from [value_departure] to [value_destination] .',
|
||||
'resp':
|
||||
'what time do you want to leave and what time do you want to arrive by ?',
|
||||
'sys':
|
||||
'what time do you want to leave and what time do you want to arrive by ?',
|
||||
'pointer': '0,0,0,0,0,0',
|
||||
'match': '',
|
||||
'constraint':
|
||||
"[taxi] destination pizza hut fen ditton departure saint john 's college",
|
||||
'cons_delex': '[taxi] destination departure',
|
||||
'sys_act': '[taxi] [request] leave arrive',
|
||||
'turn_num': 0,
|
||||
'turn_domain': '[taxi]'
|
||||
}, {
|
||||
'user': 'i want to leave after 17:15 .',
|
||||
'user_delex': 'i want to leave after [value_leave] .',
|
||||
'resp':
|
||||
'booking completed ! your taxi will be [value_car] contact number is [value_phone]',
|
||||
'sys':
|
||||
'booking completed ! your taxi will be blue honda contact number is 07218068540',
|
||||
'pointer': '0,0,0,0,0,0',
|
||||
'match': '',
|
||||
'constraint':
|
||||
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
|
||||
'cons_delex': '[taxi] destination departure leave',
|
||||
'sys_act': '[taxi] [inform] car phone',
|
||||
'turn_num': 1,
|
||||
'turn_domain': '[taxi]'
|
||||
}, {
|
||||
'user': 'thank you for all the help ! i appreciate it .',
|
||||
'user_delex': 'thank you for all the help ! i appreciate it .',
|
||||
'resp':
|
||||
'you are welcome . is there anything else i can help you with today ?',
|
||||
'sys':
|
||||
'you are welcome . is there anything else i can help you with today ?',
|
||||
'pointer': '0,0,0,0,0,0',
|
||||
'match': '',
|
||||
'constraint':
|
||||
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
|
||||
'cons_delex': '[taxi] destination departure leave',
|
||||
'sys_act': '[general] [reqmore]',
|
||||
'turn_num': 2,
|
||||
'turn_domain': '[general]'
|
||||
}, {
|
||||
'user': 'no , i am all set . have a nice day . bye .',
|
||||
'user_delex': 'no , i am all set . have a nice day . bye .',
|
||||
'resp': 'you too ! thank you',
|
||||
'sys': 'you too ! thank you',
|
||||
'pointer': '0,0,0,0,0,0',
|
||||
'match': '',
|
||||
'constraint':
|
||||
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
|
||||
'cons_delex': '[taxi] destination departure leave',
|
||||
'sys_act': '[general] [bye]',
|
||||
'turn_num': 3,
|
||||
'turn_domain': '[general]'
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
@unittest.skip('test with snapshot_download')
|
||||
def test_run(self):
|
||||
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
|
||||
preprocessor = DialogModelingPreprocessor(model_dir=cache_path)
|
||||
model = DialogModelingModel(
|
||||
model_dir=cache_path,
|
||||
text_field=preprocessor.text_field,
|
||||
config=preprocessor.config)
|
||||
pipelines = [
|
||||
DialogModelingPipeline(model=model, preprocessor=preprocessor),
|
||||
pipeline(
|
||||
task=Tasks.dialog_modeling,
|
||||
model=model,
|
||||
preprocessor=preprocessor)
|
||||
]
|
||||
|
||||
result = {}
|
||||
for step, item in enumerate(self.test_case['sng0073']['log']):
|
||||
user = item['user']
|
||||
print('user: {}'.format(user))
|
||||
|
||||
result = pipelines[step % 2]({
|
||||
'user_input': user,
|
||||
'history': result
|
||||
})
|
||||
print('sys : {}'.format(result['sys']))
|
||||
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir)
|
||||
|
||||
pipelines = [
|
||||
DialogModelingPipeline(model=model, preprocessor=preprocessor),
|
||||
pipeline(
|
||||
task=Tasks.dialog_modeling,
|
||||
model=model,
|
||||
preprocessor=preprocessor)
|
||||
]
|
||||
|
||||
result = {}
|
||||
for step, item in enumerate(self.test_case['sng0073']['log']):
|
||||
user = item['user']
|
||||
print('user: {}'.format(user))
|
||||
|
||||
result = pipelines[step % 2]({
|
||||
'user_input': user,
|
||||
'history': result
|
||||
})
|
||||
print('sys : {}'.format(result['sys']))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||