From 6dd0bea98d0507def157413e5a2ed6f97323400c Mon Sep 17 00:00:00 2001 From: "liugao.lg" Date: Thu, 5 Jan 2023 07:25:47 +0800 Subject: [PATCH] =?UTF-8?q?ofa-finetune=20=E6=94=AF=E6=8C=81=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E6=8E=A8=E7=90=86=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=80=E4=BD=93=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ofa的finetune的配置文档教程冗余,用户使用copy内容太多,修改复杂,提供简洁的finetune代码流程 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11282842 --- .../trainers/multi_modal/ofa/ofa_trainer.py | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py index eb3b1c98..885ca118 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -2,15 +2,19 @@ import math import os +import shutil +import tempfile from functools import partial from shutil import ignore_patterns from typing import Callable, Dict, Optional, Tuple, Union +import json import torch from torch import distributed as dist from torch import nn from torch.utils.data import Dataset +from modelscope.hub.file_download import model_file_download from modelscope.metainfo import Trainers from modelscope.models.base import Model, TorchModel from modelscope.msdatasets.ms_dataset import MsDataset @@ -83,18 +87,26 @@ class OFATrainer(EpochBasedTrainer): model, revision=model_revision, invoked_by=Invoke.TRAINER) model_dir = model.model_dir self.cfg_modify_fn = cfg_modify_fn - cfg = self.rebuild_config(Config.from_file(cfg_file)) - if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: - work_dir = cfg.train.work_dir - else: - work_dir = kwargs['work_dir'] + work_dir = kwargs.get('work_dir', 'workspace') os.makedirs(work_dir, exist_ok=True) ignore_file_set = set() - ignore_file_set.add(ModelFile.CONFIGURATION) + if cfg_file is not None: + cfg_file = self.get_config_file(cfg_file) + dst = os.path.abspath( + os.path.join(work_dir, ModelFile.CONFIGURATION)) + src = os.path.abspath(cfg_file) + if src != dst: + shutil.copy(src, work_dir) + ignore_file_set.add(ModelFile.CONFIGURATION) recursive_overwrite( model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set)) - + cfg_file = os.path.join(work_dir, ModelFile.CONFIGURATION) + cfg = self.rebuild_config(Config.from_file(cfg_file)) + if cfg_modify_fn is not None: + cfg = self.cfg_modify_fn(cfg) + with open(cfg_file, 'w') as writer: + json.dump(dict(cfg), fp=writer, indent=4) if preprocessor is None: preprocessor = { ConfigKeys.train: @@ -143,6 +155,7 @@ class OFATrainer(EpochBasedTrainer): model=model, cfg_file=cfg_file, arg_parse_fn=arg_parse_fn, + cfg_modify_fn=cfg_modify_fn, data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, @@ -160,6 +173,27 @@ class OFATrainer(EpochBasedTrainer): cfg = self.cfg_modify_fn(cfg) return cfg + def get_config_file(self, config_file: str): + r""" + support local file/ url or model_id with revision + """ + if os.path.exists(config_file): + return config_file + else: + temp_name = tempfile.TemporaryDirectory().name + if len(config_file.split('#')) == 2: + model_id = config_file.split('#')[0] + revision = config_file.split('#')[-1].split('=')[-1] + else: + model_id = config_file + revision = DEFAULT_MODEL_REVISION + file_name = model_file_download( + model_id, + file_path=ModelFile.CONFIGURATION, + revision=revision, + cache_dir=temp_name) + return file_name + def train_step(self, model, inputs): r""" A single training step.