mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #42322933] Doc2Bot documentation with retrieval rerank, generation
(cherry picked from commit 2fced1c06f)
This commit is contained in:
committed by
Zhicheng Zhang
parent
c0d7f951af
commit
e0edbf135c
@@ -155,6 +155,7 @@ class Models(object):
|
||||
xlm_roberta = 'xlm-roberta'
|
||||
transformers = 'transformers'
|
||||
plug_mental = 'plug-mental'
|
||||
doc2bot = 'doc2bot'
|
||||
|
||||
# audio models
|
||||
sambert_hifigan = 'sambert-hifigan'
|
||||
@@ -426,6 +427,9 @@ class Pipelines(object):
|
||||
token_classification = 'token-classification'
|
||||
translation_evaluation = 'translation-evaluation'
|
||||
user_satisfaction_estimation = 'user-satisfaction-estimation'
|
||||
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
|
||||
# audio tasks
|
||||
sambert_hifigan_tts = 'sambert-hifigan-tts'
|
||||
@@ -538,6 +542,15 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.table_question_answering:
|
||||
(Pipelines.table_question_answering_pipeline,
|
||||
'damo/nlp-convai-text2sql-pretrain-cn'),
|
||||
Tasks.document_grounded_dialog_generate:
|
||||
(Pipelines.document_grounded_dialog_generate,
|
||||
'DAMO_ConvAI/nlp_convai_generation_pretrain'),
|
||||
Tasks.document_grounded_dialog_rerank:
|
||||
(Pipelines.document_grounded_dialog_rerank,
|
||||
'damo/nlp_convai_rerank_pretrain'),
|
||||
Tasks.document_grounded_dialog_retrieval:
|
||||
(Pipelines.document_grounded_dialog_retrieval,
|
||||
'DAMO_ConvAI/nlp_convai_retrieval_pretrain'),
|
||||
Tasks.text_error_correction:
|
||||
(Pipelines.text_error_correction,
|
||||
'damo/nlp_bart_text-error-correction_chinese'),
|
||||
@@ -691,9 +704,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.text_driven_segmentation:
|
||||
(Pipelines.text_driven_segmentation,
|
||||
'damo/cv_vitl16_segmentation_text-driven-seg'),
|
||||
Tasks.movie_scene_segmentation:
|
||||
(Pipelines.movie_scene_segmentation,
|
||||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
|
||||
Tasks.movie_scene_segmentation: (
|
||||
Pipelines.movie_scene_segmentation,
|
||||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
|
||||
Tasks.shop_segmentation: (Pipelines.shop_segmentation,
|
||||
'damo/cv_vitb16_segmentation_shop-seg'),
|
||||
Tasks.image_inpainting: (Pipelines.image_inpainting,
|
||||
@@ -704,14 +717,14 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_video-inpainting'),
|
||||
Tasks.video_human_matting: (Pipelines.video_human_matting,
|
||||
'damo/cv_effnetv2_video-human-matting'),
|
||||
Tasks.video_frame_interpolation:
|
||||
(Pipelines.video_frame_interpolation,
|
||||
'damo/cv_raft_video-frame-interpolation'),
|
||||
Tasks.video_frame_interpolation: (
|
||||
Pipelines.video_frame_interpolation,
|
||||
'damo/cv_raft_video-frame-interpolation'),
|
||||
Tasks.video_deinterlace: (Pipelines.video_deinterlace,
|
||||
'damo/cv_unet_video-deinterlace'),
|
||||
Tasks.human_wholebody_keypoint:
|
||||
(Pipelines.human_wholebody_keypoint,
|
||||
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
|
||||
Tasks.human_wholebody_keypoint: (
|
||||
Pipelines.human_wholebody_keypoint,
|
||||
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
|
||||
Tasks.hand_static: (Pipelines.hand_static,
|
||||
'damo/cv_mobileface_hand-static'),
|
||||
Tasks.face_human_hand_detection: (
|
||||
@@ -797,6 +810,9 @@ class NLPTrainers(object):
|
||||
faq_question_answering_trainer = 'faq-question-answering-trainer'
|
||||
gpt_moe_trainer = 'nlp-gpt-moe-trainer'
|
||||
table_question_answering_trainer = 'table-question-answering-trainer'
|
||||
document_grounded_dialog_generate_trainer = 'document-grounded-dialog-generate-trainer'
|
||||
document_grounded_dialog_rerank_trainer = 'document-grounded-dialog-rerank-trainer'
|
||||
document_grounded_dialog_retrieval_trainer = 'document-grounded-dialog-retrieval-trainer'
|
||||
|
||||
|
||||
class MultiModalTrainers(object):
|
||||
@@ -923,6 +939,9 @@ class Preprocessors(object):
|
||||
sentence_piece = 'sentence-piece'
|
||||
translation_evaluation = 'translation-evaluation-preprocessor'
|
||||
dialog_use_preprocessor = 'dialog-use-preprocessor'
|
||||
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
|
||||
# audio preprocessor
|
||||
linear_aec_fbank = 'linear-aec-fbank'
|
||||
|
||||
@@ -65,6 +65,9 @@ if TYPE_CHECKING:
|
||||
from .veco import (VecoConfig, VecoForMaskedLM,
|
||||
VecoForSequenceClassification,
|
||||
VecoForTokenClassification, VecoModel)
|
||||
from .dgds import (DocumentGroundedDialogGenerateModel,
|
||||
DocumentGroundedDialogRetrievalModel,
|
||||
DocumentGroundedDialogRerankModel)
|
||||
from .xlm_roberta import XLMRobertaConfig, XLMRobertaModel
|
||||
|
||||
else:
|
||||
@@ -133,6 +136,11 @@ else:
|
||||
'T5': ['T5ForConditionalGeneration'],
|
||||
'unite': ['UniTEForTranslationEvaluation'],
|
||||
'use': ['UserSatisfactionEstimation'],
|
||||
'dgds': [
|
||||
'DocumentGroundedDialogGenerateModel',
|
||||
'DocumentGroundedDialogRetrievalModel',
|
||||
'DocumentGroundedDialogRerankModel'
|
||||
],
|
||||
'veco': [
|
||||
'VecoConfig',
|
||||
'VecoForMaskedLM',
|
||||
|
||||
28
modelscope/models/nlp/dgds/__init__.py
Normal file
28
modelscope/models/nlp/dgds/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Lazy-import package for the Doc2Bot document-grounded dialog (dgds) models."""
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .document_grounded_dialog_generate import DocumentGroundedDialogGenerateModel
    # Fix: the rerank model lives in its own module; the original imported it
    # from .document_grounded_dialog_retrieval, which disagrees with the
    # _import_structure mapping below and would fail static type checking.
    from .document_grounded_dialog_rerank import DocumentGroundedDialogRerankModel
    from .document_grounded_dialog_retrieval import DocumentGroundedDialogRetrievalModel
else:
    # Maps sub-module name -> public symbols, resolved lazily on first access.
    _import_structure = {
        'document_grounded_dialog_generate':
        ['DocumentGroundedDialogGenerateModel'],
        'document_grounded_dialog_rerank':
        ['DocumentGroundedDialogRerankModel'],
        'document_grounded_dialog_retrieval':
        ['DocumentGroundedDialogRetrievalModel']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
191
modelscope/models/nlp/dgds/backbone.py
Normal file
191
modelscope/models/nlp/dgds/backbone.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved.
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import os.path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import (AutoConfig, DPRConfig, DPRQuestionEncoder,
|
||||
MT5ForConditionalGeneration, RagTokenForGeneration,
|
||||
XLMRobertaForSequenceClassification, XLMRobertaModel,
|
||||
XLMRobertaTokenizer)
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class Wrapper(nn.Module):
    """Adapter around an encoder that exposes only its pooled output.

    The extra ``dummy_tensor`` argument is ignored; it gives
    ``torch.utils.checkpoint`` a grad-requiring input when the real inputs
    are integer id tensors (see ``DPRModel.encode``).
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids, attention_mask, dummy_tensor):
        # dummy_tensor is intentionally unused.
        encoded = self.encoder(input_ids, attention_mask)
        return encoded.pooler_output
|
||||
|
||||
|
||||
class DPRModel(nn.Module):
    """DPR-style dual encoder: separate XLM-Roberta query and context
    encoders trained with an in-batch contrastive (cross-entropy) loss."""

    def __init__(self, model_dir, config):
        super().__init__()
        self.config = config

        # Each encoder is instantiated from its own sub-directory config.
        qry_encoder = XLMRobertaModel(
            config=AutoConfig.from_pretrained(
                os.path.join(model_dir, 'qry_encoder')))
        ctx_encoder = XLMRobertaModel(
            config=AutoConfig.from_pretrained(
                os.path.join(model_dir, 'ctx_encoder')))
        self.qry_encoder = Wrapper(qry_encoder)
        self.ctx_encoder = Wrapper(ctx_encoder)
        self.loss_fct = nn.CrossEntropyLoss()

    @staticmethod
    def encode(model, input_ids, attention_mask, gck_segment=32):
        """Encode rows in chunks of ``gck_segment`` under gradient
        checkpointing (bounds activation memory); returns pooled vectors.

        A throwaway differentiable tensor is threaded through because the
        id/mask inputs are integer tensors that cannot require grad.
        """
        grad_anchor = torch.ones(1, dtype=torch.float32, requires_grad=True)
        chunk_outputs = []
        for ids_chunk, mask_chunk in zip(
                torch.split(input_ids, gck_segment),
                torch.split(attention_mask, gck_segment)):
            chunk_outputs.append(
                checkpoint(model, ids_chunk, mask_chunk, grad_anchor))
        return torch.cat(chunk_outputs, dim=0)

    def forward(self,
                query_input_ids,
                query_attention_mask,
                context_input_ids,
                context_attention_mask,
                labels,
                gck_segment=32):
        """Score every query against every context.

        ``labels`` holds, per query, the row index of its positive context;
        returns ``(loss, logits)``.
        """
        query_vecs = self.encode(self.qry_encoder, query_input_ids,
                                 query_attention_mask, gck_segment)
        context_vecs = self.encode(self.ctx_encoder, context_input_ids,
                                   context_attention_mask, gck_segment)
        similarity = torch.matmul(query_vecs, context_vecs.T)
        return self.loss_fct(similarity, labels), similarity
|
||||
|
||||
|
||||
class ClassifyRerank(nn.Module):
    """Passage reranker backed by a pretrained XLM-Roberta
    sequence-classification checkpoint loaded from ``model_dir``."""

    def __init__(self, model_dir):
        super().__init__()
        self.base_model = XLMRobertaForSequenceClassification.from_pretrained(
            model_dir)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                *args,
                **kwargs):
        """Delegate to the underlying classifier and return its output.

        NOTE(review): ``labels`` is accepted but not forwarded, so the base
        model never computes its own loss here — confirm this matches the
        training code's expectations.
        """
        forwarded = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        return self.base_model.forward(**forwarded)
|
||||
|
||||
|
||||
class Rerank(nn.Module):
    """Turns per-passage binary-classifier logits into a log-probability
    distribution over each query's ``top_k`` candidate passages."""

    def __init__(self, encoder, top_k):
        super().__init__()
        self.encoder = encoder
        self.top_k = top_k

    def forward(self, inputs):
        # Log-probability of the "relevant" class (index 1), one per passage.
        raw_logits = self.encoder(**inputs)[0]
        relevance = F.log_softmax(raw_logits, dim=-1)[:, 1]
        # Regroup the flat passage scores per query, then renormalise.
        per_query = relevance.view(-1, self.top_k)
        return F.log_softmax(per_query, dim=-1)
|
||||
|
||||
|
||||
class Re2GModel(nn.Module):
    """Re2G (rerank + generate) backbone: an XLM-Roberta reranker produces
    doc_scores that a RAG-style MT5 generator marginalises over."""

    def __init__(self, model_dir, config):
        super(Re2GModel, self).__init__()
        self.config = config
        self.top_k = self.config['top_k']
        # Rerank encoder and generator are built from per-component configs
        # stored in sub-directories of the checkpoint.
        encoder = XLMRobertaForSequenceClassification(
            config=AutoConfig.from_pretrained(
                os.path.join(model_dir, 'rerank')))
        generator = MT5ForConditionalGeneration(
            config=AutoConfig.from_pretrained(
                os.path.join(model_dir, 'generation')))

        self.rerank = Rerank(encoder, self.top_k)

        # RagTokenForGeneration is used only for its marginalised decoding;
        # its question encoder is discarded because doc_scores come from the
        # reranker instead.
        dpr_config = DPRConfig()
        dpr_config.vocab_size = encoder.config.vocab_size
        rag_model = RagTokenForGeneration(
            question_encoder=DPRQuestionEncoder(dpr_config),
            generator=generator)
        rag_model.rag.question_encoder = None
        self.generator = rag_model

    def forward(self, rerank_input_ids, input_ids, attention_mask, label_ids):
        """Teacher-forced training pass.

        Returns the RAG model output (includes the generation loss since
        ``labels`` is supplied).
        """
        doc_scores = self.rerank(rerank_input_ids)

        outputs = self.generator(
            labels=label_ids,
            context_input_ids=input_ids,
            context_attention_mask=attention_mask,
            doc_scores=doc_scores,
            n_docs=self.top_k)
        return outputs

    def generate(self, rerank_input_ids, input_ids, attention_mask):
        """Beam-search decode; returns generated token ids as a numpy array."""
        doc_scores = self.rerank(rerank_input_ids)

        beam_search_output = self.generator.generate(
            n_docs=self.top_k,
            encoder_input_ids=input_ids,
            context_input_ids=input_ids,
            context_attention_mask=attention_mask,
            doc_scores=doc_scores,
            num_beams=self.config['num_beams'],
            max_length=self.config['target_sequence_length'],
            early_stopping=True,
            no_repeat_ngram_size=self.config['no_repeat_ngram_size'],
            return_dict_in_generate=True,
            output_scores=True)
        # Fix: with return_dict_in_generate=True, generate() returns a
        # ModelOutput (no .detach()); the token ids live in `.sequences`.
        generated_ids = beam_search_output.sequences.detach().cpu().numpy()

        return generated_ids
|
||||
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .backbone import Re2GModel
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.document_grounded_dialog_generate, module_name=Models.doc2bot)
class DocumentGroundedDialogGenerateModel(TorchModel):
    """Doc2Bot response-generation model: a TorchModel facade over the
    Re2G (rerank + generate) backbone restored from a local checkpoint."""

    _backbone_prefix = ''

    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        # Configuration and weights both ship inside the model directory.
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        self.model = Re2GModel(model_dir, self.config)
        weights = torch.load(
            os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
            map_location='cpu')
        self.model.load_state_dict(weights)

    def forward(self, input: Dict[str, Tensor]):
        """Teacher-forced pass; `input` must carry rerank_input_ids,
        input_ids, attention_mask and label_ids."""
        return self.model(input['rerank_input_ids'], input['input_ids'],
                          input['attention_mask'], input['label_ids'])

    def generate(self, input: Dict[str, Tensor]):
        """Beam-search generation from the same inputs, minus label_ids."""
        return self.model.generate(input['rerank_input_ids'],
                                   input['input_ids'],
                                   input['attention_mask'])
|
||||
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model, Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .backbone import ClassifyRerank
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.document_grounded_dialog_rerank, module_name=Models.doc2bot)
class DocumentGroundedDialogRerankModel(TorchModel):
    """Doc2Bot rerank model: a thin TorchModel facade over the
    ClassifyRerank XLM-Roberta sequence classifier."""

    _backbone_prefix = ''

    def __init__(self, model_dir, **kwargs):
        super().__init__(model_dir, **kwargs)
        self.model = ClassifyRerank(model_dir)

    def forward(self, input: Dict[str, Tensor]):
        """Score tokenized (query, passage) pairs with the classifier."""
        token_ids = input['input_ids']
        mask = input['attention_mask']
        return self.model(input_ids=token_ids, attention_mask=mask)

    def resize_token_embeddings(self, size):
        # Needed after the tokenizer gains extra special tokens.
        self.model.base_model.resize_token_embeddings(size)

    def save_pretrained(self, addr):
        # Persist only the underlying transformers checkpoint.
        self.model.base_model.save_pretrained(addr)
|
||||
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .backbone import DPRModel
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.document_grounded_dialog_retrieval, module_name=Models.doc2bot)
class DocumentGroundedDialogRetrievalModel(TorchModel):
    """Doc2Bot retrieval model: DPR-style dual encoder restored from a
    local checkpoint directory."""

    _backbone_prefix = ''

    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        self.model = DPRModel(model_dir, self.config)
        weights = torch.load(
            os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
            map_location='cpu')
        self.model.load_state_dict(weights)

    def forward(self, input: Dict[str, Tensor], gck_segment=32):
        """Contrastive forward pass; returns (loss, logits)."""
        return self.model(input['query_input_ids'],
                          input['query_attention_mask'],
                          input['context_input_ids'],
                          input['context_attention_mask'], input['labels'],
                          gck_segment)

    def encode_query(self, input: Dict[str, Tensor]):
        """Pooled query vectors (no checkpoint dummy needed at inference)."""
        return self.model.qry_encoder(input['query_input_ids'],
                                      input['query_attention_mask'], None)

    def encode_context(self, input: Dict[str, Tensor]):
        """Pooled context/passage vectors."""
        return self.model.ctx_encoder(input['context_input_ids'],
                                      input['context_attention_mask'], None)
|
||||
@@ -1068,6 +1068,9 @@ TASK_OUTPUTS = {
|
||||
# "labels": ["dog", "horse", "cow", "cat"],
|
||||
# }
|
||||
Tasks.vision_efficient_tuning: [OutputKeys.SCORES, OutputKeys.LABELS],
|
||||
Tasks.document_grounded_dialog_generate: [OutputKeys.TEXT],
|
||||
Tasks.document_grounded_dialog_rerank: [OutputKeys.OUTPUT],
|
||||
Tasks.document_grounded_dialog_retrieval: [OutputKeys.OUTPUT],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -234,6 +234,19 @@ TASK_INPUTS = {
|
||||
'text': InputType.TEXT,
|
||||
'database': InputType.TEXT
|
||||
},
|
||||
Tasks.document_grounded_dialog_generate: {
|
||||
'query': InputType.LIST,
|
||||
'context': InputType.LIST,
|
||||
'label': InputType.LIST,
|
||||
},
|
||||
Tasks.document_grounded_dialog_rerank: {
|
||||
'dataset': InputType.LIST
|
||||
},
|
||||
Tasks.document_grounded_dialog_retrieval: {
|
||||
'query': InputType.LIST,
|
||||
'positive': InputType.LIST,
|
||||
'negative': InputType.LIST
|
||||
},
|
||||
|
||||
# ============ audio tasks ===================
|
||||
Tasks.auto_speech_recognition:
|
||||
|
||||
@@ -36,6 +36,9 @@ if TYPE_CHECKING:
|
||||
from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
|
||||
from .translation_evaluation_pipeline import TranslationEvaluationPipeline
|
||||
from .user_satisfaction_estimation_pipeline import UserSatisfactionEstimationPipeline
|
||||
from .document_grounded_dialog_generate_pipeline import DocumentGroundedDialogGeneratePipeline
|
||||
from .document_grounded_dialog_retrieval_pipeline import DocumentGroundedDialogRetrievalPipeline
|
||||
from .document_grounded_dialog_rerank_pipeline import DocumentGroundedDialogRerankPipeline
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -84,7 +87,16 @@ else:
|
||||
['CodeGeeXCodeGenerationPipeline'],
|
||||
'translation_evaluation_pipeline': ['TranslationEvaluationPipeline'],
|
||||
'user_satisfaction_estimation_pipeline':
|
||||
['UserSatisfactionEstimationPipeline']
|
||||
['UserSatisfactionEstimationPipeline'],
|
||||
'document_grounded_dialog_generate_pipeline': [
|
||||
'DocumentGroundedDialogGeneratePipeline'
|
||||
],
|
||||
'document_grounded_dialog_rerank_pipeline': [
|
||||
'DocumentGroundedDialogRerankPipeline'
|
||||
],
|
||||
'document_grounded_dialog_retrieval_pipeline': [
|
||||
'DocumentGroundedDialogRetrievalPipeline'
|
||||
]
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import DocumentGroundedDialogGenerateModel
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import DocumentGroundedDialogGeneratePreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
__all__ = ['DocumentGroundedDialogGeneratePipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.document_grounded_dialog_generate,
    module_name=Pipelines.document_grounded_dialog_generate)
class DocumentGroundedDialogGeneratePipeline(Pipeline):
    """Response-generation pipeline for document-grounded dialog."""

    def __init__(
            self,
            model: Union[DocumentGroundedDialogGenerateModel, str],
            preprocessor: DocumentGroundedDialogGeneratePreprocessor = None,
            config_file: str = None,
            device: str = 'gpu',
            auto_collate=True,
            **kwargs):
        """The Generate pipeline for document grounded dialog.

        Args:
            model: A model instance, a local model dir, or a model id in the model hub.
            preprocessor: A preprocessor instance; built automatically when None.
            config_file: Path to config file.
            device: Device to run the model.
            auto_collate: Apply auto collate.
            **kwargs: The preprocessor kwargs passed into the preprocessor's constructor.

        Examples:
            >>> from modelscope.pipelines import pipeline
            >>> pipe_ins = pipeline('document-grounded-dialog-generate', model='damo/nlp_convai_generate')
        """
        super().__init__(
            model=model,
            preprocessor=preprocessor,
            config_file=config_file,
            device=device,
            auto_collate=auto_collate)

        if preprocessor is None:
            self.preprocessor = DocumentGroundedDialogGeneratePreprocessor(
                self.model.model_dir, **kwargs)

    def forward(self, inputs: Union[list, Dict[str, Any]],
                **forward_params) -> Dict[str, Any]:
        # Generation happens in the model; token decoding is postprocess().
        return {'generated_ids': self.model.generate(inputs)}

    def postprocess(self, inputs: Union[list, Dict[str, Any]],
                    **postprocess_params) -> Dict[str, Any]:
        tokenizer = self.preprocessor.generation_tokenizer
        predictions = tokenizer.batch_decode(
            inputs['generated_ids'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False)
        return {OutputKeys.TEXT: predictions}

    def _collate_fn(self, data):
        # Inputs are already batched by the preprocessor; pass through.
        return data
|
||||
@@ -0,0 +1,754 @@
|
||||
import os
|
||||
import pprint
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import Any, Dict, Iterable, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
import ujson as json
|
||||
from torch import nn
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import DocumentGroundedDialogRerankModel
|
||||
from modelscope.models.nlp.ponet.configuration import PoNetConfig
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline, Tensor
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import DocumentGroundedDialogRerankPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['DocumentGroundedDialogRerankPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.document_grounded_dialog_rerank,
    module_name=Pipelines.document_grounded_dialog_rerank)
class DocumentGroundedDialogRerankPipeline(Pipeline):
    """Rerank pipeline for document-grounded dialog: scores candidate
    passages per query and accumulates KILT-style prediction records."""

    def __init__(self,
                 model: Union[DocumentGroundedDialogRerankModel, str],
                 preprocessor: DocumentGroundedDialogRerankPreprocessor = None,
                 config_file: str = None,
                 device: str = 'cuda',
                 auto_collate=True,
                 seed: int = 88,
                 **kwarg):
        """The Rerank pipeline for document grounded dialog

        Args:
            model: A model instance or a model local dir or a model id in the model hub.
            preprocessor: A preprocessor instance.
            config_file: Path to config file.
            device: Device to run the model.
            auto_collate: Apply auto collate.
            seed: Random seeds of random parameters.
            **kwarg: The preprocessor kwargs passed into the preprocessor's constructor.

        Examples:
            >>> from modelscope.pipelines import pipeline
            >>> pipe_ins = pipeline('document_grounded_dialog_rerank', model='damo/nlp_convai_rerank')
        """
        super().__init__(
            model=model,
            preprocessor=preprocessor,
            config_file=config_file,
            device=device,
            auto_collate=auto_collate,
            seed=seed)
        # NOTE(review): assumes `model` is an already-constructed
        # DocumentGroundedDialogRerankModel; a bare model-id string would
        # break the .to()/.eval() calls below — confirm callers.
        self.model = model
        self.preprocessor = preprocessor
        self.device = device
        # Fix: 'model_resize' is optional; the original kwarg['model_resize']
        # raised KeyError when the caller did not pass it.
        if kwarg.get('model_resize', False):
            self.model.resize_token_embeddings(
                len(self.preprocessor.tokenizer))
        self.model.to(self.device)
        self.model.eval()
        self.args = kwarg
        # self.model_cfg = self.model.model_cfg
        set_seed(seed)

    def one_instance(self, input_ids, attention_mask):
        """Score one query's candidate passages in max_batch_size chunks.

        Returns the probability of the "relevant" class (logit index 1)
        for each passage as a flat Python list.
        """
        all_probs = []
        for start_ndx in range(0, len(input_ids), self.args['max_batch_size']):
            probs = F.softmax(
                self.model({
                    'input_ids':
                    input_ids[start_ndx:start_ndx
                              + self.args['max_batch_size']],
                    'attention_mask':
                    attention_mask[start_ndx:start_ndx
                                   + self.args['max_batch_size']]
                }).logits.detach().cpu(),
                dim=-1)[:, 1].numpy().tolist()
            all_probs.extend(probs)
        return all_probs

    def forward(self, dataset: Union[list, Dict[str, Any]],
                **forward_params) -> Dict[str, Any]:
        """Rerank every instance in `dataset`; prediction records accumulate
        on self.guess and are surfaced by postprocess()."""
        report = Reporting()
        self.guess = []
        with torch.no_grad():
            for jobj in dataset:
                inst_id = jobj['id']
                probs = self.one_instance(jobj['input_ids'],
                                          jobj['attention_mask'])
                passages = jobj['passages']
                query = jobj['query']
                scored_pids = [(p['pid'], prob)
                               for p, prob in zip(passages, probs)]
                scored_pids.sort(key=lambda x: x[1], reverse=True)
                wids = to_distinct_doc_ids([
                    pid for pid, prob in scored_pids
                ])  # convert to Wikipedia document ids
                pred_record = {
                    'id': inst_id,
                    'input': query,
                    'scored_pids': scored_pids,
                    'output': [{
                        'answer': '',
                        'provenance': [{
                            'wikipedia_id': wid
                        } for wid in wids]
                    }]
                }
                if self.args['include_passages']:
                    pred_record['passages'] = passages

                if report.is_time():
                    print(
                        f'Finished {report.check_count}; {report.check_count / report.elapsed_seconds()} per second.'
                    )
                self.guess.append(pred_record)
        # if args['kilt_data']:
        #     evaluate(dataset, args['output'])

    def postprocess(self, inputs: list):
        return {OutputKeys.OUTPUT: inputs}
|
||||
|
||||
|
||||
class Reporting:
    """Lightweight progress/statistics tracker.

    Provides rate-limited "is it time to report" checks, exponential moving
    averages of named values, and an optional ring buffer of recent samples
    for selected names.
    """

    def __init__(self,
                 *,
                 recency_weight=0.001,
                 report_interval_secs=300,
                 check_every=1,
                 gather_samples: Iterable = (),
                 num_samples=10000):
        """The Reporting to print parameter status

        Args:
            recency_weight: when computing the moving average, how much weight to give to the current sample.
            report_interval_secs: how many seconds between returning true for is_time.
            check_every: how often to check the time, when calling is_time.
            gather_samples: keep the last num_samples of the listed names (gathered from moving_averages).
            num_samples: how many samples to keep.
        """
        self.check_count = 0
        self.check_every = check_every
        self.start_time = time.time()
        self.last_time = self.start_time
        self.report_interval_secs = report_interval_secs
        # For tracking moving averages of various values
        self.names = None
        self.averages = None
        self.counts = None
        self.recency_weight = recency_weight
        self.per_value_recency_weight = dict()
        self.report_count = 0
        self._prev_check_count = 0
        self.sample_names = list(gather_samples)
        if len(self.sample_names) > 0:
            # One ring-buffer row per sampled name.
            self.sample_values = np.zeros(
                (len(self.sample_names), num_samples), dtype=np.float32)
            self.sample_ndxs = np.zeros(len(self.sample_names), dtype=np.int32)
        else:
            self.sample_values = None
            self.sample_ndxs = None

    def reset(self):
        """Restart timing, counters, samples and moving averages."""
        self.check_count = 0
        self.start_time = time.time()
        self.last_time = self.start_time
        self.report_count = 0
        self._prev_check_count = 0
        if len(self.sample_names) > 0:
            self.sample_values[:, :] = 0
            self.sample_ndxs[:] = 0
        if self.counts is not None:
            self.counts[:] = 0
            self.averages[:] = 0

    def is_time(self):
        """Return True when report_interval_secs have elapsed.

        Adapts check_every so the wall clock is consulted neither too often
        nor too rarely relative to the call rate.
        """
        self.check_count += 1
        if self.check_count % self.check_every == 0:
            elapsed = time.time() - self.last_time
            if elapsed >= self.report_interval_secs:
                # check the time more or less often
                if self.check_every > 1 and self.check_count - self._prev_check_count < 5 * self.check_every:
                    self.check_every //= 2
                elif self.check_count - self._prev_check_count > 50 * self.check_every:
                    self.check_every *= 2
                self.last_time = time.time()
                self.report_count += 1
                self._prev_check_count = self.check_count
                return True
        return False

    def moving_averages(self, **values):
        """Fold the given name=value pairs into the moving averages and,
        for names listed in gather_samples, into the sample ring buffer."""
        # create entries in avgs and counts when needed
        # update the avgs and counts
        if self.names is None:
            self.names = list(values.keys())
            self.averages = np.zeros(len(self.names))
            self.counts = np.zeros(len(self.names))
        for name in values.keys():
            if name not in self.names:
                self.names.append(name)
        if self.averages.shape[0] < len(self.names):
            old_len = self.averages.shape[0]
            self.averages = np.resize(self.averages, len(self.names))
            self.averages[old_len:] = 0
            self.counts = np.resize(self.counts, len(self.names))
            self.counts[old_len:] = 0
        for ndx, name in enumerate(self.names):
            if name in values:
                self.counts[ndx] += 1
                # support per-name recency_weight; until enough samples are
                # seen, weight by 1/count (i.e. a plain running mean).
                if name in self.per_value_recency_weight:
                    rweight = max(self.per_value_recency_weight[name],
                                  1.0 / self.counts[ndx])
                else:
                    rweight = max(self.recency_weight, 1.0 / self.counts[ndx])
                self.averages[ndx] = rweight * values[name] + (
                    1.0 - rweight) * self.averages[ndx]
        for ndx, name in enumerate(self.sample_names):
            if name in values:
                # Fix: write into this name's slot of row `ndx`; the original
                # indexed only by the slot position and clobbered a whole row.
                self.sample_values[ndx, self.sample_ndxs[ndx]] = values[name]
                self.sample_ndxs[ndx] = (self.sample_ndxs[ndx]
                                         + 1) % self.sample_values.shape[1]

    def get_samples(self, name):
        """Return the gathered samples for `name`, or None if not tracked."""
        for ndx, n in enumerate(self.sample_names):
            if n == name:
                count = self.get_count(name)
                if count is None:
                    count = 0
                # Fix: counts is a float array and numpy slice bounds must be
                # integers, so cast before slicing.
                return self.sample_values[ndx,
                                          0:int(count)]  # NOTE: not in order
        return None

    def get_moving_average(self, name):
        """Return the moving average for `name`, or None if never seen."""
        if self.names is None:
            return None
        for ndx, n in enumerate(self.names):
            if n == name:
                return self.averages[ndx]
        return None

    def get_count(self, name):
        """Return the observation count for `name`, or None if never seen."""
        if self.names is None:
            return None
        for ndx, n in enumerate(self.names):
            if n == name:
                return self.counts[ndx]
        return None

    def elapsed_seconds(self) -> float:
        return time.time() - self.start_time

    def elapsed_time_str(self) -> str:
        # time_str is defined elsewhere in this module.
        return time_str(self.elapsed_seconds())

    def progress_str(self, instance_name='instance'):
        return f'On {instance_name} {self.check_count}, ' \
               f'{self.check_count / self.elapsed_seconds()} {instance_name}s per second.'

    def display(self, *, prefix=''):
        # display the moving averages
        logger.info('==========================================')
        if self.names is not None:
            for n, v in zip(self.names, self.averages):
                logger.info(f'{prefix}{n} = {v}')

    def display_warn(self, *, prefix=''):
        # display the moving averages (at warning level)
        logger.info('==========================================')
        if self.names is not None:
            for n, v in zip(self.names, self.averages):
                logger.warning(f'{prefix}{n} = {v}')
|
||||
|
||||
|
||||
def _remove_duplicates(obj):
|
||||
obj_tmp = []
|
||||
for o in obj:
|
||||
if o not in obj_tmp:
|
||||
obj_tmp.append(o)
|
||||
return obj_tmp
|
||||
|
||||
|
||||
def _get_ids_list(datapoint, rank_keys, verbose=False):
    """Collect, for each gold output, the de-duplicated list of provenance ids.

    Each provenance entry is flattened into a single string by joining the
    (stripped) values of *rank_keys* with '+'. Entries missing any rank key
    are skipped, optionally with a warning.

    Args:
        datapoint: Record with an 'output' list; each output may hold a
            'provenance' list of dicts.
        rank_keys: Keys whose values identify a provenance entry.
        verbose: When True, print a warning for incomplete provenance items.

    Returns:
        A list with one duplicate-free id list per output.
    """
    ids_list = []
    for output in datapoint['output']:
        current_ids_list = []
        if 'provenance' in output:
            for provenance in output['provenance']:
                # Simplified from the original
                # `rank_keys - (provenance.keys() & rank_keys)`, which is
                # just the plain set difference.
                missing = set(rank_keys) - set(provenance.keys())
                if missing:
                    if verbose:
                        print(
                            f'WARNING: missing key(s) {missing} in provenance, unable to compute retrieval for those.'
                        )
                else:
                    current_ids_list.append('+'.join(
                        str(provenance[rank_key]).strip()
                        for rank_key in rank_keys))
        # consider only unique ids
        ids_list.append(_remove_duplicates(current_ids_list))
    return ids_list
|
||||
|
||||
|
||||
def _computeRprec(guess_ids, gold_ids):
|
||||
R = len(gold_ids)
|
||||
num = 0
|
||||
|
||||
for prediction in guess_ids[:R]:
|
||||
if str(prediction).strip() in gold_ids:
|
||||
num += 1
|
||||
|
||||
Rprec = num / R if R > 0 else 0
|
||||
return Rprec
|
||||
|
||||
|
||||
# 1. Precision computation
|
||||
def _precision_at_k(rank, k):
|
||||
# precision @ k
|
||||
p = rank[:k].count(True) / k
|
||||
|
||||
return p
|
||||
|
||||
|
||||
# 2. Recall computation
|
||||
def _recall_at_k(rank, num_distinct_evidence_sets, k):
|
||||
r = rank[:k].count(True) / num_distinct_evidence_sets
|
||||
|
||||
return r
|
||||
|
||||
|
||||
# 3. Success rate computation
|
||||
def _success_rate_at_k(rank, k):
|
||||
# success rate @ k
|
||||
p = int(True in rank[:k])
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def get_rank(guess_item, gold_item, k, rank_keys, verbose=False):
    """
    The main idea is to consider each evidence set as a single point in the rank.
    The score in the rank for an evidence set is given by the lowest scored evidence in the set.

    Returns:
        (rank, num_distinct_evidence_sets): `rank` is a list whose slots are
        True (an entire evidence set was covered at this position), False
        (a guessed id matched no evidence set), or an 'evidence_set:<i>'
        marker for a partially-covered set; `num_distinct_evidence_sets` is
        the number of unique gold evidence sets.
    """
    assert k > 0, 'k must be a positive integer grater than 0.'

    rank = []
    num_distinct_evidence_sets = 0

    # Only the first output of the guess is considered.
    guess_ids = _get_ids_list(guess_item, rank_keys)[0]

    if guess_ids and len(guess_ids) > 0:

        # 1. collect evidence sets and their sizes
        evidence_sets = []
        e_size = defaultdict(int)  # evidence-set size -> frequency
        for output in gold_item['output']:
            if 'provenance' in output:
                e_set = {
                    '+'.join([
                        str(provenance[rank_key]).strip()
                        for rank_key in rank_keys
                    ])
                    for provenance in output['provenance']
                }
                if e_set not in evidence_sets:  # no duplicate evidence set
                    evidence_sets.append(e_set)
                    e_size[len(e_set)] += 1
        num_distinct_evidence_sets = len(evidence_sets)

        # 2. check what's the minimum number of predicted pages needed to get a robust P/R@k
        # Sum the k largest evidence-set sizes (greedy, largest first).
        min_prediction_size = 0
        c = 0
        for size, freq in sorted(e_size.items(), reverse=True):
            for _ in range(freq):
                min_prediction_size += size
                c += 1
                if c == k:
                    break
            if c == k:
                break
        # if the number of evidence sets is smaller than k
        min_prediction_size += k - c

        if verbose and len(guess_ids) < min_prediction_size:
            print(
                f'WARNING: you should provide at least {min_prediction_size} provenance items '
                f'for a robust recall@{k} computation (you provided {len(guess_ids)} item(s)).'
            )

        # 3. rank by grouping pages in each evidence set (each evidence set counts as 1),
        # the position in the rank of each evidence set is given by the last page in guess_ids
        # non evidence pages count as 1
        # NOTE: evidence sets are consumed (mutated) as guesses cover them.
        rank = []
        for guess_id in guess_ids:
            guess_id = str(guess_id).strip()
            found = False
            for idx, e_set in enumerate(evidence_sets):

                e_set_id = f'evidence_set:{idx}'

                if guess_id in e_set:
                    found = True

                    # remove from the rank previous points referring to this evidence set
                    if e_set_id in rank:
                        rank.remove(e_set_id)

                    # remove the guess_id from the evidence set
                    e_set.remove(guess_id)

                    if len(e_set) == 0:
                        # it was the last evidence, it counts as true in the rank
                        rank.append(True)
                    else:
                        # add a point for this partial evidence set
                        rank.append(e_set_id)

            if not found:
                rank.append(False)

    return rank, num_distinct_evidence_sets
|
||||
|
||||
|
||||
def set_seed(seed):
    """Seed the stdlib, numpy and torch RNGs for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
||||
|
||||
|
||||
def load_data(filename):
    """Read a JSON-lines file into a list of records.

    Args:
        filename: Path to a file containing one JSON object per line.

    Returns:
        List of parsed objects, in file order.
    """
    data = []
    # `with` guarantees the handle is closed (the original leaked it).
    with open(filename, 'r') as file_in:
        for line in file_in:
            data.append(json.loads(line))
    return data
|
||||
|
||||
|
||||
def rprecision(guess_item, gold_item, rank_keys):
    """Best R-precision of the guess's (single) output against any of the
    gold item's provenance lists."""
    guess_ids = _get_ids_list(guess_item, rank_keys)[0]
    gold_ids_list = _get_ids_list(gold_item, rank_keys)
    # max over an empty gold list raises, matching the original behavior.
    return max(
        _computeRprec(guess_ids, gold_ids) for gold_ids in gold_ids_list)
|
||||
|
||||
|
||||
def get_ranking_metrics(guess_item, gold_item, ks, rank_keys):
    """Compute R-precision and precision/recall/success-rate@k for one item.

    Args:
        guess_item: Prediction record; must contain exactly one 'output'.
        gold_item: Gold record carrying provenance information.
        ks: Iterable of positive cutoff values.
        rank_keys: Provenance keys identifying an evidence item.

    Returns:
        Dict mapping 'Rprec', 'precision@k', 'recall@k' and 'success_rate@k'
        to their values (0 when the item has no distinct evidence set).
    """
    # Bug fix: recall@k and success_rate@k were initialized only for k > 1,
    # so e.g. 'recall@1' was missing from the returned dict whenever
    # num_distinct_evidence_sets == 0 — which made `compute` crash with a
    # KeyError. Initialize every metric for all k > 0.
    P_at_k = {'precision@{}'.format(k): 0 for k in sorted(ks) if k > 0}
    R_at_k = {'recall@{}'.format(k): 0 for k in sorted(ks) if k > 0}
    S_at_k = {'success_rate@{}'.format(k): 0 for k in sorted(ks) if k > 0}

    assert (
        'output' in guess_item and len(guess_item['output']) == 1
    ), f"guess should provide exactly one output for {guess_item['id']}"

    Rprec = rprecision(guess_item, gold_item, rank_keys=rank_keys)
    for k in ks:

        # 0. get rank
        rank, num_distinct_evidence_sets = get_rank(
            guess_item, gold_item, k, rank_keys=rank_keys)

        if num_distinct_evidence_sets > 0:
            # 1. precision
            P_at_k['precision@{}'.format(k)] = _precision_at_k(rank, k)

            # 2. recall
            R_at_k['recall@{}'.format(k)] = _recall_at_k(
                rank, num_distinct_evidence_sets, k)

            # 3. success rate
            S_at_k['success_rate@{}'.format(k)] = _success_rate_at_k(rank, k)
        # When there is no distinct evidence set, the zero-initialized
        # entries above stand.

    return {'Rprec': Rprec, **P_at_k, **R_at_k, **S_at_k}
|
||||
|
||||
|
||||
def compute(gold_dataset, guess_dataset, ks, rank_keys):
    """Aggregate retrieval metrics (R-precision, P/R/success@k) over a dataset.

    Args:
        gold_dataset: Gold records, each carrying 'id' and 'output'.
        guess_dataset: Prediction records, aligned one-to-one with gold.
        ks: Cutoff values; coerced to sorted ints.
        rank_keys: Provenance keys identifying an evidence item.

    Returns:
        OrderedDict of metric name -> mean value over the dataset.
    """
    ks = sorted(int(x) for x in ks)

    result = OrderedDict()
    result['Rprec'] = 0.0
    for k in ks:
        if k > 0:
            result['precision@{}'.format(k)] = 0.0
            result['recall@{}'.format(k)] = 0.0
            result['success_rate@{}'.format(k)] = 0.0

    assert len(guess_dataset) == len(
        gold_dataset), 'different size gold: {} guess: {}'.format(
            len(guess_dataset), len(gold_dataset))

    # Bug fix (naming): the original bound guess items to the name `gold`
    # and gold items to `guess`. The ID-equality assertion is symmetric, so
    # behavior is unchanged, but the swapped names were misleading.
    for gold, guess in zip(gold_dataset, guess_dataset):
        assert (str(gold['id']).strip() == str(
            guess['id']).strip()), 'Items must have same order with same IDs'

    for guess_item, gold_item in zip(guess_dataset, gold_dataset):
        ranking_metrics = get_ranking_metrics(guess_item, gold_item, ks,
                                              rank_keys)
        result['Rprec'] += ranking_metrics['Rprec']
        for k in ks:
            if k > 0:
                for metric in ('precision', 'recall', 'success_rate'):
                    key = '{}@{}'.format(metric, k)
                    result[key] += ranking_metrics[key]

    if len(guess_dataset) > 0:
        # Every accumulated key (Rprec plus the per-k metrics) is averaged.
        total = len(guess_dataset)
        for key in result:
            result[key] /= total

    return result
|
||||
|
||||
|
||||
def to_distinct_doc_ids(passage_ids):
    """Return document ids for *passage_ids*, de-duplicated and in
    first-seen order. (Currently the passage id itself is used as the
    document id.)"""
    # dict preserves insertion order and drops duplicates in one pass.
    return list(dict.fromkeys(passage_ids))
|
||||
|
||||
|
||||
def validate_input(gold_records, guess_records):
    """Check id uniqueness and reorder predictions to match gold order.

    Args:
        gold_records: Gold dicts, each with a unique 'id'.
        guess_records: Prediction dicts, each with a unique 'id'.

    Returns:
        Tuple (gold_records, guess_records) with the guesses re-sorted into
        gold order.

    Raises:
        ValueError: If any gold id has no matching prediction.
    """
    if len(gold_records) != len(guess_records):
        print('WARNING: DIFFERENT SIZE gold: {} guess: {}'.format(
            len(gold_records), len(guess_records)))

    # align order
    gold_ids = []
    for gold in gold_records:
        assert str(
            gold['id']).strip() not in gold_ids, 'Gold IDs should be unique'
        gold_ids.append(str(gold['id']).strip())

    id2guess_record = {}
    for guess in guess_records:
        assert (str(guess['id']).strip()
                not in id2guess_record), 'Prediction IDs should be unique'
        id2guess_record[str(guess['id']).strip()] = guess

    guess_records = []
    # `record_id` rather than `id`, which shadowed the builtin.
    for record_id in gold_ids:
        if record_id in id2guess_record:
            guess_records.append(id2guess_record[record_id])
        else:
            raise ValueError(
                'ERROR: no prediction provided for id: {}'.format(record_id))

    return gold_records, guess_records
|
||||
|
||||
|
||||
# utility to get gold answers
|
||||
def get_gold_answers(gold):
|
||||
ground_truths = set()
|
||||
for item in gold['output']:
|
||||
if 'answer' in item and item['answer'] and len(
|
||||
item['answer'].strip()) > 0:
|
||||
ground_truths.add(item['answer'].strip())
|
||||
return ground_truths
|
||||
|
||||
|
||||
# utility to get max
|
||||
def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
if scores_for_ground_truths:
|
||||
return max(scores_for_ground_truths)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def _calculate_metrics(gold_records, guess_records):
    """Compute downstream (accuracy/EM/F1/ROUGE-L) and KILT metrics.

    KILT-* variants only credit an item when its retrieval R-precision is a
    perfect 1. Items with empty guessed answers are skipped entirely (they
    still count in the denominator via nothing — see the `continue` below,
    which happens AFTER total_count was incremented, so they contribute 0).
    """
    assert len(gold_records) == len(
        guess_records), 'different size gold: {} guess: {}'.format(
            len(gold_records), len(guess_records))

    total_count = 0

    # downstream metrics
    accuracy = 0
    normalized_em = 0
    normalized_f1 = 0
    rougel = 0

    # kilt metrics
    kilt_accuracy = 0
    kilt_em = 0
    kilt_f1 = 0
    kilt_rougel = 0

    for guess_item, gold_item in zip(guess_records, gold_records):

        # check ids
        assert (str(gold_item['id']).strip() == str(guess_item['id']).strip()
                ), 'Items must have same order with same IDs'

        total_count += 1
        # check if each output of guess file exist in set of candidate answers
        gold_candidate_answers = get_gold_answers(gold_item)

        conditions = (len(guess_item['output'])
                      == 1) and ('answer' in guess_item['output'][0])
        assert (
            conditions
        ), f"you should provide exactly one valid answer for {guess_item['id']}"
        guess_answer = str(guess_item['output'][0]['answer']).strip()

        if len(guess_answer) == 0:
            # empty answer
            continue

        # 0. accuracy = strict exact match
        local_accuracy = 0
        if guess_answer in gold_candidate_answers:
            local_accuracy = 1
        accuracy += local_accuracy

        # 1. normalized exact match
        local_em = _metric_max_over_ground_truths(_exact_match_score,
                                                  guess_answer,
                                                  gold_candidate_answers)
        normalized_em += local_em

        # 2. normalized f1
        local_f1 = _metric_max_over_ground_truths(_f1_score, guess_answer,
                                                  gold_candidate_answers)
        normalized_f1 += local_f1

        # 3. rougel
        local_rougel = _metric_max_over_ground_truths(_rougel_score,
                                                      guess_answer,
                                                      gold_candidate_answers)
        rougel += local_rougel

        # KILT-metrics: only credited under perfect retrieval (Rprec == 1).
        Rprec = rprecision(guess_item, gold_item, rank_keys=['wikipedia_id'])
        if Rprec == 1:
            # 1. KILT-AC
            kilt_accuracy += local_accuracy

            # 2. KILT-EM
            kilt_em += local_em

            # 3. KILT-F1
            kilt_f1 += local_f1

            # 4. KILT-RL
            kilt_rougel += local_rougel

    if total_count > 0:
        accuracy /= total_count
        normalized_em /= total_count
        normalized_f1 /= total_count
        rougel /= total_count
        kilt_accuracy /= total_count
        kilt_em /= total_count
        kilt_f1 /= total_count
        kilt_rougel /= total_count

    return {
        'kilt': {
            'KILT-accuracy': kilt_accuracy,
            'KILT-em': kilt_em,
            'KILT-f1': kilt_f1,
            'KILT-rougel': kilt_rougel,
        },
        'downstream': {
            'accuracy': accuracy,
            'em': normalized_em,
            'f1': normalized_f1,
            'rougel': rougel,
        },
    }
|
||||
|
||||
|
||||
def evaluate(gold, guess):
    """Run downstream + KILT + retrieval evaluation and pretty-print results.

    Args:
        gold: Already-loaded gold records (list of dicts).
        guess: Path to a JSON-lines prediction file.

    Returns:
        Dict with 'kilt', 'downstream' and 'retrieval' metric groups.
    """
    printer = pprint.PrettyPrinter(indent=4)

    gold_records = gold
    guess_records = load_data(guess)

    # 0. validate input
    gold_records, guess_records = validate_input(gold_records, guess_records)

    # 1. downstream + kilt
    result = _calculate_metrics(gold_records, guess_records)

    # 2. retrieval performance
    retrieval_results = compute(
        gold_records,
        guess_records,
        ks=[1, 5, 10, 100],
        rank_keys=['wikipedia_id'])
    wanted = ('Rprec', 'recall@1', 'recall@5', 'recall@10', 'recall@100')
    result['retrieval'] = {key: retrieval_results[key] for key in wanted}

    printer.pprint(result)
    return result
|
||||
|
||||
|
||||
# Script entry point; `main` is defined elsewhere in this module.
if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import faiss
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import DocumentGroundedDialogRetrievalModel
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import \
|
||||
DocumentGroundedDialogRetrievalPreprocessor
|
||||
from modelscope.utils.constant import ModeKeys, Tasks
|
||||
|
||||
__all__ = ['DocumentGroundedDialogRetrievalPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.document_grounded_dialog_retrieval,
    module_name=Pipelines.document_grounded_dialog_retrieval)
class DocumentGroundedDialogRetrievalPipeline(Pipeline):
    """Dense passage retrieval pipeline for document-grounded dialog.

    Keeps an in-memory passage corpus (`passages_index`: list of
    {'passage': str, 'vector': list[float]}) plus a faiss inner-product
    index over the vectors; queries are encoded by the model and the top-20
    nearest passages are returned.
    """

    def __init__(
            self,
            model: Union[DocumentGroundedDialogRetrievalModel, str],
            preprocessor: DocumentGroundedDialogRetrievalPreprocessor = None,
            config_file: str = None,
            device: str = 'gpu',
            auto_collate=True,
            index_path: str = None,
            per_gpu_batch_size: int = 32,
            **kwargs):
        """The Retrieval pipeline for document grounded dialog.

        Args:
            model: A model instance or a model local dir or a model id in the model hub.
            preprocessor: A preprocessor instance.
            config_file: Path to config file.
            device: Device to run the model.
            auto_collate: Apply auto collate.
            index_path: Index file path.
            per_gpu_batch_size: Batch size per GPU to run the code.
            **kwargs: The preprocessor kwargs passed into the preprocessor's constructor.

        Examples:
            >>> from modelscope.pipelines import pipeline
            >>> pipe_ins = pipeline('document-grounded-dialog-retrieval', model='damo/nlp_convai_retrieval')

        """
        super().__init__(
            model=model,
            preprocessor=preprocessor,
            config_file=config_file,
            device=device,
            auto_collate=auto_collate)

        if preprocessor is None:
            self.preprocessor = DocumentGroundedDialogRetrievalPreprocessor(
                self.model.model_dir, **kwargs)
        self.per_gpu_batch_size = per_gpu_batch_size
        self.passages_index = []  # list of {'passage', 'vector'} dicts
        self.passages = []  # plain passage texts, parallel to passages_index
        self.index = None  # faiss index over the passage vectors
        self.load_index(index_path)

    def forward(self, inputs: Union[list, Dict[str, Any]],
                **forward_params) -> Dict[str, Any]:
        """Encode the query batch and return the ids of the 20 nearest
        passages per query."""
        query_vector = self.model.encode_query(
            inputs).detach().cpu().numpy().astype('float32')
        _, retrieved = self.index.search(query_vector, 20)
        return {'retrieved_ids': retrieved.tolist()}

    def postprocess(self, inputs: Union[list, Dict[str, Any]],
                    **postprocess_params) -> Dict[str, Any]:
        """Map retrieved passage ids back to passage texts."""
        predictions = [[self.passages[x] for x in retrieved_ids]
                       for retrieved_ids in inputs['retrieved_ids']]
        return {OutputKeys.OUTPUT: predictions}

    def _collate_fn(self, data):
        # Batching is handled by the preprocessor; pass data through as-is.
        return data

    def _rebuild_index(self):
        # Rebuild the passage list and the faiss inner-product index from
        # self.passages_index.
        self.passages = [x['passage'] for x in self.passages_index]
        all_ctx_vector = np.array([x['vector'] for x in self.passages_index
                                   ]).astype('float32')
        index = faiss.IndexFlatIP(all_ctx_vector.shape[-1])
        index.add(all_ctx_vector)
        self.index = index

    def load_index(self, index_path: str = None):
        """Load the passages-index JSON (default: `passages_index.json` in
        the model dir) and (re)build the faiss index from it."""
        if not index_path:
            index_path = os.path.join(self.model.model_dir,
                                      'passages_index.json')
        with open(index_path) as f:
            self.passages_index = json.load(f)
        self._rebuild_index()

    def save_index(self, index_path: str = None):
        """Persist the passages index to JSON (default path as in
        `load_index`)."""
        if not index_path:
            index_path = os.path.join(self.model.model_dir,
                                      'passages_index.json')
        with open(index_path, 'w') as f:
            # Bug fix: the original dumped `self.passage_index`, an
            # attribute that never existed (AttributeError at runtime).
            json.dump(self.passages_index, f, ensure_ascii=False, indent=4)

    def add_passage(self, passages: List[str]):
        """Encode new passages in mini-batches, append them to the corpus
        and rebuild the faiss index."""
        all_ctx_vector = []
        for start in range(0, len(passages), self.per_gpu_batch_size):
            batch = passages[start:start + self.per_gpu_batch_size]
            processed = self.preprocessor({'context': batch},
                                          invoke_mode=ModeKeys.INFERENCE,
                                          input_type='context')
            sub_ctx_vector = self.model.encode_context(
                processed).detach().cpu().numpy()
            all_ctx_vector.append(sub_ctx_vector)
        all_ctx_vector = np.concatenate(all_ctx_vector, axis=0)
        for passage, vector in zip(passages, all_ctx_vector):
            # Bug fix: the original appended to `passages_index` but then
            # rebuilt from the nonexistent `passage_index` attribute.
            self.passages_index.append({
                'passage': passage,
                'vector': vector.tolist()
            })
        self._rebuild_index()
|
||||
@@ -40,7 +40,10 @@ if TYPE_CHECKING:
|
||||
TableQuestionAnsweringPreprocessor, NERPreprocessorViet,
|
||||
NERPreprocessorThai, WordSegmentationPreprocessorThai,
|
||||
TranslationEvaluationPreprocessor,
|
||||
DialogueClassificationUsePreprocessor)
|
||||
DialogueClassificationUsePreprocessor,
|
||||
DocumentGroundedDialogGeneratePreprocessor,
|
||||
DocumentGroundedDialogRetrievalPreprocessor,
|
||||
DocumentGroundedDialogRerankPreprocessor)
|
||||
from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
|
||||
|
||||
else:
|
||||
@@ -89,7 +92,10 @@ else:
|
||||
'ConversationalTextToSqlPreprocessor',
|
||||
'TableQuestionAnsweringPreprocessor',
|
||||
'TranslationEvaluationPreprocessor',
|
||||
'DialogueClassificationUsePreprocessor'
|
||||
'DialogueClassificationUsePreprocessor',
|
||||
'DocumentGroundedDialogGeneratePreprocessor',
|
||||
'DocumentGroundedDialogRetrievalPreprocessor',
|
||||
'DocumentGroundedDialogRerankPreprocessor'
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -105,6 +105,14 @@ PREPROCESSOR_MAP = {
|
||||
(Models.structbert, Tasks.word_segmentation):
|
||||
Preprocessors.token_cls_tokenizer,
|
||||
|
||||
# doc2bot
|
||||
(Models.doc2bot, Tasks.document_grounded_dialog_generate):
|
||||
Preprocessors.document_grounded_dialog_generate,
|
||||
(Models.doc2bot, Tasks.document_grounded_dialog_rerank):
|
||||
Preprocessors.document_grounded_dialog_rerank,
|
||||
(Models.doc2bot, Tasks.document_grounded_dialog_retrieval):
|
||||
Preprocessors.document_grounded_dialog_retrieval,
|
||||
|
||||
# veco
|
||||
(Models.veco, Tasks.backbone):
|
||||
Preprocessors.sen_cls_tokenizer,
|
||||
|
||||
@@ -30,6 +30,9 @@ if TYPE_CHECKING:
|
||||
from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor
|
||||
from .translation_evaluation_preprocessor import TranslationEvaluationPreprocessor
|
||||
from .dialog_classification_use_preprocessor import DialogueClassificationUsePreprocessor
|
||||
from .document_grounded_dialog_generate_preprocessor import DocumentGroundedDialogGeneratePreprocessor
|
||||
from .document_grounded_dialog_retrieval_preprocessor import DocumentGroundedDialogRetrievalPreprocessor
|
||||
from .document_grounded_dialog_retrieval_preprocessor import DocumentGroundedDialogRerankPreprocessor
|
||||
else:
|
||||
_import_structure = {
|
||||
'bert_seq_cls_tokenizer': ['Tokenize'],
|
||||
@@ -83,7 +86,13 @@ else:
|
||||
'translation_evaluation_preprocessor':
|
||||
['TranslationEvaluationPreprocessor'],
|
||||
'dialog_classification_use_preprocessor':
|
||||
['DialogueClassificationUsePreprocessor']
|
||||
['DialogueClassificationUsePreprocessor'],
|
||||
'document_grounded_dialog_generate_preprocessor':
|
||||
['DocumentGroundedDialogGeneratePreprocessor'],
|
||||
'document_grounded_dialog_retrieval_preprocessor':
|
||||
['DocumentGroundedDialogRetrievalPreprocessor'],
|
||||
'document_grounded_dialog_rerank_preprocessor':
|
||||
['DocumentGroundedDialogRerankPreprocessor']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from transformers import MT5Tokenizer, XLMRobertaTokenizer
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.preprocessors.builder import PREPROCESSORS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Fields, ModeKeys, ModelFile
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.document_grounded_dialog_generate)
class DocumentGroundedDialogGeneratePreprocessor(Preprocessor):
    """Tokenizes (query, passages[, label]) batches for the DGDS generation
    task: XLM-R inputs for the reranker and MT5 inputs for the generator."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """The preprocessor for DGDS generate task, based on transformers' tokenizer.

        Args:
            model_dir: The model dir containing the essential files to build the tokenizer.
        """
        super().__init__(*args, **kwargs)

        self.model_dir: str = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        # GPU by default when available; kwargs['device'] == 'cpu' forces CPU.
        self.device = 'cuda' \
            if ('device' not in kwargs or kwargs['device'] == 'gpu') and torch.cuda.is_available() \
            else 'cpu'

        self.top_k = self.config['top_k']
        self.query_sequence_length = self.config['query_sequence_length']
        self.rerank_source_sequence_length = self.config[
            'rerank_source_sequence_length']
        self.source_sequence_length = self.config['source_sequence_length']
        self.target_sequence_length = self.config['target_sequence_length']
        # Two tokenizers: XLM-R for the rerank branch, MT5 for generation.
        self.rerank_tokenizer = XLMRobertaTokenizer.from_pretrained(
            os.path.join(self.model_dir, 'rerank'))
        self.generation_tokenizer = MT5Tokenizer.from_pretrained(
            os.path.join(self.model_dir, 'generation'))

    @type_assert(object, Dict)
    def __call__(self,
                 data: Dict[str, Any],
                 invoke_mode=ModeKeys.INFERENCE,
                 **preprocessor_param) -> Dict[str, Any]:
        """Tokenize a batch.

        Args:
            data: Dict with 'query' (list of str), 'context' (list of
                passage lists, one list per query) and, for training/eval,
                'label' (list of target strings).
            invoke_mode: TRAIN/EVAL additionally produce 'label_ids'.

        Returns:
            Dict of tensors on `self.device`: 'rerank_input_ids',
            'input_ids', 'attention_mask' and optionally 'label_ids'.
        """
        query, context, label = data['query'], data['context'], data.get(
            'label', None)
        # Truncate every query to query_sequence_length generation tokens by
        # a tokenize/decode round-trip.
        query = [
            self.generation_tokenizer.decode(
                self.generation_tokenizer([x],
                                          add_special_tokens=False,
                                          return_tensors='pt')['input_ids'][0]
                [:self.query_sequence_length]) for x in query
        ]

        # Repeat each query top_k times, once per candidate passage.
        # (Local renamed from the misspelled `querys`.)
        queries = [x for x in query for _ in range(self.top_k)]
        contexts = [x for ctxs in context for x in ctxs[:self.top_k]]
        assert len(queries) == len(contexts)
        rerank_input_ids = self.rerank_tokenizer(
            queries,
            contexts,
            add_special_tokens=True,
            return_tensors='pt',
            max_length=self.rerank_source_sequence_length,
            padding='longest',
            truncation=True)

        # Generator input: "<query> <passage> <doc>" per (query, passage).
        generator_inputs = [
            ' '.join([query[i], '<passage>', doc])
            for i in range(len(query)) for doc in context[i][:self.top_k]
        ]
        inputs_tokenizer_outputs = self.generation_tokenizer.batch_encode_plus(
            list(generator_inputs),
            padding=True,
            return_tensors='pt',
            max_length=self.source_sequence_length,
            truncation=True)

        result = {
            'rerank_input_ids': rerank_input_ids,
            'input_ids': inputs_tokenizer_outputs.input_ids,
            'attention_mask': inputs_tokenizer_outputs.attention_mask
        }
        # Simplified from `invoke_mode in (TRAIN, EVAL) and
        # invoke_mode != INFERENCE` — the second clause was redundant.
        if invoke_mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
            result['label_ids'] = self.generation_tokenizer.batch_encode_plus(
                list(label),
                padding=True,
                return_tensors='pt',
                max_length=self.target_sequence_length,
                truncation=True).input_ids

        for k, v in result.items():
            result[k] = v.to(self.device)

        return result
|
||||
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import copy
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from transformers import XLMRobertaTokenizer
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.preprocessors.builder import PREPROCESSORS
|
||||
from modelscope.utils.constant import Fields, ModeKeys, ModelFile
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.document_grounded_dialog_rerank)
class DocumentGroundedDialogRerankPreprocessor(Preprocessor):
    """Tokenizes query/passage pairs for the DGDS rerank task with XLM-R.

    Supports two input layouts in `__call__`: a dataset-style batch (keys
    'input', 'passages', 'id', 'output', 'positive_pids' — the last four
    stored as stringified Python literals) and a single inference example
    (keys 'query', 'passages').
    """

    def __init__(self, model_dir: str, **kwargs):
        """The preprocessor for DGDS rerank task, based on transformers' tokenizer.

        Args:
            model_dir: The model dir containing the essential files to build the tokenizer.

        Required kwargs: 'query_length', 'max_seq_length',
        'tokenizer_resize'; optional 'device' ('gpu' default, 'cpu' forces CPU).
        """
        super().__init__()

        self.model_dir = model_dir
        # GPU by default when available; kwargs['device'] == 'cpu' forces CPU.
        self.device = 'cuda' \
            if ('device' not in kwargs or kwargs['device'] == 'gpu') and torch.cuda.is_available() \
            else 'cpu'
        self.query_length = kwargs['query_length']
        self.max_seq_length = kwargs['max_seq_length']
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(self.model_dir)
        if kwargs['tokenizer_resize']:
            # Dialogue/passage marker tokens used when composing inputs.
            special_tokens = [
                '<last_turn>', '<user>', '<agent>', '<response>', '<passage>'
            ]
            self.tokenizer.add_tokens(special_tokens)

    @type_assert(object, Dict)
    def __call__(self, data: Dict[str, Any],
                 **preprocessor_param) -> Dict[str, Any]:
        """Tokenize rerank inputs.

        Returns a list of per-example dicts for the dataset layout, or a
        single dict of tensors for the inference layout.
        """
        if 'query' not in data:
            # Dataset-style batch: parallel lists per field.
            query = data['input']
            passages = data['passages']
            ids = data['id']
            output = data['output']
            positive_pids = data['positive_pids']
            preprocess_output_list = []
            for index in range(len(query)):
                now_query = query[index]
                # SECURITY NOTE(review): `eval` executes arbitrary Python
                # from the dataset fields; if this data can come from an
                # untrusted source, `ast.literal_eval` or `json.loads`
                # should be used instead.
                now_passages = eval(passages[index])
                now_id = ids[index]
                now_output = eval(output[index])
                now_positive_pids = eval(positive_pids[index])
                # query: truncate to query_length tokens via a
                # tokenize/decode round-trip.
                query_ids = self.tokenizer(
                    [now_query], add_special_tokens=False,
                    return_tensors='pt')['input_ids'][0][:self.query_length]
                now_query = self.tokenizer.decode(query_ids)
                # passage: "<query> <passage> <text>" per candidate.
                texts_b = []
                for p in now_passages:
                    texts_b.append(' '.join(
                        [now_query, '<passage>', p['text']]))
                passages_input = self.tokenizer(
                    texts_b,
                    add_special_tokens=True,
                    return_tensors='pt',
                    max_length=self.max_seq_length,
                    padding='longest',
                    truncation=True)
                preprocess_output_list.append({
                    'input_ids':
                    passages_input['input_ids'].to(self.device),
                    'attention_mask':
                    passages_input['attention_mask'].to(self.device),
                    'id':
                    now_id,
                    'output':
                    now_output,
                    'positive_pids':
                    now_positive_pids,
                    'passages':
                    now_passages,
                    'query':
                    now_query
                })
            return preprocess_output_list
        else:
            # Single inference example.
            query = data['query']
            passages = data['passages']
            # query: truncate to query_length tokens.
            query_ids = self.tokenizer(
                [query], add_special_tokens=False,
                return_tensors='pt')['input_ids'][0][:self.query_length]
            query = self.tokenizer.decode(query_ids)
            # passage
            texts_b = []
            for p in passages:
                texts_b.append(' '.join([query, '<passage>', p['text']]))
            passages_input = self.tokenizer(
                texts_b,
                add_special_tokens=True,
                return_tensors='pt',
                max_length=self.max_seq_length,
                padding='longest',
                truncation=True)
            result = {n: t.to(self.device) for n, t in passages_input.items()}
            return result
|
||||
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from transformers import XLMRobertaTokenizer
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.preprocessors.builder import PREPROCESSORS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Fields, ModeKeys, ModelFile
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.document_grounded_dialog_retrieval)
class DocumentGroundedDialogRetrievalPreprocessor(Preprocessor):
    """Tokenizes queries and passages for the document-grounded dialog
    retrieval task using an XLM-Roberta tokenizer loaded from the model dir."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """The preprocessor for DGDS retrieval task, based on transformers' tokenizer.

        Args:
            model_dir: The model dir containing the essential files to build the tokenizer.
        """
        super().__init__(*args, **kwargs)

        self.model_dir: str = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        # Default to GPU when available unless kwargs explicitly requests
        # a different device ('gpu' is treated the same as no request).
        self.device = 'cuda' \
            if ('device' not in kwargs or kwargs['device'] == 'gpu') and torch.cuda.is_available() \
            else 'cpu'
        # Truncation lengths for the two encoder towers, read from
        # configuration.json.
        self.query_sequence_length = self.config['query_sequence_length']
        self.context_sequence_length = self.config['context_sequence_length']
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(
            os.path.join(self.model_dir))

    @type_assert(object, Dict)
    def __call__(self,
                 data: Dict[str, Any],
                 invoke_mode=ModeKeys.INFERENCE,
                 input_type='query',
                 **preprocessor_param) -> Dict[str, Any]:
        """Tokenize a batch for training or inference.

        Args:
            data: For TRAIN/EVAL, a dict with 'query', 'positive' and
                'negative' string lists. For INFERENCE, a dict with 'query'
                (when input_type == 'query') or 'context' (otherwise).
            invoke_mode: One of ModeKeys; selects the branch below.
            input_type: Only consulted in inference mode; 'query' encodes
                queries, anything else encodes contexts.

        Returns:
            Dict of input_ids/attention_mask tensors, moved to self.device.
        """
        # NOTE(review): the second clause is redundant — TRAIN/EVAL are
        # already distinct from INFERENCE.
        if invoke_mode in (ModeKeys.TRAIN, ModeKeys.EVAL
                           ) and invoke_mode != ModeKeys.INFERENCE:
            query, positive, negative = data['query'], data['positive'], data[
                'negative']

            query_tokenizer_outputs = self.tokenizer.batch_encode_plus(
                query,
                padding=True,
                return_tensors='pt',
                max_length=self.query_sequence_length,
                truncation=True)

            # Positives and negatives are encoded in one batch; a query's
            # positive sits at its own index, so labels below are the
            # diagonal indices (in-batch negatives) — presumably for a
            # contrastive loss; confirm against the model.
            context_tokenizer_outputs = self.tokenizer.batch_encode_plus(
                positive + negative,
                padding=True,
                return_tensors='pt',
                max_length=self.context_sequence_length,
                truncation=True)

            result = {
                'query_input_ids': query_tokenizer_outputs.input_ids,
                'query_attention_mask': query_tokenizer_outputs.attention_mask,
                'context_input_ids': context_tokenizer_outputs.input_ids,
                'context_attention_mask':
                context_tokenizer_outputs.attention_mask,
                'labels':
                torch.tensor(list(range(len(query))), dtype=torch.long)
            }
        elif input_type == 'query':
            # Inference: encode only the queries.
            query = data['query']
            query_tokenizer_outputs = self.tokenizer.batch_encode_plus(
                query,
                padding=True,
                return_tensors='pt',
                max_length=self.query_sequence_length,
                truncation=True)
            result = {
                'query_input_ids': query_tokenizer_outputs.input_ids,
                'query_attention_mask': query_tokenizer_outputs.attention_mask,
            }
        else:
            # Inference: encode only the candidate contexts/passages.
            context = data['context']
            context_tokenizer_outputs = self.tokenizer.batch_encode_plus(
                context,
                padding=True,
                return_tensors='pt',
                max_length=self.context_sequence_length,
                truncation=True)
            result = {
                'context_input_ids': context_tokenizer_outputs.input_ids,
                'context_attention_mask':
                context_tokenizer_outputs.attention_mask,
            }

        # Move every tensor to the target device before returning.
        for k, v in result.items():
            result[k] = v.to(self.device)

        return result
|
||||
@@ -0,0 +1,287 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
|
||||
import json
|
||||
import sacrebleu
|
||||
import torch
|
||||
import tqdm
|
||||
from rouge import Rouge
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AdamW, get_scheduler
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model
|
||||
from modelscope.preprocessors import DocumentGroundedDialogGeneratePreprocessor
|
||||
from modelscope.trainers import EpochBasedTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def collate(batch):
    """Split a list of raw samples into parallel query/context/label lists.

    Each sample is a mapping with keys 'query', 'rerank' (a JSON-encoded
    list of reranked passages, decoded here) and 'response'.
    """
    queries, contexts, labels = [], [], []
    for sample in batch:
        queries.append(sample['query'])
        contexts.append(json.loads(sample['rerank']))
        labels.append(sample['response'])
    return queries, contexts, labels
|
||||
|
||||
|
||||
def prepare_optimizer(model, lr, weight_decay, eps):
    """Build an AdamW optimizer with standard weight-decay grouping.

    Parameters whose names contain 'bias' or 'LayerNorm.weight' are placed
    in a zero-weight-decay group (the usual transformer convention); all
    other parameters receive ``weight_decay``.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        lr: Learning rate.
        weight_decay: Decay applied to the non-excluded parameter group.
        eps: Adam epsilon.

    Returns:
        A configured ``transformers.AdamW`` optimizer.
    """
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        weight_decay,
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0,
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=eps)
    return optimizer
|
||||
|
||||
|
||||
def prepare_scheduler(optimizer, epochs, steps_per_epoch, warmup_rate):
    """Build a linear-decay LR scheduler with a warmup phase.

    The warmup length is ``warmup_rate`` of the total optimizer step count
    (``epochs * steps_per_epoch``), truncated to an integer.
    """
    num_training_steps = epochs * steps_per_epoch
    num_warmup_steps = int(num_training_steps * warmup_rate)
    return get_scheduler(
        name='linear',
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def normalize_answer(s):
    """Normalize text for answer matching.

    Lowercases, strips ASCII punctuation, removes the English articles
    'a'/'an'/'the', and collapses whitespace runs to single spaces —
    applied in exactly that order.
    """
    text = s.lower()
    punctuation = set(string.punctuation)
    text = ''.join(ch for ch in text if ch not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one ground truth.

    Both strings are normalized first; overlap is a multiset intersection,
    so repeated tokens count multiple times. Returns 0 when there is no
    overlap at all.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
    """True when prediction equals the ground truth after normalization."""
    normalized = [normalize_answer(text) for text in (prediction, ground_truth)]
    return normalized[0] == normalized[1]
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score ``prediction`` against every ground truth and keep the best.

    ``ground_truths`` must be non-empty: ``max`` raises on an empty
    sequence, matching the original behavior.
    """
    return max(metric_fn(prediction, truth) for truth in ground_truths)
|
||||
|
||||
|
||||
def matching_evaluate(references, predictions):
    """Compute corpus-level token F1 and exact match (EM), both in [0, 100].

    Args:
        references: Iterable of reference answer strings.
        predictions: Iterable of predicted answer strings, aligned with
            ``references`` (extra items in the longer iterable are ignored,
            as with ``zip``).

    Returns:
        Tuple ``(f1, em)`` of percentages averaged over the aligned pairs.
        Empty input scores ``(0.0, 0.0)`` instead of raising.
    """
    f1 = em = total = 0
    for ref_text, prediction in zip(references, predictions):
        total += 1
        # Each prediction has exactly one reference here, but the metric
        # helpers accept a list of ground truths.
        ground_truths = [ref_text]
        f1 += metric_max_over_ground_truths(f1_score, prediction,
                                            ground_truths)
        em += metric_max_over_ground_truths(exact_match_score, prediction,
                                            ground_truths)
    if total == 0:
        # Guard: the original divided by zero on an empty evaluation set.
        return 0.0, 0.0
    f1 = 100.0 * f1 / total
    em = 100.0 * em / total

    return f1, em
|
||||
|
||||
|
||||
def measure_result(result_dict):
    """Score generation outputs with token F1, SacreBLEU and ROUGE-L.

    Args:
        result_dict: Dict with parallel string lists under 'outputs'
            (model generations) and 'targets' (references). Each string may
            contain a '<response>' marker; only the text after the last
            marker is scored.

    Returns:
        Dict with keys 'f1', 'bleu' and 'rouge', each a percentage-scale
        float averaged over the instances.
    """
    meters = dict()

    hypothesis_list = [
        x.split('<response>')[-1].strip() for x in result_dict['outputs']
    ]
    # Rouge raises on empty hypotheses, so substitute a placeholder token.
    hypothesis_list = [x if x else '@' for x in hypothesis_list]
    reference_list = [
        x.split('<response>')[-1].strip() for x in result_dict['targets']
    ]
    instance_num = len(reference_list)

    # Token-level F1 (EM is computed too but not reported).
    f1, em = matching_evaluate(reference_list, hypothesis_list)
    meters['f1'] = f1

    # Sentence-level SacreBLEU, averaged over instances.
    bleu_score = [
        sacrebleu.sentence_bleu(hypothesis, [reference]).score
        for hypothesis, reference in zip(hypothesis_list, reference_list)
    ]
    bleu_score = sum(bleu_score) / instance_num
    meters['bleu'] = bleu_score

    # ROUGE-L F-measure, averaged and scaled to a percentage.
    rouge_func = Rouge()
    rouge_score = [
        x['rouge-l']['f']
        for x in rouge_func.get_scores(hypothesis_list, reference_list)
    ]
    rouge_score = (sum(rouge_score) / instance_num) * 100
    meters['rouge'] = rouge_score

    return meters
|
||||
|
||||
|
||||
@TRAINERS.register_module(
    module_name=Trainers.document_grounded_dialog_generate_trainer)
class DocumentGroundedDialogGenerateTrainer(EpochBasedTrainer):
    """Trainer for the document-grounded dialog response-generation model."""

    def __init__(self, model: str, revision='v1.0.0', *args, **kwargs):
        """Load the pretrained model plus its preprocessor and datasets.

        Args:
            model: Model id or local path accepted by ``Model.from_pretrained``.
            revision: Model revision to load.
            **kwargs: Must contain 'train_dataset' and 'eval_dataset'.
        """
        self.model = Model.from_pretrained(model, revision=revision)
        self.preprocessor = DocumentGroundedDialogGeneratePreprocessor(
            model_dir=self.model.model_dir)
        self.device = self.preprocessor.device
        self.model.model.to(self.device)
        self.train_dataset = kwargs['train_dataset']
        self.eval_dataset = kwargs['eval_dataset']

    def train(self,
              total_epoches=10,
              batch_size=16,
              accumulation_steps=1,
              learning_rate=1e-4,
              warmup_ratio=0.1,
              weight_decay=0.1,
              eps=1e-06,
              loss_log_freq=40):
        """Fine-tune the generation model on ``self.train_dataset``.

        Trains with mixed precision (autocast + GradScaler) and gradient
        accumulation, evaluates after every epoch, and saves the weights to
        ``finetuned_model.bin`` in the model dir whenever the summed metric
        score reaches a new maximum.
        """
        # obtain train loader
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate)

        optimizer = prepare_optimizer(self.model.model, learning_rate,
                                      weight_decay, eps)
        steps_per_epoch = len(train_loader) // accumulation_steps
        scheduler = prepare_scheduler(optimizer, total_epoches,
                                      steps_per_epoch, warmup_ratio)
        scaler = GradScaler()
        best_score = 0.0
        for epoch in range(total_epoches):
            self.model.model.train()
            losses = []
            for index, payload in enumerate(tqdm.tqdm(train_loader)):
                query, context, label = payload
                processed = self.preprocessor(
                    {
                        'query': query,
                        'context': context,
                        'label': label
                    },
                    invoke_mode=ModeKeys.TRAIN)
                with autocast():
                    outputs = self.model.forward(processed)
                    loss = outputs.loss.mean()

                # Scale the loss so accumulated gradients average correctly.
                if accumulation_steps > 1:
                    loss = loss / accumulation_steps

                scaler.scale(loss).backward()

                # Step only every `accumulation_steps` batches.
                if (index + 1) % accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
                losses.append(loss.item())
                if (index + 1) % loss_log_freq == 0:
                    logger.info(
                        f'epoch: {epoch} \t batch: {batch_size * index} \t loss: {sum(losses) / len(losses)}'
                    )
                    losses = []
            # Flush any losses left over since the last periodic log.
            if losses:
                logger.info(
                    f'epoch: {epoch} \t batch: last \t loss: {sum(losses) / len(losses)}'
                )

            # Checkpoint when the summed metric (f1 + bleu + rouge) improves.
            meters = self.evaluate(batch_size=batch_size)
            total_score = sum([x for x in meters.values()])
            if total_score >= best_score:
                best_score = total_score
                model_path = os.path.join(self.model.model_dir,
                                          'finetuned_model.bin')
                state_dict = self.model.model.state_dict()
                torch.save(state_dict, model_path)
                logger.info(
                    'epoch %d obtain max score: %.4f, saving model to %s' %
                    (epoch, total_score, model_path))

    def evaluate(self, batch_size=16, checkpoint_path=None):
        """Evaluate ``self.eval_dataset`` and return the metric dict.

        Args:
            batch_size: Evaluation batch size.
            checkpoint_path: Optional state-dict file to load before
                evaluating.

        Returns:
            Dict with 'f1', 'bleu' and 'rouge' scores from measure_result.
        """
        if checkpoint_path is not None:
            state_dict = torch.load(checkpoint_path)
            self.model.model.load_state_dict(state_dict)

        valid_loader = DataLoader(
            dataset=self.eval_dataset,
            batch_size=batch_size,
            collate_fn=collate)
        self.model.model.eval()
        with torch.no_grad():
            results = {'outputs': [], 'targets': []}
            for index, payload in enumerate(tqdm.tqdm(valid_loader)):
                query, context, label = payload
                processed = self.preprocessor(
                    {
                        'query': query,
                        'context': context,
                    },
                    invoke_mode=ModeKeys.INFERENCE)
                outputs = self.model.generate(processed)
                predictions = self.preprocessor.generation_tokenizer.batch_decode(
                    outputs,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False)
                # Round-trip the references through the tokenizer so they are
                # normalized the same way as the decoded predictions.
                label = self.preprocessor.generation_tokenizer.batch_decode(
                    self.preprocessor.generation_tokenizer.batch_encode_plus(
                        label, add_special_tokens=False).input_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False)

                results['outputs'] += predictions
                results['targets'] += label
            meters = measure_result(results)
        logger.info(meters)
        return meters
|
||||
@@ -0,0 +1,603 @@
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.nn.functional as F
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model
|
||||
from modelscope.preprocessors import DocumentGroundedDialogRerankPreprocessor
|
||||
from modelscope.trainers import EpochBasedTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@TRAINERS.register_module(
    module_name=Trainers.document_grounded_dialog_rerank_trainer)
class DocumentGroundedDialogRerankTrainer(EpochBasedTrainer):
    """Trainer for the document-grounded dialog passage reranker."""

    def __init__(self, model, dataset, **args):
        """Set up model, preprocessor and the id -> positive-pid index.

        Args:
            model: Model id or path accepted by ``Model.from_pretrained``.
            dataset: Iterable of dict instances; each must provide 'id',
                'positive_pids' and 'passages' (the latter two as
                Python-literal strings).
            **args: Must contain a single 'args' dict of hyperparameters
                (seed, model_resize, train_instances, num_train_epochs,
                instances_size, full_train_batch_size,
                gradient_accumulation_steps, ...).
        """
        args = args['args']
        set_seed(args['seed'])
        self.positive_pids = ''
        self.instances_size = 1
        # load id to positive pid map
        self.inst_id2pos_pids = dict()
        self.inst_id2pos_passages = dict()
        self.dataset = dataset
        self.model = Model.from_pretrained(model, revision='v1.0.0')
        self.preprocessor = DocumentGroundedDialogRerankPreprocessor(
            self.model.model_dir, **args)
        self.tokenizer = self.preprocessor.tokenizer
        if args['model_resize']:
            self.model.resize_token_embeddings(len(self.tokenizer))
        self.device = self.preprocessor.device
        self.model.to(self.device)
        # NOTE(review): eval() on dataset fields executes arbitrary code if
        # the data is untrusted — ast.literal_eval would be safer; flagged
        # only, not changed here.
        for jobj in self.dataset:
            self.inst_id2pos_pids[jobj['id']] = eval(jobj['positive_pids'])
            assert isinstance(eval(jobj['positive_pids']), list)
        logger.info(
            f'gathered positive pids for {len(self.inst_id2pos_pids)} instances'
        )

        # remove out-of-recall: drop instances whose retrieved passages are
        # all positive or all negative — they carry no ranking signal.
        instance_count = 0
        for jobj in self.dataset:
            inst_id = jobj['id']
            if inst_id not in self.inst_id2pos_pids:
                continue
            passages = eval(jobj['passages'])
            positive_pids = self.inst_id2pos_pids[inst_id]
            target_mask = [p['pid'] in positive_pids for p in passages]
            if not any(target_mask) or all(target_mask):
                del self.inst_id2pos_pids[inst_id]
            else:
                instance_count += 1
        if instance_count != len(self.inst_id2pos_pids):
            logger.error(
                f'!!! Mismatch between --positive_pids and --initial_retrieval! '
                f'{len(self.inst_id2pos_pids)} vs {instance_count}')

        # transformer_optimize
        # NOTE(review): this mutates the caller's args dict in place.
        if args['train_instances'] <= 0:
            args['train_instances'] = instance_count
        # MARK
        instances_to_train_over = args['train_instances'] * args[
            'num_train_epochs'] // args['instances_size']
        self.optimizer = TransformerOptimize(args, instances_to_train_over,
                                             self.model)
        logger.info(' Num Epochs = %d', args['num_train_epochs'])
        self.optimizer.model.zero_grad()
        # MARK
        train_batch_size = \
            args['full_train_batch_size'] // args['gradient_accumulation_steps']
        self.loss_history = \
            LossHistory(
                args['train_instances'] // train_batch_size // args['instances_size']
            )
        self.args = args
        self.max_length_count = 0

    def one_instance(self, query, passages):
        """Score one query against its candidate passages.

        Returns log-probabilities over the passages: log_softmax of the
        positive-class logit of each (query, passage) pair.
        """
        model = self.optimizer.model
        input_dict = {'query': query, 'passages': passages}
        inputs = self.preprocessor(input_dict)
        logits = F.log_softmax(
            model(inputs).logits,
            dim=-1)[:, 1]  # log_softmax over the binary classification
        logprobs = F.log_softmax(
            logits, dim=0)  # log_softmax over the passages
        # we want the logits rather than the logprobs as the teacher labels
        return logprobs

    def limit_gpu_sequences_binary(self, passages, target_mask, rand):
        """Cap passages at max_num_seq_pairs_per_device (boolean labels).

        Keeps at most half of the budget as positives (fewer if fewer
        exist) and fills the rest with negatives, sampled by shuffling.
        """
        if len(passages) > self.args['max_num_seq_pairs_per_device']:
            num_pos = min(
                sum(target_mask),
                self.args['max_num_seq_pairs_per_device'] // 2)
            num_neg = self.args['max_num_seq_pairs_per_device'] - num_pos
            passage_and_pos = list(zip(passages, target_mask))
            rand.shuffle(passage_and_pos)
            pos_count = 0
            neg_count = 0
            passages = []
            target_mask = []
            for passage, mask in passage_and_pos:
                if mask and pos_count < num_pos:
                    passages.append(passage)
                    target_mask.append(mask)
                    pos_count += 1
                elif not mask and neg_count < num_neg:
                    passages.append(passage)
                    target_mask.append(mask)
                    neg_count += 1
        return passages, target_mask

    def limit_gpu_sequences(self, passages, correctness, rand):
        """Cap passages at max_num_seq_pairs_per_device (weighted labels).

        Same sampling as the binary variant, but labels are float
        correctness weights; any weight > 0 counts as a positive.
        """
        if len(passages) > self.args['max_num_seq_pairs_per_device']:
            num_pos = min(
                sum([c > 0 for c in correctness]),
                self.args['max_num_seq_pairs_per_device'] // 2)
            num_neg = self.args['max_num_seq_pairs_per_device'] - num_pos
            passage_and_pos = list(zip(passages, correctness))
            rand.shuffle(passage_and_pos)
            pos_count = 0
            neg_count = 0
            passages = []
            correctness = []
            for passage, pos in passage_and_pos:
                if pos > 0 and pos_count < num_pos:
                    passages.append(passage)
                    correctness.append(pos)
                    pos_count += 1
                elif pos == 0 and neg_count < num_neg:
                    passages.append(passage)
                    correctness.append(pos)
                    neg_count += 1
        return passages, correctness

    def passage_correctness(self, pid, positive_pids, positive_dids):
        """Label a passage: 1.0 for an exact pid match, doc_match_weight for
        a document-level match (pid prefix before '::'), else 0."""
        if pid in positive_pids:
            return 1.0
        elif positive_dids and pid[:pid.index('::')] in positive_dids:
            return self.args['doc_match_weight']
        else:
            return 0

    def train(self):
        """Train the reranker until TransformerOptimize says to stop.

        Each instance's loss is the negative dot product of the passage
        log-probabilities with their correctness weights (a weighted NLL).
        Saves the model via ``save_transformer`` at the end.
        """
        rand = random.Random()
        while self.optimizer.should_continue():
            self.optimizer.model.train()
            dataset = block_shuffle(self.dataset, block_size=100000, rand=rand)
            for line_ndx, jobj in enumerate(dataset):
                inst_id = jobj['id']
                # Skip instances removed as out-of-recall in __init__.
                if inst_id not in self.inst_id2pos_pids:
                    continue
                # Shard instances across workers by line index.
                if line_ndx % self.args['world_size'] != \
                        self.args['global_rank']:
                    continue
                query = jobj['input'] if 'input' in jobj else jobj['query']
                passages = eval(jobj['passages'])
                positive_pids = self.inst_id2pos_pids[inst_id]
                if self.args['doc_match_weight'] > 0:
                    positive_dids = [
                        pid[:pid.index('::')] for pid in positive_pids
                    ]
                else:
                    positive_dids = None
                correctness = [
                    self.passage_correctness(p['pid'], positive_pids,
                                             positive_dids) for p in passages
                ]
                passages, correctness = self.limit_gpu_sequences(
                    passages, correctness, rand)
                logits = self.one_instance(query, passages)
                # nll = -(logits[target_mask].sum()) # TODO: instead take the weighted sum
                nll = -(
                    logits.dot(torch.tensor(correctness).to(logits.device)))
                loss_val = self.optimizer.step_loss(nll)
                self.loss_history.note_loss(loss_val)
                if not self.optimizer.should_continue():
                    break
        get_length = self.args['max_seq_length']
        logger.info(f'loss_history = {self.loss_history.loss_history}')
        logger.info(
            f'truncated to max length ({get_length}) {self.max_length_count} times'
        )
        save_transformer(self.args, self.optimizer.model, self.tokenizer)
|
||||
|
||||
|
||||
class Reporting:
    """Tracks moving averages of named values, keeps recent samples of
    selected names, and rate-limits periodic log output."""

    def __init__(self,
                 *,
                 recency_weight=0.001,
                 report_interval_secs=300,
                 check_every=1,
                 gather_samples: Iterable = (),
                 num_samples=10000):
        """The Reporting to print parameter status

        Args:
            recency_weight: when computing the moving average, how much weight to give to the current sample.
            report_interval_secs: how many seconds between returning true for is_time.
            check_every: how often to check the time, when calling is_time.
            gather_samples: keep the last num_samples of the listed names (gathered from moving_averages).
            num_samples: how many samples to keep.
        """
        self.check_count = 0
        self.check_every = check_every
        self.start_time = time.time()
        self.last_time = self.start_time
        self.report_interval_secs = report_interval_secs
        # For tracking moving averages of various values; arrays are
        # allocated lazily on the first moving_averages() call.
        self.names = None
        self.averages = None
        self.counts = None
        self.recency_weight = recency_weight
        self.per_value_recency_weight = dict()
        self.report_count = 0
        self._prev_check_count = 0
        self.sample_names = list(gather_samples)
        if len(self.sample_names) > 0:
            # One circular buffer row per sampled name.
            self.sample_values = np.zeros(
                (len(self.sample_names), num_samples), dtype=np.float32)
            self.sample_ndxs = np.zeros(len(self.sample_names), dtype=np.int32)
        else:
            self.sample_values = None
            self.sample_ndxs = None

    def reset(self):
        """Reset timers, counters, gathered samples and moving averages."""
        self.check_count = 0
        self.start_time = time.time()
        self.last_time = self.start_time
        self.report_count = 0
        self._prev_check_count = 0
        if len(self.sample_names) > 0:
            self.sample_values[:, :] = 0
            self.sample_ndxs[:] = 0
        if self.counts is not None:
            self.counts[:] = 0
            self.averages[:] = 0

    def is_time(self):
        """Return True roughly every report_interval_secs.

        Only consults the clock every ``check_every`` calls, and adapts
        ``check_every`` so the clock is checked neither too often nor too
        rarely relative to the call rate.
        """
        self.check_count += 1
        if self.check_count % self.check_every == 0:
            elapsed = time.time() - self.last_time
            if elapsed >= self.report_interval_secs:
                # check the time more or less often
                if self.check_every > 1 and self.check_count - self._prev_check_count < 5 * self.check_every:
                    self.check_every //= 2
                elif self.check_count - self._prev_check_count > 50 * self.check_every:
                    self.check_every *= 2
                self.last_time = time.time()
                self.report_count += 1
                self._prev_check_count = self.check_count
                return True
        return False

    def moving_averages(self, **values):
        """Fold keyword values into their moving averages and sample buffers.

        Each value's effective recency weight is at least 1/count so the
        average warms up quickly; per-name overrides come from
        ``per_value_recency_weight``.
        """
        # create entries in avgs and counts when needed
        if self.names is None:
            self.names = list(values.keys())
            self.averages = np.zeros(len(self.names))
            self.counts = np.zeros(len(self.names))
        for name in values.keys():
            if name not in self.names:
                self.names.append(name)
        # Grow the arrays when new names appeared after allocation.
        if self.averages.shape[0] < len(self.names):
            old_len = self.averages.shape[0]
            self.averages = np.resize(self.averages, len(self.names))
            self.averages[old_len:] = 0
            self.counts = np.resize(self.counts, len(self.names))
            self.counts[old_len:] = 0
        # update the avgs and counts
        for ndx, name in enumerate(self.names):
            if name in values:
                self.counts[ndx] += 1
                # support per-name recency_weight
                if name in self.per_value_recency_weight:
                    rweight = max(self.per_value_recency_weight[name],
                                  1.0 / self.counts[ndx])
                else:
                    rweight = max(self.recency_weight, 1.0 / self.counts[ndx])
                self.averages[ndx] = \
                    rweight * values[name] + (1.0 - rweight) * self.averages[ndx]
        for ndx, name in enumerate(self.sample_names):
            if name in values:
                # Fix: index (row, position). The previous code indexed the
                # row dimension with the circular position, which clobbered
                # whole rows and raised IndexError once the position passed
                # the number of sampled names.
                pos = self.sample_ndxs[ndx]
                self.sample_values[ndx, pos] = values[name]
                self.sample_ndxs[ndx] = (pos + 1) % self.sample_values.shape[1]

    def get_samples(self, name):
        """Return the gathered samples for *name*, or None if not gathered."""
        for ndx, n in enumerate(self.sample_names):
            if n == name:
                count = self.get_count(name)
                if count is None:
                    count = 0
                # counts is a float array: convert to int so it can be a
                # slice bound, and clamp to the buffer size for wrapped data.
                count = min(int(count), self.sample_values.shape[1])
                return self.sample_values[ndx, 0:count]  # NOTE: not in order
        return None

    def get_moving_average(self, name):
        """Return the moving average for *name*, or None if never seen."""
        if self.names is None:
            return None
        for ndx, n in enumerate(self.names):
            if n == name:
                return self.averages[ndx]
        return None

    def get_count(self, name):
        """Return how many times *name* was folded in, or None if never."""
        if self.names is None:
            return None
        for ndx, n in enumerate(self.names):
            if n == name:
                return self.counts[ndx]
        return None

    def elapsed_seconds(self) -> float:
        """Seconds since construction (or the last reset)."""
        return time.time() - self.start_time

    def elapsed_time_str(self) -> str:
        """Human-readable elapsed time (delegates to the module's time_str
        helper — defined elsewhere in this package)."""
        return time_str(self.elapsed_seconds())

    def progress_str(self, instance_name='instance'):
        """One-line throughput summary based on is_time() call counts."""
        return f'On {instance_name} {self.check_count}, ' \
               f'{self.check_count / self.elapsed_seconds()} {instance_name}s per second.'

    def display(self, *, prefix=''):
        # display the moving averages at INFO level
        logger.info('==========================================')
        if self.names is not None:
            for n, v in zip(self.names, self.averages):
                logger.info(f'{prefix}{n} = {v}')

    def display_warn(self, *, prefix=''):
        # display the moving averages at WARNING level
        logger.info('==========================================')
        if self.names is not None:
            for n, v in zip(self.names, self.averages):
                logger.warning(f'{prefix}{n} = {v}')
|
||||
|
||||
|
||||
class LossHistory:
    """Keeps an exponential moving average of the loss and records it as a
    series of periodic loss points."""

    def __init__(self,
                 one_epoch_batch_count,
                 *,
                 loss_points_per_epoch=10,
                 recency_weight=0.001):
        self.avg_loss = 0
        self.batch_count = 0
        self.recency_weight = recency_weight
        self.loss_history = []
        # Record roughly `loss_points_per_epoch` points per epoch, but at
        # least every batch.
        self.record_loss_every = max(
            1, one_epoch_batch_count // loss_points_per_epoch)

    def note_loss(self, loss_val):
        """Fold one batch loss into the moving average.

        Returns False when no loss point was recorded this batch, True when
        one was, and 2 when the recorded point is a new minimum after more
        than 10 recorded points.
        """
        self.batch_count += 1
        # Early on, weight new samples by 1/n so the average warms up fast.
        weight = max(self.recency_weight, 1.0 / self.batch_count)
        self.avg_loss = weight * loss_val + (1.0 - weight) * self.avg_loss
        if self.batch_count % self.record_loss_every != 0:
            return False
        self.loss_history.append(self.avg_loss)
        logger.info(
            f'loss point {self.batch_count // self.record_loss_every} = {self.avg_loss}'
        )
        is_new_minimum = self.avg_loss == min(self.loss_history)
        if is_new_minimum and len(self.loss_history) > 10:
            return 2
        return True
|
||||
|
||||
|
||||
class TransformerOptimize:
|
||||
"""
|
||||
Collects standard steps to train transformer
|
||||
call step_loss after computing each loss
|
||||
"""
|
||||
|
||||
def __init__(self, hypers, num_instances_to_train_over: int, model):
|
||||
self.step = 0
|
||||
self.global_step = 0
|
||||
self.hypers = hypers
|
||||
self.model = model
|
||||
instances_per_step = hypers['full_train_batch_size'] // hypers[
|
||||
'gradient_accumulation_steps']
|
||||
self.reporting = Reporting(recency_weight=0.0001 * instances_per_step)
|
||||
args = self.hypers
|
||||
|
||||
self.t_total = num_instances_to_train_over // args[
|
||||
'full_train_batch_size']
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
args['weight_decay'],
|
||||
},
|
||||
{
|
||||
'params': [
|
||||
p for n, p in self.model.named_parameters()
|
||||
if any(nd in n for nd in no_decay)
|
||||
],
|
||||
'weight_decay':
|
||||
0.0
|
||||
},
|
||||
]
|
||||
|
||||
warmup_instances = args['warmup_instances']
|
||||
if hasattr(
|
||||
args, 'warmup_fraction'
|
||||
) and args['warmup_fraction'] > 0 >= args['warmup_instances']:
|
||||
warmup_instances = \
|
||||
args['warmup_fraction'] * num_instances_to_train_over
|
||||
if warmup_instances < 0:
|
||||
warmup_instances = 0
|
||||
|
||||
self.optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=args['learning_rate'],
|
||||
eps=args['adam_epsilon'])
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps=warmup_instances // args['full_train_batch_size'],
|
||||
num_training_steps=self.t_total)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if args['resume_from'] and os.path.isfile(os.path.join(args['resume_from'], 'optimizer.pt')) and \
|
||||
os.path.isfile(os.path.join(args['resume_from'], 'scheduler.pt')):
|
||||
resume_from = args['resume_from']
|
||||
# elif os.path.isfile(os.path.join(args['model_name_or_path'], "optimizer.pt")) and \
|
||||
# os.path.isfile(os.path.join(args['model_name_or_path'], "scheduler.pt")):
|
||||
# resume_from = args['model_name_or_path']
|
||||
else:
|
||||
resume_from = None
|
||||
if resume_from is not None:
|
||||
# Load in optimizer and scheduler states
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(resume_from, 'optimizer.pt'),
|
||||
map_location='cpu'))
|
||||
self.scheduler.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(resume_from, 'scheduler.pt'),
|
||||
map_location='cpu'))
|
||||
logger.info(f'loaded optimizer and scheduler from {resume_from}')
|
||||
|
||||
if args['fp16']:
|
||||
self.model, optimizer = amp.initialize(
|
||||
self.model, self.optimizer, opt_level=args['fp16_opt_level'])
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args['n_gpu'] > 1:
|
||||
# NOTE: won't work at O2, only O1
|
||||
self.model = torch.nn.DataParallel(
|
||||
self.model, device_ids=list(range(args['n_gpu'])))
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
# if args.local_rank != -1:
|
||||
# self.model = torch.nn.parallel.DistributedDataParallel(
|
||||
# self.model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
|
||||
# )
|
||||
# set_seed(args)
|
||||
# assert args.per_gpu_train_batch_size * (args.n_gpu if args.n_gpu > 0 else 1) * \
|
||||
# args.world_size * args.gradient_accumulation_steps == args.full_train_batch_size
|
||||
logger.info('***** Running training *****')
|
||||
logger.info(' Instantaneous batch size per GPU = %d',
|
||||
args['per_gpu_train_batch_size'])
|
||||
logger.info(
|
||||
' Total train batch size (w. parallel, distributed & accumulation) = %d',
|
||||
args['full_train_batch_size'])
|
||||
logger.info(' Gradient Accumulation steps = %d',
|
||||
args['gradient_accumulation_steps'])
|
||||
logger.info(' Total optimization steps = %d', self.t_total)
|
||||
|
||||
def should_continue(self):
|
||||
return self.global_step < self.t_total
|
||||
|
||||
def backward_on_loss(self, loss, **moving_averages):
    """Backpropagate *loss*, handling DataParallel, accumulation and fp16.

    Returns the scalar loss value taken *before* gradient-accumulation
    scaling, and folds it (plus any extra ``moving_averages``) into the
    reporting tracker.
    """
    hypers = self.hypers
    if hypers['n_gpu'] > 1:
        # DataParallel returns one loss per replica; average them.
        loss = loss.mean()
    reported_loss = loss.item()
    accumulation = hypers['gradient_accumulation_steps']
    if accumulation > 1:
        loss = loss / accumulation
    if hypers['fp16']:
        # apex amp loss scaling for mixed-precision training
        with amp.scale_loss(loss, self.optimizer) as scaled:
            scaled.backward()
    else:
        loss.backward()
    self.reporting.moving_averages(loss=reported_loss, **moving_averages)
    return reported_loss
