[to #42322933] Fine-tune TEAM on Caltech-101

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10525413
This commit is contained in:
eniac.xcw
2022-10-27 12:00:14 +08:00
committed by yingda.chen
parent de708dd518
commit 8886c3c1ae
6 changed files with 334 additions and 1 deletions

View File

@@ -305,6 +305,7 @@ class Trainers(object):
face_detection_scrfd = 'face-detection-scrfd'
card_detection_scrfd = 'card-detection-scrfd'
image_inpainting = 'image-inpainting'
# Registered by this commit: trainer id for fine-tuning a TEAM image
# encoder with a linear classification head (see TEAMImgClsTrainer).
image_classification_team = 'image-classification-team'
# nlp trainers
bert_sentiment_analysis = 'bert-sentiment-analysis'

View File

@@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule
# NOTE(review): the diff rendering stripped the -/+ markers in this hunk.
# The two consecutive `_import_structure` assignments below are the pre- and
# post-change versions; the second silently overrides the first. Confirm
# against the actual __init__.py that only the multi-line form remains.
if TYPE_CHECKING:
    # Eager imports so static type checkers can resolve the trainer names.
    from .clip import CLIPTrainer
    from .team import TEAMImgClsTrainer
else:
    # Runtime mapping consumed by LazyImportModule for deferred imports.
    _import_structure = {'clip': ['CLIPTrainer']}
    _import_structure = {
        'clip': ['CLIPTrainer'],
        'team': ['TEAMImgClsTrainer']
    }

    import sys

View File

@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Re-export the TEAM image-classification trainer as this package's public API.
from .team_trainer import TEAMImgClsTrainer

View File

@@ -0,0 +1,144 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import OrderedDict
from typing import Callable, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
get_optimizer, train_mapping, val_mapping)
from modelscope.utils.config import Config
from modelscope.utils.constant import DownloadMode, ModeKeys
from modelscope.utils.logger import get_logger
# Module-level logger shared by the trainer below.
logger = get_logger()
@TRAINERS.register_module(module_name=Trainers.image_classification_team)
class TEAMImgClsTrainer(BaseTrainer):
    """Fine-tunes a linear classification head on a frozen TEAM image encoder.

    The pretrained TEAM vision transformer is used as a fixed feature
    extractor (all of its parameters are frozen); only a newly attached
    ``nn.Linear(768, class_num)`` head receives gradient updates.
    """

    def __init__(self, cfg_file: str, model: str, device_id: int,
                 data_collator: Callable, train_dataset: Dataset,
                 val_dataset: Dataset, *args, **kwargs):
        """Build the classifier and read hyper-parameters from the config.

        Args:
            cfg_file: Path to the trainer configuration file.
            model: Model id (or local path) of the pretrained TEAM model.
            device_id: Device index passed to ``Tensor.to`` — assumed to be a
                CUDA index; -1 is passed by the CPU test path, TODO confirm
                how torch resolves it.
            data_collator: Callable merging examples into a batch dict with
                'pixel_values' and 'labels' keys.
            train_dataset: Training split.
            val_dataset: Validation split.
        """
        super().__init__(cfg_file)
        self.cfg = Config.from_file(cfg_file)

        # Reuse the TEAM vision transformer and bolt a linear head on top.
        team_model = Model.from_pretrained(model)
        image_model = team_model.model.image_model.vision_transformer
        self.model = nn.Sequential(
            OrderedDict([('encoder', image_model),
                         ('classifier',
                          nn.Linear(768, self.cfg.dataset.class_num))]))
        # Freeze the encoder; only the classifier head stays trainable.
        for pname, param in self.model.named_parameters():
            if 'encoder' in pname:
                param.requires_grad = False

        self.device_id = device_id
        self.total_epoch = self.cfg.train.epoch
        self.train_batch_size = self.cfg.train.batch_size
        self.val_batch_size = self.cfg.evaluation.batch_size
        self.ckpt_dir = self.cfg.train.ckpt_dir
        self.collate_fn = data_collator
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.criterion = nn.CrossEntropyLoss().to(self.device_id)

    def train(self, *args, **kwargs):
        """Train for ``cfg.train.epoch`` epochs.

        After every epoch the model state dict is saved to
        ``{ckpt_dir}/epoch{n}.pth`` and :meth:`evaluate` is run.
        """
        self.model.train()
        self.model.to(self.device_id)
        optimizer = get_optimizer(self.model)
        # Fix: the DataLoader was rebuilt on every epoch even though its
        # arguments are epoch-invariant — build it once (shuffle=True still
        # reshuffles each epoch).
        train_params = {
            'pin_memory': True,
            'collate_fn': self.collate_fn,
            'batch_size': self.train_batch_size,
            'shuffle': True,
            'drop_last': True,
            'num_workers': 8
        }
        train_loader = DataLoader(self.train_dataset, **train_params)
        for epoch in range(self.total_epoch):
            for batch_idx, data in enumerate(train_loader):
                img_tensor, label_tensor = data['pixel_values'], data['labels']
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                label_tensor = label_tensor.to(
                    self.device_id, non_blocking=True)
                pred_logits = self.model(img_tensor)
                loss = self.criterion(pred_logits, label_tensor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if batch_idx % 10 == 0:
                    logger.info(
                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
                            epoch, batch_idx, len(train_loader), loss.item()))
            os.makedirs(self.ckpt_dir, exist_ok=True)
            torch.save(self.model.state_dict(),
                       '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))
            self.evaluate()

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate on the validation set.

        Args:
            checkpoint_path: Optional ``.pth`` state dict to load (onto CPU)
                before evaluating; defaults to the in-memory weights.

        Returns:
            Dict with 'accuracy' (overall) and 'mean_per_class_accuracy'.
            Fix: the annotation promised ``Dict[str, float]`` but the
            original implementation returned ``None``.
        """
        if checkpoint_path is not None:
            checkpoint_params = torch.load(checkpoint_path, 'cpu')
            self.model.load_state_dict(checkpoint_params)
        self.model.eval()
        self.model.to(self.device_id)
        val_params = {
            'collate_fn': self.collate_fn,
            'batch_size': self.val_batch_size,
            'shuffle': False,
            'drop_last': False,
            'num_workers': 8
        }
        val_loader = DataLoader(self.val_dataset, **val_params)
        tp_cnt, processed_cnt = 0, 0
        all_pred_labels, all_gt_labels = [], []
        with torch.no_grad():
            for batch_idx, data in enumerate(val_loader):
                img_tensor, label_tensor = data['pixel_values'], data['labels']
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                label_tensor = label_tensor.to(
                    self.device_id, non_blocking=True)
                pred_logits = self.model(img_tensor)
                pred_labels = torch.max(pred_logits, dim=1)[1]
                tp_cnt += torch.sum(pred_labels == label_tensor).item()
                processed_cnt += img_tensor.shape[0]
                # Running (cumulative) accuracy over batches seen so far.
                logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt))
                all_pred_labels.extend(pred_labels.tolist())
                all_gt_labels.extend(label_tensor.tolist())
        # Per-class recall averaged over classes (diagonal / row sums).
        conf_mat = confusion_matrix(all_gt_labels, all_pred_labels)
        acc_mean_per_class = np.mean(conf_mat.diagonal()
                                     / conf_mat.sum(axis=1))
        logger.info(
            'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class))
        return {
            'accuracy': tp_cnt / processed_cnt,
            'mean_per_class_accuracy': float(acc_mean_per_class)
        }

View File

@@ -0,0 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.optim import AdamW
from modelscope.utils.logger import get_logger
# Module-level logger shared by the helpers below.
logger = get_logger()
# Mean/std below are CLIP-style normalization constants; presumably they
# match the statistics the TEAM image encoder was pretrained with — confirm
# against the model card.
# Training-time pipeline: random crop + horizontal flip for augmentation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])
# Eval-time pipeline: deterministic resize + center crop, same normalization.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])
def train_mapping(examples):
    """Apply the training-time augmentations to a batch of dataset examples.

    Reads image paths from the 'image:FILE' column, writes transformed
    tensors under 'pixel_values' and copies 'label:LABEL' into 'labels'.
    """
    tensors = []
    for image_path in examples['image:FILE']:
        rgb_image = Image.open(image_path).convert('RGB')
        tensors.append(train_transforms(rgb_image))
    examples['pixel_values'] = tensors
    examples['labels'] = list(examples['label:LABEL'])
    return examples
def val_mapping(examples):
    """Apply the deterministic eval-time transforms to a batch of examples.

    Same contract as ``train_mapping`` but uses ``val_transforms``.
    """
    tensors = []
    for image_path in examples['image:FILE']:
        rgb_image = Image.open(image_path).convert('RGB')
        tensors.append(val_transforms(rgb_image))
    examples['pixel_values'] = tensors
    examples['labels'] = list(examples['label:LABEL'])
    return examples
def collate_fn(examples):
    """Merge per-example tensors into a single batch dict.

    Stacks each example's 'pixel_values' tensor along a new batch dimension
    and collects the integer 'labels' into one tensor.
    """
    pixel_values = torch.stack([ex['pixel_values'] for ex in examples])
    labels = torch.tensor([ex['labels'] for ex in examples])
    return {'pixel_values': pixel_values, 'labels': labels}
def get_params_groups(ddp_model, lr):
    """Split trainable parameters into two learning-rate groups.

    Parameters whose name contains 'encoder' get lr/10, those containing
    'classifier' get the full lr; anything else is skipped with a log line.
    Frozen parameters (requires_grad=False) are excluded entirely.
    """
    encoder_params = []
    head_params = []
    for name, param in ddp_model.named_parameters():
        if not param.requires_grad:
            continue
        if 'encoder' in name:
            encoder_params.append(param)
        elif 'classifier' in name:
            head_params.append(param)
        else:
            logger.info('skip param: {}'.format(name))
    return [
        {'params': encoder_params, 'lr': lr / 10.0},
        {'params': head_params, 'lr': lr},
    ]
def get_optimizer(ddp_model):
    """Build an AdamW optimizer with per-group learning rates.

    Uses a base lr of 1e-3 (classifier head) and lr/10 for the encoder,
    as assembled by ``get_params_groups``.
    """
    base_lr = 1e-3
    groups = get_params_groups(ddp_model, lr=base_lr)
    return AdamW(groups, lr=base_lr, betas=[0.9, 0.999], weight_decay=0.02)

View File

@@ -0,0 +1,94 @@
import os
import unittest
import json
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
collate_fn, train_mapping, val_mapping)
from modelscope.utils.config import Config
from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
# Module-level logger for this test module.
logger = get_logger()
def train_worker(device_id):
    """Download Caltech-101, build the TEAM trainer, then train and evaluate.

    Args:
        device_id: Device index forwarded to the trainer (the caller passes
            0 for the first GPU or -1 when no GPU is available).
    """
    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
    ckpt_dir = './ckpt'
    os.makedirs(ckpt_dir, exist_ok=True)

    # epoch=1 keeps this integration test reasonably fast.
    trainer_cfg = Config({
        'framework': 'pytorch',
        'task': 'multi-modal-similarity',
        'pipeline': {
            'type': 'multi-modal-similarity'
        },
        'model': {
            'type': 'team-multi-modal-similarity'
        },
        'dataset': {
            'name': 'Caltech101',
            'class_num': 101
        },
        'preprocessor': {},
        'train': {
            'epoch': 1,
            'batch_size': 32,
            'ckpt_dir': ckpt_dir
        },
        'evaluation': {
            'batch_size': 64
        }
    })
    cfg_path = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
    trainer_cfg.dump(cfg_path)

    def _load_split(split, mapping):
        # Force a fresh download and attach the split-specific transform.
        dataset = MsDataset.load(
            trainer_cfg.dataset.name,
            namespace='modelscope',
            split=split,
            download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
        return dataset.with_transform(mapping)

    train_set = _load_split('train', train_mapping)
    val_set = _load_split('validation', val_mapping)

    trainer = build_trainer(
        name=Trainers.image_classification_team,
        default_args=dict(
            cfg_file=cfg_path,
            model=model_id,
            device_id=device_id,
            data_collator=collate_fn,
            train_dataset=train_set,
            val_dataset=val_set))
    trainer.train()
    trainer.evaluate()
class TEAMTransferTrainerTest(unittest.TestCase):
    """Integration test for the TEAM image-classification transfer trainer."""

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        # Use the first GPU when one is visible, otherwise fall back to CPU
        # (signalled to train_worker by device_id=-1).
        has_gpu = torch.cuda.device_count() > 0
        train_worker(device_id=0 if has_gpu else -1)
        logger.info('Training done')
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()