Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-24 20:19:22 +01:00
[to #42322933]
Enable finetuning of OFA-MMSpeech.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10981972
@@ -41,6 +41,8 @@ __all__ = ['OfaForAllTasks']

 class OfaForAllTasks(TorchModel):

     def __init__(self, model_dir, *args, **kwargs):
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         super().__init__(model_dir=model_dir, *args, **kwargs)
         self.cfg = Config.from_file(
             osp.join(model_dir, ModelFile.CONFIGURATION))
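Both here and in OfaBasePreprocessor below, an existing relative model_dir is normalized to an absolute path. A hedged note on why this guard is useful: a bare directory name is indistinguishable from a hub model id for downstream from_pretrained-style loaders, so resolving it early removes the ambiguity. A minimal sketch (the directory name is hypothetical):

    import os

    model_dir = 'ofa_mmspeech_base'   # relative path that also looks like a hub id
    if os.path.exists(model_dir):
        # only normalize when it really is a local directory; a true hub id
        # (which does not exist on disk) passes through untouched
        model_dir = os.path.abspath(model_dir)
    print(model_dir)                  # e.g. /home/user/ofa_mmspeech_base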
@@ -80,10 +80,11 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)

-        phone_item = self.to_phone(target) - 3
+        phone_item = self.to_phone(target) + 1
         phone_mask = torch.tensor([False])

-        sample['phone_item'] = phone_item
+        sample['phone_item'] = phone_item + 3
+        sample['phone_target'] = phone_item
         sample['phone_mask'] = phone_mask

         sample['prev_output_tokens'] = torch.cat(
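Taken together, the preprocessor now keeps two views of the phone sequence: a shifted copy (phone_item + 3) fed to the model as input, and the unshifted copy (phone_target) that serves as the CTC label sequence in the loss change further down. A minimal sketch of that bookkeeping; the offsets mirror the diff, but the vocabulary layout (special tokens occupying the low ids, 0 reserved for the CTC blank) is an assumption:

    import torch

    def split_phone_views(raw_phone_ids: torch.Tensor):
        """raw_phone_ids: result of to_phone(target) + 1, keeping 0 free for blank.

        Returns (model_input_ids, ctc_target_ids). The +3 shift for the model
        input is assumed to skip special tokens (e.g. bos/eos/pad) at the
        bottom of the embedding table, while the CTC head keeps the compact
        id space where 0 is the blank label.
        """
        return raw_phone_ids + 3, raw_phone_ids

    phone_item = torch.tensor([5, 12, 12, 7])        # hypothetical phone ids
    model_ids, ctc_ids = split_phone_views(phone_item)
    print(model_ids.tolist(), ctc_ids.tolist())      # [8, 15, 15, 10] [5, 12, 12, 7]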
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import io
+import os
 import re
 import string
 from os import path as osp
@@ -32,6 +33,8 @@ class OfaBasePreprocessor:
         self.cfg = cfg
         self.mode = mode
         self.language = self.cfg.model.get('language', 'en')
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         if self.language == 'en':
             tokenizer = OFATokenizer.from_pretrained(model_dir)
         elif self.language in ['zh', 'cn']:
@@ -83,6 +83,10 @@ def collate_fn(samples, pad_idx, eos_idx):
         batch['net_input']['phone_items'] = merge('phone_item')
         batch['net_input']['phone_masks'] = torch.cat(
             [s['phone_mask'] for s in samples])
+        if samples[0].get('phone_target', None) is not None:
+            batch['phone_target'] = merge('phone_target')
+            batch['phone_length'] = torch.tensor(
+                [s['phone_target'].size(0) for s in samples], dtype=torch.long)

     return batch
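The new branch pads the per-sample phone targets into one batch tensor and records each original length, which is exactly what the CTC loss below needs. A self-contained sketch of the same idea; merge in the diff is the file's padding helper, so this stand-in's details are an assumption:

    import torch

    def collate_phone_targets(samples, pad_idx=0):
        """Pad variable-length phone targets and keep their true lengths."""
        lengths = torch.tensor([s['phone_target'].size(0) for s in samples],
                               dtype=torch.long)
        padded = torch.full((len(samples), int(lengths.max())), pad_idx,
                            dtype=torch.long)
        for i, s in enumerate(samples):
            padded[i, :s['phone_target'].size(0)] = s['phone_target']
        return padded, lengths

    samples = [{'phone_target': torch.tensor([5, 12, 7])},
               {'phone_target': torch.tensor([9, 3])}]
    targets, lengths = collate_phone_targets(samples)
    print(targets.tolist(), lengths.tolist())   # [[5, 12, 7], [9, 3, 0]] [3, 2]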
@@ -2,8 +2,8 @@

 import math
 import os
-import shutil
 from functools import partial
+from shutil import ignore_patterns
 from typing import Callable, Dict, Optional, Tuple, Union

 import torch
@@ -23,9 +23,9 @@ from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.config import Config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
-                                       Invoke, ModeKeys)
+                                       Invoke, ModeKeys, ModelFile)
 from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                get_schedule)
+                                get_schedule, recursive_overwrite)


 @TRAINERS.register_module(module_name=Trainers.ofa)
@@ -58,23 +58,12 @@ class OFATrainer(EpochBasedTrainer):
             work_dir = cfg.train.work_dir
         else:
            work_dir = kwargs['work_dir']
-        tokenizer_files = {
-            'zh': [
-                'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
-                'config.json', 'ans2label.json'
-            ],
-            'en': [
-                'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
-                'ans2label.json'
-            ],
-        }
-        for filename in tokenizer_files[cfg.model.get('language', 'en')]:
-            finetune_file = os.path.join(work_dir, filename)
-            pretrain_file = os.path.join(model_dir, filename)
-            if os.path.exists(finetune_file):
-                continue
-            if os.path.exists(pretrain_file):
-                shutil.copy(pretrain_file, finetune_file)
-
+        os.makedirs(work_dir, exist_ok=True)
+        ignore_file_set = set()
+        ignore_file_set.add(ModelFile.CONFIGURATION)
+        recursive_overwrite(
+            model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set))

         if preprocessor is None:
             preprocessor = {
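The trainer previously copied a hard-coded, per-language whitelist of tokenizer files into the work dir; it now mirrors the entire pretrained snapshot and skips only the configuration file, so new assets (such as the phone vocabulary MMSpeech needs) are picked up without maintaining the list. shutil.ignore_patterns builds the ignore callback. A minimal sketch, assuming ModelFile.CONFIGURATION resolves to 'configuration.json' and using hypothetical paths:

    from shutil import ignore_patterns

    ignore = ignore_patterns('configuration.json')
    # the callback receives a directory and its entries and returns the
    # subset of names to skip, matching shutil.copytree's ignore= contract
    print(ignore('/cache/model', ['configuration.json', 'vocab.txt']))
    # -> {'configuration.json'}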
@@ -3,6 +3,8 @@
 # This source code is licensed under the Apache 2.0 license
 # found in the LICENSE file in the root directory.
 import math
+import os
+import shutil

 import numpy as np
 import torch
@@ -11,6 +13,23 @@ import transformers
 from torch.nn.modules.loss import _Loss


+def recursive_overwrite(src, dst, ignore=None):
+    if os.path.isdir(src):
+        if not os.path.isdir(dst):
+            os.makedirs(dst)
+        files = os.listdir(src)
+        if ignore is not None:
+            ignored = ignore(src, files)
+        else:
+            ignored = set()
+        for f in files:
+            if f not in ignored:
+                recursive_overwrite(
+                    os.path.join(src, f), os.path.join(dst, f), ignore)
+    else:
+        shutil.copyfile(src, dst)
+
+
 def construct_rdrop_sample(x):
     if isinstance(x, dict):
         for key in x:
@@ -211,17 +230,17 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return loss, nll_loss, ntokens

     def compute_ctc_loss(self, model, output, sample):
-        lprobs = model.get_encoder_normalized_probs(
+        lprobs = model.model.get_encoder_normalized_probs(
             output, log_probs=True).contiguous()  # (T, B, C) from the encoder

         non_padding_mask = ~output.encoder_padding_mask
         input_lengths = non_padding_mask.long().sum(-1)

-        target_lengths = sample['ctc_output_lengths']
+        target_lengths = sample['phone_length']
         pad_mask = torch.arange(target_lengths.max()).expand([
             target_lengths.shape[0], -1
         ]).to(target_lengths) < target_lengths.unsqueeze(1)
-        targets_flat = sample['ctc_outputs'].masked_select(pad_mask)
+        targets_flat = sample['phone_target'].masked_select(pad_mask)

         with torch.backends.cudnn.flags(enabled=False):
             loss = F.ctc_loss(
@@ -229,12 +248,12 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
                 targets_flat,
                 input_lengths,
                 target_lengths,
-                blank=self.blank_idx,
+                blank=0,
                 reduction='sum',
                 zero_infinity=True,
             )

-        return loss
+        return loss / lprobs.shape[1]


 def get_schedule(scheduler):
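The CTC branch now reads its labels from the phone_target/phone_length tensors produced by the collate function, fixes the blank id at 0 (presumably why the preprocessor shifts phone ids by +1, keeping 0 free for the blank), and normalizes the summed loss by batch size, since lprobs.shape[1] is B in the (T, B, C) layout. A self-contained sketch of the same computation with random inputs; shapes and ids are illustrative, not taken from the model:

    import torch
    import torch.nn.functional as F

    T, B, C = 50, 2, 64                       # time steps, batch, phone classes
    lprobs = torch.randn(T, B, C).log_softmax(-1)           # (T, B, C) log-probs
    input_lengths = torch.tensor([50, 42])                  # valid encoder steps
    phone_target = torch.tensor([[5, 12, 7], [9, 3, 0]])    # padded labels
    phone_length = torch.tensor([3, 2])                     # true label lengths

    # flatten targets exactly as compute_ctc_loss does, dropping the padding
    pad_mask = torch.arange(phone_length.max()).expand(
        phone_length.shape[0], -1) < phone_length.unsqueeze(1)
    targets_flat = phone_target.masked_select(pad_mask)

    loss = F.ctc_loss(lprobs, targets_flat, input_lengths, phone_length,
                      blank=0, reduction='sum', zero_infinity=True)
    print(loss / lprobs.shape[1])             # batch-averaged CTC loss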
tests/trainers/test_ofa_mmspeech_trainer.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import unittest
+
+import json
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestMMSpeechTrainer(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.finetune_cfg = \
+            {'framework': 'pytorch',
+             'task': 'auto-speech-recognition',
+             'model': {'type': 'ofa',
+                       'beam_search': {'beam_size': 5,
+                                       'max_len_b': 128,
+                                       'min_len': 1,
+                                       'no_repeat_ngram_size': 5,
+                                       'constraint_range': '4,21134'},
+                       'seed': 7,
+                       'max_src_length': 256,
+                       'language': 'zh',
+                       'gen_type': 'generation',
+                       'multimodal_type': 'mmspeech'},
+             'pipeline': {'type': 'ofa-asr'},
+             'n_frames_per_step': 1,
+             'dataset': {'column_map': {'wav': 'Audio:FILE', 'text': 'Text:LABEL'}},
+             'train': {'work_dir': 'work/ckpts/asr_recognition',
+                       # 'launcher': 'pytorch',
+                       'max_epochs': 1,
+                       'use_fp16': True,
+                       'dataloader': {'batch_size_per_gpu': 16, 'workers_per_gpu': 0},
+                       'lr_scheduler': {'name': 'polynomial_decay',
+                                        'warmup_proportion': 0.01,
+                                        'lr_end': 1e-07},
+                       'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
+                       'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
+                       'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
+                                          'cumulative_iters': 1,
+                                          'grad_clip': {'max_norm': 1.0, 'norm_type': 2},
+                                          'loss_keys': 'loss'},
+                       'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion',
+                                     'constraint_range': '4,21134',
+                                     'drop_worst_after': 0,
+                                     'drop_worst_ratio': 0.0,
+                                     'ignore_eos': False,
+                                     'ignore_prefix_size': 0,
+                                     'label_smoothing': 0.1,
+                                     'reg_alpha': 1.0,
+                                     'report_accuracy': False,
+                                     'sample_patch_num': 196,
+                                     'sentence_avg': True,
+                                     'use_rdrop': False,
+                                     'ctc_weight': 1.0},
+                       'hooks': [{'type': 'BestCkptSaverHook',
+                                  'metric_key': 'accuracy',
+                                  'interval': 100},
+                                 {'type': 'TextLoggerHook', 'interval': 1},
+                                 {'type': 'IterTimerHook'},
+                                 {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
+             'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
+                            'metrics': [{'type': 'accuracy'}]},
+             'preprocessor': []}
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_std(self):
+        WORKSPACE = './workspace/ckpts/asr_recognition'
+        os.makedirs(WORKSPACE, exist_ok=True)
+        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
+        with open(config_file, 'w') as writer:
+            json.dump(self.finetune_cfg, writer)
+
+        pretrained_model = 'damo/ofa_mmspeech_pretrain_base_zh'
+
+        args = dict(
+            model=pretrained_model,
+            work_dir=WORKSPACE,
+            train_dataset=MsDataset.load(
+                'aishell1_subset',
+                subset_name='default',
+                namespace='modelscope',
+                split='train',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
+            eval_dataset=MsDataset.load(
+                'aishell1_subset',
+                subset_name='default',
+                namespace='modelscope',
+                split='test',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
+            cfg_file=config_file)
+        trainer = build_trainer(name=Trainers.ofa, default_args=args)
+        trainer.train()
+
+        self.assertIn(
+            ModelFile.TORCH_MODEL_BIN_FILE,
+            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
+        shutil.rmtree(WORKSPACE)
+
+
+if __name__ == '__main__':
+    unittest.main()
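One config detail worth calling out: 'ctc_weight': 1.0 is what switches on the new phone-CTC branch during finetuning; with weight 0 the criterion reduces to plain label-smoothed cross-entropy. A hedged sketch of how such a weight is conventionally combined; the exact mixing inside AdjustLabelSmoothedCrossEntropyCriterion is not shown in this diff, so this is an assumption:

    import torch

    def combine_losses(ce_loss: torch.Tensor, ctc_loss: torch.Tensor,
                       ctc_weight: float = 1.0) -> torch.Tensor:
        # conventional weighted sum; whether the criterion also rescales the
        # cross-entropy term is an assumption, as the diff omits that code
        return ce_loss + ctc_weight * ctc_loss

    total = combine_losses(torch.tensor(2.3), torch.tensor(1.1), ctc_weight=1.0)
    print(total)   # tensor(3.4000)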
@@ -76,8 +76,7 @@ class TestOfaTrainer(unittest.TestCase):
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
-            json.dump(self.finetune_cfg, writer)
-
+            json.dump(self.finetune_cfg, writer, indent=4)
         pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'

         args = dict(