|
||||
|
||||
def optimizer_step(self):
    """Apply one (possibly gradient-accumulated) optimizer update.

    Weights are only updated every ``gradient_accumulation_steps``
    micro-batches; ``step`` counts micro-batches while ``global_step``
    counts actual optimizer updates.

    Returns:
        bool: False when the ``t_total`` step budget is already exhausted
        (no update is applied); True otherwise.
    """
    if self.global_step >= self.t_total:
        logger.warning(
            f'Warning, exceeded total steps! {self.global_step} step of {self.t_total}'
        )
        return False
    # Only touch the weights once enough micro-batch gradients accumulated.
    if (self.step + 1) % self.hypers['gradient_accumulation_steps'] == 0:
        if self.hypers['max_grad_norm'] > 0:
            if self.hypers['fp16']:
                # fp16: clip the fp32 master params kept by apex amp,
                # not the (possibly half-precision) model params.
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.hypers['max_grad_norm'])
            else:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.hypers['max_grad_norm'])

        self.optimizer.step()
        self.scheduler.step()  # Update learning rate schedule
        self.model.zero_grad()
        self.global_step += 1  # counts optimizer updates, not micro-batches
    self.step += 1

    if self.reporting.is_time():
        self.reporting.display()
        # Instances processed since the last report, across all workers/GPUs.
        inst_count = \
            self.hypers['world_size'] * self.hypers['n_gpu'] * self.hypers[
                'per_gpu_train_batch_size'] * self.reporting.check_count
        learning_rate_scalar = self.scheduler.get_lr()[0]
        logger.info(
            f'{inst_count / self.reporting.elapsed_seconds()} instances per second; '
            f'{inst_count} total ({learning_rate_scalar} learn rate)')
    return True
