mirror of https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
ofa-finetune: support a unified configuration file for training and inference
The OFA finetune configuration tutorial is redundant: users have to copy too much boilerplate and the edits are complicated. This change provides a concise finetune code flow.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11282842
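A minimal sketch of the concise finetune flow this change targets. build_trainer, Trainers.ofa, MsDataset and the trainer arguments (work_dir, cfg_modify_fn, train_dataset, eval_dataset) come from the code in this diff and the surrounding modelscope library; the model id, dataset name and the overridden config keys are placeholder assumptions, not part of the commit.

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

# Placeholder dataset id; any MsDataset with the fields the OFA preprocessor expects works.
train_dataset = MsDataset.load('coco_2014_caption', split='train')
eval_dataset = MsDataset.load('coco_2014_caption', split='validation')

def cfg_modify_fn(cfg):
    # Override only what differs from the model's bundled configuration.json;
    # the keys below are assumed examples of typical training settings.
    cfg.train.max_epochs = 2
    cfg.train.dataloader.batch_size_per_gpu = 4
    return cfg

trainer = build_trainer(
    name=Trainers.ofa,
    default_args=dict(
        model='damo/ofa_image-caption_coco_large_en',  # placeholder model id
        work_dir='workspace',
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        cfg_modify_fn=cfg_modify_fn))
trainer.train()

With the bundled configuration.json as the baseline, cfg_modify_fn is the only place a user has to express deviations, instead of copying and editing a full configuration file.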
@@ -2,15 +2,19 @@
import math
import os
import shutil
import tempfile
from functools import partial
from shutil import ignore_patterns
from typing import Callable, Dict, Optional, Tuple, Union

import json
import torch
from torch import distributed as dist
from torch import nn
from torch.utils.data import Dataset

from modelscope.hub.file_download import model_file_download
from modelscope.metainfo import Trainers
from modelscope.models.base import Model, TorchModel
from modelscope.msdatasets.ms_dataset import MsDataset
@@ -83,18 +87,26 @@ class OFATrainer(EpochBasedTrainer):
            model, revision=model_revision, invoked_by=Invoke.TRAINER)
        model_dir = model.model_dir
        self.cfg_modify_fn = cfg_modify_fn
        cfg = self.rebuild_config(Config.from_file(cfg_file))
        if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
            work_dir = cfg.train.work_dir
        else:
            work_dir = kwargs['work_dir']

        work_dir = kwargs.get('work_dir', 'workspace')
        os.makedirs(work_dir, exist_ok=True)
        ignore_file_set = set()
        ignore_file_set.add(ModelFile.CONFIGURATION)
        if cfg_file is not None:
            cfg_file = self.get_config_file(cfg_file)
            dst = os.path.abspath(
                os.path.join(work_dir, ModelFile.CONFIGURATION))
            src = os.path.abspath(cfg_file)
            if src != dst:
                shutil.copy(src, work_dir)
            ignore_file_set.add(ModelFile.CONFIGURATION)
        recursive_overwrite(
            model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set))

        cfg_file = os.path.join(work_dir, ModelFile.CONFIGURATION)
        cfg = self.rebuild_config(Config.from_file(cfg_file))
        if cfg_modify_fn is not None:
            cfg = self.cfg_modify_fn(cfg)
        with open(cfg_file, 'w') as writer:
            json.dump(dict(cfg), fp=writer, indent=4)
        if preprocessor is None:
            preprocessor = {
                ConfigKeys.train:
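The hunk above is what makes the configuration "unified": the model directory is copied into work_dir (configuration.json excluded), the resolved config gets cfg_modify_fn applied, and the result is written back to work_dir/configuration.json. As a hedged illustration of that idea, the same directory can later be pointed at by inference code; the task name and image path below are assumptions for an image-captioning OFA model, not something this diff specifies.

from modelscope.pipelines import pipeline

# 'workspace' is the work_dir prepared by OFATrainer above; it holds the copied
# model files plus the rewritten configuration.json.
img_captioning = pipeline('image-captioning', model='workspace')
print(img_captioning('path/to/test_image.jpg'))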
@@ -143,6 +155,7 @@ class OFATrainer(EpochBasedTrainer):
            model=model,
            cfg_file=cfg_file,
            arg_parse_fn=arg_parse_fn,
            cfg_modify_fn=cfg_modify_fn,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
@@ -160,6 +173,27 @@ class OFATrainer(EpochBasedTrainer):
            cfg = self.cfg_modify_fn(cfg)
        return cfg

    def get_config_file(self, config_file: str):
        r"""
        support local file/ url or model_id with revision
        """
        if os.path.exists(config_file):
            return config_file
        else:
            temp_name = tempfile.TemporaryDirectory().name
            if len(config_file.split('#')) == 2:
                model_id = config_file.split('#')[0]
                revision = config_file.split('#')[-1].split('=')[-1]
            else:
                model_id = config_file
                revision = DEFAULT_MODEL_REVISION
            file_name = model_file_download(
                model_id,
                file_path=ModelFile.CONFIGURATION,
                revision=revision,
                cache_dir=temp_name)
            return file_name

    def train_step(self, model, inputs):
        r"""
        A single training step.
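For reference, get_config_file above lets cfg_file be passed in three forms; the model id and revision tag in this sketch are placeholders.

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

trainer = build_trainer(
    name=Trainers.ofa,
    default_args=dict(
        model='damo/ofa_image-caption_coco_large_en',  # placeholder model id
        work_dir='workspace',
        # 1) an existing local path is used as-is:
        # cfg_file='./my_configuration.json',
        # 2) a bare model id downloads configuration.json at DEFAULT_MODEL_REVISION:
        # cfg_file='damo/ofa_image-caption_coco_large_en',
        # 3) 'model_id#revision=<tag>' pins the revision:
        cfg_file='damo/ofa_image-caption_coco_large_en#revision=v1.0.1'))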