ofa-finetune 支持训练推理配置文件一体化

ofa的finetune的配置文档教程冗余,用户使用copy内容太多,修改复杂,提供简洁的finetune代码流程
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11282842
This commit is contained in:
liugao.lg
2023-01-05 07:25:47 +08:00
committed by yingda.chen
parent 60bd40742a
commit 6dd0bea98d

View File

@@ -2,15 +2,19 @@
import math
import os
import shutil
import tempfile
from functools import partial
from shutil import ignore_patterns
from typing import Callable, Dict, Optional, Tuple, Union
import json
import torch
from torch import distributed as dist
from torch import nn
from torch.utils.data import Dataset
from modelscope.hub.file_download import model_file_download
from modelscope.metainfo import Trainers
from modelscope.models.base import Model, TorchModel
from modelscope.msdatasets.ms_dataset import MsDataset
@@ -83,18 +87,26 @@ class OFATrainer(EpochBasedTrainer):
model, revision=model_revision, invoked_by=Invoke.TRAINER)
model_dir = model.model_dir
self.cfg_modify_fn = cfg_modify_fn
cfg = self.rebuild_config(Config.from_file(cfg_file))
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
work_dir = cfg.train.work_dir
else:
work_dir = kwargs['work_dir']
work_dir = kwargs.get('work_dir', 'workspace')
os.makedirs(work_dir, exist_ok=True)
ignore_file_set = set()
ignore_file_set.add(ModelFile.CONFIGURATION)
if cfg_file is not None:
cfg_file = self.get_config_file(cfg_file)
dst = os.path.abspath(
os.path.join(work_dir, ModelFile.CONFIGURATION))
src = os.path.abspath(cfg_file)
if src != dst:
shutil.copy(src, work_dir)
ignore_file_set.add(ModelFile.CONFIGURATION)
recursive_overwrite(
model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set))
cfg_file = os.path.join(work_dir, ModelFile.CONFIGURATION)
cfg = self.rebuild_config(Config.from_file(cfg_file))
if cfg_modify_fn is not None:
cfg = self.cfg_modify_fn(cfg)
with open(cfg_file, 'w') as writer:
json.dump(dict(cfg), fp=writer, indent=4)
if preprocessor is None:
preprocessor = {
ConfigKeys.train:
@@ -143,6 +155,7 @@ class OFATrainer(EpochBasedTrainer):
model=model,
cfg_file=cfg_file,
arg_parse_fn=arg_parse_fn,
cfg_modify_fn=cfg_modify_fn,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
@@ -160,6 +173,27 @@ class OFATrainer(EpochBasedTrainer):
cfg = self.cfg_modify_fn(cfg)
return cfg
def get_config_file(self, config_file: str):
r"""
support local file/ url or model_id with revision
"""
if os.path.exists(config_file):
return config_file
else:
temp_name = tempfile.TemporaryDirectory().name
if len(config_file.split('#')) == 2:
model_id = config_file.split('#')[0]
revision = config_file.split('#')[-1].split('=')[-1]
else:
model_id = config_file
revision = DEFAULT_MODEL_REVISION
file_name = model_file_download(
model_id,
file_path=ModelFile.CONFIGURATION,
revision=revision,
cache_dir=temp_name)
return file_name
def train_step(self, model, inputs):
r"""
A single training step.