|
||||
|
||||
def step_loss(self, loss, **moving_averages):
    """Backpropagate *loss* and attempt one optimizer update.

    Returns the (unscaled) loss value when the optimizer stepped, or
    None once the training step budget is exhausted.
    """
    value = self.backward_on_loss(loss, **moving_averages)
    return value if self.optimizer_step() else None
|
||||
|
||||
|
||||
def block_shuffle(iter, *, block_size=20000, rand=random):
    """Shuffle a (possibly endless) iterator in memory-bounded blocks.

    Items are buffered until ``block_size`` of them are held, then the
    buffer is shuffled and its newest half is emitted; whatever remains
    is re-shuffled and flushed when the input is exhausted. Useful for
    good shuffling over multiple files, e.g.
    ``block_shuffle(read_lines(files, shuffled_files=rand), rand=rand,
    block_size=100000)``.

    :param iter: the iterator we will yield shuffled items from
    :param block_size: size of memory to use for block shuffling
    :param rand: rand.shuffle will be used on the list block
    :return: generator over the same items in shuffled order
    """
    assert block_size >= 4
    buffered = []
    half = block_size // 2
    for element in iter:
        buffered.append(element)
        if len(buffered) < block_size:
            continue
        rand.shuffle(buffered)
        for _ in range(half):
            yield buffered.pop()
    rand.shuffle(buffered)
    yield from buffered
