import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '../utils'))
import numpy as np
import argparse
import time
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

from utilities import (create_folder, get_filename, create_logging, Mixup,
    StatisticsContainer)
from models import (PVT, PVT2, PVT_lr, PVT_nopretrain, PVT_2layer, Cnn14, Cnn14_no_specaug, Cnn14_no_dropout,
    Cnn6, Cnn10, ResNet22, ResNet38, ResNet54, Cnn14_emb512, Cnn14_emb128,
    Cnn14_emb32, MobileNetV1, MobileNetV2, LeeNet11, LeeNet24, DaiNet19,
    Res1dNet31, Res1dNet51, Wavegram_Cnn14, Wavegram_Logmel_Cnn14,
    Wavegram_Logmel128_Cnn14, Cnn14_16k, Cnn14_8k, Cnn14_mel32, Cnn14_mel128,
    Cnn14_mixup_time_domain, Cnn14_DecisionLevelMax, Cnn14_DecisionLevelAtt, Cnn6_Transformer, GLAM, GLAM2, GLAM3, Cnn4, EAT)
#from models_test import (PVT_test)
#from models1 import (PVT1)
#from models_vig import (VIG, VIG2)
#from models_vvt import (VVT)
#from models2 import (MPVIT, MPVIT2)
#from models_reshape import (PVT_reshape, PVT_tscam)
#from models_swin import (Swin, Swin_nopretrain)
#from models_swin2 import (Swin2)
#from models_van import (Van, Van_tiny)
#from models_focal import (Focal)
#from models_cross import (Cross)
#from models_cov import (Cov)
#from models_cnn import (Cnn_light)
#from models_twins import (Twins)
#from models_cmt import (Cmt, Cmt1)
#from models_shunted import (Shunted)
#from models_quadtree import (Quadtree, Quadtree2, Quadtree_nopretrain)
#from models_davit import (Davit_tscam, Davit, Davit_nopretrain)
from pytorch_utils import (move_data_to_device, count_parameters, count_flops,
    do_mixup)
from data_generator import (AudioSetDataset, TrainSampler, BalancedTrainSampler,
    AlternateTrainSampler, EvaluateSampler, collate_fn)
from evaluate import Evaluator
import config
from losses import get_loss_func


def train(args):
    """Train AudioSet tagging model.

    Args:
      workspace: str
      data_type: 'balanced_train' | 'full_train'
      sample_rate: int
      window_size: int
      hop_size: int
      mel_bins: int
      fmin: int
      fmax: int
      model_type: str
      loss_type: 'clip_bce'
      balanced: 'none' | 'balanced' | 'alternate'
      augmentation: 'none' | 'mixup'
      batch_size: int
      learning_rate: float
      resume_iteration: int
      early_stop: int
      cuda: bool
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
    filename = args.filename

    num_workers = 8
    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = None

    train_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
        '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace,
        'hdf5s', 'indexes', 'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
        'eval.h5')

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size))
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size))

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')
        device = 'cpu'

    # Model
    Model = eval(model_type)
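    # eval() resolves the class object from its name; model_type must exactly
    # match one of the classes imported from models above.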
    model = Model(sample_rate=sample_rate, window_size=window_size,
        hop_size=hop_size, mel_bins=mel_bins, fmin=fmin, fmax=fmax,
        classes_num=classes_num)
    total = sum(p.numel() for p in model.parameters())
    print("Total params: %.2fM" % (total / 1e6))
    logging.info("Total params: %.2fM" % (total / 1e6))
    #params_num = count_parameters(model)
    # flops_num = count_flops(model, clip_samples)
    #logging.info('Parameters num: {}'.format(params_num))
    # logging.info('Flops num: {:.3f} G'.format(flops_num / 1e9))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and returns a waveform and a target.
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Train sampler
    if balanced == 'none':
        Sampler = TrainSampler
    elif balanced == 'balanced':
        Sampler = BalancedTrainSampler
    elif balanced == 'alternate':
        Sampler = AlternateTrainSampler

    train_sampler = Sampler(
        indexes_hdf5_path=train_indexes_hdf5_path,
        batch_size=batch_size * 2 if 'mixup' in augmentation else batch_size,
        black_list_csv=black_list_csv)
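    # With mixup, two clips are mixed into each training example, so the
    # sampler yields 2 * batch_size clips per batch; do_mixup later collapses
    # them pairwise back to batch_size examples.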

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=train_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=eval_bal_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=eval_test_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    mix = 0.5
    if 'mixup' in augmentation:
        mixup_augmenter = Mixup(mixup_alpha=mix)
        print('Mixup alpha: {}'.format(mix))
        logging.info('Mixup alpha: {}'.format(mix))

    # Evaluator
    evaluator = Evaluator(model=model)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
        eps=1e-08, weight_decay=0.05, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
        factor=0.5, patience=4, min_lr=1e-06, verbose=True)
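    # The scheduler halves the learning rate (down to min_lr=1e-6) whenever
    # the monitored value, the validation mAP passed to scheduler.step()
    # below, fails to improve for 4 consecutive evaluations (mode='max').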
    train_bgn_time = time.time()

    # Resume training
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(workspace, 'checkpoints', filename,
            'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
            'data_type={}'.format(data_type), model_type,
            'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
            'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
            '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']

    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
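    # DataParallel wraps the model, so its weights live under model.module;
    # the checkpoint code below saves model.module.state_dict() accordingly.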

    if 'cuda' in str(device):
        model.to(device)

    if resume_iteration:
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print(optimizer.state_dict()['param_groups'][0]['lr'])

    time1 = time.time()

    for batch_data_dict in train_loader:
        """batch_data_dict: {
            'audio_name': (batch_size [*2 if mixup],),
            'waveform': (batch_size [*2 if mixup], clip_samples),
            'target': (batch_size [*2 if mixup], classes_num),
            (if exists) 'mixup_lambda': (batch_size * 2,)}
        """

        # Evaluate
        if (iteration % 2000 == 0 and iteration >= resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration, bal_statistics, data_type='bal')
            statistics_container.append(iteration, test_statistics, data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 2000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Mixup lambda
        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))
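            # Sketch of what get_lambda is assumed to return (PANNs-style
            # Mixup): one Beta(alpha, alpha) draw per pair of clips,
            # interleaved as [lam_0, 1 - lam_0, lam_1, 1 - lam_1, ...]:
            #   lam = np.random.beta(mix, mix, size=batch_size // 2)
            #   mixup_lambda = np.stack([lam, 1. - lam], axis=-1).flatten()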

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Forward
        model.train()

        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                batch_data_dict['mixup_lambda'])
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': do_mixup(batch_data_dict['target'],
                batch_data_dict['mixup_lambda'])}
            """{'target': (batch_size, classes_num)}"""
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': batch_data_dict['target']}
            """{'target': (batch_size, classes_num)}"""

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)
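        # For loss_type='clip_bce', the returned loss_func is assumed to be a
        # clip-level binary cross-entropy, roughly:
        #   F.binary_cross_entropy(output_dict['clipwise_output'],
        #                          target_dict['target'])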

        # Backward
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print(iteration, loss.item())
            #print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'\
            #    .format(iteration, time.time() - time1))
            #time1 = time.time()

        if iteration % 2000 == 0:
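            # test_statistics was computed by the evaluation block above,
            # which runs at the same 2000-iteration boundary.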
            scheduler.step(np.mean(test_statistics['average_precision']))
            print(optimizer.state_dict()['param_groups'][0]['lr'])
            logging.info(optimizer.state_dict()['param_groups'][0]['lr'])

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Train an AudioSet tagging model.')
    subparsers = parser.add_subparsers(dest='mode')

    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--workspace', type=str, required=True)
    parser_train.add_argument('--data_type', type=str, default='full_train', choices=['balanced_train', 'full_train'])
    parser_train.add_argument('--sample_rate', type=int, default=32000)
    parser_train.add_argument('--window_size', type=int, default=1024)
    parser_train.add_argument('--hop_size', type=int, default=320)
    parser_train.add_argument('--mel_bins', type=int, default=64)
    parser_train.add_argument('--fmin', type=int, default=50)
    parser_train.add_argument('--fmax', type=int, default=14000)
    parser_train.add_argument('--model_type', type=str, required=True)
    parser_train.add_argument('--loss_type', type=str, default='clip_bce', choices=['clip_bce'])
    parser_train.add_argument('--balanced', type=str, default='balanced', choices=['none', 'balanced', 'alternate'])
    parser_train.add_argument('--augmentation', type=str, default='mixup', choices=['none', 'mixup'])
    parser_train.add_argument('--batch_size', type=int, default=32)
    parser_train.add_argument('--learning_rate', type=float, default=1e-3)
    parser_train.add_argument('--resume_iteration', type=int, default=0)
    parser_train.add_argument('--early_stop', type=int, default=1000000)
    parser_train.add_argument('--cuda', action='store_true', default=False)
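
    # Example invocation (the script name and workspace path here are
    # illustrative, not fixed by this file):
    #   python train.py train --workspace=/path/to/workspace \
    #       --data_type=full_train --model_type=Cnn14 --augmentation=mixup --cuda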

    args = parser.parse_args()
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)

    else:
        raise Exception('Error: unknown mode {}.'.format(args.mode))