diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 419ec919..af60f072 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -305,6 +305,7 @@ class Trainers(object):
     face_detection_scrfd = 'face-detection-scrfd'
     card_detection_scrfd = 'card-detection-scrfd'
     image_inpainting = 'image-inpainting'
+    image_classification_team = 'image-classification-team'
 
     # nlp trainers
     bert_sentiment_analysis = 'bert-sentiment-analysis'
diff --git a/modelscope/trainers/multi_modal/__init__.py b/modelscope/trainers/multi_modal/__init__.py
index 89b7e1bc..448f23a3 100644
--- a/modelscope/trainers/multi_modal/__init__.py
+++ b/modelscope/trainers/multi_modal/__init__.py
@@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
     from .clip import CLIPTrainer
+    from .team import TEAMImgClsTrainer
 else:
-    _import_structure = {'clip': ['CLIPTrainer']}
+    _import_structure = {
+        'clip': ['CLIPTrainer'],
+        'team': ['TEAMImgClsTrainer']
+    }
 
 import sys
diff --git a/modelscope/trainers/multi_modal/team/__init__.py b/modelscope/trainers/multi_modal/team/__init__.py
new file mode 100644
index 00000000..b48fcc7e
--- /dev/null
+++ b/modelscope/trainers/multi_modal/team/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .team_trainer import TEAMImgClsTrainer
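The Trainers key and the lazy-import entry above are what make the new trainer constructible by name through the trainer registry. A minimal sketch of that wiring; default_args is a placeholder here, and the full set of required arguments appears in the test at the end of this diff:

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer

    # default_args is a placeholder; it must supply cfg_file, model,
    # device_id, data_collator, train_dataset and val_dataset.
    trainer = build_trainer(
        name=Trainers.image_classification_team, default_args=default_args)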
diff --git a/modelscope/trainers/multi_modal/team/team_trainer.py b/modelscope/trainers/multi_modal/team/team_trainer.py
new file mode 100644
index 00000000..7c557416
--- /dev/null
+++ b/modelscope/trainers/multi_modal/team/team_trainer.py
@@ -0,0 +1,144 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from collections import OrderedDict
+from typing import Callable, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from sklearn.metrics import confusion_matrix
+from torch.utils.data import DataLoader, Dataset
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import Model
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.multi_modal.team.team_trainer_utils import \
+    get_optimizer
+from modelscope.utils.config import Config
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@TRAINERS.register_module(module_name=Trainers.image_classification_team)
+class TEAMImgClsTrainer(BaseTrainer):
+
+    def __init__(self, cfg_file: str, model: str, device_id: int,
+                 data_collator: Callable, train_dataset: Dataset,
+                 val_dataset: Dataset, *args, **kwargs):
+        super().__init__(cfg_file)
+
+        self.cfg = Config.from_file(cfg_file)
+        team_model = Model.from_pretrained(model)
+        image_model = team_model.model.image_model.vision_transformer
+        # Attach a linear classification head to the pretrained TEAM
+        # image encoder (ViT with 768-d output features).
+        classification_model = nn.Sequential(
+            OrderedDict([('encoder', image_model),
+                         ('classifier',
+                          nn.Linear(768, self.cfg.dataset.class_num))]))
+        self.model = classification_model
+
+        # Freeze the encoder; only the classifier head is trained.
+        for pname, param in self.model.named_parameters():
+            if 'encoder' in pname:
+                param.requires_grad = False
+
+        # device_id < 0 selects the CPU; otherwise the given CUDA device.
+        if device_id >= 0:
+            self.device = torch.device('cuda:{}'.format(device_id))
+        else:
+            self.device = torch.device('cpu')
+        self.total_epoch = self.cfg.train.epoch
+        self.train_batch_size = self.cfg.train.batch_size
+        self.val_batch_size = self.cfg.evaluation.batch_size
+        self.ckpt_dir = self.cfg.train.ckpt_dir
+
+        self.collate_fn = data_collator
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+
+        self.criterion = nn.CrossEntropyLoss().to(self.device)
+
+    def train(self, *args, **kwargs):
+        self.model.train()
+        self.model.to(self.device)
+
+        optimizer = get_optimizer(self.model)
+
+        for epoch in range(self.total_epoch):
+            train_params = {
+                'pin_memory': True,
+                'collate_fn': self.collate_fn,
+                'batch_size': self.train_batch_size,
+                'shuffle': True,
+                'drop_last': True,
+                'num_workers': 8
+            }
+
+            train_loader = DataLoader(self.train_dataset, **train_params)
+
+            for batch_idx, data in enumerate(train_loader):
+                img_tensor, label_tensor = data['pixel_values'], data['labels']
+                img_tensor = img_tensor.to(self.device, non_blocking=True)
+                label_tensor = label_tensor.to(self.device, non_blocking=True)
+
+                pred_logits = self.model(img_tensor)
+                loss = self.criterion(pred_logits, label_tensor)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                if batch_idx % 10 == 0:
+                    logger.info(
+                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
+                            epoch, batch_idx, len(train_loader), loss.item()))
+
+            os.makedirs(self.ckpt_dir, exist_ok=True)
+            torch.save(self.model.state_dict(),
+                       '{}/epoch{}.pth'.format(self.ckpt_dir, epoch))
+            self.evaluate()
+
+    def evaluate(self,
+                 checkpoint_path: Optional[str] = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        if checkpoint_path is not None:
+            checkpoint_params = torch.load(checkpoint_path, map_location='cpu')
+            self.model.load_state_dict(checkpoint_params)
+        self.model.eval()
+        self.model.to(self.device)
+
+        val_params = {
+            'collate_fn': self.collate_fn,
+            'batch_size': self.val_batch_size,
+            'shuffle': False,
+            'drop_last': False,
+            'num_workers': 8
+        }
+        val_loader = DataLoader(self.val_dataset, **val_params)
+
+        tp_cnt, processed_cnt = 0, 0
+        all_pred_labels, all_gt_labels = [], []
+        with torch.no_grad():
+            for batch_idx, data in enumerate(val_loader):
+                img_tensor, label_tensor = data['pixel_values'], data['labels']
+                img_tensor = img_tensor.to(self.device, non_blocking=True)
+                label_tensor = label_tensor.to(self.device, non_blocking=True)
+
+                pred_logits = self.model(img_tensor)
+                pred_labels = torch.max(pred_logits, dim=1)[1]
+                tp_cnt += torch.sum(pred_labels == label_tensor).item()
+                processed_cnt += img_tensor.shape[0]
+                logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt))
+
+                all_pred_labels.extend(pred_labels.tolist())
+                all_gt_labels.extend(label_tensor.tolist())
+
+        # Mean per-class accuracy from the confusion-matrix diagonal.
+        conf_mat = confusion_matrix(all_gt_labels, all_pred_labels)
+        acc_mean_per_class = np.mean(conf_mat.diagonal()
+                                     / conf_mat.sum(axis=1))
+        logger.info(
+            'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class))
+        return {
+            'accuracy': tp_cnt / processed_cnt,
+            'mean_per_class_accuracy': float(acc_mean_per_class)
+        }
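Because train() saves a plain state_dict of the encoder-plus-classifier Sequential, a checkpoint can be reloaded outside the trainer. A minimal inference sketch, assuming the same model id used in the test below, a 101-class head, and a checkpoint at ./ckpt/epoch0.pth (path and class count are illustrative assumptions):

    from collections import OrderedDict

    import torch
    import torch.nn as nn

    from modelscope.models.base import Model

    # Rebuild the exact module structure used by TEAMImgClsTrainer,
    # then load the saved weights.
    team_model = Model.from_pretrained(
        'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity')
    encoder = team_model.model.image_model.vision_transformer
    model = nn.Sequential(
        OrderedDict([('encoder', encoder),
                     ('classifier', nn.Linear(768, 101))]))
    model.load_state_dict(torch.load('./ckpt/epoch0.pth', map_location='cpu'))
    model.eval()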
diff --git a/modelscope/trainers/multi_modal/team/team_trainer_utils.py b/modelscope/trainers/multi_modal/team/team_trainer_utils.py
new file mode 100644
index 00000000..ff1a4fd6
--- /dev/null
+++ b/modelscope/trainers/multi_modal/team/team_trainer_utils.py
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.optim import AdamW
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+# CLIP-style mean/std normalization, matching the TEAM image encoder.
+train_transforms = transforms.Compose([
+    transforms.RandomResizedCrop(224),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                         (0.26862954, 0.26130258, 0.27577711)),
+])
+val_transforms = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                         (0.26862954, 0.26130258, 0.27577711)),
+])
+
+
+def train_mapping(examples):
+    examples['pixel_values'] = [
+        train_transforms(Image.open(image).convert('RGB'))
+        for image in examples['image:FILE']
+    ]
+    examples['labels'] = list(examples['label:LABEL'])
+    return examples
+
+
+def val_mapping(examples):
+    examples['pixel_values'] = [
+        val_transforms(Image.open(image).convert('RGB'))
+        for image in examples['image:FILE']
+    ]
+    examples['labels'] = list(examples['label:LABEL'])
+    return examples
+
+
+def collate_fn(examples):
+    images = []
+    labels = []
+    for example in examples:
+        images.append(example['pixel_values'])
+        labels.append(example['labels'])
+
+    pixel_values = torch.stack(images)
+    labels = torch.tensor(labels)
+    return {'pixel_values': pixel_values, 'labels': labels}
+
+
+def get_params_groups(ddp_model, lr):
+    # The pretrained encoder gets a 10x smaller learning rate than the
+    # freshly initialized classifier head; frozen params are skipped.
+    large_lr_params = []
+    small_lr_params = []
+    for name, param in ddp_model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if 'encoder' in name:
+            small_lr_params.append(param)
+        elif 'classifier' in name:
+            large_lr_params.append(param)
+        else:
+            logger.info('skip param: {}'.format(name))
+
+    params_groups = [{
+        'params': small_lr_params,
+        'lr': lr / 10.0
+    }, {
+        'params': large_lr_params,
+        'lr': lr
+    }]
+    return params_groups
+
+
+def get_optimizer(ddp_model):
+    lr_init = 1e-3
+    betas = [0.9, 0.999]
+    weight_decay = 0.02
+    params_groups = get_params_groups(ddp_model, lr=lr_init)
+    return AdamW(
+        params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay)
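The discriminative learning rates produced by get_params_groups can be checked in isolation. A toy sketch (the two-layer model is hypothetical; parameter names only need to contain 'encoder' and 'classifier'):

    from collections import OrderedDict

    import torch.nn as nn

    from modelscope.trainers.multi_modal.team.team_trainer_utils import \
        get_params_groups

    toy = nn.Sequential(
        OrderedDict([('encoder', nn.Linear(4, 4)),
                     ('classifier', nn.Linear(4, 2))]))
    for group in get_params_groups(toy, lr=1e-3):
        # Prints 2 params at lr=1e-4 (encoder), then 2 at lr=1e-3 (head).
        print(len(group['params']), group['lr'])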
diff --git a/tests/trainers/test_team_transfer_trainer.py b/tests/trainers/test_team_transfer_trainer.py
new file mode 100644
index 00000000..0f6b88bb
--- /dev/null
+++ b/tests/trainers/test_team_transfer_trainer.py
@@ -0,0 +1,94 @@
+import os
+import unittest
+
+import torch
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.trainers.multi_modal.team.team_trainer_utils import (
+    collate_fn, train_mapping, val_mapping)
+from modelscope.utils.config import Config
+from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+def train_worker(device_id):
+    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
+    ckpt_dir = './ckpt'
+    os.makedirs(ckpt_dir, exist_ok=True)
+    # Use epoch=1 for faster training here
+    cfg = Config({
+        'framework': 'pytorch',
+        'task': 'multi-modal-similarity',
+        'pipeline': {
+            'type': 'multi-modal-similarity'
+        },
+        'model': {
+            'type': 'team-multi-modal-similarity'
+        },
+        'dataset': {
+            'name': 'Caltech101',
+            'class_num': 101
+        },
+        'preprocessor': {},
+        'train': {
+            'epoch': 1,
+            'batch_size': 32,
+            'ckpt_dir': ckpt_dir
+        },
+        'evaluation': {
+            'batch_size': 64
+        }
+    })
+    cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
+    cfg.dump(cfg_file)
+
+    train_dataset = MsDataset.load(
+        cfg.dataset.name,
+        namespace='modelscope',
+        split='train',
+        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
+    train_dataset = train_dataset.with_transform(train_mapping)
+    val_dataset = MsDataset.load(
+        cfg.dataset.name,
+        namespace='modelscope',
+        split='validation',
+        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
+    val_dataset = val_dataset.with_transform(val_mapping)
+
+    default_args = dict(
+        cfg_file=cfg_file,
+        model=model_id,
+        device_id=device_id,
+        data_collator=collate_fn,
+        train_dataset=train_dataset,
+        val_dataset=val_dataset)
+
+    trainer = build_trainer(
+        name=Trainers.image_classification_team, default_args=default_args)
+    trainer.train()
+    trainer.evaluate()
+
+
+class TEAMTransferTrainerTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        if torch.cuda.device_count() > 0:
+            train_worker(device_id=0)
+        else:
+            train_worker(device_id=-1)
+        logger.info('Training done')
+
+
+if __name__ == '__main__':
+    unittest.main()
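Since the module guards unittest.main(), the test can also be run directly with python tests/trainers/test_team_transfer_trainer.py; on a machine without a GPU it falls back to device_id=-1, which the trainer maps to the CPU.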