|
||||
|
||||
|
||||
def save_transformer(hypers, model, tokenizer, *, save_dir=None):
    """Persist a transformer checkpoint (rank-0 only).

    Saves the hyperparameters, the model via ``save_pretrained`` and,
    when given, the tokenizer, so everything can be reloaded with
    ``from_pretrained``. Non-zero ranks are a no-op.
    """
    if hypers['global_rank'] != 0:
        return
    target_dir = hypers['output_dir'] if save_dir is None else save_dir
    # Create the output directory if needed.
    os.makedirs(target_dir, exist_ok=True)
    logger.info('Saving model checkpoint to %s', target_dir)
    # Unwrap a distributed/parallel container before saving.
    unwrapped = getattr(model, 'module', model)
    torch.save(hypers, os.path.join(target_dir, 'training_args.bin'))
    unwrapped.save_pretrained(target_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(target_dir)
|
||||
|
||||
|
||||
def kofn(kofn: str):
    """Parse a 'KofN' shard spec into (zero-based index, shard count).

    '' -> (0, 1)
    '1of2' -> (0, 2)
    '2of2' -> (1, 2)

    :param kofn: shard spec string, case-insensitive, empty means "all"
    :return: tuple of (zero-based shard index, total shard count)
    """
    if not kofn:
        return 0, 1
    k_text, n_text = kofn.lower().split('of')
    k, n = int(k_text), int(n_text)
    assert 1 <= k <= n
    return k - 1, n
|
||||
|
||||
|
||||
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
|
||||
@@ -0,0 +1,216 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
|
||||
import faiss
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AdamW, get_scheduler
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model
|
||||
from modelscope.preprocessors import \
|
||||
DocumentGroundedDialogRetrievalPreprocessor
|
||||
from modelscope.trainers import EpochBasedTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def collate(batch):
    """Split a batch of retrieval samples into parallel field lists.

    :param batch: iterable of dicts with 'query', 'positive', 'negative'
    :return: (queries, positives, negatives) lists in batch order
    """
    queries, positives, negatives = [], [], []
    for sample in batch:
        queries.append(sample['query'])
        positives.append(sample['positive'])
        negatives.append(sample['negative'])
    return queries, positives, negatives
|
||||
|
||||
|
||||
def prepare_optimizer(model, lr, weight_decay, eps):
    """Build an AdamW optimizer whose bias/LayerNorm params skip weight decay.

    :param model: module whose named parameters are optimized
    :param lr: learning rate
    :param weight_decay: decay applied to all non-excluded parameters
    :param eps: Adam epsilon
    :return: configured AdamW optimizer
    """
    skip_decay = ('bias', 'LayerNorm.weight')
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if any(tag in name for tag in skip_decay):
            undecayed.append(param)
        else:
            decayed.append(param)
    groups = [
        {'params': decayed, 'weight_decay': weight_decay},
        {'params': undecayed, 'weight_decay': 0.0},
    ]
    return AdamW(groups, lr=lr, eps=eps)
|
||||
|
||||
|
||||
def prepare_scheduler(optimizer, epochs, steps_per_epoch, warmup_rate):
    """Create a linear LR schedule with warmup over the first fraction of steps.

    :param optimizer: optimizer whose learning rate is scheduled
    :param epochs: total training epochs
    :param steps_per_epoch: optimizer updates per epoch
    :param warmup_rate: fraction of total steps spent warming up
    :return: the configured scheduler
    """
    total = epochs * steps_per_epoch
    warmup = int(total * warmup_rate)
    return get_scheduler(
        name='linear',
        optimizer=optimizer,
        num_warmup_steps=warmup,
        num_training_steps=total)
|
||||
|
||||
|
||||
def measure_result(result_dict):
    """Compute Recall@k for k in {1, 5, 10, 20} over retrieval results.

    Args:
        result_dict: dict with parallel lists 'outputs' (ranked passage
            lists, best first) and 'targets' (gold passage per query).

    Returns:
        dict mapping 'R@k' to the fraction of queries whose target
        appears among the top-k retrieved passages. Returns 0.0 for every
        metric when there are no queries (the original implementation
        raised ZeroDivisionError on empty input).
    """
    recall_k = (1, 5, 10, 20)
    hits = {f'R@{k}': 0 for k in recall_k}
    total = 0
    for output, target in zip(result_dict['outputs'], result_dict['targets']):
        total += 1
        for k in recall_k:
            if target in output[:k]:
                hits[f'R@{k}'] += 1
    if total == 0:
        # Guard the empty-evaluation edge case instead of dividing by zero.
        return {key: 0.0 for key in hits}
    return {key: count / total for key, count in hits.items()}
|
||||
|
||||
|
||||
@TRAINERS.register_module(
    module_name=Trainers.document_grounded_dialog_retrieval_trainer)
class DocumentGroundedDialogRetrievalTrainer(EpochBasedTrainer):
    """Fine-tunes and evaluates a dual-encoder retrieval model for
    document-grounded dialog.

    Training uses in-batch contrastive learning over (query, positive,
    negative) triples; evaluation encodes the full passage pool into a
    faiss inner-product index and measures Recall@k of the gold passage.
    """

    def __init__(self, model: str, revision='v1.0.0', *args, **kwargs):
        """Load the pretrained model and its preprocessor.

        Args:
            model: ModelScope model id or local path.
            revision: model revision to download.
            **kwargs: must contain 'train_dataset', 'eval_dataset' and
                'all_passages' (the candidate passage pool used at eval).
        """
        self.model = Model.from_pretrained(model, revision=revision)
        self.preprocessor = DocumentGroundedDialogRetrievalPreprocessor(
            model_dir=self.model.model_dir)
        # Run the model on the same device the preprocessor targets.
        self.device = self.preprocessor.device
        self.model.model.to(self.device)
        self.train_dataset = kwargs['train_dataset']
        self.eval_dataset = kwargs['eval_dataset']
        self.all_passages = kwargs['all_passages']

    def train(self,
              total_epoches=20,
              batch_size=128,
              per_gpu_batch_size=32,
              accumulation_steps=1,
              learning_rate=2e-5,
              warmup_ratio=0.1,
              weight_decay=0.1,
              eps=1e-06,
              loss_log_freq=40):
        """Fine-tune on the training set, evaluating after every epoch.

        The checkpoint with the best summed Recall@k across epochs is
        written to ``<model_dir>/finetuned_model.bin``.
        """
        # obtain train loader
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate)

        optimizer = prepare_optimizer(self.model.model, learning_rate,
                                      weight_decay, eps)
        steps_per_epoch = len(train_loader) // accumulation_steps
        scheduler = prepare_scheduler(optimizer, total_epoches,
                                      steps_per_epoch, warmup_ratio)

        best_score = 0.0
        for epoch in range(total_epoches):
            self.model.model.train()
            losses = []
            for index, payload in enumerate(tqdm.tqdm(train_loader)):
                query, positive, negative = payload
                processed = self.preprocessor(
                    {
                        'query': query,
                        'positive': positive,
                        'negative': negative
                    },
                    invoke_mode=ModeKeys.TRAIN)
                loss, logits = self.model.forward(processed)

                # Scale for gradient accumulation so the effective update
                # matches a single large batch.
                if accumulation_steps > 1:
                    loss = loss / accumulation_steps

                loss.backward()

                if (index + 1) % accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                losses.append(loss.item())
                if (index + 1) % loss_log_freq == 0:
                    logger.info(
                        f'epoch: {epoch} \t batch: {batch_size * index} \t loss: {sum(losses) / len(losses)}'
                    )
                    losses = []
            # Log whatever is left over from the last (partial) log window.
            if losses:
                logger.info(
                    f'epoch: {epoch} \t batch: last \t loss: {sum(losses) / len(losses)}'
                )

            meters = self.evaluate(per_gpu_batch_size=per_gpu_batch_size)
            # Model selection: keep the checkpoint maximizing the sum of
            # all Recall@k metrics.
            total_score = sum([x for x in meters.values()])
            if total_score >= best_score:
                best_score = total_score
                model_path = os.path.join(self.model.model_dir,
                                          'finetuned_model.bin')
                state_dict = self.model.model.state_dict()
                torch.save(state_dict, model_path)
                logger.info(
                    'epoch %d obtain max score: %.4f, saving model to %s' %
                    (epoch, total_score, model_path))

    def evaluate(self, per_gpu_batch_size=32, checkpoint_path=None):
        """Evaluate retrieval Recall@k on the eval set.

        Encodes every passage in ``all_passages`` into a faiss
        inner-product index, retrieves the top-20 passages per query and
        reports Recall@{1,5,10,20}. Raw outputs are also dumped to
        ``<model_dir>/evaluate_result.json``.

        Args:
            per_gpu_batch_size: batch size for passage/query encoding.
            checkpoint_path: optional fine-tuned state dict to load first.
        """
        if checkpoint_path is not None:
            state_dict = torch.load(checkpoint_path)
            self.model.model.load_state_dict(state_dict)

        valid_loader = DataLoader(
            dataset=self.eval_dataset,
            batch_size=per_gpu_batch_size,
            collate_fn=collate)
        self.model.model.eval()
        with torch.no_grad():
            # Encode the whole passage pool in mini-batches.
            all_ctx_vector = []
            for mini_batch in tqdm.tqdm(
                    range(0, len(self.all_passages), per_gpu_batch_size)):
                context = self.all_passages[mini_batch:mini_batch
                                            + per_gpu_batch_size]
                processed = \
                    self.preprocessor({'context': context},
                                      invoke_mode=ModeKeys.INFERENCE,
                                      input_type='context')
                sub_ctx_vector = self.model.encode_context(
                    processed).detach().cpu().numpy()
                all_ctx_vector.append(sub_ctx_vector)

            all_ctx_vector = np.concatenate(all_ctx_vector, axis=0)
            # faiss requires float32 inputs.
            all_ctx_vector = np.array(all_ctx_vector).astype('float32')
            faiss_index = faiss.IndexFlatIP(all_ctx_vector.shape[-1])
            faiss_index.add(all_ctx_vector)

            results = {'outputs': [], 'targets': []}
            for index, payload in enumerate(tqdm.tqdm(valid_loader)):
                query, positive, negative = payload
                processed = self.preprocessor({'query': query},
                                              invoke_mode=ModeKeys.INFERENCE)
                query_vector = self.model.encode_query(
                    processed).detach().cpu().numpy().astype('float32')
                # Top-20 passages per query by inner product.
                D, Index = faiss_index.search(query_vector, 20)
                results['outputs'] += [[
                    self.all_passages[x] for x in retrieved_ids
                ] for retrieved_ids in Index.tolist()]
                results['targets'] += positive
            meters = measure_result(results)
            result_path = os.path.join(self.model.model_dir,
                                       'evaluate_result.json')
            with open(result_path, 'w') as f:
                json.dump(results, f, ensure_ascii=False, indent=4)

        logger.info(meters)
        return meters
