mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
[to #42322933] Fine-tune TEAM on Caltech-101
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10525413
This commit is contained in:
@@ -305,6 +305,7 @@ class Trainers(object):
|
||||
face_detection_scrfd = 'face-detection-scrfd'
|
||||
card_detection_scrfd = 'card-detection-scrfd'
|
||||
image_inpainting = 'image-inpainting'
|
||||
image_classification_team = 'image-classification-team'
|
||||
|
||||
# nlp trainers
|
||||
bert_sentiment_analysis = 'bert-sentiment-analysis'
|
||||
|
||||
@@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .clip import CLIPTrainer
|
||||
from .team import TEAMImgClsTrainer
|
||||
|
||||
else:
|
||||
_import_structure = {'clip': ['CLIPTrainer']}
|
||||
_import_structure = {
|
||||
'clip': ['CLIPTrainer'],
|
||||
'team': ['TEAMImgClsTrainer']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
3
modelscope/trainers/multi_modal/team/__init__.py
Normal file
3
modelscope/trainers/multi_modal/team/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .team_trainer import TEAMImgClsTrainer
|
||||
144
modelscope/trainers/multi_modal/team/team_trainer.py
Normal file
144
modelscope/trainers/multi_modal/team/team_trainer.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers.base import BaseTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
|
||||
get_optimizer, train_mapping, val_mapping)
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DownloadMode, ModeKeys
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@TRAINERS.register_module(module_name=Trainers.image_classification_team)
class TEAMImgClsTrainer(BaseTrainer):
    """Transfer-learning trainer that fits a linear classifier on top of a
    frozen TEAM vision transformer.

    The TEAM image encoder is frozen in ``__init__``; only the appended
    ``classifier`` head (a single ``nn.Linear``) receives gradients.
    """

    def __init__(self, cfg_file: str, model: str, device_id: int,
                 data_collator: Callable, train_dataset: Dataset,
                 val_dataset: Dataset, *args, **kwargs):
        """Build the frozen-encoder + linear-head model and store run config.

        Args:
            cfg_file: Path to the modelscope configuration file; must provide
                ``dataset.class_num``, ``train.epoch``, ``train.batch_size``,
                ``train.ckpt_dir`` and ``evaluation.batch_size``.
            model: Model id or local path passed to ``Model.from_pretrained``.
            device_id: Device passed to ``Tensor.to`` / ``Module.to``
                (a CUDA index, or -1 for CPU — see the caller's test).
            data_collator: Collation callable for both data loaders.
            train_dataset: Training split (yields ``pixel_values``/``labels``).
            val_dataset: Validation split with the same item schema.
        """
        super().__init__(cfg_file)

        self.cfg = Config.from_file(cfg_file)
        team_model = Model.from_pretrained(model)
        image_model = team_model.model.image_model.vision_transformer
        # Frozen ViT backbone + trainable linear head.
        # 768 is the backbone's output width — TODO(review): confirm it
        # matches every TEAM variant this trainer may be given.
        classification_model = nn.Sequential(
            OrderedDict([('encoder', image_model),
                         ('classifier',
                          nn.Linear(768, self.cfg.dataset.class_num))]))
        self.model = classification_model

        # Freeze the backbone so only the classifier head is optimized.
        for pname, param in self.model.named_parameters():
            if 'encoder' in pname:
                param.requires_grad = False

        self.device_id = device_id
        self.total_epoch = self.cfg.train.epoch
        self.train_batch_size = self.cfg.train.batch_size
        self.val_batch_size = self.cfg.evaluation.batch_size
        self.ckpt_dir = self.cfg.train.ckpt_dir

        self.collate_fn = data_collator
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        self.criterion = nn.CrossEntropyLoss().to(self.device_id)

    def train(self, *args, **kwargs):
        """Train the classifier head; checkpoint and evaluate every epoch.

        Checkpoints are written to ``{ckpt_dir}/epoch{n}.pth``.
        """
        self.model.to(self.device_id)

        optimizer = get_optimizer(self.model)

        train_params = {
            'pin_memory': True,
            'collate_fn': self.collate_fn,
            'batch_size': self.train_batch_size,
            'shuffle': True,
            'drop_last': True,
            'num_workers': 8
        }

        for epoch in range(self.total_epoch):
            # Bug fix: re-enter train mode every epoch. The per-epoch
            # ``self.evaluate()`` below switches the model to eval mode, so
            # without this, epochs >= 2 trained with dropout/eval behavior.
            self.model.train()

            train_loader = DataLoader(self.train_dataset, **train_params)

            for batch_idx, data in enumerate(train_loader):
                img_tensor, label_tensor = data['pixel_values'], data['labels']
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                label_tensor = label_tensor.to(
                    self.device_id, non_blocking=True)

                pred_logits = self.model(img_tensor)
                loss = self.criterion(pred_logits, label_tensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if batch_idx % 10 == 0:
                    logger.info(
                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
                            epoch, batch_idx, len(train_loader), loss.item()))

            os.makedirs(self.ckpt_dir, exist_ok=True)
            torch.save(self.model.state_dict(),
                       '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))
            self.evaluate()

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate on the validation split.

        Args:
            checkpoint_path: Optional path to a ``state_dict`` checkpoint to
                load before evaluating; when ``None``, the in-memory weights
                are used.

        Returns:
            ``{'accuracy': ..., 'mean_per_class_accuracy': ...}``. Both values
            are 0.0 when the validation set is empty.
        """
        if checkpoint_path is not None:
            # Load to CPU first; the model is moved to the target device below.
            checkpoint_params = torch.load(checkpoint_path, 'cpu')
            self.model.load_state_dict(checkpoint_params)
        self.model.eval()
        self.model.to(self.device_id)

        val_params = {
            'collate_fn': self.collate_fn,
            'batch_size': self.val_batch_size,
            'shuffle': False,
            'drop_last': False,
            'num_workers': 8
        }
        val_loader = DataLoader(self.val_dataset, **val_params)

        tp_cnt, processed_cnt = 0, 0
        all_pred_labels, all_gt_labels = [], []
        with torch.no_grad():
            for batch_idx, data in enumerate(val_loader):
                img_tensor, label_tensor = data['pixel_values'], data['labels']
                img_tensor = img_tensor.to(self.device_id, non_blocking=True)
                label_tensor = label_tensor.to(
                    self.device_id, non_blocking=True)

                pred_logits = self.model(img_tensor)
                pred_labels = torch.max(pred_logits, dim=1)[1]
                tp_cnt += torch.sum(pred_labels == label_tensor).item()
                processed_cnt += img_tensor.shape[0]
                # Running accuracy over all batches processed so far.
                logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt))

                all_pred_labels.extend(pred_labels.tolist())
                all_gt_labels.extend(label_tensor.tolist())

        # Bug fix: guard against an empty validation set (ZeroDivisionError).
        if processed_cnt == 0:
            logger.info('evaluate called with an empty validation set')
            return {'accuracy': 0.0, 'mean_per_class_accuracy': 0.0}

        accuracy = tp_cnt / processed_cnt
        conf_mat = confusion_matrix(all_gt_labels, all_pred_labels)
        acc_mean_per_class = np.mean(conf_mat.diagonal()
                                     / conf_mat.sum(axis=1))
        logger.info(
            'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class))
        # Bug fix: the annotation promises Dict[str, float]; actually return
        # the computed metrics instead of None.
        return {
            'accuracy': float(accuracy),
            'mean_per_class_accuracy': float(acc_mean_per_class)
        }
|
||||
87
modelscope/trainers/multi_modal/team/team_trainer_utils.py
Normal file
87
modelscope/trainers/multi_modal/team/team_trainer_utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torch.optim import AdamW
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# CLIP's RGB channel statistics, shared by both pipelines below.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Training pipeline: random 224x224 crop + horizontal flip augmentation,
# then tensor conversion and CLIP normalization.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CLIP_MEAN, _CLIP_STD),
])

# Evaluation pipeline: deterministic resize + center crop, same normalization.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_CLIP_MEAN, _CLIP_STD),
])
|
||||
|
||||
|
||||
def _map_examples(examples, transform):
    """Decode the image files in *examples* and apply *transform*.

    Shared implementation for ``train_mapping`` / ``val_mapping``, which
    previously duplicated this body verbatim. Mutates and returns *examples*,
    adding ``pixel_values`` (transformed RGB images) and ``labels``.
    """
    examples['pixel_values'] = [
        transform(Image.open(image).convert('RGB'))
        for image in examples['image:FILE']
    ]
    # list() replaces the redundant `[label for label in ...]` copy loop.
    examples['labels'] = list(examples['label:LABEL'])
    return examples


def train_mapping(examples):
    """Batch mapping for the training split: augmenting transforms."""
    return _map_examples(examples, train_transforms)


def val_mapping(examples):
    """Batch mapping for the validation split: deterministic transforms."""
    return _map_examples(examples, val_transforms)
|
||||
|
||||
|
||||
def collate_fn(examples):
    """Collate a list of per-sample dicts into one batched dict.

    Stacks each sample's ``pixel_values`` tensor along a new batch axis and
    packs the ``labels`` into a single tensor.
    """
    pixel_values = torch.stack([sample['pixel_values'] for sample in examples])
    labels = torch.tensor([sample['labels'] for sample in examples])
    return {'pixel_values': pixel_values, 'labels': labels}
|
||||
|
||||
|
||||
def get_params_groups(ddp_model, lr):
    """Split trainable parameters into two learning-rate groups.

    Parameters whose name contains ``'encoder'`` get ``lr / 10``; those
    containing ``'classifier'`` get the full ``lr``. Frozen parameters are
    ignored; anything matching neither group is logged and skipped.
    """
    encoder_params = []
    classifier_params = []
    for pname, p in ddp_model.named_parameters():
        if not p.requires_grad:
            continue

        if 'encoder' in pname:
            encoder_params.append(p)
        elif 'classifier' in pname:
            classifier_params.append(p)
        else:
            logger.info('skip param: {}'.format(pname))

    return [
        {'params': encoder_params, 'lr': lr / 10.0},
        {'params': classifier_params, 'lr': lr},
    ]
|
||||
|
||||
|
||||
def get_optimizer(ddp_model):
    """Build an AdamW optimizer over the two-rate parameter groups.

    Base learning rate is 1e-3 (classifier head); the encoder group set up by
    ``get_params_groups`` runs at a tenth of that.
    """
    base_lr = 1e-3
    groups = get_params_groups(ddp_model, lr=base_lr)
    return AdamW(groups, lr=base_lr, betas=[0.9, 0.999], weight_decay=0.02)
|
||||
94
tests/trainers/test_team_transfer_trainer.py
Normal file
94
tests/trainers/test_team_transfer_trainer.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import json
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
|
||||
collate_fn, train_mapping, val_mapping)
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def train_worker(device_id):
    """Run one TEAM transfer-learning job on Caltech-101.

    Writes a configuration file into a local checkpoint directory, downloads
    the train/validation splits, then builds the trainer and runs a full
    train + evaluate cycle on *device_id*.
    """
    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
    ckpt_dir = './ckpt'
    os.makedirs(ckpt_dir, exist_ok=True)

    # Use epoch=1 for faster training here
    config = Config({
        'framework': 'pytorch',
        'task': 'multi-modal-similarity',
        'pipeline': {
            'type': 'multi-modal-similarity'
        },
        'model': {
            'type': 'team-multi-modal-similarity'
        },
        'dataset': {
            'name': 'Caltech101',
            'class_num': 101
        },
        'preprocessor': {},
        'train': {
            'epoch': 1,
            'batch_size': 32,
            'ckpt_dir': ckpt_dir
        },
        'evaluation': {
            'batch_size': 64
        }
    })
    config_path = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
    config.dump(config_path)

    def load_split(split, mapping):
        # Force a fresh download so a stale cache cannot affect the test.
        dataset = MsDataset.load(
            config.dataset.name,
            namespace='modelscope',
            split=split,
            download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
        return dataset.with_transform(mapping)

    trainer = build_trainer(
        name=Trainers.image_classification_team,
        default_args=dict(
            cfg_file=config_path,
            model=model_id,
            device_id=device_id,
            data_collator=collate_fn,
            train_dataset=load_split('train', train_mapping),
            val_dataset=load_split('validation', val_mapping)))
    trainer.train()
    trainer.evaluate()
|
||||
|
||||
|
||||
class TEAMTransferTrainerTest(unittest.TestCase):
    """Smoke test: run the TEAM transfer trainer end to end."""

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        # Train on GPU 0 when one is available, otherwise fall back to CPU
        # (device_id=-1).
        device = 0 if torch.cuda.device_count() > 0 else -1
        train_worker(device_id=device)
        logger.info('Training done')
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user