mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
update multi_modal_embedding example
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12626062
This commit is contained in:
@@ -1,15 +1,13 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
from modelscope import MsDataset, TrainingArgs
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.trainers.args import (TrainingArgs, get_flatten_value,
|
||||
set_flatten_value)
|
||||
from modelscope.trainers.training_args import set_flatten_value
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(init=False)
|
||||
class MultiModalEmbeddingArguments(TrainingArgs):
|
||||
|
||||
trainer: str = field(
|
||||
@@ -17,6 +15,12 @@ class MultiModalEmbeddingArguments(TrainingArgs):
|
||||
'help': 'The trainer used',
|
||||
})
|
||||
|
||||
work_dir: str = field(
|
||||
default='./tmp',
|
||||
metadata={
|
||||
'help': 'The working path for saving checkpoint',
|
||||
})
|
||||
|
||||
use_fp16: bool = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@@ -35,7 +39,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': 'train.optimizer_hparams',
|
||||
'cfg_getter': partial(get_flatten_value, exclusions=['lr']),
|
||||
'cfg_setter': set_flatten_value,
|
||||
'help': 'The optimizer init params except `lr`',
|
||||
})
|
||||
@@ -51,7 +54,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': 'dataset.column_map',
|
||||
'cfg_getter': get_flatten_value,
|
||||
'cfg_setter': set_flatten_value,
|
||||
'help': 'The column map for dataset',
|
||||
})
|
||||
@@ -67,7 +69,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': 'train.lr_scheduler_hook',
|
||||
'cfg_getter': get_flatten_value,
|
||||
'cfg_setter': set_flatten_value,
|
||||
'help': 'The parameters for lr scheduler hook',
|
||||
})
|
||||
@@ -76,7 +77,6 @@ class MultiModalEmbeddingArguments(TrainingArgs):
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': 'train.optimizer_hook',
|
||||
'cfg_getter': get_flatten_value,
|
||||
'cfg_setter': set_flatten_value,
|
||||
'help': 'The parameters for optimizer hook',
|
||||
})
|
||||
@@ -92,23 +92,28 @@ class MultiModalEmbeddingArguments(TrainingArgs):
|
||||
'help': 'The data parallel world size',
|
||||
})
|
||||
|
||||
def __call__(self, config):
|
||||
config = super().__call__(config)
|
||||
config.merge_from_dict({'pretrained_model.model_name': self.model})
|
||||
if self.clip_clamp:
|
||||
config.train.hooks.append({'type': 'ClipClampLogitScaleHook'})
|
||||
if self.world_size > 1:
|
||||
config.train.launcher = 'pytorch'
|
||||
return config
|
||||
|
||||
config, args = MultiModalEmbeddingArguments().parse_cli().to_config()
|
||||
print(config, args)
|
||||
|
||||
|
||||
args = MultiModalEmbeddingArguments.from_cli(task='multi-modal-embedding')
|
||||
print(args)
|
||||
def cfg_modify_fn(cfg):
|
||||
if args.use_model_config:
|
||||
cfg.merge_from_dict(config)
|
||||
else:
|
||||
cfg = config
|
||||
cfg.merge_from_dict({'pretrained_model.model_name': args.model})
|
||||
if args.clip_clamp:
|
||||
cfg.train.hooks.append({'type': 'ClipClampLogitScaleHook'})
|
||||
if args.world_size > 1:
|
||||
cfg.train.launcher = 'pytorch'
|
||||
return cfg
|
||||
|
||||
|
||||
train_dataset = MsDataset.load(
|
||||
args.dataset_name, namespace='modelscope', split='train')
|
||||
args.train_dataset_name, namespace='modelscope', split='train')
|
||||
eval_dataset = MsDataset.load(
|
||||
args.dataset_name, namespace='modelscope', split='validation')
|
||||
args.train_dataset_name, namespace='modelscope', split='validation')
|
||||
|
||||
os.makedirs(args.work_dir, exist_ok=True)
|
||||
kwargs = dict(
|
||||
@@ -116,6 +121,6 @@ kwargs = dict(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
work_dir=args.work_dir,
|
||||
cfg_modify_fn=args)
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
trainer = build_trainer(name=args.trainer, default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
@@ -6,14 +6,16 @@ PYTHONPATH=. torchrun --nproc_per_node $DATA_PARALLEL_SIZE \
|
||||
--trainer 'clip-multi-modal-embedding' \
|
||||
--work_dir './workspace/ckpts/clip' \
|
||||
--model 'damo/multi-modal_clip-vit-base-patch16_zh' \
|
||||
--dataset_name 'muge' \
|
||||
--train_dataset_name 'muge' \
|
||||
--dataset_column_map 'img=image,text=query' \
|
||||
--max_epochs 1 \
|
||||
--use_fp16 true \
|
||||
--per_device_train_batch_size 180 \
|
||||
--train_data_worker 0 \
|
||||
--train_shuffle true \
|
||||
--train_drop_last true \
|
||||
--per_device_eval_batch_size 128 \
|
||||
--eval_data_worker 0 \
|
||||
--eval_shuffle true \
|
||||
--eval_drop_last true \
|
||||
--save_ckpt_best true \
|
||||
@@ -33,3 +35,4 @@ PYTHONPATH=. torchrun --nproc_per_node $DATA_PARALLEL_SIZE \
|
||||
--optimizer_hook 'type=TorchAMPOptimizerHook,cumulative_iters=1,loss_keys=loss' \
|
||||
--clip_clamp true \
|
||||
--world_size $DATA_PARALLEL_SIZE \
|
||||
--use_model_config true \
|
||||
|
||||
@@ -2,7 +2,8 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope import (EpochBasedTrainer, MsDataset, TrainingArgs,
|
||||
build_dataset_from_file, build_trainer)
|
||||
build_dataset_from_file)
|
||||
from modelscope.trainers import build_trainer
|
||||
|
||||
|
||||
def set_labels(labels):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope import (EpochBasedTrainer, MsDataset, TrainingArgs,
|
||||
build_trainer)
|
||||
from modelscope import EpochBasedTrainer, MsDataset, TrainingArgs
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
|
||||
@@ -176,11 +176,10 @@ class CLIPTrainer(EpochBasedTrainer):
|
||||
self.dataset_cfg = cfg.dataset
|
||||
if hasattr(self.dataset_cfg, 'column_map'):
|
||||
# cases where dataset key names are not "img" and "text"
|
||||
img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img')
|
||||
img_key_name = self.dataset_cfg['column_map'].get('img', 'img')
|
||||
preprocessor[ConfigKeys.train].set_input_img_key(img_key_name)
|
||||
preprocessor[ConfigKeys.val].set_input_img_key(img_key_name)
|
||||
text_key_name = getattr(self.dataset_cfg.column_map, 'text',
|
||||
'text')
|
||||
text_key_name = self.dataset_cfg['column_map'].get('text', 'text')
|
||||
preprocessor[ConfigKeys.train].set_input_text_key(text_key_name)
|
||||
preprocessor[ConfigKeys.val].set_input_text_key(text_key_name)
|
||||
self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size
|
||||
|
||||
Reference in New Issue
Block a user