|
||||
@@ -194,6 +194,9 @@ class NLPTasks(object):
|
||||
translation_evaluation = 'translation-evaluation'
|
||||
sudoku = 'sudoku'
|
||||
text2sql = 'text2sql'
|
||||
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
|
||||
|
||||
class AudioTasks(object):
|
||||
|
||||
88
tests/pipelines/test_document_grounded_dialog_generate.py
Normal file
88
tests/pipelines/test_document_grounded_dialog_generate.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path
|
||||
import unittest
|
||||
from threading import Thread
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.preprocessors.nlp import \
|
||||
DocumentGroundedDialogGeneratePreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class DocumentGroundedDialogGenerateTest(unittest.TestCase,
                                         DemoCompatibilityCheck):
    """Pipeline tests for document-grounded dialog response generation."""

    def setUp(self) -> None:
        self.task = Tasks.document_grounded_dialog_generate
        self.model_id = 'DAMO_ConvAI/nlp_convai_generation_pretrain'

    # A two-sample batch: dialog histories, candidate contexts and labels.
    param = {
        'query': [
            '<last_turn>:我想知道孩子如果出现阑尾炎的话会怎么样?',
            '<last_turn>:好像是从肚脐开始,然后到右下方<system>您可以描述一下孩子的情况吗?<user>我想知道孩子如果出现阑尾炎的话会怎么样?',
        ],
        'context': [
            ['c1', 'c2', 'c3', 'c4', 'c5'],
            ['c1', 'c2', 'c3', 'c4', 'c5'],
        ],
        'label': [
            '<response>您可以描述一下孩子的情况吗?',
            '<response>那还有没有烦躁或无精打采的表现呢?',
        ]
    }

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        local_dir = snapshot_download(self.model_id, revision='v1.0.0')
        pre = DocumentGroundedDialogGeneratePreprocessor(model_dir=local_dir)
        pipe = pipeline(
            Tasks.document_grounded_dialog_generate,
            model=local_dir,
            preprocessor=pre)
        print(pipe(self.param))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download_with_multithreads(self):
        local_dir = snapshot_download(self.model_id, revision='v1.0.0')
        pipe = pipeline(
            Tasks.document_grounded_dialog_generate, model=local_dir)

        def run_and_print(worker_pipe, tag):
            print(tag, worker_pipe(self.param))

        # Exercise the pipeline concurrently from several threads.
        workers = [
            Thread(target=run_and_print, args=(pipe, i)) for i in range(5)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id, revision='v1.0.0')

        pre = DocumentGroundedDialogGeneratePreprocessor(
            model_dir=model.model_dir)
        pipe = pipeline(
            Tasks.document_grounded_dialog_generate,
            model=model,
            preprocessor=pre)
        print(pipe(self.param))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipe = pipeline(
            task=Tasks.document_grounded_dialog_generate,
            model_revision='v1.0.0')
        print(pipe(self.param))


