update multi_modal_embedding example

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12626062
This commit is contained in:
hemu.zp
2023-05-16 14:31:26 +08:00
committed by yuze.zyz
parent 6c8c3a53f8
commit 5804ad2dc1
5 changed files with 37 additions and 29 deletions

View File

@@ -1,15 +1,13 @@
import os
from dataclasses import dataclass, field
from functools import partial
from modelscope import MsDataset, TrainingArgs
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.trainers.args import (TrainingArgs, get_flatten_value,
set_flatten_value)
from modelscope.trainers.training_args import set_flatten_value
@dataclass
@dataclass(init=False)
class MultiModalEmbeddingArguments(TrainingArgs):
trainer: str = field(
@@ -17,6 +15,12 @@ class MultiModalEmbeddingArguments(TrainingArgs):
'help': 'The trainer used',
})
work_dir: str = field(
default='./tmp',
metadata={
'help': 'The working path for saving checkpoint',
})
use_fp16: bool = field(
default=None,
metadata={
@@ -35,7 +39,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
default=None,
metadata={
'cfg_node': 'train.optimizer_hparams',
'cfg_getter': partial(get_flatten_value, exclusions=['lr']),
'cfg_setter': set_flatten_value,
'help': 'The optimizer init params except `lr`',
})
@@ -51,7 +54,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
default=None,
metadata={
'cfg_node': 'dataset.column_map',
'cfg_getter': get_flatten_value,
'cfg_setter': set_flatten_value,
'help': 'The column map for dataset',
})
@@ -67,7 +69,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
default=None,
metadata={
'cfg_node': 'train.lr_scheduler_hook',
'cfg_getter': get_flatten_value,
'cfg_setter': set_flatten_value,
'help': 'The parameters for lr scheduler hook',
})
@@ -76,7 +77,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
default=None,
metadata={
'cfg_node': 'train.optimizer_hook',
'cfg_getter': get_flatten_value,
'cfg_setter': set_flatten_value,
'help': 'The parameters for optimizer hook',
})
@@ -92,23 +92,28 @@ class MultiModalEmbeddingArguments(TrainingArgs):
'help': 'The data parallel world size',
})
def __call__(self, config):
config = super().__call__(config)
config.merge_from_dict({'pretrained_model.model_name': self.model})
if self.clip_clamp:
config.train.hooks.append({'type': 'ClipClampLogitScaleHook'})
if self.world_size > 1:
config.train.launcher = 'pytorch'
return config
config, args = MultiModalEmbeddingArguments().parse_cli().to_config()
print(config, args)
args = MultiModalEmbeddingArguments.from_cli(task='multi-modal-embedding')
print(args)
def cfg_modify_fn(cfg):
if args.use_model_config:
cfg.merge_from_dict(config)
else:
cfg = config
cfg.merge_from_dict({'pretrained_model.model_name': args.model})
if args.clip_clamp:
cfg.train.hooks.append({'type': 'ClipClampLogitScaleHook'})
if args.world_size > 1:
cfg.train.launcher = 'pytorch'
return cfg
train_dataset = MsDataset.load(
args.dataset_name, namespace='modelscope', split='train')
args.train_dataset_name, namespace='modelscope', split='train')
eval_dataset = MsDataset.load(
args.dataset_name, namespace='modelscope', split='validation')
args.train_dataset_name, namespace='modelscope', split='validation')
os.makedirs(args.work_dir, exist_ok=True)
kwargs = dict(
@@ -116,6 +121,6 @@ kwargs = dict(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
work_dir=args.work_dir,
cfg_modify_fn=args)
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=args.trainer, default_args=kwargs)
trainer.train()

View File

@@ -6,14 +6,16 @@ PYTHONPATH=. torchrun --nproc_per_node $DATA_PARALLEL_SIZE \
--trainer 'clip-multi-modal-embedding' \
--work_dir './workspace/ckpts/clip' \
--model 'damo/multi-modal_clip-vit-base-patch16_zh' \
--dataset_name 'muge' \
--train_dataset_name 'muge' \
--dataset_column_map 'img=image,text=query' \
--max_epochs 1 \
--use_fp16 true \
--per_device_train_batch_size 180 \
--train_data_worker 0 \
--train_shuffle true \
--train_drop_last true \
--per_device_eval_batch_size 128 \
--eval_data_worker 0 \
--eval_shuffle true \
--eval_drop_last true \
--save_ckpt_best true \
@@ -33,3 +35,4 @@ PYTHONPATH=. torchrun --nproc_per_node $DATA_PARALLEL_SIZE \
--optimizer_hook 'type=TorchAMPOptimizerHook,cumulative_iters=1,loss_keys=loss' \
--clip_clamp true \
--world_size $DATA_PARALLEL_SIZE \
--use_model_config true \

View File

@@ -2,7 +2,8 @@ import os
from dataclasses import dataclass, field
from modelscope import (EpochBasedTrainer, MsDataset, TrainingArgs,
build_dataset_from_file, build_trainer)
build_dataset_from_file)
from modelscope.trainers import build_trainer
def set_labels(labels):

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from modelscope import (EpochBasedTrainer, MsDataset, TrainingArgs,
build_trainer)
from modelscope import EpochBasedTrainer, MsDataset, TrainingArgs
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
@dataclass(init=False)

View File

@@ -176,11 +176,10 @@ class CLIPTrainer(EpochBasedTrainer):
self.dataset_cfg = cfg.dataset
if hasattr(self.dataset_cfg, 'column_map'):
# cases where dataset key names are not "img" and "text"
img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img')
img_key_name = self.dataset_cfg['column_map'].get('img', 'img')
preprocessor[ConfigKeys.train].set_input_img_key(img_key_name)
preprocessor[ConfigKeys.val].set_input_img_key(img_key_name)
text_key_name = getattr(self.dataset_cfg.column_map, 'text',
'text')
text_key_name = self.dataset_cfg['column_map'].get('text', 'text')
preprocessor[ConfigKeys.train].set_input_text_key(text_key_name)
preprocessor[ConfigKeys.val].set_input_text_key(text_key_name)
self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size