diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 1dd32621..ad343957 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -5,4 +5,6 @@ from .palm_for_text_generation import * # noqa F403 from .sbert_for_sentence_similarity import * # noqa F403 from .sbert_for_token_classification import * # noqa F403 from .sentiment_classification_model import * # noqa F403 +from .space.dialog_intent_prediction_model import * # noqa F403 +from .space.dialog_modeling_model import * # noqa F403 from .zero_shot_classification_model import * # noqa F403 diff --git a/modelscope/models/nlp/space/__init__.py b/modelscope/models/nlp/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/space/dialog_intent_prediction_model.py b/modelscope/models/nlp/space/dialog_intent_prediction_model.py new file mode 100644 index 00000000..3ea500e5 --- /dev/null +++ b/modelscope/models/nlp/space/dialog_intent_prediction_model.py @@ -0,0 +1,81 @@ +import os +from typing import Any, Dict + +from modelscope.preprocessors.space.fields.intent_field import \ + IntentBPETextField +from modelscope.trainers.nlp.space.trainers.intent_trainer import IntentTrainer +from modelscope.utils.config import Config +from modelscope.utils.constant import Tasks +from ...base import Model, Tensor +from ...builder import MODELS +from .model.generator import Generator +from .model.model_base import ModelBase + +__all__ = ['DialogIntentModel'] + + +@MODELS.register_module( + Tasks.dialog_intent_prediction, module_name=r'space-intent') +class DialogIntentModel(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the test generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None. 
+ """ + + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, 'configuration.json'))) + self.text_field = kwargs.pop( + 'text_field', + IntentBPETextField(self.model_dir, config=self.config)) + + self.generator = Generator.create(self.config, reader=self.text_field) + self.model = ModelBase.create( + model_dir=model_dir, + config=self.config, + reader=self.text_field, + generator=self.generator) + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.config.use_gpu else array + + self.trainer = IntentTrainer( + model=self.model, + to_tensor=to_tensor, + config=self.config, + reader=self.text_field) + self.trainer.load() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + import numpy as np + pred = self.trainer.forward(input) + pred = np.squeeze(pred[0], 0) + + return {'pred': pred} diff --git a/modelscope/models/nlp/space/dialog_modeling_model.py b/modelscope/models/nlp/space/dialog_modeling_model.py new file mode 100644 index 00000000..bae8a822 --- /dev/null +++ b/modelscope/models/nlp/space/dialog_modeling_model.py @@ -0,0 +1,82 @@ +import os +from typing import Any, Dict, Optional + +from modelscope.preprocessors.space.fields.gen_field import \ + MultiWOZBPETextField +from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer +from modelscope.utils.config import Config +from modelscope.utils.constant import Tasks +from ...base import Model, Tensor +from ...builder import MODELS +from .model.generator import Generator +from .model.model_base import ModelBase + +__all__ = ['DialogModelingModel'] + + +@MODELS.register_module(Tasks.dialog_modeling, module_name=r'space-modeling') +class DialogModelingModel(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the test generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None. 
+ """ + + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, 'configuration.json'))) + self.text_field = kwargs.pop( + 'text_field', + MultiWOZBPETextField(self.model_dir, config=self.config)) + self.generator = Generator.create(self.config, reader=self.text_field) + self.model = ModelBase.create( + model_dir=model_dir, + config=self.config, + reader=self.text_field, + generator=self.generator) + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.config.use_gpu else array + + self.trainer = MultiWOZTrainer( + model=self.model, + to_tensor=to_tensor, + config=self.config, + reader=self.text_field, + evaluator=None) + self.trainer.load() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + + turn = {'user': input['user']} + old_pv_turn = input['history'] + + pv_turn = self.trainer.forward(turn=turn, old_pv_turn=old_pv_turn) + + return pv_turn diff --git a/modelscope/models/nlp/space/model/__init__.py b/modelscope/models/nlp/space/model/__init__.py new file mode 100644 index 00000000..7e1b5264 --- /dev/null +++ b/modelscope/models/nlp/space/model/__init__.py @@ -0,0 +1,3 @@ +from .gen_unified_transformer import GenUnifiedTransformer +from .intent_unified_transformer import IntentUnifiedTransformer +from .unified_transformer import UnifiedTransformer diff --git a/modelscope/models/nlp/space/model/gen_unified_transformer.py b/modelscope/models/nlp/space/model/gen_unified_transformer.py new file mode 100644 index 00000000..c076cce4 --- /dev/null +++ b/modelscope/models/nlp/space/model/gen_unified_transformer.py @@ -0,0 +1,285 @@ +""" +IntentUnifiedTransformer +""" +import torch + +from modelscope.models.nlp.space.model.unified_transformer import \ + UnifiedTransformer + + +class GenUnifiedTransformer(UnifiedTransformer): + """ + Implement generation unified transformer. + """ + + def __init__(self, model_dir, config, reader, generator): + super(GenUnifiedTransformer, self).__init__(model_dir, config, reader, + generator) + self.understand = config.BPETextField.understand + + if self.use_gpu: + self.cuda() + return + + def _forward(self, inputs, is_training, with_label): + """ Real forward process of model in different mode(train/test). 
""" + + def cat(x, y, dim=1): + return torch.cat([x, y], dim=dim) + + outputs = {} + + if self.understand or self.policy: + if self.understand: + prompt_token = inputs['understand_token'] + prompt_mask = inputs['understand_mask'] + if self.policy: + prompt_token = cat(prompt_token, inputs['policy_token']) + prompt_mask = cat(prompt_mask, inputs['policy_mask']) + else: + prompt_token = inputs['policy_token'] + prompt_mask = inputs['policy_mask'] + + enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'][:, :-1], + tgt_mask=inputs['tgt_mask'][:, :-1], + prompt_token=prompt_token, + prompt_mask=prompt_mask, + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn'], + tgt_pos=inputs['tgt_pos'][:, :-1], + tgt_type=inputs['tgt_type'][:, :-1], + tgt_turn=inputs['tgt_turn'][:, :-1]) + else: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'][:, :-1], + tgt_mask=inputs['tgt_mask'][:, :-1], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn'], + tgt_pos=inputs['tgt_pos'][:, :-1], + tgt_type=inputs['tgt_type'][:, :-1], + tgt_turn=inputs['tgt_turn'][:, :-1]) + + outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed) + return outputs + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + + metrics = {} + loss = 0. + + label = inputs['tgt_token'][:, 1:] + token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1) + nll = self.nll_loss( + torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label) + nll = torch.sum(nll, dim=1) + token_nll = torch.sum(nll) / token_num + nll = torch.mean(nll) + metrics['nll'] = nll + metrics['token_nll'] = token_nll + metrics['token_num'] = token_num + loss = loss + (token_nll if self.token_loss else nll) + + metrics['loss'] = loss + if self.gpu > 1: + return nll, token_nll, token_num + else: + return metrics + + def _optimize(self, loss, do_update=False, optimizer=None): + """ Optimize loss function and update model. """ + assert optimizer is not None + + if self.gradient_accumulation_steps > 1: + loss = loss / self.gradient_accumulation_steps + + loss.backward() + + if self.grad_clip is not None and self.grad_clip > 0: + torch.nn.utils.clip_grad_norm_( + parameters=self.parameters(), max_norm=self.grad_clip) + + if do_update: + optimizer.step() + optimizer.zero_grad() + + return + + def _init_state(self, + src_token, + src_mask, + src_pos=None, + src_type=None, + src_turn=None): + """ Initialize decode state. 
""" + state = {} + batch_size = src_token.shape[0] + + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + src_embed = self.embed_layer_norm(src_embed) + + mask = self._create_mask(src_mask, append_head=False) + + enc_out = src_embed + + cache = {} + for _l, layer in enumerate(self.layers): + cache[f'layer_{_l}'] = {} + enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) + + state['cache'] = cache + state['mask'] = mask[:, :1] + state['batch_size'] = batch_size + shape = [batch_size, 1, 1] + state['pred_mask'] = torch.ones(shape, dtype=torch.float32) + state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) + state['pred_type'] = torch.zeros(shape, dtype=torch.int64) + state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) + if self.use_gpu: + state['pred_mask'] = state['pred_mask'].cuda() + state['pred_pos'] = state['pred_pos'].cuda() + state['pred_type'] = state['pred_type'].cuda() + state['pred_turn'] = state['pred_turn'].cuda() + + return state + + def _init_prompt_state(self, + src_token, + src_mask, + prompt_token, + prompt_mask, + src_pos=None, + src_type=None, + src_turn=None, + prompt_pos=None, + prompt_type=None, + prompt_turn=None): + """ Initialize decode state. """ + state = {} + batch_size = src_token.shape[0] + + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, + prompt_turn) + embed = torch.cat([src_embed, prompt_embed], dim=1) + embed = self.embed_layer_norm(embed) + enc_out = embed + + enc_mask = self._create_mask(src_mask, auto_regressive=False) + dec_mask = self._create_mask(prompt_mask, auto_regressive=True) + mask = self._join_mask(enc_mask, dec_mask) + + cache = {} + for _l, layer in enumerate(self.layers): + cache[f'layer_{_l}'] = {} + enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) + + state['cache'] = cache + state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1] + state['batch_size'] = batch_size + shape = [batch_size, 1, 1] + state['pred_mask'] = torch.ones(shape, dtype=torch.float32) + state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) + state['pred_type'] = torch.zeros(shape, dtype=torch.int64) + state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) + if self.use_gpu: + state['pred_mask'] = state['pred_mask'].cuda() + state['pred_pos'] = state['pred_pos'].cuda() + state['pred_type'] = state['pred_type'].cuda() + state['pred_turn'] = state['pred_turn'].cuda() + + return state + + def _decode(self, state): + """ Decoding one time stamp. 
""" + + # shape: [batch_size, 1, seq_len] + mask = state['mask'] + + # shape: [batch_size, 1, 1] + pred_token = state['pred_token'] + pred_mask = state['pred_mask'] + pred_pos = state['pred_pos'] + pred_type = state['pred_type'] + pred_turn = state['pred_turn'] + + # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim] + cache = state['cache'] + + pred_embed = self.embedder(pred_token, pred_pos, pred_type, + pred_turn).squeeze(-2) + pred_embed = self.embed_layer_norm(pred_embed) + + # shape: [batch_size, 1, seq_len + 1] + mask = torch.cat([mask, 1 - pred_mask], dim=2) + + # shape: [batch_size, 1, hidden_dim] + for _l, layer in enumerate(self.layers): + pred_embed = layer(pred_embed, mask, cache[f'layer_{_l}']) + + # shape: [batch_size, vocab_size] + pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) + pred_logits = torch.log(pred_probs) + + state['mask'] = mask + return pred_logits, state + + def _infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ Real inference process of model. """ + + def cat(x, y, dim=1): + return torch.cat([x, y], dim=dim) + + # Initial decode state. + if self.understand or self.policy: + if self.understand: + prompt_token = inputs['understand_token'] + prompt_mask = inputs['understand_mask'] + if self.policy: + prompt_token = cat(prompt_token, inputs['policy_token']) + prompt_mask = cat(prompt_mask, inputs['policy_mask']) + else: + prompt_token = inputs['policy_token'] + prompt_mask = inputs['policy_mask'] + + state = self._init_prompt_state( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + prompt_token=prompt_token, + prompt_mask=prompt_mask, + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + else: + state = self._init_state( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + + # Generation process. + gen_results = self.generator( + step_fn=self._decode, + state=state, + start_id=start_id, + eos_id=eos_id, + max_gen_len=max_gen_len, + prev_input=prev_input) + + outputs = gen_results['preds'] + return outputs + + +GenUnifiedTransformer.register('GenUnifiedTransformer') diff --git a/modelscope/models/nlp/space/model/generator.py b/modelscope/models/nlp/space/model/generator.py new file mode 100644 index 00000000..bdf6b135 --- /dev/null +++ b/modelscope/models/nlp/space/model/generator.py @@ -0,0 +1,290 @@ +""" +Generator class. +""" + +import math + +import numpy as np +import torch + + +def repeat(var, times): + if isinstance(var, list): + return [repeat(x, times) for x in var] + elif isinstance(var, dict): + return {k: repeat(v, times) for k, v in var.items()} + elif isinstance(var, torch.Tensor): + var = var.unsqueeze(1) + expand_times = [1] * len(var.shape) + expand_times[1] = times + dtype = var.dtype + var = var.float() + var = var.repeat(*expand_times) + shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:]) + var = var.reshape(*shape) + var = torch.tensor(var, dtype=dtype) + return var + else: + return var + + +def gather(var, idx): + if isinstance(var, list): + return [gather(x, idx) for x in var] + elif isinstance(var, dict): + return {k: gather(v, idx) for k, v in var.items()} + elif isinstance(var, torch.Tensor): + out = var.index_select(dim=0, index=idx) + return out + else: + return var + + +class Generator(object): + """ Genrator class. 
""" + + _registry = dict() + + @classmethod + def register(cls, name): + Generator._registry[name] = cls + return + + @staticmethod + def by_name(name): + return Generator._registry[name] + + @staticmethod + def create(config, *args, **kwargs): + """ Create generator. """ + generator_cls = Generator.by_name(config.Generator.generator) + return generator_cls(config, *args, **kwargs) + + def __init__(self, config, reader): + self.vocab_size = reader.vocab_size + self.bos_id = reader.bos_id + self.eos_id = reader.eos_id + self.unk_id = reader.unk_id + self.pad_id = reader.pad_id + self.min_gen_len = config.Generator.min_gen_len + self.max_gen_len = config.Generator.max_gen_len + self.use_gpu = config.use_gpu + assert 1 <= self.min_gen_len <= self.max_gen_len + return + + def __call__(self, step_fn, state): + """ + Running generation. + + @param : step_fn : decoding one step + @type : function + + @param : state : initial state + @type : dict + """ + raise NotImplementedError + + +class BeamSearch(Generator): + """ BeamSearch generator. """ + + def __init__(self, config, reader): + super().__init__(config, reader) + self.beam_size = config.Generator.beam_size + self.length_average = config.Generator.length_average + self.length_penalty = config.Generator.length_penalty + self.ignore_unk = config.Generator.ignore_unk + return + + def __call__(self, + step_fn, + state, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ + Running beam search. + + @param : step_fn : decoding one step + @type : function + + @param : state : initial state + @type : dict + """ + if prev_input is not None: + + if isinstance(prev_input, list): + length = max(list(map(lambda x: len(x), prev_input))) + prev_input_numpy = np.full((len(prev_input), length), + self.pad_id) + for i, x in enumerate(prev_input): + prev_input_numpy[i, :len(x)] = x + prev_input_tensor = torch.from_numpy(prev_input_numpy) + if self.use_gpu: + prev_input_tensor = prev_input_tensor.cuda() + + for i in range(length): + state['pred_token'] = prev_input_tensor[:, i].unsqueeze( + -1).unsqueeze(-1) + if i != 0: + state['pred_mask'] = torch.not_equal( + state['pred_token'], self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + state[ + 'pred_mask'].int() + _, state = step_fn(state) + else: + assert isinstance(prev_input, torch.Tensor) + for i, input in enumerate(prev_input): + state['pred_token'] = input.expand(1, 1, 1) + if i != 0: + state['pred_mask'] = torch.not_equal( + state['pred_token'], self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + 1 + _, state = step_fn(state) + + batch_size = state['batch_size'] + beam_size = self.beam_size + + # shape: [batch_size, 1] + pos_index = torch.arange( + 0, batch_size, 1, dtype=torch.int64) * beam_size + pos_index = pos_index.unsqueeze(1) + + # shape: [batch_size, beam_size, 1] + if start_id is None: + start_id = self.bos_id + if eos_id is None: + eos_id = self.eos_id + predictions = torch.ones([batch_size, beam_size, 1], + dtype=torch.int64) * start_id + + if self.use_gpu: + pos_index = pos_index.cuda() + predictions = predictions.cuda() + + # initial input (start_id) + state['pred_token'] = predictions[:, :1] + if prev_input is not None: + state['pred_mask'] = torch.not_equal(state['pred_token'], + self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + 1 + + # shape: [batch_size, vocab_size] + scores, state = step_fn(state) + + unk_penalty = np.zeros(self.vocab_size, dtype='float32') + unk_penalty[self.unk_id] = -1e10 + unk_penalty = 
torch.from_numpy(unk_penalty) + + eos_penalty = np.zeros(self.vocab_size, dtype='float32') + eos_penalty[eos_id] = -1e10 + eos_penalty = torch.from_numpy(eos_penalty) + + scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32') + scores_after_end[ + self.pad_id] = 0 # 希望之后只生成,故使词表中log(p())最高(0) + scores_after_end = torch.from_numpy(scores_after_end) + + if self.use_gpu: + unk_penalty = unk_penalty.cuda() + eos_penalty = eos_penalty.cuda() + scores_after_end = scores_after_end.cuda() + + if self.ignore_unk: + scores = scores + unk_penalty + scores = scores + eos_penalty + + # shape: [batch_size, beam_size] + sequence_scores, preds = torch.topk(scores, self.beam_size) + + predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) + state = repeat(state, beam_size) + + if max_gen_len is None: + max_gen_len = self.max_gen_len + for step in range(2, max_gen_len + 1): + pre_ids = predictions[:, :, -1:] + state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1) + state['pred_mask'] = torch.not_equal(state['pred_token'], + self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + 1 + scores, state = step_fn(state) + + # Generate next + # scores shape: [batch_size * beam_size, vocab_size] + if self.ignore_unk: + scores = scores + unk_penalty + + if step <= self.min_gen_len: + scores = scores + eos_penalty + + # scores shape: [batch_size, beam_size, vocab_size] + scores = scores.reshape(batch_size, beam_size, self.vocab_size) + + # previous token is [PAD] or [EOS] + pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ + (1 - torch.not_equal(pre_ids, self.pad_id).float()) + + scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat( + 1, 1, self.vocab_size) * scores_after_end + if self.length_average: + scaled_value = \ + pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step) + sequence_scores = sequence_scores.unsqueeze(2) * scaled_value + scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step) + scores = scores * scaled_value + elif self.length_penalty >= 0.0: + scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ + (math.pow((4 + step) / (5 + step), self.length_penalty)) + sequence_scores = scaled_value * sequence_scores + scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ + (math.pow(1 / (5 + step), self.length_penalty)) + scores = scores * scaled_value + scores = scores + sequence_scores.unsqueeze(-1) + scores = scores.reshape(batch_size, beam_size * self.vocab_size) + + topk_scores, topk_indices = torch.topk(scores, beam_size) + # topk_indices: [batch_size, beam_size * self.vocab_size] (已reshape) + # 判断当前时间步产生词的前一个词在哪个beam中,对vocab_size取商 + parent_idx = topk_indices.floor_divide(self.vocab_size) + # 对vocab_size取余 + preds = topk_indices % self.vocab_size + + # Gather state / sequence_scores + parent_idx = parent_idx + pos_index + parent_idx = parent_idx.reshape(batch_size * beam_size) + state = gather(state, parent_idx) + sequence_scores = topk_scores + + predictions = predictions.reshape(batch_size * beam_size, step) + predictions = gather(predictions, parent_idx) + predictions = predictions.reshape(batch_size, beam_size, step) + predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) + + # 希望生成的整个句子已完结,所以要求最后一个token为或者(跟在之后),否则惩罚 + pre_ids = predictions[:, :, -1] + pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ + (1 - torch.not_equal(pre_ids, self.pad_id).float()) + sequence_scores = sequence_scores * pre_eos_mask + ( + 1 - pre_eos_mask) * (-1e10) + + # 
先获得ascending排序的index,便于之后对predictions和sequence_scores排序(针对beam size轴) + indices = torch.argsort(sequence_scores, dim=1) + indices = indices + pos_index + indices = indices.reshape(-1) + sequence_scores = sequence_scores.reshape(batch_size * beam_size) + predictions = predictions.reshape(batch_size * beam_size, -1) + sequence_scores = gather(sequence_scores, indices) + predictions = gather(predictions, indices) + sequence_scores = sequence_scores.reshape(batch_size, beam_size) + predictions = predictions.reshape(batch_size, beam_size, -1) + + results = { + 'preds': predictions[:, -1], + 'scores': sequence_scores[:, -1] + } + return results + + +BeamSearch.register('BeamSearch') diff --git a/modelscope/models/nlp/space/model/intent_unified_transformer.py b/modelscope/models/nlp/space/model/intent_unified_transformer.py new file mode 100644 index 00000000..646a8044 --- /dev/null +++ b/modelscope/models/nlp/space/model/intent_unified_transformer.py @@ -0,0 +1,198 @@ +""" +IntentUnifiedTransformer +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.utils.nlp.space.criterions import compute_kl_loss +from .unified_transformer import UnifiedTransformer + + +class IntentUnifiedTransformer(UnifiedTransformer): + """ + Implement intent unified transformer. + """ + + def __init__(self, model_dir, config, reader, generator): + super(IntentUnifiedTransformer, self).__init__(model_dir, config, + reader, generator) + self.example = config.Model.example + self.num_intent = config.Model.num_intent + self.with_rdrop = config.Model.with_rdrop + self.kl_ratio = config.Model.kl_ratio + self.loss_fct = nn.CrossEntropyLoss() + if self.example: + self.loss_fct = nn.NLLLoss() + else: + self.intent_classifier = nn.Linear(self.hidden_dim, + self.num_intent) + self.loss_fct = nn.CrossEntropyLoss() + + if self.use_gpu: + self.cuda() + return + + def _forward(self, inputs, is_training, with_label): + """ Real forward process of model in different mode(train/test). 
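+
+ When `with_rdrop` or `with_contrastive` is enabled, every input tensor is
+ duplicated along the batch dimension (see `aug` below), so each example is
+ encoded twice with independent dropout masks; `_collect_metrics` then
+ compares the two halves of the logits for the R-Drop KL term.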
""" + + def aug(v): + assert isinstance(v, torch.Tensor) + return torch.cat([v, v], dim=0) + + outputs = {} + + if self.with_mlm: + mlm_embed = self._encoder_network( + input_token=inputs['mlm_token'], + input_mask=inputs['src_mask'], + input_pos=inputs['src_pos'], + input_type=inputs['src_type'], + input_turn=inputs['src_turn']) + outputs['mlm_probs'] = self._mlm_head(mlm_embed=mlm_embed) + + if self.with_rdrop or self.with_contrastive: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=aug(inputs['src_token']), + src_mask=aug(inputs['src_mask']), + tgt_token=aug(inputs['tgt_token']), + tgt_mask=aug(inputs['tgt_mask']), + src_pos=aug(inputs['src_pos']), + src_type=aug(inputs['src_type']), + src_turn=aug(inputs['src_turn'])) + else: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'], + tgt_mask=inputs['tgt_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + features = dec_embed[:, -1] + features = self.pooler(features) if self.with_pool else features + + if self.example: + assert not self.with_rdrop + ex_enc_embed, ex_dec_embed = self._encoder_decoder_network( + src_token=inputs['example_src_token'], + src_mask=inputs['example_src_mask'], + tgt_token=inputs['example_tgt_token'], + tgt_mask=inputs['example_tgt_mask'], + src_pos=inputs['example_src_pos'], + src_type=inputs['example_src_type'], + src_turn=inputs['example_src_turn']) + ex_features = ex_dec_embed[:, -1] + ex_features = self.pooler( + ex_features) if self.with_pool else ex_features + + probs = self.softmax(features.mm(ex_features.t())) + example_intent = inputs['example_intent'].unsqueeze(0) + intent_probs = torch.zeros(probs.size(0), self.num_intent) + intent_probs = intent_probs.cuda( + ) if self.use_gpu else intent_probs + intent_probs = intent_probs.scatter_add( + -1, example_intent.repeat(probs.size(0), 1), probs) + outputs['intent_probs'] = intent_probs + else: + intent_logits = self.intent_classifier(features) + outputs['intent_logits'] = intent_logits + + if self.with_contrastive: + features = features if self.with_pool else self.pooler(features) + batch_size = features.size(0) // 2 + features = \ + torch.cat( + [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], + dim=1 + ) + features = F.normalize(features, dim=-1, p=2) + outputs['features'] = features + + return outputs + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + + metrics = {} + batch_size = inputs['src_token'].size(0) + + intent_label = torch.cat([inputs['intent_label'], inputs['intent_label']], dim=0) \ + if self.with_rdrop or self.with_contrastive else inputs['intent_label'] + + if self.example: + intent_loss = self.loss_fct( + torch.log(outputs['intent_probs'] + 1e-12).view( + -1, self.num_intent), intent_label.type(torch.long)) + else: + intent_loss = self.loss_fct( + outputs['intent_logits'].view(-1, self.num_intent), + intent_label.type(torch.long)) + metrics['intent_loss'] = intent_loss + loss = intent_loss + + if self.with_mlm: + mlm_num = torch.sum(torch.sum(inputs['mlm_mask'], dim=1)) + mlm = self.nll_loss( + torch.log(outputs['mlm_probs'] + 1e-12).permute(0, 2, 1), + inputs['mlm_label']) + mlm = torch.sum(mlm, dim=1) + token_mlm = torch.sum(mlm) / mlm_num + mlm = torch.mean(mlm) + metrics['mlm'] = mlm + metrics['token_mlm'] = token_mlm + metrics['mlm_num'] = mlm_num + loss = loss + (token_mlm + if self.token_loss else mlm) * self.mlm_ratio 
+ else: + mlm, token_mlm, mlm_num = None, None, None + + if self.with_rdrop: + kl = compute_kl_loss( + p=outputs['intent_logits'][:batch_size], + q=outputs['intent_logits'][batch_size:]) + metrics['kl'] = kl + loss = loss + kl * self.kl_ratio + else: + kl = None + + if self.with_contrastive: + pass + con = None + else: + con = None + + metrics['loss'] = loss + + if self.gpu > 1: + return intent_loss, mlm, token_mlm, mlm_num, kl, con + else: + return metrics + + def _infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ Real inference process of model. """ + results = {} + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'], + tgt_mask=inputs['tgt_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + features = dec_embed[:, -1] + features = self.pooler(features) if self.with_pool else features + if self.example: + results['features'] = features + else: + intent_logits = self.intent_classifier(features) + intent_probs = self.softmax(intent_logits) + results['intent_probs'] = intent_probs + return results + + +IntentUnifiedTransformer.register('IntentUnifiedTransformer') diff --git a/modelscope/models/nlp/space/model/model_base.py b/modelscope/models/nlp/space/model/model_base.py new file mode 100644 index 00000000..cdd355a5 --- /dev/null +++ b/modelscope/models/nlp/space/model/model_base.py @@ -0,0 +1,99 @@ +""" +Model base +""" +import os + +import torch.nn as nn + + +class ModelBase(nn.Module): + """ + Basic model wrapper for static graph and dygrpah. + """ + _registry = dict() + + @classmethod + def register(cls, name): + ModelBase._registry[name] = cls + return + + @staticmethod + def by_name(name): + return ModelBase._registry[name] + + @staticmethod + def create(model_dir, config, *args, **kwargs): + model_cls = ModelBase.by_name(config.Model.model) + return model_cls(model_dir, config, *args, **kwargs) + + def __init__(self, model_dir, config): + super(ModelBase, self).__init__() + self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin') + self.abandon_label = config.Dataset.abandon_label + self.use_gpu = config.use_gpu + self.gpu = config.Trainer.gpu + return + + def _create_parameters(self): + """ Create model's paramters. """ + raise NotImplementedError + + def _forward(self, inputs, is_training, with_label): + """ NO LABEL: Real forward process of model in different mode(train/test). """ + raise NotImplementedError + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + """ NO LABEL: Calculate loss function by using inputs and outputs. """ + raise NotImplementedError + + def _optimize(self, loss, optimizer, lr_scheduler): + """ Optimize loss function and update model. """ + raise NotImplementedError + + def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input): + """ Real inference process of model. """ + raise NotImplementedError + + def forward(self, + inputs, + is_training=False, + with_label=False, + data_file=None): + """ + Forward process, include real forward, collect metrices and optimize(optional) + + @params : inputs : input data + @type : dict of numpy.ndarray/int/float/... 
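+
+ Switches the module between train() and eval() according to `is_training`
+ and returns the metrics dict produced by `_collect_metrics`, including the
+ 'loss' entry used for optimization.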
+ """ + if is_training: + self.train() + else: + self.eval() + + with_label = False if self.abandon_label else with_label + outputs = self._forward(inputs, is_training, with_label=with_label) + metrics = self._collect_metrics( + inputs, outputs, with_label=with_label, data_file=data_file) + + return metrics + + def infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ + Inference process. + + @params : inputs : input data + @type : dict of numpy.ndarray/int/float/... + """ + self.eval() + results = self._infer( + inputs, + start_id=start_id, + eos_id=eos_id, + max_gen_len=max_gen_len, + prev_input=prev_input) + return results diff --git a/modelscope/models/nlp/space/model/unified_transformer.py b/modelscope/models/nlp/space/model/unified_transformer.py new file mode 100644 index 00000000..a25bc7f4 --- /dev/null +++ b/modelscope/models/nlp/space/model/unified_transformer.py @@ -0,0 +1,322 @@ +""" +UnifiedTransformer +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.nlp.space.model.model_base import ModelBase +from modelscope.models.nlp.space.modules.embedder import Embedder +from modelscope.models.nlp.space.modules.transformer_block import \ + TransformerBlock + + +class UnifiedTransformer(ModelBase): + """ + Implement unified transformer. + """ + + def __init__(self, model_dir, config, reader, generator, dtype='float32'): + super(UnifiedTransformer, self).__init__(model_dir, config) + self.reader = reader + self.generator = generator + self.policy = config.BPETextField.policy + self.generation = config.BPETextField.generation + self.num_token_embeddings = config.Model.num_token_embeddings + self.num_pos_embeddings = config.Model.num_pos_embeddings + self.num_type_embeddings = config.Model.num_type_embeddings + self.num_turn_embeddings = config.Model.num_turn_embeddings + self.temperature = config.Model.temperature + self.hidden_dim = config.Model.hidden_dim + self.num_heads = config.Model.num_heads + self.num_layers = config.Model.num_layers + self.padding_idx = config.Model.padding_idx + self.dropout = config.Model.dropout + self.embed_dropout = config.Model.embed_dropout + self.attn_dropout = config.Model.attn_dropout + self.ff_dropout = config.Model.ff_dropout + self.mlm_ratio = config.Model.mlm_ratio + self.mmd_ratio = config.Model.mmd_ratio + self.pos_trainable = config.Model.pos_trainable + self.label_smooth = config.Model.label_smooth + self.initializer_range = config.Model.initializer_range + self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps + self.token_loss = config.Trainer.token_loss + self.learning_method = config.Dataset.learning_method + self.with_contrastive = config.Dataset.with_contrastive + self.with_query_bow = config.BPETextField.with_query_bow + self.with_resp_bow = config.BPETextField.with_resp_bow + self.with_pool = config.Model.with_pool + self.with_mlm = config.Dataset.with_mlm + self._dtype = dtype + + self.embedder = Embedder( + self.hidden_dim, + self.num_token_embeddings, + self.num_pos_embeddings, + self.num_type_embeddings, + self.num_turn_embeddings, + padding_idx=self.padding_idx, + dropout=self.embed_dropout, + pos_trainable=self.pos_trainable) + self.embed_layer_norm = nn.LayerNorm( + normalized_shape=self.hidden_dim, + eps=1e-12, + elementwise_affine=True) + + self.layers = nn.ModuleList([ + TransformerBlock(self.hidden_dim, self.num_heads, self.dropout, + self.attn_dropout, self.ff_dropout) + for _ in 
range(config.Model.num_layers) + ]) + + if self.with_mlm: + self.mlm_transform = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), + nn.LayerNorm( + normalized_shape=self.hidden_dim, + eps=1e-12, + elementwise_affine=True)) + self.mlm_bias = nn.Parameter( + torch.zeros(self.num_token_embeddings)) + + self.pooler = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh()) + + if self.with_query_bow or self.with_resp_bow: + self.bow_predictor = nn.Linear( + self.hidden_dim, self.num_token_embeddings, bias=False) + + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(dim=-1) + self.bce_loss = nn.BCELoss(reduction='none') + self.nll_loss = nn.NLLLoss( + ignore_index=self.padding_idx, reduction='none') + self._create_parameters() + + self.max_grad_norm = config.Model.max_grad_norm + if self.max_grad_norm is not None: + self.grad_clip = self.max_grad_norm + else: + self.grad_clip = None + self.weight_decay = config.Model.weight_decay + + if self.use_gpu: + self.cuda() + + return + + def _create_parameters(self): + """ Create model's paramters. """ + sequence_mask = np.tri( + self.num_pos_embeddings, + self.num_pos_embeddings, + dtype=self._dtype) + self.sequence_mask = torch.tensor(sequence_mask) + return + + def _create_mask(self, + input_mask, + append_head=False, + auto_regressive=False): + """ + Create attention mask. + 创建从序列形式到矩阵形式的mask:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] + mask除了要考虑attention mask(自回归),还需要考虑pad的mask(自回归和双向) + 注: + 1. 一个句子中的非词看整个句子,该句中只有词才被mask + 2. 一个句子中的词看整个句子,该句的所有词都应该被mask + + @param : input_mask + @type : Variable(shape: [batch_size, max_seq_len]) + + @param : auto_regressive + @type : bool + """ + seq_len = input_mask.shape[1] + + input_mask = input_mask.float() + mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len) + mask2 = mask1.permute(0, 2, 1) + mask = mask1 * mask2 + + if append_head: + # 拼接上句首位置([M]/z)的mask + mask = torch.cat([mask[:, :1, :], mask], dim=1) + mask = torch.cat([mask[:, :, :1], mask], dim=2) + seq_len += 1 + + if auto_regressive: + # 将tgt端的 mask和自回归attention mask融合 + seq_mask = self.sequence_mask[:seq_len, :seq_len] + seq_mask = seq_mask.to(mask.device) + mask = mask * seq_mask + + mask = 1 - mask + return mask + + def _join_mask(self, mask1, mask2): + """ + Merge source attention mask and target attention mask. 
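+
+ Both inputs follow the `_create_mask` convention (1 marks a position to be
+ masked out), so the upper-right block of the merged matrix is all ones:
+ source positions can never attend to target positions.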
+ 合并后的整个mask矩阵可以分为四个部分:左上lu/右上ru/左下lb/右下rb + + @param : mask1 : source attention mask + @type : Variable(shape: [batch_size, max_src_len, max_src_len]) + + @param : mask1 : target attention mask + @type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len]) + """ + batch_size = mask1.shape[0] + seq_len1 = mask1.shape[1] + seq_len2 = mask2.shape[1] + # seq_len = seq_len1 + seq_len2 + + mask_lu = mask1 + mask_ru = torch.ones(batch_size, seq_len1, seq_len2) + if self.use_gpu: + mask_ru = mask_ru.cuda() + mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1) + mask4 = mask1[:, :1].repeat(1, seq_len2, 1) + mask_lb = mask3 + mask4 - mask3 * mask4 + mask_rb = mask2 + mask_u = torch.cat([mask_lu, mask_ru], dim=2) + mask_b = torch.cat([mask_lb, mask_rb], dim=2) + mask = torch.cat([mask_u, mask_b], dim=1) + return mask + + def _mlm_head(self, mlm_embed): + mlm_embed = self.mlm_transform(mlm_embed) + mlm_logits = torch.matmul( + mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias + mlm_probs = self.softmax(mlm_logits) + return mlm_probs + + def _dec_head(self, dec_embed): + dec_logits = torch.matmul(dec_embed, + self.embedder.token_embedding.weight.T) + dec_probs = self.softmax(dec_logits) + return dec_probs + + def _refactor_feature(self, features): + features = self.pooler(features) if self.with_pool else features + batch_size = features.size(0) // 2 + features = \ + torch.cat( + [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], + dim=1 + ) + features = F.normalize(features, dim=-1, p=2) + return features + + def _encoder_network(self, + input_token, + input_mask, + input_pos=None, + input_type=None, + input_turn=None): + embed = self.embedder(input_token, input_pos, input_type, input_turn) + embed = self.embed_layer_norm(embed) + mask = self._create_mask(input_mask, auto_regressive=False) + + for layer in self.layers: + embed = layer(embed, mask, None) + + return embed + + def _encoder_decoder_network(self, + src_token, + src_mask, + tgt_token, + tgt_mask, + src_pos=None, + src_type=None, + src_turn=None, + tgt_pos=None, + tgt_type=None, + tgt_turn=None): + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) + embed = torch.cat([src_embed, tgt_embed], dim=1) + embed = self.embed_layer_norm(embed) + + enc_mask = self._create_mask(src_mask, auto_regressive=False) + dec_mask = self._create_mask(tgt_mask, auto_regressive=True) + mask = self._join_mask(enc_mask, dec_mask) + + for layer in self.layers: + embed = layer(embed, mask, None) + + tgt_len = tgt_token.shape[1] + enc_embed = embed[:, :-tgt_len] + dec_embed = embed[:, -tgt_len:] + + return enc_embed, dec_embed + + def _encoder_prompt_decoder_network(self, + src_token, + src_mask, + tgt_token, + tgt_mask, + prompt_token, + prompt_mask, + src_pos=None, + src_type=None, + src_turn=None, + tgt_pos=None, + tgt_type=None, + tgt_turn=None, + prompt_pos=None, + prompt_type=None, + prompt_turn=None): + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) + prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, + prompt_turn) + + embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1) + embed = self.embed_layer_norm(embed) + + enc_mask = self._create_mask(src_mask, auto_regressive=False) + dec_mask = self._create_mask( + torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True) + mask = self._join_mask(enc_mask, dec_mask) + + for layer 
in self.layers: + embed = layer(embed, mask, None) + + src_len = src_token.shape[1] + tgt_len = tgt_token.shape[1] + enc_embed = embed[:, :src_len] + dec_embed = embed[:, -tgt_len:] + prompt_embed = embed[:, src_len:-tgt_len] + + return enc_embed, dec_embed, prompt_embed + + def _optimize(self, loss, optimizer=None, lr_scheduler=None): + """ Optimize loss function and update model. """ + assert optimizer is not None + optimizer.zero_grad() + loss.backward() + + if self.grad_clip is not None and self.grad_clip > 0: + torch.nn.utils.clip_grad_norm_( + parameters=self.parameters(), max_norm=self.grad_clip) + optimizer.step() + if lr_scheduler is not None: + lr_scheduler.step() + return + + def _infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ Real inference process of model. """ + results = {} + return results + + +UnifiedTransformer.register('UnifiedTransformer') diff --git a/modelscope/models/nlp/space/modules/__init__.py b/modelscope/models/nlp/space/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/space/modules/embedder.py b/modelscope/models/nlp/space/modules/embedder.py new file mode 100644 index 00000000..4fb592ef --- /dev/null +++ b/modelscope/models/nlp/space/modules/embedder.py @@ -0,0 +1,67 @@ +""" +Embedder class. +""" + +import torch +import torch.nn as nn + + +class Embedder(nn.Module): + """ + Composite embedding layer. + """ + + def __init__(self, + hidden_dim, + num_token_embeddings, + num_pos_embeddings, + num_type_embeddings, + num_turn_embeddings, + padding_idx=None, + dropout=0.1, + pos_trainable=False): + super(Embedder, self).__init__() + + self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim) + self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim) + self.pos_embedding.weight.requires_grad = pos_trainable + self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim) + self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + + # follow the default xavier_uniform initializer in paddle version + # otherwise, there are bugs for dec_probs computation in weight typing setting + # default norm initializer in nn.Embedding in pytorch, which samples larger values + nn.init.xavier_uniform_(self.token_embedding.weight) + nn.init.xavier_uniform_(self.pos_embedding.weight) + nn.init.xavier_uniform_(self.type_embedding.weight) + nn.init.xavier_uniform_(self.turn_embedding.weight) + return + + def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None): + embed = self.token_embedding(token_inp) + if pos_inp is not None: + embed += self.pos_embedding(pos_inp) + if type_inp is not None: + embed += self.type_embedding(type_inp) + if turn_inp is not None: + embed += self.turn_embedding(turn_inp) + embed = self.dropout_layer(embed) + return embed + + +def main(): + import numpy as np + + model = Embedder(10, 20, 20, 20, 20) + token_inp = torch.tensor( + np.random.randint(0, 19, [10, 10]).astype('int64')) + pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + out = model(token_inp, pos_inp, type_inp, turn_inp) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/models/nlp/space/modules/feedforward.py b/modelscope/models/nlp/space/modules/feedforward.py new file mode 100644 index 
00000000..e9a5f4c7 --- /dev/null +++ b/modelscope/models/nlp/space/modules/feedforward.py @@ -0,0 +1,43 @@ +""" +FeedForward class. +""" + +import torch +import torch.nn as nn + + +class FeedForward(nn.Module): + """ + Positional feed forward layer. + """ + + def __init__(self, hidden_dim, inner_dim, dropout): + super(FeedForward, self).__init__() + + self.hidden_dim = hidden_dim + self.inner_dim = inner_dim + self.linear_hidden = nn.Sequential( + nn.Linear(hidden_dim, inner_dim), nn.GELU()) + self.linear_out = nn.Linear(inner_dim, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + return + + def forward(self, x): + out = self.linear_hidden(x) + out = self.dropout_layer(out) + out = self.linear_out(out) + return out + + +def main(): + import numpy as np + + model = FeedForward(10, 20, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + out = model(inp) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/models/nlp/space/modules/functions.py b/modelscope/models/nlp/space/modules/functions.py new file mode 100644 index 00000000..45c02e21 --- /dev/null +++ b/modelscope/models/nlp/space/modules/functions.py @@ -0,0 +1,64 @@ +""" +Helpful functions. +""" + +import numpy as np +import torch +import torch.nn.functional as F + + +def unsqueeze(input, dims): + """ Implement multi-dimension unsqueeze function. """ + if isinstance(dims, (list, tuple)): + dims = [ + dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims + ] + dims = sorted(dims, reverse=True) + shape = list(input.shape) + for dim in dims: + shape.insert(dim, 1) + return torch.reshape(input, shape) + elif isinstance(dims, int): + return input.unsqueeze(dims) + else: + raise ValueError('Warning: type(dims) must in (list, tuple, int)!') + + +def gumbel_softmax(input, tau=1, eps=1e-10): + """ Basic implement of gumbel_softmax. """ + U = torch.tensor(np.random.rand(*input.shape)) + gumbel = 0.0 - torch.log(eps - torch.log(U + eps)) + y = input + gumbel + return F.softmax(y / tau) + + +def equal(x, y, dtype=None): + """ Implement equal in dygraph mode. (paddle) """ + if dtype is None: + dtype = 'float32' + if isinstance(x, torch.Tensor): + x = x.numpy() + if isinstance(y, torch.Tensor): + y = y.numpy() + out = np.equal(x, y).astype(dtype) + return torch.tensor(out) + + +def not_equal(x, y, dtype=None): + """ Implement not_equal in dygraph mode. (paddle) """ + return 1 - equal(x, y, dtype) + + +if __name__ == '__main__': + a = torch.tensor([[1, 1], [3, 4]]) + b = torch.tensor([[1, 1], [3, 4]]) + c = torch.equal(a, a) + c1 = equal(a, 3) + d = 1 - torch.not_equal(a, 3).float() + print(c) + print(c1) + print(d) + e = F.gumbel_softmax(a) + f = a.unsqueeze(a) + g = unsqueeze(a, dims=[0, 0, 1]) + print(g, g.shape) diff --git a/modelscope/models/nlp/space/modules/multihead_attention.py b/modelscope/models/nlp/space/modules/multihead_attention.py new file mode 100644 index 00000000..209eab5e --- /dev/null +++ b/modelscope/models/nlp/space/modules/multihead_attention.py @@ -0,0 +1,109 @@ +""" +MultiheadAttention class. +""" + +import torch +import torch.nn as nn + + +class MultiheadAttention(nn.Module): + """ + Multi head attention layer. 
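+
+ `hidden_dim` must be divisible by `num_heads`. Queries and values are reshaped
+ to [batch_size, num_heads, seq_len, head_dim] while keys are reshaped to
+ [batch_size, num_heads, head_dim, seq_len], so that `matmul(query, key)`
+ directly yields the [batch_size, num_heads, seq_len, seq_len] attention scores.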
+ """ + + def __init__(self, hidden_dim, num_heads, dropout): + assert hidden_dim % num_heads == 0 + super(MultiheadAttention, self).__init__() + + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.scale = self.head_dim**-0.5 + self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3) + self.linear_out = nn.Linear(hidden_dim, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + self.softmax = nn.Softmax(dim=-1) + return + + def _split_heads(self, x, is_key=False): + x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim) + x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) + return x + + def _merge_heads(self, x): + x = x.permute(0, 2, 1, 3) + x = x.reshape(x.size(0), x.size(1), self.hidden_dim) + return x + + def _attn(self, query, key, value, mask): + # shape: [batch_size, num_head, seq_len, seq_len] + scores = torch.matmul(query, key) + scores = scores * self.scale + + if mask is not None: + mask = mask.unsqueeze(1) + mask = mask.repeat(1, self.num_heads, 1, 1) + scores.masked_fill_( + mask.bool(), + float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10) + + attn = self.softmax(scores) + attn = self.dropout_layer(attn) + + if mask is not None: + ''' + mask: [batch size, num_heads, seq_len, seq_len] + mask后两维(seq_len, seq_len)矩阵来看,其中有的行可能都是true(1),对应句子中位看的行 + 导致softmax后该行的每个位置的attn prob都为1/n而非0,所以此处需重置为0 + + >>> F.softmax([-1e10, -100, -100]) + >>> [0.00, 0.50, 0.50] + >>> F.softmax([-1e10, -1e10, -1e10]) + >>> [0.33, 0.33, 0.33] + ==> [0.00, 0.00, 0.00] + ''' + attn.masked_fill_(mask.bool(), 0.) # attn = (1 - mask) * attn + + out = torch.matmul(attn, value) + return out + + def forward(self, inp, mask=None, cache=None): + """ Forward process of self attention. """ + # shape: [batch_size, seq_len, 3 * hidden_dim] + qkv = self.linear_qkv(inp) + query, key, value = torch.split(qkv, self.hidden_dim, dim=2) + + # shape: [batch_size, num_head, seq_len, head_dim] + query = self._split_heads(query) + # shape: [batch_size, num_head, head_dim, seq_len] + key = self._split_heads(key, is_key=True) + # shape: [batch_size, num_head, seq_len, head_dim] + value = self._split_heads(value) + + if cache is not None: + if 'key' in cache and 'value' in cache: + key = torch.cat([cache['key'], key], dim=3) + value = torch.cat([cache['value'], value], dim=2) + cache['key'] = key + cache['value'] = value + + out = self._attn(query, key, value, mask) + out = self._merge_heads(out) + out = self.linear_out(out) + return out + + +def main(): + import numpy as np + + model = MultiheadAttention(10, 2, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') + mask = torch.tensor(mask) + out = model(inp, mask=mask, cache=None) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/models/nlp/space/modules/transformer_block.py b/modelscope/models/nlp/space/modules/transformer_block.py new file mode 100644 index 00000000..1a0565d6 --- /dev/null +++ b/modelscope/models/nlp/space/modules/transformer_block.py @@ -0,0 +1,73 @@ +""" +TransformerBlock class. +""" + +import torch +import torch.nn as nn + +from modelscope.models.nlp.space.modules.feedforward import FeedForward +from modelscope.models.nlp.space.modules.multihead_attention import \ + MultiheadAttention + + +class TransformerBlock(nn.Module): + """ + Transformer block module. 
+ """ + + def __init__(self, hidden_dim, num_heads, dropout, attn_dropout, + ff_dropout): + super(TransformerBlock, self).__init__() + + self.attn = MultiheadAttention( + hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout) + self.attn_norm = nn.LayerNorm( + normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) + self.ff = FeedForward( + hidden_dim=hidden_dim, + inner_dim=4 * hidden_dim, + dropout=ff_dropout) + self.ff_norm = nn.LayerNorm( + normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) + self.dropout_layer = nn.Dropout(p=dropout) + return + + def forward(self, inp, mask=None, cache=None): + """ + Forward process on one transformer layer. + + @param : x + @type : Variable(shape: [batch_size, seq_len, hidden_size]) + + @param : memory + @type : Variable(shape: [batch_size, seq_len, hidden_size]) + + @param : mask + + @param : cache + """ + attn_out = self.attn(inp, mask, cache) + attn_out = self.dropout_layer(attn_out) + attn_out = self.attn_norm(attn_out + inp) + + ff_out = self.ff(attn_out) + ff_out = self.dropout_layer(ff_out) + ff_out = self.ff_norm(ff_out + attn_out) + + return ff_out + + +def main(): + import numpy as np + + model = TransformerBlock(10, 2, 0.5, 0.5, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') + mask = torch.tensor(mask) + out = model(inp, mask=mask, cache=None) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/pipelines/__init__.py b/modelscope/pipelines/__init__.py index 14865872..6e2645de 100644 --- a/modelscope/pipelines/__init__.py +++ b/modelscope/pipelines/__init__.py @@ -4,3 +4,4 @@ from .builder import pipeline from .cv import * # noqa F403 from .multi_modal import * # noqa F403 from .nlp import * # noqa F403 +from .nlp.space import * # noqa F403 diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 7a21d5d9..ac6a1d32 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -14,7 +14,7 @@ from .outputs import TASK_OUTPUTS from .util import is_model_name Tensor = Union['torch.Tensor', 'tf.Tensor'] -Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] +Input = Union[str, tuple, dict, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] InputModel = Union[str, Model] output_keys = [ @@ -120,6 +120,7 @@ class Pipeline(ABC): out = self.preprocess(input, **preprocess_params) out = self.forward(out, **forward_params) out = self.postprocess(out, **postprocess_params) + self._check_output(out) return out diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 090e1384..dc99a157 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -3,6 +3,8 @@ from .nli_pipeline import * # noqa F403 from .sentence_similarity_pipeline import * # noqa F403 from .sentiment_classification_pipeline import * # noqa F403 from .sequence_classification_pipeline import * # noqa F403 +from .space.dialog_intent_prediction_pipeline import * # noqa F403 +from .space.dialog_modeling_pipeline import * # noqa F403 from .text_generation_pipeline import * # noqa F403 from .word_segmentation_pipeline import * # noqa F403 from .zero_shot_classification_pipeline import * # noqa F403 diff --git a/modelscope/pipelines/nlp/space/__init__.py b/modelscope/pipelines/nlp/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/modelscope/pipelines/nlp/space/dialog_intent_prediction_pipeline.py b/modelscope/pipelines/nlp/space/dialog_intent_prediction_pipeline.py new file mode 100644 index 00000000..57245bdf --- /dev/null +++ b/modelscope/pipelines/nlp/space/dialog_intent_prediction_pipeline.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, Optional + +from modelscope.models.nlp import DialogIntentModel +from modelscope.preprocessors import DialogIntentPredictionPreprocessor +from modelscope.utils.constant import Tasks +from ...base import Input, Pipeline +from ...builder import PIPELINES + +__all__ = ['DialogIntentPredictionPipeline'] + + +@PIPELINES.register_module( + Tasks.dialog_intent_prediction, module_name=r'space-intent') +class DialogIntentPredictionPipeline(Pipeline): + + def __init__(self, model: DialogIntentModel, + preprocessor: DialogIntentPredictionPreprocessor, **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model (SequenceClassificationModel): a model instance + preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + """ + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.model = model + # self.tokenizer = preprocessor.tokenizer + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + import numpy as np + pred = inputs['pred'] + pos = np.where(pred == np.max(pred)) + + result = {'pred': pred, 'label': pos[0]} + + return result diff --git a/modelscope/pipelines/nlp/space/dialog_modeling_pipeline.py b/modelscope/pipelines/nlp/space/dialog_modeling_pipeline.py new file mode 100644 index 00000000..afa352b6 --- /dev/null +++ b/modelscope/pipelines/nlp/space/dialog_modeling_pipeline.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, Optional + +from modelscope.models.nlp import DialogModelingModel +from modelscope.preprocessors import DialogModelingPreprocessor +from modelscope.utils.constant import Tasks +from ...base import Pipeline, Tensor +from ...builder import PIPELINES + +__all__ = ['DialogModelingPipeline'] + + +@PIPELINES.register_module( + Tasks.dialog_modeling, module_name=r'space-modeling') +class DialogModelingPipeline(Pipeline): + + def __init__(self, model: DialogModelingModel, + preprocessor: DialogModelingPreprocessor, **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model (SequenceClassificationModel): a model instance + preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + """ + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.model = model + self.preprocessor = preprocessor + + def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens( + inputs['resp']) + assert len(sys_rsp) > 2 + sys_rsp = sys_rsp[1:len(sys_rsp) - 1] + # sys_rsp = self.preprocessor.text_field.tokenizer. 
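+ # The slice above drops the first and last generated tokens, which are
+ # assumed to be the start/end-of-response special tokens added by the
+ # generator.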
+ + inputs['sys'] = sys_rsp + + return inputs diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 50860514..a94cbca1 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -7,4 +7,6 @@ from .common import Compose from .image import LoadImage, load_image from .multi_model import OfaImageCaptionPreprocessor from .nlp import * # noqa F403 +from .space.dialog_intent_prediction_preprocessor import * # noqa F403 +from .space.dialog_modeling_preprocessor import * # noqa F403 from .text_to_speech import * # noqa F403 diff --git a/modelscope/preprocessors/space/__init__.py b/modelscope/preprocessors/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py new file mode 100644 index 00000000..c5a6b34c --- /dev/null +++ b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.preprocessors.space.fields.intent_field import \ + IntentBPETextField +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from ..base import Preprocessor +from ..builder import PREPROCESSORS + +__all__ = ['DialogIntentPredictionPreprocessor'] + + +@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-intent') +class DialogIntentPredictionPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, 'configuration.json')) + self.text_field = IntentBPETextField( + self.model_dir, config=self.config) + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + samples = self.text_field.preprocessor([data]) + samples, _ = self.text_field.collate_fn_multi_turn(samples) + + return samples diff --git a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py new file mode 100644 index 00000000..5061ba35 --- /dev/null +++ b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
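+# Note: this preprocessor wraps MultiWOZBPETextField; its __call__ expects a dict carrying the latest user utterance under the key 'user_input' and adds the corresponding token ids under 'user' for the dialog modeling model.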
+ +import os +import uuid +from typing import Any, Dict, Union + +from modelscope.preprocessors.space.fields.gen_field import \ + MultiWOZBPETextField +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, InputFields +from modelscope.utils.type_assert import type_assert +from ..base import Preprocessor +from ..builder import PREPROCESSORS + +__all__ = ['DialogModelingPreprocessor'] + + +@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-modeling') +class DialogModelingPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, 'configuration.json')) + self.text_field = MultiWOZBPETextField( + self.model_dir, config=self.config) + + @type_assert(object, Dict) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (Dict[str, Any]): a dict carrying the raw dialog inputs; the latest user utterance is expected under the key 'user_input' + Example: + {'user_input': 'I want to find a cheap restaurant in the centre.'} + + Returns: + Dict[str, Any]: the input dict with the tokenized user utterance ids added under the key 'user' + """ + + user_ids = self.text_field.get_ids(data['user_input']) + data['user'] = user_ids + + return data diff --git a/modelscope/preprocessors/space/fields/__init__.py b/modelscope/preprocessors/space/fields/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/preprocessors/space/fields/dst_processors.py b/modelscope/preprocessors/space/fields/dst_processors.py new file mode 100644 index 00000000..6d888bff --- /dev/null +++ b/modelscope/preprocessors/space/fields/dst_processors.py @@ -0,0 +1,1522 @@ +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +import json +import numpy as np +import six +from tqdm import tqdm + +logger = logging.getLogger(__name__) +USER_NAME = 'User' +SYSTEM_NAME = 'System' +DIALOG_ACT = 'Dialog_Act' + +utter1 = { + 'User-1': + "I'd really like to take my client out to a nice restaurant that serves indian food." +} +history_states1 = [ + {}, +] +utter2 = { + 'User-1': + "I'd really like to take my client out to a nice restaurant that serves indian food.", + 'System-1': + 'I show many restaurants that serve Indian food in that price range. 
What area would you like to travel to?', + 'Dialog_Act-1': { + 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], + ['pricerange', 'that price range']] + }, + 'User-2': + 'I am looking for an expensive indian restaurant in the area of centre.', +} + +history_states2 = [{}, { + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'area': '', + 'name': '', + 'type': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [{ + 'name': 'alexander bed and breakfast', + 'reference': 'JXVKZ7KV' + }], + 'day': + 'sunday', + 'people': + '6', + 'stay': + '4' + }, + 'semi': { + 'area': '', + 'internet': 'yes', + 'name': 'alexander bed and breakfast', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': '', + 'type': 'guesthouse' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [{ + 'name': 'ask', + 'reference': 'Y2Y8QYBY' + }], + 'day': 'sunday', + 'people': '6', + 'time': '18:45' + }, + 'semi': { + 'area': 'centre', + 'food': 'italian', + 'name': 'ask', + 'pricerange': 'cheap' + } + }, + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'arriveBy': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'arriveBy': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + } +}, {}] + +utter3 = { + 'User-1': + "I'd really like to take my client out to a nice restaurant that serves indian food.", + 'System-1': + 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', + 'Dialog_Act-1': { + 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], + ['pricerange', 'that price range']] + }, + 'User-2': + 'I am looking for an expensive indian restaurant in the area of centre.', + 'System-2': + 'Might I recommend Saffron Brasserie? That is an expensive Indian restaurant in the center of town. I can book a table for you, if you like.', + 'Dialog_Act-2': { + 'Restaurant-Recommend': [['area', 'center of town'], + ['food', 'Indian'], + ['name', 'Saffron Brasserie'], + ['pricerange', 'expensive']] + }, + 'User-3': 'Sure thing, please book for 6 people at 19:30 on Saturday.' 
+} + +history_states3 = [{}, { + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'area': '', + 'name': '', + 'type': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [{ + 'name': 'alexander bed and breakfast', + 'reference': 'JXVKZ7KV' + }], + 'day': + 'sunday', + 'people': + '6', + 'stay': + '4' + }, + 'semi': { + 'area': '', + 'internet': 'yes', + 'name': 'alexander bed and breakfast', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': '', + 'type': 'guesthouse' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [{ + 'name': 'ask', + 'reference': 'Y2Y8QYBY' + }], + 'day': 'sunday', + 'people': '6', + 'time': '18:45' + }, + 'semi': { + 'area': 'centre', + 'food': 'italian', + 'name': 'ask', + 'pricerange': 'cheap' + } + }, + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'arriveBy': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'arriveBy': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + } +}, {}, { + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'area': '', + 'name': '', + 'type': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [{ + 'name': 'alexander bed and breakfast', + 'reference': 'JXVKZ7KV' + }], + 'day': + 'sunday', + 'people': + '6', + 'stay': + '4' + }, + 'semi': { + 'area': '', + 'internet': 'yes', + 'name': 'alexander bed and breakfast', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': '', + 'type': 'guesthouse' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [{ + 'name': 'ask', + 'reference': 'Y2Y8QYBY' + }], + 'day': 'sunday', + 'people': '6', + 'time': '18:45' + }, + 'semi': { + 'area': 'centre', + 'food': 'italian', + 'name': 'ask', + 'pricerange': 'cheap' + } + }, + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'arriveBy': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'arriveBy': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + } +}, {}] + + +class DSTProcessor(object): + + ACTS_DICT = { + 'taxi-depart': 'taxi-departure', + 'taxi-dest': 'taxi-destination', + 'taxi-leaveat': 'taxi-leaveAt', + 'taxi-arriveby': 'taxi-arriveBy', + 'train-depart': 'train-departure', + 'train-dest': 'train-destination', + 'train-leaveat': 'train-leaveAt', + 'train-arriveby': 'train-arriveBy', + 'train-bookpeople': 'train-book_people', + 'restaurant-price': 'restaurant-pricerange', + 'restaurant-bookpeople': 'restaurant-book_people', + 'restaurant-bookday': 'restaurant-book_day', + 'restaurant-booktime': 'restaurant-book_time', + 'hotel-price': 'hotel-pricerange', + 'hotel-bookpeople': 'hotel-book_people', + 'hotel-bookday': 'hotel-book_day', + 'hotel-bookstay': 'hotel-book_stay', + 'booking-bookpeople': 'booking-book_people', + 'booking-bookday': 'booking-book_day', + 'booking-bookstay': 'booking-book_stay', + 'booking-booktime': 'booking-book_time', + } + + LABEL_MAPS = {} # Loaded from file + + def __init__(self): + # Required for mapping slot names in dialogue_acts.json file + # to proper designations. 
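+ # ACTS_DICT above provides that mapping; LABEL_MAPS starts empty here and is filled in later from the label_maps argument of the subclass's create_example() / create_examples().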
+ pass + + def _convert_inputs_to_utterances(self, inputs: dict, + history_states: list): + """This method is to generate the utterances with user, sys, dialog_acts and metadata, while metadata is from the history_states or the output from the inference pipline""" + + utterances = [] + user_inputs = [] + sys_gen_inputs = [] + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == USER_NAME: + user_inputs.insert(int(turn) - 1, inputs[item]) + elif name == SYSTEM_NAME: + sys_gen_inputs.insert(int(turn) - 1, inputs[item]) + else: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + + # user is leading the topic should aways larger than sys and dialog acts + assert len(user_inputs) - 1 == len(sys_gen_inputs) + assert len(user_inputs) - 1 == len(dialog_acts_inputs) + # the history states record both user and sys states + assert len(history_states) == len(user_inputs) + len(sys_gen_inputs) + + # the dialog_act at user turn is useless + for i, item in enumerate(history_states): + utterance = {} + # the dialog_act at user turn is useless + utterance['dialog_act'] = dialog_acts_inputs[ + i // 2] if i % 2 == 1 else {} + utterance['text'] = sys_gen_inputs[ + i // 2] if i % 2 == 1 else user_inputs[i // 2] + utterance['metadata'] = item + utterance['span_info'] = [] + utterances.append(utterance) + + return utterances + + def _load_acts(self, inputs: dict, dialog_id='example.json'): + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == DIALOG_ACT: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + s_dict = {} + + for j, item in enumerate(dialog_acts_inputs): + if isinstance(item, dict): + for a in item: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' or aa[ + 1] == 'select' or aa[1] == 'book': + for i in item[a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = dialog_id, str(int(j) + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + #s_dict[key] = list([v]) + + return s_dict + + +class multiwoz22Processor(DSTProcessor): + + def __init__(self): + super().__init__() + + def normalize_time(self, text): + text = re.sub('(\d{1})(a\.?m\.?|p\.?m\.?)', r'\1 \2', + text) # am/pm without space + text = re.sub('(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)', r'\1\2:00 \3', + text) # am/pm short to long form + text = re.sub( + '(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)', + r'\1\2 \3:\4\5', text) # Missing separator + text = re.sub('(^| )(\d{2})[;.,](\d{2})', r'\1\2:\3', + text) # Wrong separator + text = re.sub('(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)', + r'\1\2 \3:00\4', text) # normalize simple full hour time + text = re.sub('(^| )(\d{1}:\d{2})', r'\g<1>0\2', + text) # Add missing leading 0 + # Map 12 hour times to 24 hour times + text = re.sub( + '(\d{2})(:\d{2}) ?p\.?m\.?', lambda x: str( + int(x.groups()[0]) + 12 + if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups( + )[1], text) + text = re.sub('(^| )24:(\d{2})', r'\g<1>00:\2', + text) # Correct times that use 24 as hour + return text + + def normalize_text(self, text): + text = self.normalize_time(text) + text = re.sub("n't", ' not', text) + text = re.sub('(^| )zero(-| )star([s.,? 
]|$)', r'\g<1>0 star\3', text) + text = re.sub('(^| )one(-| )star([s.,? ]|$)', r'\g<1>1 star\3', text) + text = re.sub('(^| )two(-| )star([s.,? ]|$)', r'\g<1>2 star\3', text) + text = re.sub('(^| )three(-| )star([s.,? ]|$)', r'\g<1>3 star\3', text) + text = re.sub('(^| )four(-| )star([s.,? ]|$)', r'\g<1>4 star\3', text) + text = re.sub('(^| )five(-| )star([s.,? ]|$)', r'\g<1>5 star\3', text) + text = re.sub('archaelogy', 'archaeology', text) # Systematic typo + text = re.sub('guesthouse', 'guest house', text) # Normalization + text = re.sub('(^| )b ?& ?b([.,? ]|$)', r'\1bed and breakfast\2', + text) # Normalization + text = re.sub('bed & breakfast', 'bed and breakfast', + text) # Normalization + return text + + # Loads the dialogue_acts.json and returns a list + # of slot-value pairs. + def load_acts(self, input_file): + with open(input_file) as f: + acts = json.load(f) + s_dict = {} + for d in acts: + for t in acts[d]: + if int(t) % 2 == 0: + continue + # Only process, if turn has annotation + if isinstance(acts[d][t]['dialog_act'], dict): + for a in acts[d][t]['dialog_act']: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' or aa[ + 1] == 'select' or aa[1] == 'book': + for i in acts[d][t]['dialog_act'][a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = d, str(int(t) // 2 + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + #s_dict[key] = list([v]) + return s_dict + + # This should only contain label normalizations. All other mappings should + # be defined in LABEL_MAPS. 
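+ # For instance, given the rules below, normalize_label('hotel-parking', 'free') yields 'true', normalize_label('hotel-type', 'guesthouse') yields 'false', and an empty or 'not mentioned' value yields 'none'.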
+ def normalize_label(self, slot, value_label): + # Normalization of empty slots + if value_label == '' or value_label == 'not mentioned': + return 'none' + + # Normalization of time slots + if 'leaveAt' in slot or 'arriveBy' in slot or slot == 'restaurant-book_time': + return self.normalize_time(value_label) + + # Normalization + if 'type' in slot or 'name' in slot or 'destination' in slot or 'departure' in slot: + value_label = re.sub('guesthouse', 'guest house', value_label) + + # Map to boolean slots + if slot == 'hotel-parking' or slot == 'hotel-internet': + if value_label == 'yes' or value_label == 'free': + return 'true' + if value_label == 'no': + return 'false' + if slot == 'hotel-type': + if value_label == 'hotel': + return 'true' + if value_label == 'guest house': + return 'false' + + return value_label + + def tokenize(self, utt): + utt_lower = convert_to_unicode(utt).lower() + utt_lower = self.normalize_text(utt_lower) + utt_tok = [ + tok for tok in map(str.strip, re.split('(\W+)', utt_lower)) + if len(tok) > 0 + ] + return utt_tok + + def delex_utt(self, utt, values, unk_token='[UNK]'): + utt_norm = self.tokenize(utt) + for s, vals in values.items(): + # TODO vals可能不是数组形式,而是初始化的字符串"none" + for v in vals: + if v != 'none': + v_norm = self.tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = [unk_token] * v_len + return utt_norm + + def get_token_pos(self, tok_list, value_label): + find_pos = [] + found = False + label_list = [ + item for item in map(str.strip, re.split('(\W+)', value_label)) + if len(item) > 0 + ] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + def check_label_existence(self, value_label, usr_utt_tok): + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, value_label) + # If no hit even though there should be one, check for value label variants + if not in_usr and value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, + value_label_variant) + if in_usr: + break + return in_usr, usr_pos + + def check_slot_referral(self, value_label, slot, seen_slots): + referred_slot = 'none' + if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking': + return referred_slot + for s in seen_slots: + # Avoid matches for slots that share values with different meaning. + # hotel-internet and -parking are handled separately as Boolean slots. 
+ if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking': + continue + if re.match('(hotel|restaurant)-book_people', + s) and slot == 'hotel-book_stay': + continue + if re.match('(hotel|restaurant)-book_people', + slot) and s == 'hotel-book_stay': + continue + if slot != s and (slot not in seen_slots + or seen_slots[slot] != value_label): + if seen_slots[s] == value_label: + referred_slot = s + break + elif value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + if seen_slots[s] == value_label_variant: + referred_slot = s + break + return referred_slot + + def is_in_list(self, tok, value): + found = False + tok_list = [ + item for item in map(str.strip, re.split('(\W+)', tok)) + if len(item) > 0 + ] + value_list = [ + item for item in map(str.strip, re.split('(\W+)', value)) + if len(item) > 0 + ] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + # Fuzzy matching to label informed slot values + def check_slot_inform(self, value_label, inform_label): + result = False + informed_value = 'none' + vl = ' '.join(self.tokenize(value_label)) + for il in inform_label: + if vl == il: + result = True + elif self.is_in_list(il, vl): + result = True + elif self.is_in_list(vl, il): + result = True + elif il in self.LABEL_MAPS: + for il_variant in self.LABEL_MAPS[il]: + if vl == il_variant: + result = True + break + elif self.is_in_list(il_variant, vl): + result = True + break + elif self.is_in_list(vl, il_variant): + result = True + break + elif vl in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[vl]: + if value_label_variant == il: + result = True + break + elif self.is_in_list(il, value_label_variant): + result = True + break + elif self.is_in_list(value_label_variant, il): + result = True + break + if result: + informed_value = il + break + return result, informed_value + + def get_turn_label(self, value_label, inform_label, sys_utt_tok, + usr_utt_tok, slot, seen_slots, slot_last_occurrence): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + informed_value = 'none' + referred_slot = 'none' + if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false': + class_type = value_label + else: + in_usr, usr_pos = self.check_label_existence( + value_label, usr_utt_tok) + is_informed, informed_value = self.check_slot_inform( + value_label, inform_label) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + elif is_informed: + class_type = 'inform' + else: + referred_slot = self.check_slot_referral( + value_label, slot, seen_slots) + if referred_slot != 'none': + class_type = 'refer' + else: + class_type = 'unpointable' + return informed_value, referred_slot, usr_utt_tok_label, class_type + + def _create_example(self, + utterances, + sys_inform_dict, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='example.json'): + + # Collects all slot changes throughout the dialog + cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + 
utt_tok_list = [[]] + mod_slots_list = [] + + # Collect all utterances and their metadata + usr_sys_switch = True + turn_itr = 0 + + for utt in utterances: + # Assert that system and user utterances alternate + is_sys_utt = utt['metadata'] != {} + if usr_sys_switch == is_sys_utt: + print( + 'WARN: Wrong order of system and user utterances. Skipping rest of the dialog %s' + % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + if is_sys_utt: + turn_itr += 1 + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + inform_dict = {slot: 'none' for slot in slot_list} + for slot in slot_list: + if (str(dialog_id), str(turn_itr), + slot) in sys_inform_dict: + inform_dict[slot] = sys_inform_dict[(str(dialog_id), + str(turn_itr), + slot)] + utt_tok_list.append( + self.delex_utt(utt['text'], inform_dict, + unk_token)) # normalize utterances + else: + utt_tok_list.append(self.tokenize( + utt['text'])) # normalize utterances + + modified_slots = {} + + # If sys utt, extract metadata (identify and collect modified slots) + if is_sys_utt: + for d in utt['metadata']: + booked = utt['metadata'][d]['book']['booked'] + booked_slots = {} + # Check the booked section + if booked != []: + for s in booked[0]: + booked_slots[s] = self.normalize_label( + '%s-%s' % (d, s), + booked[0][s]) # normalize labels + # Check the semi and the inform slots + for category in ['book', 'semi']: + for s in utt['metadata'][d][category]: + cs = '%s-book_%s' % ( + d, s) if category == 'book' else '%s-%s' % (d, + s) + value_label = self.normalize_label( + cs, utt['metadata'][d][category] + [s]) # normalize labels + # Prefer the slot value as stored in the booked section + if s in booked_slots: + value_label = booked_slots[s] + # Remember modified slots and entire dialog state + if cs in slot_list and cumulative_labels[ + cs] != value_label: + modified_slots[cs] = value_label + cumulative_labels[cs] = value_label + + mod_slots_list.append(modified_slots.copy()) + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + hst_utt_tok = [] + hst_utt_tok_label_dict = {slot: [] for slot in slot_list} + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + + for i in range(0, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + + # Collect turn data + if append_history: + if swap_utterances: + hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok + else: + hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok + sys_utt_tok = utt_tok_list[i] + usr_utt_tok = utt_tok_list[i + 1] + turn_slots = mod_slots_list[ + i + 1] if len(mod_slots_list) > 1 else {} + + guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) + + if analyze: + print('%15s %2s %s ||| %s' % + (dialog_id, turn_itr, ' '.join(sys_utt_tok), + ' '.join(usr_utt_tok))) + print('%15s %2s [' % (dialog_id, turn_itr), end='') + + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make 
comparison difficult. + value_dict[slot] = value_label + elif label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + inform_slot_dict[slot] = 0 + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), slot)] + ]) + inform_slot_dict[slot] = 1 + elif (str(dialog_id), str(turn_itr), + 'booking-' + slot.split('-')[1]) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), 'booking-' + + slot.split('-')[1])] + ]) + inform_slot_dict[slot] = 1 + + (informed_value, referred_slot, usr_utt_tok_label, + class_type) = self.get_turn_label( + value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True) + + inform_dict[slot] = informed_value + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list( + diag_seen_slots_value_dict.values()).count( + value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if append_history: + if use_history_labels: + if swap_utterances: + new_hst_utt_tok_label_dict[ + slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[ + slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[slot] = [ + 0 for _ in sys_utt_tok_label + usr_utt_tok_label + + new_hst_utt_tok_label_dict[slot] + ] + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[ + slot]: + print('(%s): %s, ' % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[ + slot] and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already. 
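+ # For reference, get_turn_label above yields one of: none, dontcare, true, false, copy_value, inform, refer or unpointable; 'unpointable' was mapped to 'none' in class_type_dict above and is stored as 'copy_value' in the dialog state a few lines below.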
+ if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print(']') + + if swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + + example = DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + history=hst_utt_tok, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + history_label=hst_utt_tok_label_dict, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + class_label=class_type_dict) + # Update some variables. + hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy() + diag_state = new_diag_state.copy() + + turn_itr += 1 + return example + + def create_example(self, + inputs, + history_states, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='0'): + utterances = self._convert_inputs_to_utterances(inputs, history_states) + sys_inform_dict = self._load_acts(inputs) + self.LABEL_MAPS = label_maps + example = self._create_example(utterances, sys_inform_dict, set_type, + slot_list, label_maps, append_history, + use_history_labels, swap_utterances, + label_value_repetitions, + delexicalize_sys_utts, unk_token, + analyze) + + return example + + def create_examples(self, + input_file, + acts_file, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = self.load_acts(acts_file) + + with open(input_file, 'r', encoding='utf-8') as reader: + input_data = json.load(reader) + + self.LABEL_MAPS = label_maps + + examples = [] + for dialog_id in tqdm(input_data): + entry = input_data[dialog_id] + utterances = entry['log'] + + example = self._create_example( + utterances, sys_inform_dict, set_type, slot_list, label_maps, + append_history, use_history_labels, swap_utterances, + label_value_repetitions, delexicalize_sys_utts, unk_token, + analyze) + examples.append(example) + + return examples + + +class DSTExample(object): + """ + A single training/test example for the DST dataset. 
+ """ + + def __init__(self, + guid, + text_a, + text_b, + history, + text_a_label=None, + text_b_label=None, + history_label=None, + values=None, + inform_label=None, + inform_slot_label=None, + refer_label=None, + diag_state=None, + class_label=None): + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.history = history + self.text_a_label = text_a_label + self.text_b_label = text_b_label + self.history_label = history_label + self.values = values + self.inform_label = inform_label + self.inform_slot_label = inform_slot_label + self.refer_label = refer_label + self.diag_state = diag_state + self.class_label = class_label + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = '' + s += 'guid: %s' % (self.guid) + s += ', text_a: %s' % (self.text_a) + s += ', text_b: %s' % (self.text_b) + s += ', history: %s' % (self.history) + if self.text_a_label: + s += ', text_a_label: %d' % (self.text_a_label) + if self.text_b_label: + s += ', text_b_label: %d' % (self.text_b_label) + if self.history_label: + s += ', history_label: %d' % (self.history_label) + if self.values: + s += ', values: %d' % (self.values) + if self.inform_label: + s += ', inform_label: %d' % (self.inform_label) + if self.inform_slot_label: + s += ', inform_slot_label: %d' % (self.inform_slot_label) + if self.refer_label: + s += ', refer_label: %d' % (self.refer_label) + if self.diag_state: + s += ', diag_state: %d' % (self.diag_state) + if self.class_label: + s += ', class_label: %d' % (self.class_label) + return s + + +class InputFeatures(object): + """A single set of features of data.""" + + def __init__(self, + input_ids, + input_ids_unmasked, + input_mask, + segment_ids, + start_pos=None, + end_pos=None, + values=None, + inform=None, + inform_slot=None, + refer_id=None, + diag_state=None, + class_label_id=None, + guid='NONE'): + self.guid = guid + self.input_ids = input_ids + self.input_ids_unmasked = input_ids_unmasked + self.input_mask = input_mask + self.segment_ids = segment_ids + self.start_pos = start_pos + self.end_pos = end_pos + self.values = values + self.inform = inform + self.inform_slot = inform_slot + self.refer_id = refer_id + self.diag_state = diag_state + self.class_label_id = class_label_id + + +def convert_examples_to_features(examples, + slot_list, + class_types, + model_type, + tokenizer, + max_seq_length, + slot_value_dropout=0.0): + """Loads a data file into a list of `InputBatch`s.""" + + if model_type == 'bert': + model_specs = { + 'MODEL_TYPE': 'bert', + 'CLS_TOKEN': '[CLS]', + 'UNK_TOKEN': '[UNK]', + 'SEP_TOKEN': '[SEP]', + 'TOKEN_CORRECTION': 4 + } + else: + logger.error('Unknown model type (%s). Aborting.' 
% (model_type)) + exit(1) + + def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, + model_specs, slot_value_dropout): + joint_text_label = [0 for _ in text_label_dict[slot] + ] # joint all slots' label + for slot_text_label in text_label_dict.values(): + for idx, label in enumerate(slot_text_label): + if label == 1: + joint_text_label[idx] = 1 + + text_label = text_label_dict[slot] + tokens = [] + tokens_unmasked = [] + token_labels = [] + for token, token_label, joint_label in zip(text, text_label, + joint_text_label): + token = convert_to_unicode(token) + sub_tokens = tokenizer.tokenize(token) # Most time intensive step + tokens_unmasked.extend(sub_tokens) + if slot_value_dropout == 0.0 or joint_label == 0: + tokens.extend(sub_tokens) + else: + rn_list = np.random.random_sample((len(sub_tokens), )) + for rn, sub_token in zip(rn_list, sub_tokens): + if rn > slot_value_dropout: + tokens.append(sub_token) + else: + tokens.append(model_specs['UNK_TOKEN']) + token_labels.extend([token_label for _ in sub_tokens]) + assert len(tokens) == len(token_labels) + assert len(tokens_unmasked) == len(token_labels) + return tokens, tokens_unmasked, token_labels + + def _truncate_seq_pair(tokens_a, tokens_b, history, max_length): + """Truncates a sequence pair in place to the maximum length. + Copied from bert/run_classifier.py + """ + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + len(history) + if total_length <= max_length: + break + if len(history) > 0: + history.pop() + elif len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, + model_specs, guid): + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) + if len(tokens_a) + len(tokens_b) + len( + history) > max_seq_length - model_specs['TOKEN_CORRECTION']: + logger.info('Truncate Example %s. Total len=%d.' % + (guid, len(tokens_a) + len(tokens_b) + len(history))) + input_text_too_long = True + else: + input_text_too_long = False + _truncate_seq_pair(tokens_a, tokens_b, history, + max_seq_length - model_specs['TOKEN_CORRECTION']) + return input_text_too_long + + def _get_token_label_ids(token_labels_a, token_labels_b, + token_labels_history, max_seq_length, + model_specs): + token_label_ids = [] + token_label_ids.append(0) # [CLS] + for token_label in token_labels_a: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_b: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_history: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + while len(token_label_ids) < max_seq_length: + token_label_ids.append(0) # padding + assert len(token_label_ids) == max_seq_length + return token_label_ids + + def _get_start_end_pos(class_type, token_label_ids, max_seq_length): + if class_type == 'copy_value' and 1 not in token_label_ids: + #logger.warn("copy_value label, but token_label not detected. 
Setting label to 'none'.") + class_type = 'none' + start_pos = 0 + end_pos = 0 + if 1 in token_label_ids: + start_pos = token_label_ids.index(1) + # Parsing is supposed to find only first location of wanted value + if 0 not in token_label_ids[start_pos:]: + end_pos = len(token_label_ids[start_pos:]) + start_pos - 1 + else: + end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1 + for i in range(max_seq_length): + if i >= start_pos and i <= end_pos: + assert token_label_ids[i] == 1 + return class_type, start_pos, end_pos + + def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, + tokenizer, model_specs): + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + tokens = [] + segment_ids = [] + tokens.append(model_specs['CLS_TOKEN']) + segment_ids.append(0) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(0) + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + for token in history: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + # Zero-pad up to the sequence length. 
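+ # For example, with tokens_a=['hi'], tokens_b=['there'], history=[] and max_seq_length=8: tokens = [CLS] hi [SEP] there [SEP] [SEP], segment_ids = 0 0 0 1 1 1 (padded with two 0s), and input_mask = 1 1 1 1 1 1 0 0.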
+ while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + return tokens, input_ids, input_mask, segment_ids + + total_cnt = 0 + too_long_cnt = 0 + + refer_list = ['none'] + slot_list + + features = [] + # Convert single example + for (example_index, example) in enumerate(examples): + if example_index % 1000 == 0: + logger.info('Writing example %d of %d' % + (example_index, len(examples))) + + total_cnt += 1 + + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + refer_id_dict = {} + diag_state_dict = {} + class_label_id_dict = {} + start_pos_dict = {} + end_pos_dict = {} + for slot in slot_list: + tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label( + example.text_a, example.text_a_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label( + example.text_b, example.text_b_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label( + example.history, example.history_label, slot, tokenizer, + model_specs, slot_value_dropout) + + input_text_too_long = _truncate_length_and_warn( + tokens_a, tokens_b, tokens_history, max_seq_length, + model_specs, example.guid) + + if input_text_too_long: + if example_index < 10: + if len(token_labels_a) > len(tokens_a): + logger.info(' tokens_a truncated labels: %s' + % str(token_labels_a[len(tokens_a):])) + if len(token_labels_b) > len(tokens_b): + logger.info(' tokens_b truncated labels: %s' + % str(token_labels_b[len(tokens_b):])) + if len(token_labels_history) > len(tokens_history): + logger.info( + ' tokens_history truncated labels: %s' + % str(token_labels_history[len(tokens_history):])) + + token_labels_a = token_labels_a[:len(tokens_a)] + token_labels_b = token_labels_b[:len(tokens_b)] + token_labels_history = token_labels_history[:len(tokens_history + )] + tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)] + tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)] + tokens_history_unmasked = tokens_history_unmasked[:len( + tokens_history)] + + assert len(token_labels_a) == len(tokens_a) + assert len(token_labels_b) == len(tokens_b) + assert len(token_labels_history) == len(tokens_history) + assert len(token_labels_a) == len(tokens_a_unmasked) + assert len(token_labels_b) == len(tokens_b_unmasked) + assert len(token_labels_history) == len(tokens_history_unmasked) + token_label_ids = _get_token_label_ids(token_labels_a, + token_labels_b, + token_labels_history, + max_seq_length, model_specs) + + value_dict[slot] = example.values[slot] + inform_dict[slot] = example.inform_label[slot] + + class_label_mod, start_pos_dict[slot], end_pos_dict[ + slot] = _get_start_end_pos(example.class_label[slot], + token_label_ids, max_seq_length) + if class_label_mod != example.class_label[slot]: + example.class_label[slot] = class_label_mod + inform_slot_dict[slot] = example.inform_slot_label[slot] + refer_id_dict[slot] = refer_list.index(example.refer_label[slot]) + diag_state_dict[slot] = class_types.index(example.diag_state[slot]) + class_label_id_dict[slot] = class_types.index( + example.class_label[slot]) + + if input_text_too_long: + too_long_cnt += 1 + + tokens, input_ids, input_mask, segment_ids = _get_transformer_input( + tokens_a, tokens_b, tokens_history, max_seq_length, tokenizer, + 
model_specs) + if slot_value_dropout > 0.0: + _, input_ids_unmasked, _, _ = _get_transformer_input( + tokens_a_unmasked, tokens_b_unmasked, tokens_history_unmasked, + max_seq_length, tokenizer, model_specs) + else: + input_ids_unmasked = input_ids + + assert (len(input_ids) == len(input_ids_unmasked)) + + if example_index < 10: + logger.info('*** Example ***') + logger.info('guid: %s' % (example.guid)) + logger.info('tokens: %s' % ' '.join(tokens)) + logger.info('input_ids: %s' % ' '.join([str(x) + for x in input_ids])) + logger.info('input_mask: %s' + % ' '.join([str(x) for x in input_mask])) + logger.info('segment_ids: %s' + % ' '.join([str(x) for x in segment_ids])) + logger.info('start_pos: %s' % str(start_pos_dict)) + logger.info('end_pos: %s' % str(end_pos_dict)) + logger.info('values: %s' % str(value_dict)) + logger.info('inform: %s' % str(inform_dict)) + logger.info('inform_slot: %s' % str(inform_slot_dict)) + logger.info('refer_id: %s' % str(refer_id_dict)) + logger.info('diag_state: %s' % str(diag_state_dict)) + logger.info('class_label_id: %s' % str(class_label_id_dict)) + + features.append( + InputFeatures( + guid=example.guid, + input_ids=input_ids, + input_ids_unmasked=input_ids_unmasked, + input_mask=input_mask, + segment_ids=segment_ids, + start_pos=start_pos_dict, + end_pos=end_pos_dict, + values=value_dict, + inform=inform_dict, + inform_slot=inform_slot_dict, + refer_id=refer_id_dict, + diag_state=diag_state_dict, + class_label_id=class_label_id_dict)) + + logger.info('========== %d out of %d examples have text too long' % + (too_long_cnt, total_cnt)) + + return features + + +# From bert.tokenization (TF code) +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode('utf-8', 'ignore') + elif isinstance(text, unicode): + return text + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +if __name__ == '__main__': + processor = multiwoz22Processor() + set_type = 'test' + slot_list = [ + 'taxi-leaveAt', 'taxi-destination', 'taxi-departure', 'taxi-arriveBy', + 'restaurant-book_people', 'restaurant-book_day', + 'restaurant-book_time', 'restaurant-food', 'restaurant-pricerange', + 'restaurant-name', 'restaurant-area', 'hotel-book_people', + 'hotel-book_day', 'hotel-book_stay', 'hotel-name', 'hotel-area', + 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-internet', + 'hotel-type', 'attraction-type', 'attraction-name', 'attraction-area', + 'train-book_people', 'train-leaveAt', 'train-destination', 'train-day', + 'train-arriveBy', 'train-departure' + ] + append_history = True + use_history_labels = True + swap_utterances = True + label_value_repetitions = True + delexicalize_sys_utts = True, + unk_token = '[UNK]' + analyze = False + example = processor.create_example(utter1, history_states1, set_type, + slot_list, {}, append_history, + use_history_labels, swap_utterances, + label_value_repetitions, + delexicalize_sys_utts, unk_token, + analyze) + print(f'utterances is {example}') diff --git a/modelscope/preprocessors/space/fields/gen_field.py b/modelscope/preprocessors/space/fields/gen_field.py new file mode 100644 index 00000000..7012697f --- /dev/null +++ 
b/modelscope/preprocessors/space/fields/gen_field.py @@ -0,0 +1,687 @@ +""" +Field class +""" +import os +import random +from collections import OrderedDict +from itertools import chain + +import numpy as np + +from modelscope.preprocessors.space.tokenizer import Tokenizer +from modelscope.utils.nlp.space import ontology, utils +from modelscope.utils.nlp.space.db_ops import MultiWozDB +from modelscope.utils.nlp.space.utils import list2np + + +class BPETextField(object): + + pad_token = '[PAD]' + bos_token = '[BOS]' + eos_token = '[EOS]' + unk_token = '[UNK]' + sos_u_token = '' + eos_u_token = '' + sos_b_token = '' + eos_b_token = '' + sos_d_token = '' + eos_d_token = '' + sos_a_token = '' + eos_a_token = '' + sos_db_token = '' + eos_db_token = '' + sos_r_token = '' + eos_r_token = '' + + @property + def bot_id(self): + """ + 用于区分user和bot两个角色 + 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' + """ + return 0 + + @property + def user_id(self): + """ + 用于区分user和bot两个角色 + 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' + """ + return 1 + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + @property + def num_specials(self): + return len(self.tokenizer.special_tokens) + + @property + def pad_id(self): + return self.tokenizer.convert_tokens_to_ids([self.pad_token])[0] + + @property + def bos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.bos_token])[0] + + @property + def eos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_token])[0] + + @property + def unk_id(self): + return self.tokenizer.convert_tokens_to_ids([self.unk_token])[0] + + @property + def sos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_u_token])[0] + + @property + def eos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_u_token])[0] + + @property + def sos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_b_token])[0] + + @property + def eos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_b_token])[0] + + @property + def sos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_db_token])[0] + + @property + def eos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_db_token])[0] + + @property + def sos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_a_token])[0] + + @property + def eos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_a_token])[0] + + @property + def sos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_r_token])[0] + + @property + def eos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_r_token])[0] + + @property + def sos_d_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_d_token])[0] + + @property + def eos_d_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] + + def __init__(self, config): + self.gpu = 0 + self.tokenizer = None + self.vocab = None + self.db = None + self.set_stats = {} + + self.prompt_num_for_understand = config.BPETextField.prompt_num_for_understand + self.prompt_num_for_policy = config.BPETextField.prompt_num_for_policy + self.understand_tokens = ontology.get_understand_tokens( + self.prompt_num_for_understand) + self.policy_tokens = ontology.get_policy_tokens( + self.prompt_num_for_policy) + + self.with_query_bow = config.BPETextField.with_query_bow + self.understand = config.BPETextField.understand + self.policy = config.BPETextField.policy + + self.batch_size = 
config.Trainer.batch_size + self.filtered = config.BPETextField.filtered + self.max_len = config.BPETextField.max_len + self.min_utt_len = config.BPETextField.min_utt_len + self.max_utt_len = config.BPETextField.max_utt_len + self.min_ctx_turn = config.BPETextField.min_ctx_turn + self.max_ctx_turn = config.BPETextField.max_ctx_turn - 1 # subtract reply turn + + self.use_true_prev_bspn = config.Generator.use_true_prev_bspn + self.use_true_prev_aspn = config.Generator.use_true_prev_aspn + self.use_true_db_pointer = config.Generator.use_true_db_pointer + self.use_true_prev_resp = config.Generator.use_true_prev_resp + self.use_true_curr_bspn = config.Generator.use_true_curr_bspn + self.use_true_curr_aspn = config.Generator.use_true_curr_aspn + self.use_all_previous_context = config.Generator.use_all_previous_context + self.use_true_bspn_for_ctr_eval = config.Generator.use_true_bspn_for_ctr_eval + self.use_true_domain_for_ctr_eval = config.Generator.use_true_domain_for_ctr_eval + + def collate_fn_multi_turn(self, samples): + batch_size = len(samples) + batch = {} + + src = [sp['src'][-self.max_ctx_turn:] for sp in samples] + query_token, src_token, src_pos, src_turn, src_role = [], [], [], [], [] + for utts in src: + query_token.append(utts[-1]) + utt_lens = [len(utt) for utt in utts] + + # Token ids + src_token.append(list(chain(*utts))[-self.max_len:]) + + # Position ids + pos = [list(range(utt_len)) for utt_len in utt_lens] + src_pos.append(list(chain(*pos))[-self.max_len:]) + + # Turn ids + turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] + src_turn.append(list(chain(*turn))[-self.max_len:]) + + # Role ids + role = [ + [self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id] * l + for i, l in enumerate(utt_lens) + ] + src_role.append(list(chain(*role))[-self.max_len:]) + + # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 + src_token = list2np(src_token, padding=self.pad_id) + src_pos = list2np(src_pos, padding=self.pad_id) + src_turn = list2np(src_turn, padding=self.pad_id) + src_role = list2np(src_role, padding=self.pad_id) + batch['src_token'] = src_token + batch['src_pos'] = src_pos + batch['src_type'] = src_role + batch['src_turn'] = src_turn + batch['src_mask'] = (src_token != self.pad_id).astype('int64') + + if self.with_query_bow: + query_token = list2np(query_token, padding=self.pad_id) + batch['query_token'] = query_token + batch['query_mask'] = (query_token != self.pad_id).astype('int64') + + if self.understand_ids and self.understand: + understand = [self.understand_ids for _ in samples] + understand_token = np.array(understand).astype('int64') + batch['understand_token'] = understand_token + batch['understand_mask'] = \ + (understand_token != self.pad_id).astype('int64') + + if self.policy_ids and self.policy: + policy = [self.policy_ids for _ in samples] + policy_token = np.array(policy).astype('int64') + batch['policy_token'] = policy_token + batch['policy_mask'] = \ + (policy_token != self.pad_id).astype('int64') + + if 'tgt' in samples[0]: + tgt = [sp['tgt'] for sp in samples] + + # Token ids & Label ids + tgt_token = list2np(tgt, padding=self.pad_id) + + # Position ids + tgt_pos = np.zeros_like(tgt_token) + tgt_pos[:] = np.arange(tgt_token.shape[1], dtype=tgt_token.dtype) + + # Turn ids + tgt_turn = np.zeros_like(tgt_token) + + # Role ids + tgt_role = np.full_like(tgt_token, self.bot_id) + + batch['tgt_token'] = tgt_token + batch['tgt_pos'] = tgt_pos + batch['tgt_type'] = tgt_role + batch['tgt_turn'] = tgt_turn + batch['tgt_mask'] = (tgt_token != 
self.pad_id).astype('int64') + + return batch, batch_size + + def _bucket_by_turn(self, encoded_data): + turn_bucket = {} + for dial in encoded_data: + turn_len = len(dial) + if turn_len not in turn_bucket: + turn_bucket[turn_len] = [] + turn_bucket[turn_len].append(dial) + return OrderedDict(sorted(turn_bucket.items(), key=lambda i: i[0])) + + def _construct_mini_batch(self, data): + all_batches = [] + batch = [] + for dial in data: + batch.append(dial) + if len(batch) == self.batch_size: + # print('batch size: %d, batch num +1'%(len(batch))) + all_batches.append(batch) + batch = [] + # if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch + # print('last batch size: %d, batch num +1'%(len(batch))) + # if (len(batch) % len(cfg.cuda_device)) != 0: + # batch = batch[:-(len(batch) % len(cfg.cuda_device))] + # TODO deal with deleted data + if self.gpu <= 1: + if len(batch) > 0.5 * self.batch_size: + all_batches.append(batch) + elif len(all_batches): + all_batches[-1].extend(batch) + else: + all_batches.append(batch) + + return all_batches + + def transpose_batch(self, batch): + dial_batch = [] + turn_num = len(batch[0]) + for turn in range(turn_num): + turn_l = {} + for dial in batch: + this_turn = dial[turn] + for k in this_turn: + if k not in turn_l: + turn_l[k] = [] + turn_l[k].append(this_turn[k]) + dial_batch.append(turn_l) + return dial_batch + + def get_eval_data(self, set_name='dev'): + name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} + dial = name_to_set[set_name] + + if set_name not in self.set_stats: + self.set_stats[set_name] = {} + num_turns = 0 + num_dials = len(dial) + for d in dial: + num_turns += len(d) + + self.set_stats[set_name]['num_turns'] = num_turns + self.set_stats[set_name]['num_dials'] = num_dials + + return dial + + def get_nontranspose_data_iterator(self, all_batches): + for i, batch in enumerate(all_batches): + yield batch + + def get_data_iterator(self, all_batches): + for i, batch in enumerate(all_batches): + yield self.transpose_batch(batch) + + +class MultiWOZBPETextField(BPETextField): + + def __init__(self, model_dir, config): + super(MultiWOZBPETextField, self).__init__(config) + import spacy + self.nlp = spacy.load('en_core_web_sm') + + self.db = MultiWozDB( + model_dir, { + 'attraction': 'db/attraction_db_processed.json', + 'hospital': 'db/hospital_db_processed.json', + 'hotel': 'db/hotel_db_processed.json', + 'police': 'db/police_db_processed.json', + 'restaurant': 'db/restaurant_db_processed.json', + 'taxi': 'db/taxi_db_processed.json', + 'train': 'db/train_db_processed.json', + }) + self._build_vocab(model_dir) + + special_tokens = [ + self.pad_token, self.bos_token, self.eos_token, self.unk_token + ] + special_tokens.extend(self.add_sepcial_tokens()) + self.tokenizer = Tokenizer( + vocab_path=os.path.join(model_dir, 'vocab.txt'), + special_tokens=special_tokens, + tokenizer_type=config.BPETextField.tokenizer_type) + self.understand_ids = self.tokenizer.convert_tokens_to_ids( + self.understand_tokens) + self.policy_ids = self.tokenizer.convert_tokens_to_ids( + self.policy_tokens) + + return + + def get_ids(self, data: str): + result = [self.sos_u_id] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str(data))) + [self.eos_u_id] + return result + + def inverse_transpose_turn(self, turn_list): + """ + eval, one dialog at a time + """ + dialogs = {} + turn_num = len(turn_list) + dial_id = turn_list[0]['dial_id'] + dialogs[dial_id] = [] + for turn_idx in 
range(turn_num): + dial_turn = {} + turn = turn_list[turn_idx] + for key, value in turn.items(): + if key == 'dial_id': + continue + if key == 'pointer' and self.db is not None: + turn_domain = turn['turn_domain'][-1] + value = self.db.pointerBack(value, turn_domain) + dial_turn[key] = value + dialogs[dial_id].append(dial_turn) + return dialogs + + def inverse_transpose_batch(self, turn_batch_list): + """ + :param turn_batch_list: list of transpose dial batch + """ + dialogs = {} + total_turn_num = len(turn_batch_list) + # initialize + for idx_in_batch, dial_id in enumerate(turn_batch_list[0]['dial_id']): + dialogs[dial_id] = [] + for turn_n in range(total_turn_num): + dial_turn = {} + turn_batch = turn_batch_list[turn_n] + for key, v_list in turn_batch.items(): + if key == 'dial_id': + continue + value = v_list[idx_in_batch] + if key == 'pointer' and self.db is not None: + turn_domain = turn_batch['turn_domain'][idx_in_batch][ + -1] + value = self.db.pointerBack(value, turn_domain) + dial_turn[key] = value + dialogs[dial_id].append(dial_turn) + return dialogs + + def get_batches(self, set_name): + """ + compute dataset stats. + """ + global dia_count + log_str = '' + name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} + dial = name_to_set[set_name] + turn_bucket = self._bucket_by_turn(dial) + # self._shuffle_turn_bucket(turn_bucket) + all_batches = [] + + if set_name not in self.set_stats: + self.set_stats[set_name] = {} + num_training_steps = 0 + num_turns = 0 + num_dials = 0 + + for k in turn_bucket: + if set_name != 'test' and k == 1 or k >= 17: + continue + batches = self._construct_mini_batch(turn_bucket[k]) + try: + log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( + k, len(turn_bucket[k]), len(batches), len(batches[-1])) + except Exception: + log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( + k, len(turn_bucket[k]), len(batches), 0.0) + # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches))) + num_training_steps += k * len(batches) + num_turns += k * len(turn_bucket[k]) + num_dials += len(turn_bucket[k]) + all_batches += batches + log_str += 'total batch num: %d\n' % len(all_batches) + # print('total batch num: %d'%len(all_batches)) + # print('dialog count: %d'%dia_count) + # return all_batches + + # log stats + # logging.info(log_str) + # cfg.num_training_steps = num_training_steps * cfg.epoch_num + self.set_stats[set_name][ + 'num_training_steps_per_epoch'] = num_training_steps # turn-level的steps + self.set_stats[set_name]['num_turns'] = num_turns + self.set_stats[set_name]['num_dials'] = num_dials + + if set_name == 'train': + random.shuffle(all_batches) + return all_batches + + def add_sepcial_tokens(self): + """ + add special tokens to gpt tokenizer + serves a similar role of Vocab.construt() + make a dict of special tokens + """ + special_tokens = [] + prompt_tokens = self.understand_tokens + self.policy_tokens + special_tokens.extend( + ontology.get_special_tokens(other_tokens=prompt_tokens)) + + for word in ontology.all_domains + ['general']: + word = '[' + word + ']' + special_tokens.append(word) + for word in ontology.all_acts: + word = '[' + word + ']' + special_tokens.append(word) + for word in self.vocab._word2idx.keys(): + if word.startswith('[value_') and word.endswith(']'): + special_tokens.append(word) + + return special_tokens + + def _build_vocab(self, model_dir: str): + self.vocab = utils.MultiWOZVocab(3000) + vp = 
os.path.join('{}/vocab'.format(model_dir)) + self.vocab.load_vocab(vp) + return self.vocab.vocab_size + + def _get_convert_str(self, sent): + assert isinstance(sent, str) + return ' '.join([ + self.tokenizer.spec_convert_dict.get(tok, tok) + for tok in sent.split() + ]) + + def bspan_to_DBpointer(self, bspan, turn_domain): + constraint_dict = self.bspan_to_constraint_dict(bspan) + # print(constraint_dict) + matnums = self.db.get_match_num(constraint_dict) + match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] + match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom + match = matnums[match_dom] + # vector = self.db.addDBPointer(match_dom, match) + vector = self.db.addDBIndicator(match_dom, match) + return vector + + def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'): + """ + ['[hotel]', 'pricerange', 'cheap', 'type', 'hotel'] -> {'hotel': {'pricerange': 'cheap', 'type': 'hotel'}} + """ + bspan = bspan.split() if isinstance(bspan, str) else bspan + constraint_dict = {} + domain = None + conslen = len(bspan) + for idx, cons in enumerate(bspan): + cons = self.vocab.decode(cons) if type(cons) is not str else cons + if cons == '': + break + if '[' in cons: + if cons[1:-1] not in ontology.all_domains: + continue + domain = cons[1:-1] + elif cons in ontology.get_slot: + if domain is None: + continue + if cons == 'people': + # handle confusion of value name "people's portraits..." and slot people + try: + ns = bspan[idx + 1] + ns = self.vocab.decode(ns) if type( + ns) is not str else ns + if ns == "'s": + continue + except Exception: + continue + if not constraint_dict.get(domain): + constraint_dict[domain] = {} + if bspn_mode == 'bsdx': + constraint_dict[domain][cons] = 1 + continue + vidx = idx + 1 + if vidx == conslen: + break + vt_collect = [] + vt = bspan[vidx] + vt = self.vocab.decode(vt) if type(vt) is not str else vt + while vidx < conslen and vt != '' and '[' not in vt and vt not in ontology.get_slot: + vt_collect.append(vt) + vidx += 1 + if vidx == conslen: + break + vt = bspan[vidx] + vt = self.vocab.decode(vt) if type(vt) is not str else vt + if vt_collect: + constraint_dict[domain][cons] = ' '.join(vt_collect) + + return constraint_dict + + def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False): + """ + URURU:这里的含义是指轮级别的训练(数据整理),区别于session级别的训练方式(convert_batch_session); + 但不同于eval时的含义,eval时二者都是逐轮依次生成的,那时URURU的含义请见相关的函数注释; + + convert the current and the last turn + concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t] + firts turn: [U_t, B_t, A_t, R_t] + try: [user, bspn, db, aspn, resp] + + """ + inputs = [] + if first_turn: + batch_zipped = zip(turn_batch['user'], turn_batch['bspn'], + turn_batch['db'], turn_batch['aspn'], + turn_batch['resp']) + for u, b, db, a, r in batch_zipped: + if self.use_true_curr_bspn: + src = [u + b + db] + tgt = a + r + else: + src = [u] + tgt = b + db + a + r + inputs.append({'src': src, 'tgt': tgt}) + pv = [src[-1], tgt] + pv_batch.append(pv) + else: + batch_zipped = zip(pv_batch, turn_batch['user'], + turn_batch['bspn'], turn_batch['db'], + turn_batch['aspn'], turn_batch['resp']) + for i, (pv, u, b, db, a, r) in enumerate(batch_zipped): + if self.use_true_curr_bspn: + src = pv + [u + b + db] + tgt = a + r + else: + src = pv + [u] + tgt = b + db + a + r + inputs.append({'src': src, 'tgt': tgt}) + pv = [src[-1], tgt] + pv_batch[i].extend(pv) + + return inputs, pv_batch + + def wrap_result_lm(self, result_dict, eos_syntax=None): + results = [] + eos_syntax = ontology.eos_tokens if not 
eos_syntax else eos_syntax + sos_syntax = ontology.sos_tokens + # ground truth bs, as, ds.. generate response + field = [ + 'dial_id', 'turn_num', 'user', 'bspn_gen', 'bsdx', 'resp_gen', + 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn', 'pointer', + 'qspn_gen', 'qspn' + ] + + for dial_id, turns in result_dict.items(): + entry = {'dial_id': dial_id, 'trun_num': len(turns)} + for f in field[2:]: + entry[f] = '' # TODO ??? + results.append(entry) + for turn_idx, turn in enumerate(turns): + entry = {'dial_id': dial_id} + for key in field: + if key in ['dial_id']: + continue + v = turn.get(key, '') + if key == 'turn_domain': + v = ' '.join(v) + + if key in eos_syntax and v != '': + # remove eos tokens + v = self.tokenizer.decode(v) + v = v.split() + # remove eos/sos in span + if eos_syntax[key] in v: + v.remove(eos_syntax[key]) + if sos_syntax[key] in v: + v.remove(sos_syntax[key]) + v = ' '.join(v) + else: + pass # v = v + entry[key] = v + + results.append(entry) + + return results, field + + def convert_turn_eval(self, turn, pv_turn, first_turn=False): + """ + input: [all previous ubar, U_t, B_t, A_t] predict R_t + firts turn: [U_t, B_t, A_t] predict R_t + + regarding the context, all previous ubar is too slow, try the previous ubar + """ + inputs = {} + + context_list = [] + prompt_id = None + if self.use_true_curr_bspn: + if self.use_true_curr_aspn: # only predict resp + context_list = ['user', 'bspn', 'db', 'aspn'] + prompt_id = self.sos_r_id + else: # predicted aspn + context_list = ['user', 'bspn', 'db'] + prompt_id = self.sos_a_id + else: # predict bspn aspn resp. db are not predicted. this part tbd. + context_list = ['user'] + prompt_id = self.sos_b_id + + if first_turn: + context = [] + for c in context_list: + context += turn[c] + + inputs['src'] = [context] + inputs['labels'] = [context] + else: + context = [] + for c in context_list: + context += turn[c] + + if self.use_true_curr_bspn: + pv_context = pv_turn['labels'] + [ + pv_turn['aspn'] + pv_turn['resp'] + ] + else: + pv_info = pv_turn['bspn'] + pv_turn['db'] + pv_turn[ + 'aspn'] + pv_turn['resp'] + pv_context = pv_turn['labels'] + [pv_info] + + # prompt response, add sos_r + inputs['src'] = pv_context + [context] + + if self.use_all_previous_context: + inputs['labels'] = pv_context + [ + context + ] # use all previous ubar history + else: + inputs['labels'] = [context] # use previous turn + + return inputs, prompt_id diff --git a/modelscope/preprocessors/space/fields/intent_field.py b/modelscope/preprocessors/space/fields/intent_field.py new file mode 100644 index 00000000..9907165e --- /dev/null +++ b/modelscope/preprocessors/space/fields/intent_field.py @@ -0,0 +1,1093 @@ +""" +Intent Field class +""" +import glob +import multiprocessing +import os +import random +import re +import time +from collections import defaultdict +from itertools import chain + +import json +import numpy as np +from tqdm import tqdm + +from modelscope.preprocessors.space.tokenizer import Tokenizer +from modelscope.utils.nlp.space import ontology, utils +from modelscope.utils.nlp.space.scores import hierarchical_set_score +from modelscope.utils.nlp.space.utils import list2np + + +class BPETextField(object): + + pad_token = '[PAD]' + bos_token = '[BOS]' + eos_token = '[EOS]' + unk_token = '[UNK]' + mask_token = '[MASK]' + sos_u_token = '' + eos_u_token = '' + sos_b_token = '' + eos_b_token = '' + sos_db_token = '' + eos_db_token = '' + sos_a_token = '' + eos_a_token = '' + sos_r_token = '' + eos_r_token = '' + + def __init__(self, model_dir, 
config): + self.score_matrixs = {} + self.prompt_num_for_understand = config.BPETextField.prompt_num_for_understand + self.prompt_num_for_policy = config.BPETextField.prompt_num_for_policy + self.understand_tokens = ontology.get_understand_tokens( + self.prompt_num_for_understand) + self.policy_tokens = ontology.get_policy_tokens( + self.prompt_num_for_policy) + special_tokens = [ + self.pad_token, self.bos_token, self.eos_token, self.unk_token + ] + special_tokens.extend(self.add_sepcial_tokens()) + self.tokenizer = Tokenizer( + vocab_path=os.path.join(model_dir, 'vocab.txt'), + special_tokens=special_tokens, + tokenizer_type=config.BPETextField.tokenizer_type) + self.understand_ids = self.numericalize(self.understand_tokens) + self.policy_ids = self.numericalize(self.policy_tokens) + + self.tokenizer_type = config.BPETextField.tokenizer_type + self.filtered = config.BPETextField.filtered + self.max_len = config.BPETextField.max_len + self.min_utt_len = config.BPETextField.min_utt_len + self.max_utt_len = config.BPETextField.max_utt_len + self.min_ctx_turn = config.BPETextField.min_ctx_turn + self.max_ctx_turn = config.BPETextField.max_ctx_turn + self.policy = config.BPETextField.policy + self.generation = config.BPETextField.generation + self.with_mlm = config.Dataset.with_mlm + self.with_query_bow = config.BPETextField.with_query_bow + self.with_contrastive = config.Dataset.with_contrastive + self.num_process = config.Dataset.num_process + self.dynamic_score = config.Dataset.dynamic_score + self.abandon_label = config.Dataset.abandon_label + self.trigger_role = config.Dataset.trigger_role + self.trigger_data = config.Dataset.trigger_data.split( + ',') if config.Dataset.trigger_data else [] + + # data_paths = list(os.path.dirname(c) for c in sorted( + # glob.glob(hparams.data_dir + '/**/' + f'train.{hparams.tokenizer_type}.jsonl', recursive=True))) + # self.data_paths = self.filter_data_path(data_paths=data_paths) + # self.labeled_data_paths = [data_path for data_path in self.data_paths if 'UniDA' in data_path] + # self.unlabeled_data_paths = [data_path for data_path in self.data_paths if 'UnDial' in data_path] + # assert len(self.unlabeled_data_paths) + len(self.labeled_data_paths) == len(self.data_paths) + # assert len(self.labeled_data_paths) or len(self.unlabeled_data_paths), 'No dataset is loaded' + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + @property + def num_specials(self): + return len(self.tokenizer.special_tokens) + + @property + def pad_id(self): + return self.tokenizer.convert_tokens_to_ids([self.pad_token])[0] + + @property + def bos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.bos_token])[0] + + @property + def eos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_token])[0] + + @property + def unk_id(self): + return self.tokenizer.convert_tokens_to_ids([self.unk_token])[0] + + @property + def mask_id(self): + return self.tokenizer.convert_tokens_to_ids([self.mask_token])[0] + + @property + def sos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_u_token])[0] + + @property + def eos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_u_token])[0] + + @property + def sos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_b_token])[0] + + @property + def eos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_b_token])[0] + + @property + def sos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_db_token])[0] + + @property + 
def eos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_db_token])[0] + + @property + def sos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_a_token])[0] + + @property + def eos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_a_token])[0] + + @property + def sos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_r_token])[0] + + @property + def eos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_r_token])[0] + + @property + def bot_id(self): + """ + 用于区分user和bot两个角色 + 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' + """ + return 0 + + @property + def user_id(self): + """ + 用于区分user和bot两个角色 + 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' + """ + return 1 + + def add_sepcial_tokens(self): + prompt_tokens = self.understand_tokens + self.policy_tokens + return ontology.get_special_tokens(other_tokens=prompt_tokens) + + def filter_data_path(self, data_paths): + if self.trigger_data: + filtered_data_paths = [] + for data_path in data_paths: + for data_name in self.trigger_data: + if data_path.endswith(f'/{data_name}'): + filtered_data_paths.append(data_path) + break + else: + filtered_data_paths = data_paths + return filtered_data_paths + + def load_score_matrix(self, data_type, data_iter=None): + """ + load score matrix for all labeled datasets + """ + for data_path in self.labeled_data_paths: + file_index = os.path.join( + data_path, f'{data_type}.{self.tokenizer_type}.jsonl') + file = os.path.join(data_path, f'{data_type}.Score.npy') + if self.dynamic_score: + score_matrix = {} + print(f"Created 1 score cache dict for data in '{file_index}'") + else: + # TODO add post score matrix + assert os.path.exists(file), f"{file} isn't exist" + print(f"Loading 1 score matrix from '{file}' ...") + fp = np.memmap(file, dtype='float32', mode='r') + assert len(fp.shape) == 1 + num = int(np.sqrt(fp.shape[0])) + score_matrix = fp.reshape(num, num) + print(f"Loaded 1 score matrix for data in '{file_index}'") + self.score_matrixs[file_index] = score_matrix + + def random_word(self, chars): + output_label = [] + output_chars = [] + + for i, char in enumerate(chars): + # TODO delete this part to learn special tokens + if char in [ + self.sos_u_id, self.eos_u_id, self.sos_r_id, self.eos_r_id + ]: + output_chars.append(char) + output_label.append(self.pad_id) + continue + + prob = random.random() + if prob < 0.15: + prob /= 0.15 + + # 80% randomly change token to mask token + if prob < 0.8: + output_chars.append(self.mask_id) + + # 10% randomly change token to random token + elif prob < 0.9: + tmp = random.randint(1, self.vocab_size - 1) + output_chars.append(tmp) # start from 1, to exclude pad_id + + # 10% randomly change token to current token + else: + output_chars.append(char) + + output_label.append(char) + + else: + output_chars.append(char) + output_label.append(self.pad_id) + + return output_chars, output_label + + def create_masked_lm_predictions(self, sample): + src = sample['src'] + src_span_mask = sample['src_span_mask'] + mlm_inputs = [] + mlm_labels = [] + for chars, chars_span_mask in zip(src, src_span_mask): + if sum(chars_span_mask): + mlm_input, mlm_label = [], [] + for char, char_mask in zip(chars, chars_span_mask): + if char_mask: + mlm_input.append(self.mask_id) + mlm_label.append(char) + else: + mlm_input.append(char) + mlm_label.append(self.pad_id) + else: + mlm_input, mlm_label = self.random_word(chars) + mlm_inputs.append(mlm_input) + 
mlm_labels.append(mlm_label) + + sample['mlm_inputs'] = mlm_inputs + sample['mlm_labels'] = mlm_labels + return sample + + def create_span_masked_lm_predictions(self, sample): + src = sample['src'] + src_span_mask = sample['src_span_mask'] + mlm_inputs = [] + mlm_labels = [] + for chars, chars_span_mask in zip(src, src_span_mask): + mlm_input, mlm_label = [], [] + for char, char_mask in zip(chars, chars_span_mask): + if char_mask: + mlm_input.append(self.mask_id) + mlm_label.append(char) + else: + mlm_input.append(char) + mlm_label.append(self.pad_id) + mlm_inputs.append(mlm_input) + mlm_labels.append(mlm_label) + + sample['mlm_inputs'] = mlm_inputs + sample['mlm_labels'] = mlm_labels + return sample + + def create_token_masked_lm_predictions(self, sample): + mlm_inputs = sample['mlm_inputs'] + mlm_labels = sample['mlm_labels'] + + for i, span_mlm_label in enumerate(mlm_labels): + if not sum(span_mlm_label): + mlm_input, mlm_label = self.random_word(mlm_inputs[i]) + mlm_inputs[i] = mlm_input + mlm_labels[i] = mlm_label + + return sample + + def numericalize(self, tokens): + """ + here only "convert_tokens_to_ids", + which need be tokenized into tokens(sub-words) by "tokenizer.tokenize" before + """ + assert isinstance(tokens, list) + if len(tokens) == 0: + return [] + element = tokens[0] + if isinstance(element, list): + return [self.numericalize(s) for s in tokens] + else: + return self.tokenizer.convert_tokens_to_ids(tokens) + + def denumericalize(self, numbers): + """ + here first "convert_ids_to_tokens", then combine sub-words into origin words + """ + assert isinstance(numbers, list) + if len(numbers) == 0: + return [] + element = numbers[0] + if isinstance(element, list): + return [self.denumericalize(x) for x in numbers] + else: + return self.tokenizer.decode( + numbers, + ignore_tokens=[self.bos_token, self.eos_token, self.pad_token]) + + def save_examples(self, examples, filename): + start = time.time() + if filename.endswith('npy'): + print(f"Saving 1 object to '{filename}' ...") + assert len( + examples.shape) == 2 and examples.shape[0] == examples.shape[1] + num = examples.shape[0] + fp = np.memmap( + filename, dtype='float32', mode='w+', shape=(num, num)) + fp[:] = examples[:] + fp.flush() + elapsed = time.time() - start + print(f'Saved 1 object (elapsed {elapsed:.2f}s)') + elif filename.endswith('jsonl'): + print(f"Saving examples to '{filename}' ...") + with open(filename, 'w', encoding='utf-8') as fp: + for ex in examples: + fp.write(json.dumps(ex) + '\n') + elapsed = time.time() - start + print(f'Saved {len(examples)} examples (elapsed {elapsed:.2f}s)') + else: + print(f"Saving examples to '{filename}' ...") + raise ValueError(f'Unsport file format: {filename}') + + def load_examples(self, filename): + start = time.time() + if filename.endswith('npy'): + print(f"Loading 1 object from '{filename}' ...") + fp = np.memmap(filename, dtype='float32', mode='r') + assert len(fp.shape) == 1 + num = int(np.sqrt(fp.shape[0])) + examples = fp.reshape(num, num) + elapsed = time.time() - start + print(f'Loaded 1 object (elapsed {elapsed:.2f}s)') + else: + print(f"Loading examples from '{filename}' ...") + with open(filename, 'r', encoding='utf-8') as fp: + examples = list(map(lambda s: json.loads(s.strip()), fp)) + elapsed = time.time() - start + print(f'Loaded {len(examples)} examples (elapsed {elapsed:.2f}s)') + return examples + + def utt_filter_pred(self, utt): + return self.min_utt_len <= len(utt) \ + and (not self.filtered or len(utt) <= self.max_utt_len) + + def 
utts_filter_pred(self, utts): + return self.min_ctx_turn <= len(utts) \ + and (not self.filtered or len(utts) <= self.max_ctx_turn) + + def get_token_pos(self, tok_list, value_label): + find_pos = [] + found = False + label_list = [ + item + for item in map(str.strip, re.split('(\\W+)', value_label.lower())) + if len(item) > 0 + ] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + def build_score_matrix(self, examples): + """ + build symmetric score matrix + """ + assert self.num_process == 1 + print('Building score matrix from examples ...') + num = len(examples) + score_matrix = np.eye( + num, num, dtype='float32' + ) # in case of empty label of self, resulting in score 0. + + for i in tqdm(range(num)): + for j in range(i): + # TODO change the score method + score = hierarchical_set_score( + frame1=examples[i]['label'], frame2=examples[j]['label']) + score_matrix[i][j] = score + score_matrix[j][i] = score + + print('Built score matrix') + return score_matrix + + def build_score_matrix_on_the_fly(self, + ids, + labels, + data_file, + is_post=False): + """ + build symmetric score matrix on the fly + @is_post: True for resp label of sample i and j, False for query label of sample i and j + """ + num = len(labels) + tag = 'r' if is_post else 'q' + assert len(ids) == len(labels) + score_matrix = np.eye( + num, num, dtype='float32' + ) # in case of empty label of self, resulting in score 0. + + for i in range(num): + for j in range(i): + score = self.score_matrixs[data_file].get( + f'{ids[i]}-{ids[j]}-{tag}', None) + if score is None: + score = self.score_matrixs[data_file].get( + f'{ids[j]}-{ids[i]}-{tag}', None) + if score is None: + # TODO change the score method + score = hierarchical_set_score( + frame1=labels[i], frame2=labels[j]) + self.score_matrixs[data_file][ + f'{ids[i]}-{ids[j]}-{tag}'] = score + score_matrix[i][j] = score + score_matrix[j][i] = score + + return score_matrix + + def build_score_matrix_func(self, examples, start, exclusive_end): + """ + build sub score matrix + """ + num = len(examples) + process_id = os.getpid() + description = f'PID: {process_id} Start: {start} End: {exclusive_end}' + print( + f'PID-{process_id}: Building {start} to {exclusive_end} lines score matrix from examples ...' 
+ ) + score_matrix = np.zeros((exclusive_end - start, num), dtype='float32') + + for abs_i, i in enumerate( + tqdm(range(start, exclusive_end), desc=description)): + for j in range(num): + # TODO change the score method + score = hierarchical_set_score( + frame1=examples[i]['label'], frame2=examples[j]['label']) + score_matrix[abs_i][j] = score + + print( + f'PID-{process_id}: Built {start} to {exclusive_end} lines score matrix' + ) + return {'start': start, 'score_matrix': score_matrix} + + def build_score_matrix_multiprocessing(self, examples): + """ + build score matrix + """ + assert self.num_process >= 2 and multiprocessing.cpu_count() >= 2 + print('Building score matrix from examples ...') + results = [] + num = len(examples) + sub_num, res_num = num // self.num_process, num % self.num_process + patches = [sub_num] * (self.num_process - 1) + [sub_num + res_num] + + start = 0 + pool = multiprocessing.Pool(processes=self.num_process) + for patch in patches: + exclusive_end = start + patch + results.append( + pool.apply_async(self.build_score_matrix_func, + (examples, start, exclusive_end))) + start = exclusive_end + pool.close() + pool.join() + + sub_score_matrixs = [result.get() for result in results] + sub_score_matrixs = sorted( + sub_score_matrixs, key=lambda sub: sub['start']) + sub_score_matrixs = [ + sub_score_matrix['score_matrix'] + for sub_score_matrix in sub_score_matrixs + ] + score_matrix = np.concatenate(sub_score_matrixs, axis=0) + assert score_matrix.shape == (num, num) + np.fill_diagonal( + score_matrix, + 1.) # in case of empty label of self, resulting in score 0. + + print('Built score matrix') + return score_matrix + + def extract_span_texts(self, text, label): + span_texts = [] + for domain, frame in label.items(): + for act, slot_values in frame.items(): + for slot, values in slot_values.items(): + for value in values: + if value['span']: + span_texts.append( + text[value['span'][0]:value['span'][1]]) + elif str(value['value']).strip().lower() in text.strip( + ).lower(): + span_texts.append(str(value['value'])) + return span_texts + + def fix_label(self, label): + for domain, frame in label.items(): + if not frame: + return {} + for act, slot_values in frame.items(): + if act == 'DEFAULT_INTENT' and not slot_values: + return {} + return label + + def build_examples_multi_turn(self, data_file, data_type='train'): + print(f"Reading examples from '{data_file}' ...") + examples = [] + ignored = 0 + + with open(data_file, 'r', encoding='utf-8') as f: + input_data = json.load(f) + for dialog_id in tqdm(input_data): + turns = input_data[dialog_id]['turns'] + history, history_role, history_span_mask, history_label = [], [], [], [] + for t, turn in enumerate(turns): + label = turn['label'] + role = turn['role'] + text = turn['text'] + utterance, span_mask = [], [] + + token_list = [ + tok for tok in map(str.strip, + re.split('(\\W+)', text.lower())) + if len(tok) > 0 + ] + span_list = np.zeros(len(token_list), dtype=np.int32) + span_texts = self.extract_span_texts( + text=text, label=label) + + for span_text in span_texts: + found, find_pos = self.get_token_pos( + tok_list=token_list, value_label=span_text) + if found: + for start, exclusive_end in find_pos: + span_list[start:exclusive_end] = 1 + + token_list = [ + self.tokenizer.tokenize(token) for token in token_list + ] + span_list = [[tag] * len(token_list[i]) + for i, tag in enumerate(span_list)] + for sub_tokens in token_list: + utterance.extend(sub_tokens) + for sub_spans in span_list: + span_mask.extend(sub_spans) 
+ assert len(utterance) == len(span_mask) + + history.append(utterance) + history_role.append(role) + history_span_mask.append(span_mask) + history_label.append(self.fix_label(label)) + + tmp = self.utts_filter_pred(history[:-1]) and all( + map(self.utt_filter_pred, history)) + if ( + tmp or data_type == 'test' + ) and role in self.trigger_role and t: # TODO consider test + src = [ + s[-self.max_utt_len:] + for s in history[:-1][-self.max_ctx_turn:] + ] + src_span_mask = [ + s[-self.max_utt_len:] for s in + history_span_mask[:-1][-self.max_ctx_turn:] + ] + roles = [ + role + for role in history_role[:-1][-self.max_ctx_turn:] + ] + + new_src = [] + for i, s in enumerate(src): + if roles[i] == 'user': + user_or_sys = [self.eos_u_id] + else: + user_or_sys = [self.sos_r_id] + tmp = [self.sos_u_id + ] + self.numericalize(s) + user_or_sys + tmp = tmp + self.numericalize(s) + [self.eos_r_id] + new_src.append(tmp) + + src_span_mask = [[0] + list(map(int, s)) + [0] + for s in src_span_mask] + + tgt = [self.sos_r_id] + self.numericalize( + history[-1]) + [self.eos_r_id] + if data_type != 'test': + tgt = tgt[:self.max_utt_len + 2] + + ex = { + 'dialog_id': dialog_id, + 'turn_id': turn['turn_id'], + 'src': new_src, + 'src_span_mask': src_span_mask, + 'tgt': tgt, + 'query_label': history_label[-2], + 'resp_label': history_label[-1], + 'extra_info': turn.get('extra_info', '') + } + examples.append(ex) + else: + ignored += 1 + + # add span mlm inputs and span mlm labels in advance + if self.with_mlm: + examples = [ + self.create_span_masked_lm_predictions(example) + for example in examples + ] + + # add absolute id of the dataset for indexing scores in its score matrix + for i, example in enumerate(examples): + example['id'] = i + + print( + f'Built {len(examples)} {data_type.upper()} examples ({ignored} filtered)' + ) + return examples + + def preprocessor(self, text_list): + role = 'user' + examples = [] + + for text in text_list: + history, history_role, history_span_mask = [], [], [] + utterance, span_mask = [], [] + token_list = [ + tok for tok in map(str.strip, re.split('(\\W+)', text.lower())) + if len(tok) > 0 + ] + span_list = np.zeros(len(token_list), dtype=np.int32) + token_list = [ + self.tokenizer.tokenize(token) for token in token_list + ] + span_list = [[tag] * len(token_list[i]) + for i, tag in enumerate(span_list)] + + for sub_tokens in token_list: + utterance.extend(sub_tokens) + for sub_spans in span_list: + span_mask.extend(sub_spans) + assert len(utterance) == len(span_mask) + + history.append(utterance) + history_role.append(role) + history_span_mask.append(span_mask) + + src = [s[-self.max_utt_len:] for s in history[-self.max_ctx_turn:]] + src_span_mask = [ + s[-self.max_utt_len:] + for s in history_span_mask[-self.max_ctx_turn:] + ] + roles = [role for role in history_role[-self.max_ctx_turn:]] + + new_src = [] + for i, s in enumerate(src): + if roles[i] == 'user': + user_or_sys = [self.eos_u_id] + else: + user_or_sys = [self.sos_r_id] + tmp = [self.sos_u_id] + self.numericalize(s) + user_or_sys + tmp = tmp + self.numericalize(s) + [self.eos_r_id] + new_src.append(tmp) + + src_span_mask = [[0] + list(map(int, s)) + [0] + for s in src_span_mask] + + ex = { + 'dialog_id': 'inference', + 'turn_id': 0, + 'role': role, + 'src': new_src, + 'src_span_mask': src_span_mask, + 'query_label': { + 'DEFAULT_DOMAIN': { + 'card_arrival': {} + } + }, + 'extra_info': { + 'intent_label': -1 + } + } + examples.append(ex) + # add span mlm inputs and span mlm labels in advance + if self.with_mlm: + 
examples = [ + self.create_span_masked_lm_predictions(example) + for example in examples + ] + + # add absolute id of the dataset for indexing scores in its score matrix + for i, example in enumerate(examples): + example['id'] = i + + return examples + + def build_examples_single_turn(self, data_file, data_type='train'): + print(f"Reading examples from '{data_file}' ...") + examples = [] + ignored = 0 + + with open(data_file, 'r', encoding='utf-8') as f: + input_data = json.load(f) + for dialog_id in tqdm(input_data): + turns = input_data[dialog_id]['turns'] + history, history_role, history_span_mask = [], [], [] + for turn in turns: + label = turn['label'] + role = turn['role'] + text = turn['text'] + utterance, span_mask = [], [] + + token_list = [ + tok for tok in map(str.strip, + re.split('(\\W+)', text.lower())) + if len(tok) > 0 + ] + span_list = np.zeros(len(token_list), dtype=np.int32) + span_texts = self.extract_span_texts( + text=text, label=label) + + for span_text in span_texts: + found, find_pos = self.get_token_pos( + tok_list=token_list, value_label=span_text) + if found: + for start, exclusive_end in find_pos: + span_list[start:exclusive_end] = 1 + + token_list = [ + self.tokenizer.tokenize(token) for token in token_list + ] + span_list = [[tag] * len(token_list[i]) + for i, tag in enumerate(span_list)] + for sub_tokens in token_list: + utterance.extend(sub_tokens) + for sub_spans in span_list: + span_mask.extend(sub_spans) + assert len(utterance) == len(span_mask) + + history.append(utterance) + history_role.append(role) + history_span_mask.append(span_mask) + + tmp = self.utts_filter_pred(history) and all( + map(self.utt_filter_pred, history)) + tmp = tmp or data_type == 'test' + if tmp and role in self.trigger_role: # TODO consider test + src = [ + s[-self.max_utt_len:] + for s in history[-self.max_ctx_turn:] + ] + src_span_mask = [ + s[-self.max_utt_len:] + for s in history_span_mask[-self.max_ctx_turn:] + ] + roles = [ + role for role in history_role[-self.max_ctx_turn:] + ] + new_src = [] + for i, s in enumerate(src): + if roles[i] == 'user': + user_or_sys = [self.eos_u_id] + else: + user_or_sys = [self.sos_r_id] + tmp = [self.sos_u_id + ] + self.numericalize(s) + user_or_sys + tmp = tmp + self.numericalize(s) + [self.eos_r_id] + new_src.append(tmp) + + src_span_mask = [[0] + list(map(int, s)) + [0] + for s in src_span_mask] + + ex = { + 'dialog_id': dialog_id, + 'turn_id': turn['turn_id'], + 'role': role, + 'src': new_src, + 'src_span_mask': src_span_mask, + 'query_label': self.fix_label(label), + 'extra_info': turn.get('extra_info', '') + } + examples.append(ex) + else: + ignored += 1 + + # add span mlm inputs and span mlm labels in advance + if self.with_mlm: + examples = [ + self.create_span_masked_lm_predictions(example) + for example in examples + ] + + # add absolute id of the dataset for indexing scores in its score matrix + for i, example in enumerate(examples): + example['id'] = i + + print( + f'Built {len(examples)} {data_type.upper()} examples ({ignored} filtered)' + ) + return examples + + def collate_fn_multi_turn(self, samples): + batch_size = len(samples) + batch = {} + + src = [sp['src'] for sp in samples] + query_token, src_token, src_pos, src_turn, src_role = [], [], [], [], [] + for utts in src: + query_token.append(utts[-1]) + utt_lens = [len(utt) for utt in utts] + + # Token ids + src_token.append(list(chain(*utts))[-self.max_len:]) + + # Position ids + pos = [list(range(utt_len)) for utt_len in utt_lens] + 
src_pos.append(list(chain(*pos))[-self.max_len:]) + + # Turn ids + turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] + src_turn.append(list(chain(*turn))[-self.max_len:]) + + # Role ids + role = [ + [self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id] * l + for i, l in enumerate(utt_lens) + ] + src_role.append(list(chain(*role))[-self.max_len:]) + + # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 + src_token = list2np(src_token, padding=self.pad_id) + src_pos = list2np(src_pos, padding=self.pad_id) + src_turn = list2np(src_turn, padding=self.pad_id) + src_role = list2np(src_role, padding=self.pad_id) + batch['src_token'] = src_token + batch['src_pos'] = src_pos + batch['src_type'] = src_role + batch['src_turn'] = src_turn + batch['src_mask'] = (src_token != self.pad_id).astype('int64') + + if self.with_query_bow: + query_token = list2np(query_token, padding=self.pad_id) + batch['query_token'] = query_token + batch['query_mask'] = (query_token != self.pad_id).astype('int64') + + if self.with_mlm: + mlm_token, mlm_label = [], [] + raw_mlm_input = [sp['mlm_inputs'] for sp in samples] + raw_mlm_label = [sp['mlm_labels'] for sp in samples] + for inputs in raw_mlm_input: + mlm_token.append(list(chain(*inputs))[-self.max_len:]) + for labels in raw_mlm_label: + mlm_label.append(list(chain(*labels))[-self.max_len:]) + + mlm_token = list2np(mlm_token, padding=self.pad_id) + mlm_label = list2np(mlm_label, padding=self.pad_id) + batch['mlm_token'] = mlm_token + batch['mlm_label'] = mlm_label + batch['mlm_mask'] = (mlm_label != self.pad_id).astype('int64') + + if self.dynamic_score and self.with_contrastive and not self.abandon_label: + query_labels = [sp['query_label'] for sp in samples] + batch['query_labels'] = query_labels + if self.trigger_role == 'system': + resp_labels = [sp['resp_label'] for sp in samples] + batch['resp_labels'] = resp_labels + batch['label_ids'] = np.arange( + batch_size) # to identify labels for each GPU when multi-gpu + + if self.understand_ids: + understand = [self.understand_ids for _ in samples] + understand_token = np.array(understand).astype('int64') + batch['understand_token'] = understand_token + batch['understand_mask'] = \ + (understand_token != self.pad_id).astype('int64') + + if self.policy_ids and self.policy: + policy = [self.policy_ids for _ in samples] + policy_token = np.array(policy).astype('int64') + batch['policy_token'] = policy_token + batch['policy_mask'] = \ + (policy_token != self.pad_id).astype('int64') + + if 'tgt' in samples[0]: + tgt = [sp['tgt'] for sp in samples] + + # Token ids & Label ids + tgt_token = list2np(tgt, padding=self.pad_id) + + # Position ids + tgt_pos = np.zeros_like(tgt_token) + tgt_pos[:] = np.arange(tgt_token.shape[1], dtype=tgt_token.dtype) + + # Turn ids + tgt_turn = np.zeros_like(tgt_token) + + # Role ids + tgt_role = np.full_like(tgt_token, self.bot_id) + + batch['tgt_token'] = tgt_token + batch['tgt_pos'] = tgt_pos + batch['tgt_type'] = tgt_role + batch['tgt_turn'] = tgt_turn + batch['tgt_mask'] = (tgt_token != self.pad_id).astype('int64') + + if 'id' in samples[0]: + ids = [sp['id'] for sp in samples] + ids = np.array(ids).astype('int64') + batch['ids'] = ids + + return batch, batch_size + + +class IntentBPETextField(BPETextField): + + def __init__(self, model_dir, config): + super(IntentBPETextField, self).__init__(model_dir, config) + + def retrieve_examples(self, + dataset, + labels, + inds, + task, + num=None, + cache=None): + assert task == 'intent', 'Example-driven may only be used with intent prediction' + if 
num is None and labels is not None: + num = len(labels) * 2 + + # Populate cache + if cache is None: + cache = defaultdict(list) + for i, example in enumerate(dataset): + assert i == example['id'] + cache[example['extra_info']['intent_label']].append(i) + + # One example for each label + example_inds = [] + for lable in set(labels.tolist()): + if lable == -1: + continue + + ind = random.choice(cache[l]) + retries = 0 + while ind in inds.tolist() or type(ind) is not int: + ind = random.choice(cache[l]) + retries += 1 + if retries > len(dataset): + break + + example_inds.append(ind) + + # Sample randomly until we hit batch size + while len(example_inds) < min(len(dataset), num): + ind = random.randint(0, len(dataset) - 1) + if ind not in example_inds and ind not in inds.tolist(): + example_inds.append(ind) + + # Create examples + example_batch = {} + examples = [dataset[i] for i in example_inds] + examples, _ = self.collate_fn_multi_turn(examples) + example_batch['example_src_token'] = examples['src_token'] + example_batch['example_src_pos'] = examples['src_pos'] + example_batch['example_src_type'] = examples['src_type'] + example_batch['example_src_turn'] = examples['src_turn'] + example_batch['example_src_mask'] = examples['src_mask'] + example_batch['example_tgt_token'] = examples['tgt_token'] + example_batch['example_tgt_mask'] = examples['tgt_mask'] + example_batch['example_intent'] = examples['intent_label'] + + return example_batch + + def collate_fn_multi_turn(self, samples): + batch_size = len(samples) + batch = {} + + cur_roles = [sp['role'] for sp in samples] + src = [sp['src'] for sp in samples] + src_token, src_pos, src_turn, src_role = [], [], [], [] + for utts, cur_role in zip(src, cur_roles): + utt_lens = [len(utt) for utt in utts] + + # Token ids + src_token.append(list(chain(*utts))[-self.max_len:]) + + # Position ids + pos = [list(range(utt_len)) for utt_len in utt_lens] + src_pos.append(list(chain(*pos))[-self.max_len:]) + + # Turn ids + turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] + src_turn.append(list(chain(*turn))[-self.max_len:]) + + # Role ids + if cur_role == 'user': + role = [[ + self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id + ] * l for i, l in enumerate(utt_lens)] + else: + role = [[ + self.user_id if (len(utts) - i) % 2 == 0 else self.bot_id + ] * l for i, l in enumerate(utt_lens)] + src_role.append(list(chain(*role))[-self.max_len:]) + + # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 + src_token = list2np(src_token, padding=self.pad_id) + src_pos = list2np(src_pos, padding=self.pad_id) + src_turn = list2np(src_turn, padding=self.pad_id) + src_role = list2np(src_role, padding=self.pad_id) + batch['src_token'] = src_token + batch['src_pos'] = src_pos + batch['src_type'] = src_role + batch['src_turn'] = src_turn + batch['src_mask'] = (src_token != self.pad_id).astype( + 'int64') # input mask + + if self.with_mlm: + mlm_token, mlm_label = [], [] + raw_mlm_input = [sp['mlm_inputs'] for sp in samples] + raw_mlm_label = [sp['mlm_labels'] for sp in samples] + for inputs in raw_mlm_input: + mlm_token.append(list(chain(*inputs))[-self.max_len:]) + for labels in raw_mlm_label: + mlm_label.append(list(chain(*labels))[-self.max_len:]) + + mlm_token = list2np(mlm_token, padding=self.pad_id) + mlm_label = list2np(mlm_label, padding=self.pad_id) + batch['mlm_token'] = mlm_token + batch['mlm_label'] = mlm_label + batch['mlm_mask'] = (mlm_label != self.pad_id).astype( + 'int64') # label mask + + if self.understand_ids: + tgt = [self.understand_ids for _ in 
samples] + tgt_token = np.array(tgt).astype('int64') + batch['tgt_token'] = tgt_token + batch['tgt_mask'] = (tgt_token != self.pad_id).astype( + 'int64') # input mask + + if 'id' in samples[0]: + ids = [sp['id'] for sp in samples] + ids = np.array(ids).astype('int64') + batch['ids'] = ids + + if self.dynamic_score and self.with_contrastive: + query_labels = [sp['query_label'] for sp in samples] + batch['query_labels'] = query_labels + batch['label_ids'] = np.arange(batch_size) + + if 'intent_label' in samples[0]['extra_info']: + intent_label = [ + sample['extra_info']['intent_label'] for sample in samples + ] + intent_label = np.array(intent_label).astype('int64') + batch['intent_label'] = intent_label + + return batch, batch_size diff --git a/modelscope/preprocessors/space/tokenizer.py b/modelscope/preprocessors/space/tokenizer.py new file mode 100644 index 00000000..764552cd --- /dev/null +++ b/modelscope/preprocessors/space/tokenizer.py @@ -0,0 +1,672 @@ +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import collections +import logging +import os +import sys +import unicodedata + +import json +import regex as re + + +def clean_string(string): + replace_mp = { + ' - ': '-', + " ' ": "'", + " n't": "n't", + " 'm": "'m", + ' do not': " don't", + " 's": "'s", + " 've": "'ve", + " 're": "'re" + } + for k, v in replace_mp.items(): + string = string.replace(k, v) + return string + + +class Tokenizer(object): + + def __init__(self, vocab_path, special_tokens=[], tokenizer_type='Bert'): + self.tokenizer_type = tokenizer_type + if tokenizer_type == 'Bert': + self.spec_convert_dict = { + '[BOS]': '[unused0]', + '[EOS]': '[unused1]' + } + for token in special_tokens: + if token not in self.spec_convert_dict and token not in [ + '[PAD]', '[UNK]' + ]: + self.spec_convert_dict[ + token] = f'[unused{len(self.spec_convert_dict)}]' + self.spec_revert_dict = { + v: k + for k, v in self.spec_convert_dict.items() + } + special_tokens = [ + self.spec_convert_dict.get(tok, tok) for tok in special_tokens + ] + self.special_tokens = ('[UNK]', '[SEP]', '[PAD]', '[CLS]', + '[MASK]') + self.special_tokens += tuple(x for x in special_tokens + if x not in self.special_tokens) + + self._tokenizer = BertTokenizer( + vocab_path, never_split=self.special_tokens) + for tok in self.special_tokens: + ''' + 需要先保证special_tokens在词表中,这里设置special_tokens的目的是为了这些词能够完整占位,不再切分为子词; + 若不在词表中,可以使用词表中的[unused]符号进行转换:spec_convert_dict; + ''' + assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary" + self.vocab_size = len(self._tokenizer.vocab) + elif tokenizer_type == 'GPT2': + self.spec_convert_dict = {'[UNK]': ''} + self.spec_revert_dict = { + v: k + for k, v in self.spec_convert_dict.items() + } + special_tokens = [ + tok for tok in special_tokens + if tok not in self.spec_convert_dict + ] + vocab_file = os.path.join(vocab_path, 'vocab.json') + merges_file = os.path.join(vocab_path, 'merges.txt') + self._tokenizer = GPT2Tokenizer( + vocab_file, merges_file, special_tokens=special_tokens) + self.num_specials = len(special_tokens) + self.vocab_size = len(self._tokenizer) + else: + raise ValueError + + def tokenize(self, text): + return self._tokenizer.tokenize(text) + + def convert_tokens_to_ids(self, tokens): + if self.tokenizer_type == 'Bert': + tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] + ids = self._tokenizer.convert_tokens_to_ids(tokens) + return ids + else: + tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] + ids = 
self._tokenizer.convert_tokens_to_ids(tokens) + ids = [(i + self.num_specials) % self.vocab_size for i in ids] + return ids + + def convert_ids_to_tokens(self, ids): + if self.tokenizer_type == 'Bert': + tokens = self._tokenizer.convert_ids_to_tokens(ids) + tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] + return tokens + else: + ids = [(i - self.num_specials) % self.vocab_size for i in ids] + tokens = self._tokenizer.convert_ids_to_tokens(ids) + tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] + return tokens + + def decode(self, ids, ignore_tokens=[]): + tokens = self.convert_ids_to_tokens(ids) + if len(ignore_tokens) > 0: + ignore_tokens = set(ignore_tokens) + tokens = [tok for tok in tokens if tok not in ignore_tokens] + if self.tokenizer_type == 'Bert': + string = ' '.join(tokens).replace(' ##', '') + else: + string = ''.join(tokens) + string = bytearray([ + self._tokenizer.byte_decoder[c] for c in string + ]).decode('utf-8') + string = clean_string(string) + return string + + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +logger = logging.getLogger(__name__) + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, 'r', encoding='utf-8') as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, + vocab_file, + do_lower_case=True, + max_len=None, + do_basic_tokenize=True, + never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " + 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + .format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([ + (ids, tok) for tok, ids in self.vocab.items() + ]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + 'Token indices sequence length is longer than the specified maximum ' + ' sequence length for this BERT model ({} > {}). Running this' + ' sequence through BERT will result in indexing errors'.format( + len(ids), self.max_len)) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
+ text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(' '.join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(' ') + output.append(char) + output.append(' ') + else: + output.append(char) + return ''.join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + tmp = (cp >= 0x4E00 and cp <= 0x9FFF) + tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF) + tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF) + tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F) + tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F) + tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF) + tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF) + tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F) + if tmp: + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(' ') + else: + output.append(char) + return ''.join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. 
+ """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = ''.join(chars[start:end]) + if start > 0: + substr = '##' + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat.startswith('C'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + tmp = (cp >= 33 and cp <= 47) + tmp = tmp or (cp >= 58 and cp <= 64) + tmp = tmp or (cp >= 91 and cp <= 96) + tmp = tmp or (cp >= 123 and cp <= 126) + if tmp: + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False + + +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + _chr = unichr if sys.version_info[0] == 2 else chr + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [_chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class GPT2Tokenizer(object): + """ + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE + """ + + def __init__(self, + vocab_file, + merges_file, + errors='replace', + special_tokens=None, + max_len=None): + self.max_len = max_len if max_len is not None else int(1e12) + self.encoder = json.load(open(vocab_file)) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) + for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = { + v: k + for k, v in self.special_tokens.items() + } + logger.info('Special tokens {}'.format(self.special_tokens)) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def tokenize(self, text): + """ Tokenize a string. 
""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[ord(b)] for b in token + if ord(b) in self.byte_encoder) + if token == '': + continue + bpe_tokens.extend( + bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. """ + ids = [] + python_version_3 = isinstance(tokens, str) + python_version_2 = ( + sys.version_info[0] == 2 and isinstance(tokens, unicode)) + if python_version_3 or python_version_2: + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + 'Token indices sequence length is longer than the specified maximum ' + ' sequence length for this OpenAI GPT model ({} > {}). Running this' + ' sequence through the model will result in indexing errors'. + format(len(ids), self.max_len)) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors=self.errors) + return text diff --git a/modelscope/trainers/nlp/space/__init__.py b/modelscope/trainers/nlp/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/nlp/space/metrics/__init__.py b/modelscope/trainers/nlp/space/metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/nlp/space/metrics/metrics_tracker.py b/modelscope/trainers/nlp/space/metrics/metrics_tracker.py new file mode 100644 index 00000000..c08eba68 --- /dev/null +++ b/modelscope/trainers/nlp/space/metrics/metrics_tracker.py @@ -0,0 +1,73 @@ +""" +MetricsTracker class +""" + +import math +from collections import defaultdict + + +class MetricsTracker(object): + """ Tracking metrics. 
""" + + def __init__(self): + self.metrics_val = defaultdict(float) # 记录最新一个batch返回的指标 + self.metrics_avg = defaultdict(float) # 维护一个epoch内已训练batches的平均指标 + self.num_samples = 0 + + def update(self, metrics, num_samples): + for key, val in metrics.items(): + if val is not None: + val = float(val) # [val] -> val + self.metrics_val[key] = val + avg_val = \ + (self.metrics_avg.get(key, 0) * self.num_samples + val * num_samples) / \ + (self.num_samples + num_samples) + self.metrics_avg[key] = avg_val + self.num_samples += num_samples + + def clear(self): + self.metrics_val = defaultdict(float) + self.metrics_avg = defaultdict(float) + self.num_samples = 0 + + def items(self): + return self.metrics_avg.items() + + def get(self, name): + if self.num_samples == 0: + raise ValueError('There is no data in Metrics.') + return self.metrics_avg.get(name) + + def state_dict(self): + return { + 'metrics_val': self.metrics_val, + 'metrics_avg': self.metrics_avg, + 'num_samples': self.num_samples, + } + + def load_state_dict(self, state_dict): + self.metrics_val = state_dict['metrics_val'] + self.metrics_avg = state_dict['metrics_avg'] + self.num_samples = state_dict['num_samples'] + + def value(self): + metric_strs = [] + for key, val in self.metrics_val.items(): + metric_str = f'{key.upper()}-{val:.3f}' + metric_strs.append(metric_str) + if 'token_nll' in self.metrics_val: + metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}" + metric_strs.append(metric_str) + metric_strs = ' '.join(metric_strs) + return metric_strs + + def summary(self): + metric_strs = [] + for key, val in self.metrics_avg.items(): + metric_str = f'{key.upper()}-{val:.3f}' + metric_strs.append(metric_str) + if 'token_nll' in self.metrics_avg: + metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}" + metric_strs.append(metric_str) + metric_strs = ' '.join(metric_strs) + return metric_strs diff --git a/modelscope/trainers/nlp/space/trainers/__init__.py b/modelscope/trainers/nlp/space/trainers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/nlp/space/trainers/gen_trainer.py b/modelscope/trainers/nlp/space/trainers/gen_trainer.py new file mode 100644 index 00000000..a0cda25c --- /dev/null +++ b/modelscope/trainers/nlp/space/trainers/gen_trainer.py @@ -0,0 +1,761 @@ +""" +Trainer class. 
+""" +import logging +import os +import sys +import time +from collections import OrderedDict + +import json +import numpy as np +import torch +from tqdm import tqdm +from transformers.optimization import AdamW, get_linear_schedule_with_warmup + +import modelscope.utils.nlp.space.ontology as ontology +from ..metrics.metrics_tracker import MetricsTracker + + +def get_logger(log_path, name='default'): + logger = logging.getLogger(name) + logger.propagate = False + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(message)s') + + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(formatter) + logger.addHandler(sh) + + fh = logging.FileHandler(log_path, mode='w') + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger + + +class Trainer(object): + + def __init__(self, + model, + to_tensor, + config, + logger=None, + lr_scheduler=None, + optimizer=None, + reader=None, + evaluator=None): + self.to_tensor = to_tensor + + self.do_train = config.do_train + self.do_infer = config.do_infer + self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ + 0] == '-' + self.valid_metric_name = config.Trainer.valid_metric_name[1:] + self.num_epochs = config.Trainer.num_epochs + # self.save_dir = config.Trainer.save_dir + self.log_steps = config.Trainer.log_steps + self.valid_steps = config.Trainer.valid_steps + self.save_checkpoint = config.Trainer.save_checkpoint + self.save_summary = config.Trainer.save_summary + self.lr = config.Model.lr + self.weight_decay = config.Model.weight_decay + self.batch_size = config.Trainer.batch_size + self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps + self.warmup_steps = config.Model.warmup_steps + self.gpu = config.Trainer.gpu + + self.lr_scheduler = lr_scheduler + self.optimizer = optimizer + + self.model = model + self.func_model = self.model.module if self.gpu > 1 else self.model + self.reader = reader + self.evaluator = evaluator + self.tokenizer = reader.tokenizer + + # if not os.path.exists(self.save_dir): + # os.makedirs(self.save_dir) + + # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") + self.logger = logger or get_logger('trainer.log', 'trainer') + + self.batch_metrics_tracker = MetricsTracker() + self.token_metrics_tracker = MetricsTracker() + + self.best_valid_metric = float( + 'inf' if self.is_decreased_valid_metric else '-inf') + self.epoch = 0 + + def decode_generated_bspn_resp(self, generated): + """ + decode generated + return decoded ('bspn', 'resp') + """ + decoded = {} + eos_r_id = self.reader.eos_r_id + eos_b_id = self.reader.eos_b_id + + # eos_r may not exists if gpt2 generated repetitive words. + if eos_r_id in generated: + eos_r_idx = generated.index(eos_r_id) + else: + eos_r_idx = len(generated) - 1 + # self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated)) + + # predicted bspn, resp + eos_b_idx = generated.index(eos_b_id) + decoded['bspn'] = generated[:eos_b_idx + 1] + decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1] + return decoded + + def decode_generated_act_resp(self, generated): + """ + decode generated + return decoded['resp'] ('bspn', 'aspn') + """ + decoded = {} + eos_a_id = self.reader.eos_a_id + eos_r_id = self.reader.eos_r_id + # eos_b_id = self.reader.eos_b_id + + # eos_r may not exists if gpt2 generated repetitive words. 
+ if eos_r_id in generated: + eos_r_idx = generated.index(eos_r_id) + else: + eos_r_idx = len(generated) - 1 + msg = 'eos_r not in generated: ' + self.tokenizer.decode(generated) + self.logger.info(msg) + + if self.reader.use_true_curr_aspn: # only predict resp + decoded['resp'] = generated[:eos_r_idx + 1] + else: # predicted aspn, resp + eos_a_idx = generated.index(eos_a_id) + decoded['aspn'] = generated[:eos_a_idx + 1] + decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1] + return decoded + + def decode_generated_bspn(self, generated): + eos_b_id = self.reader.eos_b_id + if eos_b_id in generated: + eos_b_idx = generated.index(eos_b_id) + else: + eos_b_idx = len(generated) - 1 + return generated[:eos_b_idx + 1] + + def set_optimizers(self): + """ + Setup the optimizer and the learning rate scheduler. + + from transformers.Trainer + + parameters from cfg: lr (1e-3); warmup_steps + """ + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'norm.weight'] + optimizer_grouped_parameters = [ + { + 'params': [ + p for n, p in self.model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + 'weight_decay': + self.weight_decay, + }, + { + 'params': [ + p for n, p in self.model.named_parameters() + if any(nd in n for nd in no_decay) + ], + 'weight_decay': + 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) + + num_training_steps = \ + self.reader.set_stats['train']['num_training_steps_per_epoch'] \ + * self.num_epochs \ + // self.gradient_accumulation_steps + num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( + num_training_steps * 0.1) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def train(self, train_data, dev_data): + # log info + set_stats = self.reader.set_stats['train'] + self.logger.info('***** Running training *****') + self.logger.info( + ' Num Training steps(one turn in a batch of dialogs) per epoch = %d', + set_stats['num_training_steps_per_epoch']) + self.logger.info(' Num Turns = %d', set_stats['num_turns']) + self.logger.info(' Num Dialogs = %d', set_stats['num_dials']) + self.logger.info(' Num Epochs = %d', self.num_epochs) + self.logger.info(' Batch size = %d', self.batch_size) + self.logger.info(' Gradient Accumulation steps = %d', + self.gradient_accumulation_steps) + steps = set_stats[ + 'num_training_steps_per_epoch'] * self.num_epochs // self.gradient_accumulation_steps + msg = ' Total optimization steps = %d' % steps + self.logger.info(msg) + + # begin training + num_epochs = self.num_epochs - self.epoch + for epoch in range(num_epochs): + self.train_epoch(train_data=train_data, dev_data=dev_data) + + def train_epoch(self, train_data, dev_data): + """ + Train an epoch. + """ + raise NotImplementedError + + def infer(self, data_type): + """ + Inference interface. 
+ """ + raise NotImplementedError + + def forward(self, turn, old_pv_turn): + """ + one turn inference + """ + raise NotImplementedError + + def save(self, is_best=False): + """ save """ + train_state = { + 'epoch': self.epoch, + 'best_valid_metric': self.best_valid_metric, + 'optimizer': self.optimizer.state_dict() + } + if self.lr_scheduler is not None: + train_state['lr_scheduler'] = self.lr_scheduler.state_dict() + + # Save checkpoint + if self.save_checkpoint: + model_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.model') + torch.save(self.model.state_dict(), model_file) + self.logger.info(f"Saved model state to '{model_file}'") + + train_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.train') + torch.save(train_state, train_file) + self.logger.info(f"Saved train state to '{train_file}'") + + # Save current best model + if is_best: + best_model_file = os.path.join(self.save_dir, 'best.model') + torch.save(self.model.state_dict(), best_model_file) + best_train_file = os.path.join(self.save_dir, 'best.train') + torch.save(train_state, best_train_file) + self.logger.info( + f"Saved best model state to '{best_model_file}' with new best valid metric " + f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' + ) + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}', + map_location=lambda storage, loc: storage) + + if 'module.' in list(model_state_dict.keys())[0]: + new_model_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + assert k[:7] == 'module.' + new_model_state_dict[k[7:]] = v + model_state_dict = new_model_state_dict + + new_model_state_dict = OrderedDict() + parameters = { + name: param + for name, param in self.func_model.named_parameters() + } + for name, param in model_state_dict.items(): + if name in parameters: + if param.shape != parameters[name].shape: + assert hasattr(param, 'numpy') + arr = param.numpy() + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + if name == 'embedder.token_embedding.weight': + z[-param.shape[0]:] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + else: + if z.shape[0] < param.shape[0]: + z = arr[:z.shape[0]] + print(f'part of parameter({name}) are dropped') + else: + z[:param.shape[0]] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + dtype, device = param.dtype, param.device + z = torch.tensor(z, dtype=dtype, device=device) + new_model_state_dict[name] = z + else: + new_model_state_dict[name] = param + else: + print(f'parameter({name}) are dropped') + model_state_dict = new_model_state_dict + + for name in parameters: + if name not in model_state_dict: + if parameters[name].requires_grad: + print(f'parameter({name}) random normlize initialize') + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + dtype, device = parameters[name].dtype, parameters[ + name].device + model_state_dict[name] = torch.tensor( + z, dtype=dtype, device=device) + else: + model_state_dict[name] = parameters[name] + + self.func_model.load_state_dict(model_state_dict) + self.logger.info( + f"Loaded model state from '{self.func_model.init_checkpoint}.model'" + ) + + def _load_train_state(): + train_file = f'{self.func_model.init_checkpoint}.train' + if os.path.exists(train_file): + train_state_dict = torch.load( + train_file, map_location=lambda 
storage, loc: storage) + self.epoch = train_state_dict['epoch'] + self.best_valid_metric = train_state_dict['best_valid_metric'] + if self.optimizer is not None and 'optimizer' in train_state_dict: + self.optimizer.load_state_dict( + train_state_dict['optimizer']) + if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: + self.lr_scheduler.load_state_dict( + train_state_dict['lr_scheduler']) + self.logger.info( + f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " + f'best_valid_metric={self.best_valid_metric:.3f})') + else: + self.logger.info('Loaded no train state') + + if self.func_model.init_checkpoint is None: + self.logger.info('Loaded no model !!!') + return + + if self.do_train: + _load_model_state() + return + + if self.do_infer: + _load_model_state() + _load_train_state() + + +class MultiWOZTrainer(Trainer): + + def __init__(self, + model, + to_tensor, + config, + logger=None, + lr_scheduler=None, + optimizer=None, + reader=None, + evaluator=None): + super(MultiWOZTrainer, + self).__init__(model, to_tensor, config, logger, lr_scheduler, + optimizer, reader, evaluator) + + def train_epoch(self, train_data, dev_data): + """ + Train an epoch. + """ + times = [] + epoch_step = 0 + global_step = 0 + tr_batch_loss = 0.0 + tr_token_loss = 0.0 + self.epoch += 1 + self.batch_metrics_tracker.clear() + self.token_metrics_tracker.clear() + num_training_steps = \ + self.reader.set_stats['train']['num_training_steps_per_epoch'] // \ + self.gradient_accumulation_steps # similar to the original num_batches + + self.model.zero_grad() + data_iterator = self.reader.get_data_iterator(all_batches=train_data) + + for batch_idx, dial_batch in enumerate(data_iterator): + pv_batch = [] + for turn_num, turn_batch in enumerate(dial_batch): + first_turn = (turn_num == 0) + samples, pv_batch = self.reader.convert_batch_turn( + turn_batch, pv_batch, first_turn) + batch, batch_size = self.reader.collate_fn_multi_turn( + samples=samples) + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + + # Do a training iteration + start_time = time.time() + metrics = self.model(batch, is_training=True) + if self.gpu > 1: + for metric in metrics: + if metric is not None: + assert len(metric) == self.gpu + nll, token_nll, token_num = metrics + metrics = {} + + token_num = torch.sum(token_num) + token_nll = \ + torch.sum(nll) * (batch_size / self.gpu) / \ + token_num + nll = torch.mean(nll) + metrics['token_num'] = token_num + metrics['token_nll'] = token_nll + metrics['nll'] = nll + loss = token_nll if self.func_model.token_loss else nll + + metrics['loss'] = loss + else: + loss = metrics['loss'] + self.func_model._optimize( + loss, do_update=False, optimizer=self.optimizer) + metrics = { + k: v.cpu().detach().numpy() + if isinstance(v, torch.Tensor) else v + for k, v in metrics.items() + } + token_num = metrics.pop('token_num', None) + # bow_num = metrics.pop("bow_num", None) + elapsed = time.time() - start_time + times.append(elapsed) + epoch_step += 1 + + tr_batch_loss += metrics['nll'] + tr_token_loss += metrics['token_nll'] + batch_metrics = { + k: v + for k, v in metrics.items() if 'token' not in k + } + token_metrics = { + k: v + for k, v in metrics.items() if 'token' in k + } + self.batch_metrics_tracker.update(batch_metrics, batch_size) + self.token_metrics_tracker.update(token_metrics, token_num) + + if (epoch_step % self.gradient_accumulation_steps == 0) or \ + (epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']): + 
self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + global_step += 1 + + if self.log_steps > 0 and global_step % self.log_steps == 0: + batch_metrics_message = self.batch_metrics_tracker.value( + ) + token_metrics_message = self.token_metrics_tracker.value( + ) + message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]' + avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' + message = ' '.join([ + message_prefix, batch_metrics_message, + token_metrics_message, avg_time + ]) + self.logger.info(message) + + self.logger.info('-' * 150) + avg_batch_loss = tr_batch_loss / epoch_step + avg_token_loss = tr_token_loss / epoch_step + batch_metrics_message = self.batch_metrics_tracker.summary() + token_metrics_message = self.token_metrics_tracker.summary() + message_prefix = f'[Valid][{self.epoch}]' + message = ' '.join([ + message_prefix, batch_metrics_message, token_metrics_message, + str(avg_batch_loss), + str(avg_token_loss) + ]) + self.logger.info(message) + + cur_valid_metric = self.batch_metrics_tracker.get( + self.valid_metric_name) + if self.is_decreased_valid_metric: + is_best = cur_valid_metric < self.best_valid_metric + else: + is_best = cur_valid_metric > self.best_valid_metric + if is_best: + self.best_valid_metric = cur_valid_metric + self.save(is_best) + self.logger.info('-' * 150) + + return + + def infer(self, data_type='test'): + """ + Inference interface. + """ + self.logger.info('Generation starts ...') + infer_save_file = os.path.join(self.save_dir, + f'infer_{self.epoch}.result.json') + infer_samples_save_file = os.path.join( + self.save_dir, f'infer_samples_{self.epoch}.result.json') + + # Inference + result_collection = {} + begin_time = time.time() + + eval_data = self.reader.get_eval_data(data_type) + set_stats = self.reader.set_stats[data_type] + self.logger.info('***** Running Evaluation *****') + self.logger.info(' Num Turns = %d', set_stats['num_turns']) + + with torch.no_grad(): + pbar = tqdm(eval_data) + for dial_idx, dialog in enumerate(pbar): + pv_turn = {} + for turn_idx, turn in enumerate(dialog): + first_turn = (turn_idx == 0) + inputs, prompt_id = self.reader.convert_turn_eval( + turn, pv_turn, first_turn) + batch, batch_size = self.reader.collate_fn_multi_turn( + samples=[inputs]) + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + if self.reader.use_true_curr_bspn: # generate act, response + max_len = 60 + if not self.reader.use_true_curr_aspn: + max_len = 80 + outputs = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_r_id, + max_gen_len=max_len) + # resp_gen, need to trim previous context + generated = outputs[0].cpu().numpy().tolist() + try: + decoded = self.decode_generated_act_resp(generated) + except ValueError as exception: + self.logger.info(str(exception)) + self.logger.info(self.tokenizer.decode(generated)) + decoded = {'resp': [], 'bspn': [], 'aspn': []} + else: # predict bspn, access db, then generate act and resp + outputs = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_b_id, + max_gen_len=60) + generated_bs = outputs[0].cpu().numpy().tolist() + bspn_gen = self.decode_generated_bspn(generated_bs) + # check DB result + if self.reader.use_true_db_pointer: # 控制当前轮的db是否为ground truth + db = turn['db'] + else: + db_result = self.reader.bspan_to_DBpointer( + self.tokenizer.decode(bspn_gen), + turn['turn_domain']) + assert len(turn['db']) == 4 + book_result = 
turn['db'][2] + assert isinstance(db_result, str) + db = \ + [self.reader.sos_db_id] + \ + self.tokenizer.convert_tokens_to_ids([db_result]) + \ + [book_result] + \ + [self.reader.eos_db_id] + prompt_id = self.reader.sos_a_id + + prev_input = torch.tensor(bspn_gen + db) + if self.func_model.use_gpu: + prev_input = prev_input.cuda() + outputs_db = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_r_id, + max_gen_len=80, + prev_input=prev_input) + generated_ar = outputs_db[0].cpu().numpy().tolist() + try: + decoded = self.decode_generated_act_resp( + generated_ar) + decoded['bspn'] = bspn_gen + except ValueError as exception: + self.logger.info(str(exception)) + self.logger.info( + self.tokenizer.decode(generated_ar)) + decoded = {'resp': [], 'bspn': [], 'aspn': []} + + turn['resp_gen'] = decoded['resp'] + turn['bspn_gen'] = turn[ + 'bspn'] if self.reader.use_true_curr_bspn else decoded[ + 'bspn'] + turn['aspn_gen'] = turn[ + 'aspn'] if self.reader.use_true_curr_aspn else decoded[ + 'aspn'] + turn['dspn_gen'] = turn['dspn'] + + pv_turn['labels'] = inputs[ + 'labels'] # all true previous context + pv_turn['resp'] = turn[ + 'resp'] if self.reader.use_true_prev_resp else decoded[ + 'resp'] + if not self.reader.use_true_curr_bspn: + pv_turn['bspn'] = turn[ + 'bspn'] if self.reader.use_true_prev_bspn else decoded[ + 'bspn'] + pv_turn['db'] = turn[ + 'db'] if self.reader.use_true_prev_bspn else db + pv_turn['aspn'] = turn[ + 'aspn'] if self.reader.use_true_prev_aspn else decoded[ + 'aspn'] + + tmp_dialog_result = self.reader.inverse_transpose_turn(dialog) + result_collection.update(tmp_dialog_result) + + # compute tmp scores + results, _ = self.reader.wrap_result_lm(tmp_dialog_result) + bleu, success, match = self.evaluator.validation_metric( + results) + score = 0.5 * (success + match) + bleu + pbar.set_description( + 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % + (match, success, bleu, score)) + + # compute scores + results, _ = self.reader.wrap_result_lm(result_collection) + bleu, success, match = self.evaluator.validation_metric(results) + score = 0.5 * (success + match) + bleu + + # log results + metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\ + (match, success, bleu, score) + message_prefix = f'[Infer][{self.epoch}]' + time_cost = f'TIME-{time.time() - begin_time:.3f}' + message = ' '.join([message_prefix, metrics_message, time_cost]) + self.logger.info(message) + + # save results + eval_results = { + 'bleu': bleu, + 'success': success, + 'match': match, + 'score': score, + 'result': message + } + with open(infer_save_file, 'w') as fp: + json.dump(eval_results, fp, indent=2) + self.logger.info(f'Saved inference results to {infer_save_file}') + with open(infer_samples_save_file, 'w') as fp: + for sample in results: + line = json.dumps(sample) + fp.write(line) + fp.write('\n') + self.logger.info( + f'Saved inference samples to {infer_samples_save_file}') + + return + + def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn): + + def _get_slots(constraint): + domain_name = '' + slots = {} + for item in constraint: + if item in ontology.placeholder_tokens: + continue + if item in ontology.all_domains_with_bracket: + domain_name = item + slots[domain_name] = set() + else: + assert domain_name in ontology.all_domains_with_bracket + slots[domain_name].add(item) + return slots + + turn_domain = [] + if first_turn and len(bspn_gen_ids) == 0: + turn_domain = ['[general]'] + return turn_domain + + bspn_token = 
self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) + turn_slots = _get_slots(bspn_token) + if first_turn: + return list(turn_slots.keys()) + + assert 'bspn' in old_pv_turn + pv_bspn_token = self.tokenizer.convert_ids_to_tokens( + old_pv_turn['bspn']) + pv_turn_slots = _get_slots(pv_bspn_token) + for domain, value in turn_slots.items(): + pv_value = pv_turn_slots[ + domain] if domain in pv_turn_slots else set() + if len(value - pv_value) > 0 or len(pv_value - value): + turn_domain.append(domain) + if len(turn_domain) == 0: + turn_domain = list(turn_slots.keys()) + + return turn_domain + + def forward(self, turn, old_pv_turn): + with torch.no_grad(): + first_turn = True if len(old_pv_turn) == 0 else False + inputs, prompt_id = self.reader.convert_turn_eval( + turn, old_pv_turn, first_turn) + batch, batch_size = self.reader.collate_fn_multi_turn( + samples=[inputs]) + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) + pv_turn = {} + + outputs = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_b_id, + max_gen_len=60) + generated_bs = outputs[0].cpu().numpy().tolist() + bspn_gen = self.decode_generated_bspn(generated_bs) + + turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen, + first_turn) + + db_result = self.reader.bspan_to_DBpointer( + self.tokenizer.decode(bspn_gen), turn_domain) + assert isinstance(db_result, str) + db = \ + [self.reader.sos_db_id] + \ + self.tokenizer.convert_tokens_to_ids([db_result]) + \ + [self.reader.eos_db_id] + prompt_id = self.reader.sos_a_id + prev_input = torch.tensor(bspn_gen + db) + if self.func_model.use_gpu: + prev_input = prev_input.cuda() + outputs_db = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_r_id, + max_gen_len=80, + prev_input=prev_input) + generated_ar = outputs_db[0].cpu().numpy().tolist() + decoded = self.decode_generated_act_resp(generated_ar) + decoded['bspn'] = bspn_gen + + pv_turn['labels'] = inputs['labels'] + pv_turn['resp'] = decoded['resp'] + pv_turn['bspn'] = decoded['bspn'] + pv_turn['db'] = db + pv_turn['aspn'] = decoded['aspn'] + + return pv_turn diff --git a/modelscope/trainers/nlp/space/trainers/intent_trainer.py b/modelscope/trainers/nlp/space/trainers/intent_trainer.py new file mode 100644 index 00000000..bd43e9a5 --- /dev/null +++ b/modelscope/trainers/nlp/space/trainers/intent_trainer.py @@ -0,0 +1,824 @@ +""" +Trainer class. 
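+
+Contains a generic `Trainer` base class plus `IntentTrainer`, which trains and
+evaluates the intent prediction model on labeled (and optionally unlabeled)
+batches, with optional example-memory retrieval and confidence-based
+re-normalization of predictions (`can_norm`).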
+""" + +import logging +import os +import sys +import time +from collections import OrderedDict + +import json +import numpy as np +import torch +from tqdm import tqdm +from transformers.optimization import AdamW, get_linear_schedule_with_warmup + +from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ + MetricsTracker +from modelscope.utils.nlp.space.args import str2bool + + +def get_logger(log_path, name='default'): + logger = logging.getLogger(name) + logger.propagate = False + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(message)s') + + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(formatter) + logger.addHandler(sh) + + fh = logging.FileHandler(log_path, mode='w') + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger + + +class Trainer(object): + + def __init__(self, + model, + to_tensor, + config, + reader=None, + logger=None, + lr_scheduler=None, + optimizer=None): + self.model = model + self.to_tensor = to_tensor + self.do_train = config.do_train + self.do_infer = config.do_infer + + self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ + 0] == '-' + self.valid_metric_name = config.Trainer.valid_metric_name[1:] + self.num_epochs = config.Trainer.num_epochs + self.save_dir = config.Trainer.save_dir + self.log_steps = config.Trainer.log_steps + self.valid_steps = config.Trainer.valid_steps + self.save_checkpoint = config.Trainer.save_checkpoint + self.save_summary = config.Trainer.save_summary + self.learning_method = config.Dataset.learning_method + self.weight_decay = config.Model.weight_decay + self.warmup_steps = config.Model.warmup_steps + self.batch_size_label = config.Trainer.batch_size_label + self.batch_size_nolabel = config.Trainer.batch_size_nolabel + self.gpu = config.Trainer.gpu + self.lr = config.Model.lr + + self.model = model + self.func_model = self.model.module if self.gpu > 1 else self.model + self.reader = reader + self.tokenizer = reader.tokenizer + + self.lr_scheduler = lr_scheduler + self.optimizer = optimizer + + # if not os.path.exists(self.save_dir): + # os.makedirs(self.save_dir) + + # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") + self.logger = logger or get_logger('trainer.log', 'trainer') + + self.batch_metrics_tracker_label = MetricsTracker() + self.token_metrics_tracker_label = MetricsTracker() + self.batch_metrics_tracker_nolabel = MetricsTracker() + self.token_metrics_tracker_nolabel = MetricsTracker() + + self.best_valid_metric = float( + 'inf' if self.is_decreased_valid_metric else '-inf') + self.epoch = 0 + self.batch_num = 0 + + def set_optimizers(self, num_training_steps_per_epoch): + """ + Setup the optimizer and the learning rate scheduler. 
+ + from transformers.Trainer + + parameters from cfg: lr (1e-3); warmup_steps + """ + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'norm.weight'] + optimizer_grouped_parameters = [ + { + 'params': [ + p for n, p in self.model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + 'weight_decay': + self.weight_decay, + }, + { + 'params': [ + p for n, p in self.model.named_parameters() + if any(nd in n for nd in no_decay) + ], + 'weight_decay': + 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) + + num_training_steps = num_training_steps_per_epoch * self.num_epochs + num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( + num_training_steps * 0.1) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps) + + # reset optimizer and lr_scheduler + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # log info + self.logger.info( + f'***** Running training: {self.learning_method} *****') + self.logger.info(' Num Epochs = %d', self.num_epochs) + self.logger.info( + ' Num Training steps(one turn in a batch of dialogs) per epoch = %d', + num_training_steps_per_epoch) + self.logger.info(' Batch size for labeled data = %d', + self.batch_size_label) + self.logger.info(' Batch size for unlabeled data = %d', + self.batch_size_nolabel) + self.logger.info(' Total optimization steps = %d', num_training_steps) + self.logger.info(' Total warmup steps = %d', num_warmup_steps) + self.logger.info('************************************') + + def train(self, + train_label_iter, + train_nolabel_iter=None, + valid_label_iter=None, + valid_nolabel_iter=None): + # begin training + num_epochs = self.num_epochs - self.epoch + for epoch in range(num_epochs): + self.train_epoch( + train_label_iter=train_label_iter, + train_nolabel_iter=train_nolabel_iter, + valid_label_iter=valid_label_iter, + valid_nolabel_iter=valid_nolabel_iter) + + def train_epoch(self, train_label_iter, train_nolabel_iter, + valid_label_iter, valid_nolabel_iter): + """ + Train an epoch. 
+ """ + raise NotImplementedError + + def evaluate(self, data_label_iter, data_nolabel_iter, need_save=True): + raise NotImplementedError + + def infer(self, data_iter, num_batches=None): + raise NotImplementedError + + def save(self, is_best=False): + """ save """ + train_state = { + 'epoch': self.epoch, + 'batch_num': self.batch_num, + 'best_valid_metric': self.best_valid_metric, + 'optimizer': self.optimizer.state_dict() + } + if self.lr_scheduler is not None: + train_state['lr_scheduler'] = self.lr_scheduler.state_dict() + + # Save checkpoint + if self.save_checkpoint: + model_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.model') + torch.save(self.model.state_dict(), model_file) + self.logger.info(f"Saved model state to '{model_file}'") + + train_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.train') + torch.save(train_state, train_file) + self.logger.info(f"Saved train state to '{train_file}'") + + # Save current best model + if is_best: + best_model_file = os.path.join(self.save_dir, 'best.model') + torch.save(self.model.state_dict(), best_model_file) + best_train_file = os.path.join(self.save_dir, 'best.train') + torch.save(train_state, best_train_file) + self.logger.info( + f"Saved best model state to '{best_model_file}' with new best valid metric " + f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' + ) + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}.model', + map_location=lambda storage, loc: storage) + + if 'module.' in list(model_state_dict.keys())[0]: + new_model_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + assert k[:7] == 'module.' + new_model_state_dict[k[7:]] = v + model_state_dict = new_model_state_dict + + new_model_state_dict = OrderedDict() + parameters = { + name: param + for name, param in self.func_model.named_parameters() + } + for name, param in model_state_dict.items(): + if name in parameters: + if param.shape != parameters[name].shape: + assert hasattr(param, 'numpy') + arr = param.numpy() + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + if name == 'embedder.token_embedding.weight': + z[-param.shape[0]:] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + else: + if z.shape[0] < param.shape[0]: + z = arr[:z.shape[0]] + print(f'part of parameter({name}) are dropped') + else: + z[:param.shape[0]] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + dtype, device = param.dtype, param.device + z = torch.tensor(z, dtype=dtype, device=device) + new_model_state_dict[name] = z + else: + new_model_state_dict[name] = param + else: + print(f'parameter({name}) are dropped') + model_state_dict = new_model_state_dict + + for name in parameters: + if name not in model_state_dict: + if parameters[name].requires_grad: + print(f'parameter({name}) random normlize initialize') + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + dtype, device = parameters[name].dtype, parameters[ + name].device + model_state_dict[name] = torch.tensor( + z, dtype=dtype, device=device) + else: + model_state_dict[name] = parameters[name] + + self.func_model.load_state_dict(model_state_dict) + self.logger.info( + f"Loaded model state from '{self.func_model.init_checkpoint}.model'" + ) + + def _load_train_state(): + train_file = 
f'{self.func_model.init_checkpoint}.train' + if os.path.exists(train_file): + train_state_dict = torch.load( + train_file, map_location=lambda storage, loc: storage) + self.epoch = train_state_dict['epoch'] + self.best_valid_metric = train_state_dict['best_valid_metric'] + if self.optimizer is not None and 'optimizer' in train_state_dict: + self.optimizer.load_state_dict( + train_state_dict['optimizer']) + if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: + self.lr_scheduler.load_state_dict( + train_state_dict['lr_scheduler']) + self.logger.info( + f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " + f'best_valid_metric={self.best_valid_metric:.3f})') + else: + self.logger.info('Loaded no train state') + + if self.func_model.init_checkpoint is None: + self.logger.info('Loaded no model !!!') + return + + _load_model_state() + _load_train_state() + + +class IntentTrainer(Trainer): + + def __init__(self, model, to_tensor, config, reader=None): + super(IntentTrainer, self).__init__(model, to_tensor, config, reader) + self.example = config.Model.example + self.can_norm = config.Trainer.can_norm + + def can_normalization(self, y_pred, y_true, ex_data_iter): + # 预测结果,计算修正前准确率 + acc_original = np.mean([y_pred.argmax(1) == y_true]) + message = 'original acc: %s' % acc_original + + # 评价每个预测结果的不确定性 + k = 3 + y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] + y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) + y_pred_uncertainty =\ + -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) + + # 选择阈值,划分高、低置信度两部分 + # print(np.sort(y_pred_uncertainty)[-100:].tolist()) + threshold = 0.7 + y_pred_confident = y_pred[y_pred_uncertainty < threshold] + y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold] + y_true_confident = y_true[y_pred_uncertainty < threshold] + y_true_unconfident = y_true[y_pred_uncertainty >= threshold] + + # 显示两部分各自的准确率 + # 一般而言,高置信度集准确率会远高于低置信度的 + acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \ + if len(y_true_confident) else 0. + acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \ + if len(y_true_unconfident) else 0. + message += ' (%s) confident acc: %s' % (len(y_true_confident), + acc_confident) + message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident), + acc_unconfident) + + # 从训练集统计先验分布 + prior = np.zeros(self.func_model.num_intent) + for _, (batch, batch_size) in ex_data_iter: + for intent_label in batch['intent_label']: + prior[intent_label] += 1. + + prior /= prior.sum() + + # 逐个修改低置信度样本,并重新评价准确率 + right, alpha, iters = 0, 1, 1 + for i, y in enumerate(y_pred_unconfident): + Y = np.concatenate([y_pred_confident, y[None]], axis=0) + for j in range(iters): + Y = Y**alpha + Y /= Y.mean(axis=0, keepdims=True) + Y *= prior[None] + Y /= Y.sum(axis=1, keepdims=True) + y = Y[-1] + if y.argmax() == y_true_unconfident[i]: + right += 1 + + # 输出修正后的准确率 + acc_final = \ + (acc_confident * len(y_pred_confident) + right) / \ + len(y_pred) + if len(y_pred_unconfident): + message += ' new unconfident acc: %s' % ( + right / len(y_pred_unconfident)) + else: + message += ' no unconfident predictions' + message += ' final acc: %s' % acc_final + return acc_original, acc_final, message + + def train_epoch(self, train_label_iter, train_nolabel_iter, + valid_label_iter, valid_nolabel_iter): + """ + Train an epoch. 
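+
+        Each step draws one labeled batch and, if an unlabeled iterator is
+        given, one unlabeled batch; their losses are summed into a single
+        optimizer update, and metrics are tracked separately per stream.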
+ """ + times = [] + self.epoch += 1 + self.batch_metrics_tracker_label.clear() + self.token_metrics_tracker_label.clear() + self.batch_metrics_tracker_nolabel.clear() + self.token_metrics_tracker_nolabel.clear() + + num_label_batches = len(train_label_iter) + num_nolabel_batches = len( + train_nolabel_iter) if train_nolabel_iter is not None else 0 + num_batches = max(num_label_batches, num_nolabel_batches) + + train_label_iter_loop = iter(train_label_iter) + train_nolabel_iter_loop = iter( + train_nolabel_iter) if train_nolabel_iter is not None else None + report_for_unlabeled_data = True if train_nolabel_iter is not None else False + + for batch_id in range(1, num_batches + 1): + # Do a training iteration + start_time = time.time() + batch_list, batch_size_list, with_label_list, loss_list, metrics_list = [], [], [], [], [] + data_file_list = [] + + # collect batch for labeled data + try: + data_file_label, ( + batch_label, + batch_size_label) = next(train_label_iter_loop) + except StopIteration: + train_label_iter_loop = iter(train_label_iter) + data_file_label, ( + batch_label, + batch_size_label) = next(train_label_iter_loop) + batch_list.append(batch_label) + batch_size_list.append(batch_size_label) + with_label_list.append(True) + data_file_list.append(data_file_label) + + # collect batch for unlabeled data + if train_nolabel_iter is not None: + try: + data_file_nolabel, ( + batch_nolabel, + batch_size_nolabel) = next(train_nolabel_iter_loop) + except StopIteration: + train_nolabel_iter_loop = iter(train_nolabel_iter) + data_file_nolabel, ( + batch_nolabel, + batch_size_nolabel) = next(train_nolabel_iter_loop) + batch_list.append(batch_nolabel) + batch_size_list.append(batch_size_nolabel) + with_label_list.append(False) + data_file_list.append(data_file_nolabel) + + # forward labeled batch and unlabeled batch and collect outputs, respectively + for (batch, batch_size, with_label, data_file) in \ + zip(batch_list, batch_size_list, with_label_list, data_file_list): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + if self.example and with_label: + current_dataset = train_label_iter.data_file_to_dataset[ + data_file] + example_batch = self.reader.retrieve_examples( + dataset=current_dataset, + labels=batch['intent_label'], + inds=batch['ids'], + task='intent') + example_batch = type(example_batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + example_batch.items())) + for k, v in example_batch.items(): + batch[k] = v + batch['epoch'] = self.epoch + batch['num_steps'] = self.batch_num + metrics = self.model( + batch, + is_training=True, + with_label=with_label, + data_file=data_file) + loss, metrics = self.balance_metrics( + metrics=metrics, batch_size=batch_size) + loss_list.append(loss) + metrics_list.append(metrics) + + # combine loss for labeled data and unlabeled data + # TODO change the computation of combined loss of labeled batch and unlabeled batch + loss = loss_list[0] if len( + loss_list) == 1 else loss_list[0] + loss_list[1] + + # optimization procedure + self.func_model._optimize( + loss, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler) + elapsed = time.time() - start_time + times.append(elapsed) + self.batch_num += 1 + + # track metrics and log temporary message + for (batch_size, metrics, + with_label) in zip(batch_size_list, metrics_list, + with_label_list): + self.track_and_log_message( + metrics=metrics, + batch_id=batch_id, + batch_size=batch_size, + num_batches=num_batches, + times=times, + 
with_label=with_label) + + # evaluate + if self.valid_steps > 0 and valid_label_iter is not None and valid_nolabel_iter is not None \ + and batch_id % self.valid_steps == 0: + self.evaluate( + data_label_iter=valid_label_iter, + data_nolabel_iter=valid_nolabel_iter) + + # compute accuracy for valid dataset + accuracy = self.infer( + data_iter=valid_label_iter, ex_data_iter=train_label_iter) + + # report summary message and save checkpoints + self.save_and_log_message( + report_for_unlabeled_data, cur_valid_metric=-accuracy) + + def forward(self, batch): + pred = [] + + with torch.no_grad(): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + intent_probs = result['intent_probs'] + if self.can_norm: + pred += [intent_probs] + else: + pred += np.argmax(intent_probs, axis=1).tolist() + + return pred + + def infer(self, data_iter, num_batches=None, ex_data_iter=None): + """ + Inference interface. + """ + self.logger.info('Generation starts ...') + infer_save_file = os.path.join(self.save_dir, + f'infer_{self.epoch}.result.json') + + # Inference + batch_cnt = 0 + pred, true = [], [] + outputs, labels = [], [] + begin_time = time.time() + + with torch.no_grad(): + if self.example: + for _, (batch, batch_size) in tqdm( + ex_data_iter, desc='Building train memory.'): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + outputs.append(torch.from_numpy(result['features'])) + labels += batch['intent_label'].tolist() + + mem = torch.cat(outputs, dim=0) + mem = mem.cuda() if self.func_model.use_gpu else mem + labels = torch.LongTensor(labels).unsqueeze(0) + labels = labels.cuda() if self.func_model.use_gpu else labels + self.logger.info(f'Memory size: {mem.size()}') + + for _, (batch, batch_size) in tqdm(data_iter, total=num_batches): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + + if self.example: + features = torch.from_numpy(result['features']) + features = features.cuda( + ) if self.func_model.use_gpu else features + probs = torch.softmax(features.mm(mem.t()), dim=-1) + intent_probs = torch.zeros( + probs.size(0), self.func_model.num_intent) + intent_probs = intent_probs.cuda( + ) if self.func_model.use_gpu else intent_probs + intent_probs = intent_probs.scatter_add( + -1, labels.repeat(probs.size(0), 1), probs) + intent_probs = intent_probs.cpu().detach().numpy() + else: + intent_probs = result['intent_probs'] + + if self.can_norm: + pred += [intent_probs] + true += batch['intent_label'].cpu().detach().tolist() + else: + pred += np.argmax(intent_probs, axis=1).tolist() + true += batch['intent_label'].cpu().detach().tolist() + + batch_cnt += 1 + if batch_cnt == num_batches: + break + + if self.can_norm: + true = np.array(true) + pred = np.concatenate(pred, axis=0) + acc_original, acc_final, message = self.can_normalization( + y_pred=pred, y_true=true, ex_data_iter=ex_data_iter) + accuracy = max(acc_original, acc_final) + infer_results = { + 'accuracy': accuracy, + 'pred_labels': pred.tolist(), + 'message': message + } + metrics_message = f'Accuracy: {accuracy} {message}' + else: + accuracy = sum(p == t for p, t 
in zip(pred, true)) / len(pred) + infer_results = {'accuracy': accuracy, 'pred_labels': pred} + metrics_message = f'Accuracy: {accuracy}' + + self.logger.info(f'Saved inference results to {infer_save_file}') + with open(infer_save_file, 'w') as fp: + json.dump(infer_results, fp, indent=2) + message_prefix = f'[Infer][{self.epoch}]' + time_cost = f'TIME-{time.time() - begin_time:.3f}' + message = ' '.join([message_prefix, metrics_message, time_cost]) + self.logger.info(message) + return accuracy + + def track_and_log_message(self, metrics, batch_id, batch_size, num_batches, + times, with_label): + # track metrics + batch_metrics_tracker = self.batch_metrics_tracker_label if with_label else self.batch_metrics_tracker_nolabel + token_metrics_tracker = self.token_metrics_tracker_label if with_label else self.token_metrics_tracker_nolabel + + metrics = { + k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v + for k, v in metrics.items() + } + mlm_num = metrics.pop('mlm_num', 0) + + batch_metrics = {k: v for k, v in metrics.items() if 'token' not in k} + token_metrics = {k: v for k, v in metrics.items() if 'token' in k} + batch_metrics_tracker.update(batch_metrics, batch_size) + token_metrics_tracker.update(token_metrics, mlm_num) + + # log message + if self.log_steps > 0 and batch_id % self.log_steps == 0: + batch_metrics_message = batch_metrics_tracker.value() + token_metrics_message = token_metrics_tracker.value() + label_prefix = 'Labeled' if with_label else 'Unlabeled' + message_prefix = f'[Train][{self.epoch}][{batch_id}/{num_batches}][{label_prefix}]' + avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' + message = ' '.join([ + message_prefix, batch_metrics_message, token_metrics_message, + avg_time + ]) + self.logger.info(message) + + def save_and_log_message(self, + report_for_unlabeled_data, + cur_valid_metric=None): + # report message + batch_metrics_message = self.batch_metrics_tracker_label.summary() + token_metrics_message = self.token_metrics_tracker_label.summary() + message_prefix = f'[Valid][{self.epoch}][Labeled]' + message = ' '.join( + [message_prefix, batch_metrics_message, token_metrics_message]) + self.logger.info(message) + if report_for_unlabeled_data: + batch_metrics_message = self.batch_metrics_tracker_nolabel.summary( + ) + token_metrics_message = self.token_metrics_tracker_nolabel.summary( + ) + message_prefix = f'[Valid][{self.epoch}][Unlabeled]' + message = ' '.join( + [message_prefix, batch_metrics_message, token_metrics_message]) + self.logger.info(message) + + # save checkpoints + assert cur_valid_metric is not None + if self.is_decreased_valid_metric: + is_best = cur_valid_metric < self.best_valid_metric + else: + is_best = cur_valid_metric > self.best_valid_metric + if is_best: + self.best_valid_metric = cur_valid_metric + self.save(is_best) + + def balance_metrics(self, metrics, batch_size): + if self.gpu > 1: + for metric in metrics: + if metric is not None: + assert len(metric) == self.gpu + + intent_loss, mlm, token_mlm, mlm_num, kl, con = metrics + metrics = {} + + intent_loss = torch.mean(intent_loss) + metrics['intent_loss'] = intent_loss + loss = intent_loss + + if mlm is not None: + mlm_num = torch.sum(mlm_num) + token_mlm = torch.sum(mlm) * (batch_size / self.gpu) / mlm_num + mlm = torch.mean(mlm) + metrics['mlm_num'] = mlm_num + metrics['token_mlm'] = token_mlm + metrics['mlm'] = mlm + loss = loss + (token_mlm if self.func_model.token_loss else + mlm) * self.func_model.mlm_ratio + + if kl is not None: + kl = 
torch.mean(kl) + metrics['kl'] = kl + loss = loss + kl * self.func_model.kl_ratio + + if con is not None: + con = torch.mean(con) + metrics['con'] = con + loss = loss + con + + metrics['loss'] = loss + + assert 'loss' in metrics + return metrics['loss'], metrics + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}', + map_location=lambda storage, loc: storage) + + if 'module.' in list(model_state_dict.keys())[0]: + new_model_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + assert k[:7] == 'module.' + new_model_state_dict[k[7:]] = v + model_state_dict = new_model_state_dict + + new_model_state_dict = OrderedDict() + parameters = { + name: param + for name, param in self.func_model.named_parameters() + } + for name, param in model_state_dict.items(): + if name in parameters: + if param.shape != parameters[name].shape: + assert hasattr(param, 'numpy') + arr = param.numpy() + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + if name == 'embedder.token_embedding.weight': + z[-param.shape[0]:] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + else: + if z.shape[0] < param.shape[0]: + z = arr[:z.shape[0]] + print(f'part of parameter({name}) are dropped') + else: + z[:param.shape[0]] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + dtype, device = param.dtype, param.device + z = torch.tensor(z, dtype=dtype, device=device) + new_model_state_dict[name] = z + else: + new_model_state_dict[name] = param + else: + print(f'parameter({name}) are dropped') + model_state_dict = new_model_state_dict + + for name in parameters: + if name not in model_state_dict: + if parameters[name].requires_grad: + print(f'parameter({name}) random normlize initialize') + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + dtype, device = parameters[name].dtype, parameters[ + name].device + model_state_dict[name] = torch.tensor( + z, dtype=dtype, device=device) + else: + model_state_dict[name] = parameters[name] + + self.func_model.load_state_dict(model_state_dict) + self.logger.info( + f"Loaded model state from '{self.func_model.init_checkpoint}.model'" + ) + + def _load_train_state(): + train_file = f'{self.func_model.init_checkpoint}.train' + if os.path.exists(train_file): + train_state_dict = torch.load( + train_file, map_location=lambda storage, loc: storage) + self.epoch = train_state_dict['epoch'] + self.best_valid_metric = train_state_dict['best_valid_metric'] + if self.optimizer is not None and 'optimizer' in train_state_dict: + self.optimizer.load_state_dict( + train_state_dict['optimizer']) + if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: + self.lr_scheduler.load_state_dict( + train_state_dict['lr_scheduler']) + self.logger.info( + f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " + f'best_valid_metric={self.best_valid_metric:.3f})') + else: + self.logger.info('Loaded no train state') + + if self.func_model.init_checkpoint is None: + self.logger.info('Loaded no model !!!') + return + + if self.do_train: + _load_model_state() + return + + if self.do_infer: + _load_model_state() + _load_train_state() diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index be9cb403..6ef6b010 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -44,6 +44,8 
@@ class Tasks(object): token_classification = 'token-classification' conversational = 'conversational' text_generation = 'text-generation' + dialog_modeling = 'dialog-modeling' + dialog_intent_prediction = 'dialog-intent-prediction' table_question_answering = 'table-question-answering' feature_extraction = 'feature-extraction' sentence_similarity = 'sentence-similarity' diff --git a/modelscope/utils/nlp/__init__.py b/modelscope/utils/nlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/nlp/space/__init__.py b/modelscope/utils/nlp/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/nlp/space/args.py b/modelscope/utils/nlp/space/args.py new file mode 100644 index 00000000..d9e91e74 --- /dev/null +++ b/modelscope/utils/nlp/space/args.py @@ -0,0 +1,66 @@ +""" +Parse argument. +""" + +import argparse + +import json + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Unsupported value encountered.') + + +class HParams(dict): + """ Hyper-parameters class + + Store hyper-parameters in training / infer / ... scripts. + """ + + def __getattr__(self, name): + if name in self.keys(): + return self[name] + for v in self.values(): + if isinstance(v, HParams): + if name in v: + return v[name] + raise AttributeError(f"'HParams' object has no attribute '{name}'") + + def __setattr__(self, name, value): + self[name] = value + + def save(self, filename): + with open(filename, 'w', encoding='utf-8') as fp: + json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) + + def load(self, filename): + with open(filename, 'r', encoding='utf-8') as fp: + params_dict = json.load(fp) + for k, v in params_dict.items(): + if isinstance(v, dict): + self[k].update(HParams(v)) + else: + self[k] = v + + +def parse_args(parser): + """ Parse hyper-parameters from cmdline. 
""" + parsed = parser.parse_args() + args = HParams() + optional_args = parser._action_groups[1] + for action in optional_args._group_actions[1:]: + arg_name = action.dest + args[arg_name] = getattr(parsed, arg_name) + for group in parser._action_groups[2:]: + group_args = HParams() + for action in group._group_actions: + arg_name = action.dest + group_args[arg_name] = getattr(parsed, arg_name) + if len(group_args) > 0: + args[group.title] = group_args + return args diff --git a/modelscope/utils/nlp/space/criterions.py b/modelscope/utils/nlp/space/criterions.py new file mode 100644 index 00000000..60f98457 --- /dev/null +++ b/modelscope/utils/nlp/space/criterions.py @@ -0,0 +1,52 @@ +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + + +def compute_kl_loss(p, q, filter_scores=None): + p_loss = F.kl_div( + F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') + q_loss = F.kl_div( + F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') + + # You can choose whether to use function "sum" and "mean" depending on your task + p_loss = p_loss.sum(dim=-1) + q_loss = q_loss.sum(dim=-1) + + # mask is for filter mechanism + if filter_scores is not None: + p_loss = filter_scores * p_loss + q_loss = filter_scores * q_loss + + p_loss = p_loss.mean() + q_loss = q_loss.mean() + + loss = (p_loss + q_loss) / 2 + return loss + + +class CatKLLoss(_Loss): + """ + CatKLLoss + """ + + def __init__(self, reduction='mean'): + super(CatKLLoss, self).__init__() + assert reduction in ['none', 'sum', 'mean'] + self.reduction = reduction + + def forward(self, log_qy, log_py): + """ + KL(qy|py) = Eq[qy * log(q(y) / p(y))] + + log_qy: (batch_size, latent_size) + log_py: (batch_size, latent_size) + """ + qy = torch.exp(log_qy) + kl = torch.sum(qy * (log_qy - log_py), dim=1) + + if self.reduction == 'mean': + kl = kl.mean() + elif self.reduction == 'sum': + kl = kl.sum() + return kl diff --git a/modelscope/utils/nlp/space/db_ops.py b/modelscope/utils/nlp/space/db_ops.py new file mode 100644 index 00000000..2168c079 --- /dev/null +++ b/modelscope/utils/nlp/space/db_ops.py @@ -0,0 +1,321 @@ +import os +import random +import sqlite3 + +import json + +from .ontology import all_domains, db_domains + + +class MultiWozDB(object): + + def __init__(self, db_dir, db_paths): + self.dbs = {} + self.sql_dbs = {} + for domain in all_domains: + with open(os.path.join(db_dir, db_paths[domain]), 'r') as f: + self.dbs[domain] = json.loads(f.read().lower()) + + def oneHotVector(self, domain, num): + """Return number of available entities for particular domain.""" + vector = [0, 0, 0, 0] + if num == '': + return vector + if domain != 'train': + if num == 0: + vector = [1, 0, 0, 0] + elif num == 1: + vector = [0, 1, 0, 0] + elif num <= 3: + vector = [0, 0, 1, 0] + else: + vector = [0, 0, 0, 1] + else: + if num == 0: + vector = [1, 0, 0, 0] + elif num <= 5: + vector = [0, 1, 0, 0] + elif num <= 10: + vector = [0, 0, 1, 0] + else: + vector = [0, 0, 0, 1] + return vector + + def addBookingPointer(self, turn_da): + """Add information about availability of the booking option.""" + # Booking pointer + # Do not consider booking two things in a single turn. 
+ vector = [0, 0] + if turn_da.get('booking-nobook'): + vector = [1, 0] + if turn_da.get('booking-book') or turn_da.get('train-offerbooked'): + vector = [0, 1] + return vector + + def addDBPointer(self, domain, match_num, return_num=False): + """Create database pointer for all related domains.""" + # if turn_domains is None: + # turn_domains = db_domains + if domain in db_domains: + vector = self.oneHotVector(domain, match_num) + else: + vector = [0, 0, 0, 0] + return vector + + def addDBIndicator(self, domain, match_num, return_num=False): + """Create database indicator for all related domains.""" + # if turn_domains is None: + # turn_domains = db_domains + if domain in db_domains: + vector = self.oneHotVector(domain, match_num) + else: + vector = [0, 0, 0, 0] + + # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' + if vector == [0, 0, 0, 0]: + indicator = '[db_nores]' + else: + indicator = '[db_%s]' % vector.index(1) + return indicator + + def get_match_num(self, constraints, return_entry=False): + """Create database pointer for all related domains.""" + match = {'general': ''} + entry = {} + # if turn_domains is None: + # turn_domains = db_domains + for domain in all_domains: + match[domain] = '' + if domain in db_domains and constraints.get(domain): + matched_ents = self.queryJsons(domain, constraints[domain]) + match[domain] = len(matched_ents) + if return_entry: + entry[domain] = matched_ents + if return_entry: + return entry + return match + + def pointerBack(self, vector, domain): + # multi domain implementation + # domnum = cfg.domain_num + if domain.endswith(']'): + domain = domain[1:-1] + if domain != 'train': + nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'} + else: + nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'} + if vector[:4] == [0, 0, 0, 0]: + report = '' + else: + num = vector.index(1) + report = domain + ': ' + nummap[num] + '; ' + + if vector[-2] == 0 and vector[-1] == 1: + report += 'booking: ok' + if vector[-2] == 1 and vector[-1] == 0: + report += 'booking: unable' + + return report + + def queryJsons(self, + domain, + constraints, + exactly_match=True, + return_name=False): + """Returns the list of entities for a given domain + based on the annotation of the belief state + constraints: dict e.g. 
{'pricerange': 'cheap', 'area': 'west'} + """ + # query the db + if domain == 'taxi': + return [{ + 'taxi_colors': + random.choice(self.dbs[domain]['taxi_colors']), + 'taxi_types': + random.choice(self.dbs[domain]['taxi_types']), + 'taxi_phone': [random.randint(1, 9) for _ in range(10)] + }] + if domain == 'police': + return self.dbs['police'] + if domain == 'hospital': + if constraints.get('department'): + for entry in self.dbs['hospital']: + if entry.get('department') == constraints.get( + 'department'): + return [entry] + else: + return [] + + valid_cons = False + for v in constraints.values(): + if v not in ['not mentioned', '']: + valid_cons = True + if not valid_cons: + return [] + + match_result = [] + + if 'name' in constraints: + for db_ent in self.dbs[domain]: + if 'name' in db_ent: + cons = constraints['name'] + dbn = db_ent['name'] + if cons == dbn: + db_ent = db_ent if not return_name else db_ent['name'] + match_result.append(db_ent) + return match_result + + for db_ent in self.dbs[domain]: + match = True + for s, v in constraints.items(): + if s == 'name': + continue + if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \ + (domain == 'restaurant' and s in ['day', 'time']): + # 因为这些inform slot属于book info,而数据库中没有这些slot; + # 能否book是根据user goal中的信息判断,而非通过数据库查询; + continue + + skip_case = { + "don't care": 1, + "do n't care": 1, + 'dont care': 1, + 'not mentioned': 1, + 'dontcare': 1, + '': 1 + } + if skip_case.get(v): + continue + + if s not in db_ent: + # logging.warning('Searching warning: slot %s not in %s db'%(s, domain)) + match = False + break + + # v = 'guesthouse' if v == 'guest house' else v + # v = 'swimmingpool' if v == 'swimming pool' else v + v = 'yes' if v == 'free' else v + + if s in ['arrive', 'leave']: + try: + h, m = v.split( + ':' + ) # raise error if time value is not xx:xx format + v = int(h) * 60 + int(m) + except Exception: + match = False + break + time = int(db_ent[s].split(':')[0]) * 60 + int( + db_ent[s].split(':')[1]) + if s == 'arrive' and v > time: + match = False + if s == 'leave' and v < time: + match = False + else: + if exactly_match and v != db_ent[s]: + match = False + break + elif v not in db_ent[s]: + match = False + break + + if match: + match_result.append(db_ent) + + if not return_name: + return match_result + else: + if domain == 'train': + match_result = [e['id'] for e in match_result] + else: + match_result = [e['name'] for e in match_result] + return match_result + + def querySQL(self, domain, constraints): + if not self.sql_dbs: + for dom in db_domains: + db = 'db/{}-dbase.db'.format(dom) + conn = sqlite3.connect(db) + c = conn.cursor() + self.sql_dbs[dom] = c + + sql_query = 'select * from {}'.format(domain) + + flag = True + for key, val in constraints.items(): + if val == '' \ + or val == 'dontcare' \ + or val == 'not mentioned' \ + or val == "don't care" \ + or val == 'dont care' \ + or val == "do n't care": + pass + else: + if flag: + sql_query += ' where ' + val2 = val.replace("'", "''") + # val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'" + else: + sql_query += r' ' + key + '=' + r"'" + val2 + r"'" + flag = False + else: + val2 = val.replace("'", "''") + # val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'" + else: + sql_query += r' and ' + key 
+ '=' + r"'" + val2 + r"'" + + try: # "select * from attraction where name = 'queens college'" + print(sql_query) + return self.sql_dbs[domain].execute(sql_query).fetchall() + except Exception: + return [] # TODO test it + + +if __name__ == '__main__': + dbPATHs = { + 'attraction': 'db/attraction_db_processed.json', + 'hospital': 'db/hospital_db_processed.json', + 'hotel': 'db/hotel_db_processed.json', + 'police': 'db/police_db_processed.json', + 'restaurant': 'db/restaurant_db_processed.json', + 'taxi': 'db/taxi_db_processed.json', + 'train': 'db/train_db_processed.json', + } + db = MultiWozDB(dbPATHs) + while True: + constraints = {} + inp = input( + 'input belief state in fomat: domain-slot1=value1;slot2=value2...\n' + ) + domain, cons = inp.split('-') + for sv in cons.split(';'): + s, v = sv.split('=') + constraints[s] = v + # res = db.querySQL(domain, constraints) + res = db.queryJsons(domain, constraints, return_name=True) + report = [] + reidx = { + 'hotel': 8, + 'restaurant': 6, + 'attraction': 5, + 'train': 1, + } + # for ent in res: + # if reidx.get(domain): + # report.append(ent[reidx[domain]]) + # for ent in res: + # if 'name' in ent: + # report.append(ent['name']) + # if 'trainid' in ent: + # report.append(ent['trainid']) + print(constraints) + print(res) + print('count:', len(res), '\nnames:', report) diff --git a/modelscope/utils/nlp/space/ontology.py b/modelscope/utils/nlp/space/ontology.py new file mode 100644 index 00000000..4f27168a --- /dev/null +++ b/modelscope/utils/nlp/space/ontology.py @@ -0,0 +1,217 @@ +all_domains = [ + 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' +] +all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains] +db_domains = ['restaurant', 'hotel', 'attraction', 'train'] +placeholder_tokens = [ + '', '', '', '', '', '', '', + '', '', '', '', '', '', + '', '', '' +] + +normlize_slot_names = { + 'car type': 'car', + 'entrance fee': 'price', + 'duration': 'time', + 'leaveat': 'leave', + 'arriveby': 'arrive', + 'trainid': 'id' +} + +requestable_slots = { + 'taxi': ['car', 'phone'], + 'police': ['postcode', 'address', 'phone'], + 'hospital': ['address', 'phone', 'postcode'], + 'hotel': [ + 'address', 'postcode', 'internet', 'phone', 'parking', 'type', + 'pricerange', 'stars', 'area', 'reference' + ], + 'attraction': + ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'], + 'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'], + 'restaurant': [ + 'phone', 'postcode', 'address', 'pricerange', 'food', 'area', + 'reference' + ] +} +all_reqslot = [ + 'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type', + 'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave', + 'price', 'arrive', 'id' +] + +informable_slots = { + 'taxi': ['leave', 'destination', 'departure', 'arrive'], + 'police': [], + 'hospital': ['department'], + 'hotel': [ + 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', + 'area', 'stars', 'name' + ], + 'attraction': ['area', 'type', 'name'], + 'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'], + 'restaurant': + ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people'] +} +all_infslot = [ + 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', + 'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive', + 'department', 'food', 'time' +] + +all_slots = all_reqslot + [ + 'stay', 'day', 'people', 'name', 'destination', 'departure', 'department' +] +get_slot = {} +for s in 
all_slots: + get_slot[s] = 1 + +# mapping slots in dialogue act to original goal slot names +da_abbr_to_slot_name = { + 'addr': 'address', + 'fee': 'price', + 'post': 'postcode', + 'ref': 'reference', + 'ticket': 'price', + 'depart': 'departure', + 'dest': 'destination', +} + +dialog_acts = { + 'restaurant': [ + 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', + 'offerbooked', 'nobook' + ], + 'hotel': [ + 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', + 'offerbooked', 'nobook' + ], + 'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'], + 'train': + ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'], + 'taxi': ['inform', 'request'], + 'police': ['inform', 'request'], + 'hospital': ['inform', 'request'], + # 'booking': ['book', 'inform', 'nobook', 'request'], + 'general': ['bye', 'greet', 'reqmore', 'welcome'], +} +all_acts = [] +for acts in dialog_acts.values(): + for act in acts: + if act not in all_acts: + all_acts.append(act) + +dialog_act_params = { + 'inform': all_slots + ['choice', 'open'], + 'request': all_infslot + ['choice', 'price'], + 'nooffer': all_slots + ['choice'], + 'recommend': all_reqslot + ['choice', 'open'], + 'select': all_slots + ['choice'], + # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], + 'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], + 'offerbook': all_slots + ['choice'], + 'offerbooked': all_slots + ['choice'], + 'reqmore': [], + 'welcome': [], + 'bye': [], + 'greet': [], +} + +dialog_act_all_slots = all_slots + ['choice', 'open'] + +# special slot tokens in belief span +# no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange] +slot_name_to_slot_token = {} + +# special slot tokens in responses +# not use at the momoent +slot_name_to_value_token = { + # 'entrance fee': '[value_price]', + # 'pricerange': '[value_price]', + # 'arriveby': '[value_time]', + # 'leaveat': '[value_time]', + # 'departure': '[value_place]', + # 'destination': '[value_place]', + # 'stay': 'count', + # 'people': 'count' +} + +# eos tokens definition +eos_tokens = { + 'user': '', + 'user_delex': '', + 'resp': '', + 'resp_gen': '', + 'pv_resp': '', + 'bspn': '', + 'bspn_gen': '', + 'pv_bspn': '', + 'bsdx': '', + 'bsdx_gen': '', + 'pv_bsdx': '', + 'qspn': '', + 'qspn_gen': '', + 'pv_qspn': '', + 'aspn': '', + 'aspn_gen': '', + 'pv_aspn': '', + 'dspn': '', + 'dspn_gen': '', + 'pv_dspn': '' +} + +# sos tokens definition +sos_tokens = { + 'user': '', + 'user_delex': '', + 'resp': '', + 'resp_gen': '', + 'pv_resp': '', + 'bspn': '', + 'bspn_gen': '', + 'pv_bspn': '', + 'bsdx': '', + 'bsdx_gen': '', + 'pv_bsdx': '', + 'qspn': '', + 'qspn_gen': '', + 'pv_qspn': '', + 'aspn': '', + 'aspn_gen': '', + 'pv_aspn': '', + 'dspn': '', + 'dspn_gen': '', + 'pv_dspn': '' +} + +# db tokens definition +db_tokens = [ + '', '', '[book_nores]', '[book_fail]', '[book_success]', + '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' +] + + +# understand tokens definition +def get_understand_tokens(prompt_num_for_understand): + understand_tokens = [] + for i in range(prompt_num_for_understand): + understand_tokens.append(f'') + return understand_tokens + + +# policy tokens definition +def get_policy_tokens(prompt_num_for_policy): + policy_tokens = [] + for i in range(prompt_num_for_policy): + policy_tokens.append(f'') + return policy_tokens + + +# all special tokens definition +def get_special_tokens(other_tokens): + special_tokens = [ + '', '', '', '', '', '', + 
'', '', '', '', '', '', + '', '', '', '' + ] + db_tokens + other_tokens + return special_tokens diff --git a/modelscope/utils/nlp/space/scores.py b/modelscope/utils/nlp/space/scores.py new file mode 100644 index 00000000..fe0a8a17 --- /dev/null +++ b/modelscope/utils/nlp/space/scores.py @@ -0,0 +1,6 @@ +def hierarchical_set_score(frame1, frame2): + # deal with empty frame + if not (frame1 and frame2): + return 0. + pass + return 0. diff --git a/modelscope/utils/nlp/space/utils.py b/modelscope/utils/nlp/space/utils.py new file mode 100644 index 00000000..ba956b7d --- /dev/null +++ b/modelscope/utils/nlp/space/utils.py @@ -0,0 +1,206 @@ +import logging +from collections import OrderedDict + +import json +import numpy as np + +from . import ontology + + +def max_lens(X): + lens = [len(X)] + while isinstance(X[0], list): + lens.append(max(map(len, X))) + X = [x for xs in X for x in xs] + return lens + + +def list2np(X: object, padding: object = 0, dtype: object = 'int64') -> object: + shape = max_lens(X) + ret = np.full(shape, padding, dtype=np.int32) + + if len(shape) == 1: + ret = np.array(X) + elif len(shape) == 2: + for i, x in enumerate(X): + ret[i, :len(x)] = np.array(x) + elif len(shape) == 3: + for i, xs in enumerate(X): + for j, x in enumerate(xs): + ret[i, j, :len(x)] = np.array(x) + return ret.astype(dtype) + + +def clean_replace(s, r, t, forward=True, backward=False): + + def clean_replace_single(s, r, t, forward, backward, sidx=0): + # idx = s[sidx:].find(r) + idx = s.find(r) + if idx == -1: + return s, -1 + idx_r = idx + len(r) + if backward: + while idx > 0 and s[idx - 1]: + idx -= 1 + elif idx > 0 and s[idx - 1] != ' ': + return s, -1 + + if forward: + while \ + idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): + idx_r += 1 + elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): + return s, -1 + return s[:idx] + t + s[idx_r:], idx_r + + # source, replace, target = s, r, t + # count = 0 + sidx = 0 + while sidx != -1: + s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) + # count += 1 + # print(s, sidx) + # if count == 20: + # print(source, '\n', replace, '\n', target) + # quit() + return s + + +def py2np(list): + return np.array(list) + + +def write_dict(fn, dic): + with open(fn, 'w') as f: + json.dump(dic, f, indent=2) + + +def f1_score(label_list, pred_list): + tp = len([t for t in pred_list if t in label_list]) + fp = max(0, len(pred_list) - tp) + fn = max(0, len(label_list) - tp) + precision = tp / (tp + fp + 1e-10) + recall = tp / (tp + fn + 1e-10) + f1 = 2 * precision * recall / (precision + recall + 1e-10) + return f1 + + +class MultiWOZVocab(object): + + def __init__(self, vocab_size=0): + """ + vocab for multiwoz dataset + """ + self.vocab_size = vocab_size + self.vocab_size_oov = 0 # get after construction + self._idx2word = {} # word + oov + self._word2idx = {} # word + self._freq_dict = {} # word + oov + for w in [ + '[PAD]', '', '[UNK]', '', '', '', + '', '', '', '', '' + ]: + self._absolute_add_word(w) + + def _absolute_add_word(self, w): + idx = len(self._idx2word) + self._idx2word[idx] = w + self._word2idx[w] = idx + + def add_word(self, word): + if word not in self._freq_dict: + self._freq_dict[word] = 0 + self._freq_dict[word] += 1 + + def has_word(self, word): + return self._freq_dict.get(word) + + def _add_to_vocab(self, word): + if word not in self._word2idx: + idx = len(self._idx2word) + self._idx2word[idx] = word + self._word2idx[word] = idx + + def construct(self): + freq_dict_sorted = sorted( + 
self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) + print('Vocabulary size including oov: %d' % + (len(freq_dict_sorted) + len(self._idx2word))) + if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: + logging.warning( + 'actual label set smaller than that configured: {}/{}'.format( + len(freq_dict_sorted) + len(self._idx2word), + self.vocab_size)) + for word in ontology.all_domains + ['general']: + word = '[' + word + ']' + self._add_to_vocab(word) + for word in ontology.all_acts: + word = '[' + word + ']' + self._add_to_vocab(word) + for word in ontology.all_slots: + self._add_to_vocab(word) + for word in freq_dict_sorted: + if word.startswith('[value_') and word.endswith(']'): + self._add_to_vocab(word) + for word in freq_dict_sorted: + self._add_to_vocab(word) + self.vocab_size_oov = len(self._idx2word) + + def load_vocab(self, vocab_path): + self._freq_dict = json.loads( + open(vocab_path + '.freq.json', 'r').read()) + self._word2idx = json.loads( + open(vocab_path + '.word2idx.json', 'r').read()) + self._idx2word = {} + for w, idx in self._word2idx.items(): + self._idx2word[idx] = w + self.vocab_size_oov = len(self._idx2word) + print('vocab file loaded from "' + vocab_path + '"') + print('Vocabulary size including oov: %d' % (self.vocab_size_oov)) + + def save_vocab(self, vocab_path): + _freq_dict = OrderedDict( + sorted( + self._freq_dict.items(), key=lambda kv: kv[1], reverse=True)) + write_dict(vocab_path + '.word2idx.json', self._word2idx) + write_dict(vocab_path + '.freq.json', _freq_dict) + + def encode(self, word, include_oov=True): + if include_oov: + if self._word2idx.get(word, None) is None: + raise ValueError( + 'Unknown word: %s. Vocabulary should include oovs here.' + % word) + return self._word2idx[word] + else: + word = '' if word not in self._word2idx else word + return self._word2idx[word] + + def sentence_encode(self, word_list): + return [self.encode(_) for _ in word_list] + + def oov_idx_map(self, idx): + return 2 if idx > self.vocab_size else idx + + def sentence_oov_map(self, index_list): + return [self.oov_idx_map(_) for _ in index_list] + + def decode(self, idx, indicate_oov=False): + if not self._idx2word.get(idx): + raise ValueError( + 'Error idx: %d. Vocabulary should include oovs here.' 
% idx) + if not indicate_oov or idx < self.vocab_size: + return self._idx2word[idx] + else: + return self._idx2word[idx] + '(o)' + + # def sentence_decode(self, index_list, eos=None, indicate_oov=False): + # l = [self.decode(_, indicate_oov) for _ in index_list] + # if not eos or eos not in l: + # return ' '.join(l) + # else: + # idx = l.index(eos) + # return ' '.join(l[:idx]) + # + # def nl_decode(self, l, eos=None): + # return [self.sentence_decode(_, eos) + '\n' for _ in l] diff --git a/requirements/nlp.txt b/requirements/nlp.txt index 261b9ec5..84a57b5c 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -1 +1,4 @@ https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl +https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz +spacy>=2.3.5 +# python -m spacy download en_core_web_sm diff --git a/tests/pipelines/nlp/__init__.py b/tests/pipelines/nlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pipelines/nlp/test_dialog_intent_prediction.py b/tests/pipelines/nlp/test_dialog_intent_prediction.py new file mode 100644 index 00000000..0ec4e1e7 --- /dev/null +++ b/tests/pipelines/nlp/test_dialog_intent_prediction.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from maas_hub.snapshot_download import snapshot_download + +from modelscope.models import Model +from modelscope.models.nlp import DialogIntentModel +from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline +from modelscope.preprocessors import DialogIntentPredictionPreprocessor +from modelscope.utils.constant import Tasks + + +class DialogIntentPredictionTest(unittest.TestCase): + model_id = 'damo/nlp_space_dialog-intent-prediction' + test_case = [ + 'How do I locate my card?', + 'I still have not received my new card, I ordered over a week ago.' + ] + + @unittest.skip('test with snapshot_download') + def test_run(self): + cache_path = snapshot_download(self.model_id) + preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) + model = DialogIntentModel( + model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + + pipelines = [ + DialogIntentPredictionPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_intent_prediction, + model=model, + preprocessor=preprocessor) + ] + + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = DialogIntentPredictionPreprocessor( + model_dir=model.model_dir) + + pipelines = [ + DialogIntentPredictionPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_intent_prediction, + model=model, + preprocessor=preprocessor) + ] + + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/nlp/test_dialog_modeling.py b/tests/pipelines/nlp/test_dialog_modeling.py new file mode 100644 index 00000000..7d4da8fe --- /dev/null +++ b/tests/pipelines/nlp/test_dialog_modeling.py @@ -0,0 +1,149 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
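A small sketch of how the MultiWOZVocab class added in utils.py above is typically driven; the toy corpus and vocabulary size are illustrative:

from modelscope.utils.nlp.space.utils import MultiWOZVocab

vocab = MultiWOZVocab(vocab_size=3000)
for sent in ['i need a cheap restaurant', 'book a taxi to the hotel']:
    for tok in sent.split():
        vocab.add_word(tok)      # accumulate token frequencies
vocab.construct()                # special/domain tokens first, then by frequency

ids = vocab.sentence_encode('i need a taxi'.split())
words = [vocab.decode(i) for i in ids]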
+import os +import os.path as osp +import tempfile +import unittest + +from maas_hub.snapshot_download import snapshot_download + +from modelscope.models import Model +from modelscope.models.nlp import DialogModelingModel +from modelscope.pipelines import DialogModelingPipeline, pipeline +from modelscope.preprocessors import DialogModelingPreprocessor +from modelscope.utils.constant import Tasks + + +class DialogModelingTest(unittest.TestCase): + model_id = 'damo/nlp_space_dialog-modeling' + test_case = { + 'sng0073': { + 'goal': { + 'taxi': { + 'info': { + 'leaveat': '17:15', + 'destination': 'pizza hut fen ditton', + 'departure': "saint john's college" + }, + 'reqt': ['car', 'phone'], + 'fail_info': {} + } + }, + 'log': [{ + 'user': + "i would like a taxi from saint john 's college to pizza hut fen ditton .", + 'user_delex': + 'i would like a taxi from [value_departure] to [value_destination] .', + 'resp': + 'what time do you want to leave and what time do you want to arrive by ?', + 'sys': + 'what time do you want to leave and what time do you want to arrive by ?', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college", + 'cons_delex': '[taxi] destination departure', + 'sys_act': '[taxi] [request] leave arrive', + 'turn_num': 0, + 'turn_domain': '[taxi]' + }, { + 'user': 'i want to leave after 17:15 .', + 'user_delex': 'i want to leave after [value_leave] .', + 'resp': + 'booking completed ! your taxi will be [value_car] contact number is [value_phone]', + 'sys': + 'booking completed ! your taxi will be blue honda contact number is 07218068540', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", + 'cons_delex': '[taxi] destination departure leave', + 'sys_act': '[taxi] [inform] car phone', + 'turn_num': 1, + 'turn_domain': '[taxi]' + }, { + 'user': 'thank you for all the help ! i appreciate it .', + 'user_delex': 'thank you for all the help ! i appreciate it .', + 'resp': + 'you are welcome . is there anything else i can help you with today ?', + 'sys': + 'you are welcome . is there anything else i can help you with today ?', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", + 'cons_delex': '[taxi] destination departure leave', + 'sys_act': '[general] [reqmore]', + 'turn_num': 2, + 'turn_domain': '[general]' + }, { + 'user': 'no , i am all set . have a nice day . bye .', + 'user_delex': 'no , i am all set . have a nice day . bye .', + 'resp': 'you too ! thank you', + 'sys': 'you too ! 
thank you', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", + 'cons_delex': '[taxi] destination departure leave', + 'sys_act': '[general] [bye]', + 'turn_num': 3, + 'turn_domain': '[general]' + }] + } + } + + @unittest.skip('test with snapshot_download') + def test_run(self): + + cache_path = snapshot_download(self.model_id) + + preprocessor = DialogModelingPreprocessor(model_dir=cache_path) + model = DialogModelingModel( + model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + pipelines = [ + DialogModelingPipeline(model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_modeling, + model=model, + preprocessor=preprocessor) + ] + + result = {} + for step, item in enumerate(self.test_case['sng0073']['log']): + user = item['user'] + print('user: {}'.format(user)) + + result = pipelines[step % 2]({ + 'user_input': user, + 'history': result + }) + print('sys : {}'.format(result['sys'])) + + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) + + pipelines = [ + DialogModelingPipeline(model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_modeling, + model=model, + preprocessor=preprocessor) + ] + + result = {} + for step, item in enumerate(self.test_case['sng0073']['log']): + user = item['user'] + print('user: {}'.format(user)) + + result = pipelines[step % 2]({ + 'user_input': user, + 'history': result + }) + print('sys : {}'.format(result['sys'])) + + +if __name__ == '__main__': + unittest.main()
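For reference, a hedged sketch of how the HParams container introduced in modelscope/utils/nlp/space/args.py behaves; the group name, attribute names, and file path used here are illustrative only:

from modelscope.utils.nlp.space.args import HParams

args = HParams(batch_size=16)
args.Model = HParams(hidden_dim=768, kl_ratio=1.0)

print(args.batch_size)   # 16, a top-level entry
print(args.hidden_dim)   # 768, found by falling through to the nested group

args.save('/tmp/space_hparams.json')          # dumps the nested dict as JSON

restored = HParams(batch_size=0, Model=HParams())
restored.load('/tmp/space_hparams.json')      # nested groups are updated in place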