if __name__ == '__main__':
    unittest.main()
|
||||
51
tests/pipelines/test_document_grounded_dialog_rerank.py
Normal file
51
tests/pipelines/test_document_grounded_dialog_rerank.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
import json
|
||||
import torch
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import DocumentGroundedDialogRerankModel
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines.nlp import DocumentGroundedDialogRerankPipeline
|
||||
from modelscope.preprocessors.nlp import \
|
||||
DocumentGroundedDialogRerankPreprocessor
|
||||
from modelscope.utils.constant import DownloadMode, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class DocumentGroundedDialogRerankTest(unittest.TestCase):
    """Runs the rerank pipeline end-to-end on two FrDoc2BotRerank samples."""

    def setUp(self) -> None:
        self.task = Tasks.document_grounded_dialog_rerank
        self.model_id = 'DAMO_ConvAI/nlp_convai_ranking_pretrain'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run(self):
        # Runtime configuration shared by the model, preprocessor and
        # pipeline constructors.
        run_args = {
            'output': '../../../result.json',
            'max_batch_size': 64,
            'exclude_instances': '',
            'include_passages': False,
            'do_lower_case': True,
            'max_seq_length': 512,
            'query_length': 195,
            'tokenizer_resize': True,
            'model_resize': True,
            'kilt_data': True
        }
        rerank_model = Model.from_pretrained(
            self.model_id, revision='v1.0.0', **run_args)
        rerank_preprocessor = DocumentGroundedDialogRerankPreprocessor(
            rerank_model.model_dir, **run_args)
        rerank_pipeline = DocumentGroundedDialogRerankPipeline(
            model=rerank_model, preprocessor=rerank_preprocessor, **run_args)
        samples = MsDataset.load(
            'DAMO_ConvAI/FrDoc2BotRerank',
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            split='test')[:2]
        rerank_pipeline(samples)


if __name__ == '__main__':
    unittest.main()
|
||||
84
tests/pipelines/test_document_grounded_dialog_retrieval.py
Normal file
84
tests/pipelines/test_document_grounded_dialog_retrieval.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path
|
||||
import unittest
|
||||
from threading import Thread
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.preprocessors.nlp import \
|
||||
DocumentGroundedDialogRetrievalPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class DocumentGroundedDialogRetrievalTest(unittest.TestCase,
                                          DemoCompatibilityCheck):
    """Pipeline tests for document-grounded dialog passage retrieval."""

    def setUp(self) -> None:
        self.task = Tasks.document_grounded_dialog_retrieval
        self.model_id = 'DAMO_ConvAI/nlp_convai_retrieval_pretrain'

    # A two-sample batch of dialog queries with positive/negative passages.
    param = {
        'query': [
            '<last_turn>我想知道孩子如果出现阑尾炎的话会怎么样',
            '<last_turn>好像是从肚脐开始,然后到右下方<system>您可以描述一下孩子的情况吗?<user>我想知道孩子如果出现阑尾炎的话会怎么样?',
        ],
        'positive': ['阑尾炎', '肚脐开始'],
        'negative': [
            '肠胃炎',
            '肚脐为止',
        ]
    }

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        local_dir = snapshot_download(self.model_id, revision='v1.0.0')
        pre = DocumentGroundedDialogRetrievalPreprocessor(model_dir=local_dir)
        pipe = pipeline(
            Tasks.document_grounded_dialog_retrieval,
            model=local_dir,
            preprocessor=pre)
        print(pipe(self.param))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download_with_multithreads(self):
        local_dir = snapshot_download(self.model_id, revision='v1.0.0')
        pipe = pipeline(
            Tasks.document_grounded_dialog_retrieval, model=local_dir)

        def run_and_print(worker_pipe, tag):
            print(tag, worker_pipe(self.param))

        # Exercise the pipeline concurrently from several threads.
        workers = [
            Thread(target=run_and_print, args=(pipe, i)) for i in range(5)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id, revision='v1.0.0')
        pre = DocumentGroundedDialogRetrievalPreprocessor(
            model_dir=model.model_dir)
        pipe = pipeline(
            Tasks.document_grounded_dialog_retrieval,
            model=model,
            preprocessor=pre)
        print(pipe(self.param))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipe = pipeline(
            task=Tasks.document_grounded_dialog_retrieval,
            model_revision='v1.0.0')
        print(pipe(self.param))


if __name__ == '__main__':
    unittest.main()
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import json
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers.nlp.document_grounded_dialog_generate_trainer import \
|
||||
DocumentGroundedDialogGenerateTrainer
|
||||
from modelscope.utils.constant import DownloadMode, ModelFile
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class DocumentGroundedDialogGenerateTest(unittest.TestCase):
    """Smoke test for fine-tuning the dialog generation trainer."""

    def setUp(self) -> None:
        self.model_id = 'DAMO_ConvAI/nlp_convai_generation_pretrain'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_with_model_name(self):
        # Load the dataset and truncate every field to one character/sample
        # so the end-to-end train + evaluate round-trip stays fast.
        full_dataset = MsDataset.load(
            'DAMO_ConvAI/FrDoc2BotGeneration',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        max_len = 1
        tiny_dataset = []
        for row in list(full_dataset)[:1]:
            tiny_dataset.append({
                'query':
                row['query'][:max_len],
                'rerank':
                json.dumps([p[:max_len] for p in json.loads(row['rerank'])]),
                'response':
                row['response'][:max_len],
            })

        trainer = DocumentGroundedDialogGenerateTrainer(
            model=self.model_id,
            train_dataset=tiny_dataset,
            eval_dataset=tiny_dataset,
        )
        # Keep decoding trivially cheap for the smoke test.
        trainer.model.model.config['num_beams'] = 1
        trainer.model.model.config['target_sequence_length'] = max_len
        trainer.train(batch_size=1, total_epoches=1, learning_rate=2e-4)
        trainer.evaluate(
            checkpoint_path=os.path.join(trainer.model.model_dir,
                                         'finetuned_model.bin'))


if __name__ == '__main__':
    unittest.main()
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import json
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers.nlp.document_grounded_dialog_rerank_trainer import \
|
||||
DocumentGroundedDialogRerankTrainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DownloadMode, ModelFile, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class TestDialogIntentTrainer(unittest.TestCase):
    # NOTE(review): the class name looks copy-pasted from the dialog-intent
    # test suite — it actually exercises the document-grounded dialog
    # *rerank* trainer; consider renaming in a follow-up change.

    def setUp(self):
        # Pretrained reranking model on the ModelScope hub.
        self.model_id = 'DAMO_ConvAI/nlp_convai_ranking_pretrain'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        """Smoke-test rerank fine-tuning on a 10-sample FrDoc2BotRerank slice."""
        # Full hyperparameter/runtime configuration expected by
        # DocumentGroundedDialogRerankTrainer.
        args = {
            'device': 'gpu',
            'tokenizer_name': '',
            'cache_dir': '',
            'instances_size': 1,
            'output_dir': './model',
            'max_num_seq_pairs_per_device': 32,
            'full_train_batch_size': 32,
            'gradient_accumulation_steps': 32,
            'per_gpu_train_batch_size': 1,
            'num_train_epochs': 1,
            'train_instances': -1,
            'learning_rate': 3e-5,
            'max_seq_length': 128,
            'num_labels': 2,
            'fold': '',  # IofN
            'doc_match_weight': 0.0,
            'query_length': 64,
            'resume_from': '',  # to resume training from a checkpoint
            'config_name': '',
            'do_lower_case': True,
            'weight_decay': 0.0,  # previous default was 0.01
            'adam_epsilon': 1e-8,
            'max_grad_norm': 1.0,
            'warmup_instances': 0,  # previous default was 0.1 of total
            'warmup_fraction': 0.0,  # only applies if warmup_instances <= 0
            'no_cuda': False,
            'n_gpu': 1,
            'seed': 42,
            'fp16': False,
            'fp16_opt_level': 'O1',  # previous default was O2
            'per_gpu_eval_batch_size': 8,
            'log_on_all_nodes': False,
            'world_size': 1,
            'global_rank': 0,
            'local_rank': -1,
            'tokenizer_resize': True,
            'model_resize': True
        }
        # Recompute accumulation so per-GPU batch * world size * accumulation
        # equals the requested full train batch size.
        args['gradient_accumulation_steps'] = args['full_train_batch_size'] // (
            args['per_gpu_train_batch_size'] * args['world_size'])
        data = MsDataset.load(
            'DAMO_ConvAI/FrDoc2BotRerank',
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            split='train')
        sub_train_dataset = [x for x in data][:10]
        trainer = DocumentGroundedDialogRerankTrainer(
            model=self.model_id, dataset=sub_train_dataset, args=args)
        trainer.train()
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import json
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers.nlp.document_grounded_dialog_retrieval_trainer import \
|
||||
DocumentGroundedDialogRetrievalTrainer
|
||||
from modelscope.utils.constant import DownloadMode, ModelFile
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class DocumentGroundedDialogRetrievalTest(unittest.TestCase):
    """Smoke test for fine-tuning the dialog retrieval trainer."""

    def setUp(self) -> None:
        self.model_id = 'DAMO_ConvAI/nlp_convai_retrieval_pretrain'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_with_model_name(self):
        # Use a 10-sample slice and a tiny passage pool to keep the
        # train + evaluate round-trip fast.
        dataset = MsDataset.load(
            'DAMO_ConvAI/FrDoc2BotRetrieval',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        tiny_split = list(dataset)[:10]
        passage_pool = ['阑尾炎', '肠胃炎', '肚脐开始', '肚脐为止']

        trainer = DocumentGroundedDialogRetrievalTrainer(
            model=self.model_id,
            train_dataset=tiny_split,
            eval_dataset=tiny_split,
            all_passages=passage_pool)
        trainer.train(
            batch_size=64,
            total_epoches=2,
        )
        trainer.evaluate(
            checkpoint_path=os.path.join(trainer.model.model_dir,
                                         'finetuned_model.bin'))


if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user