1. add check model for training 2. add model_dir for automodel (#422)

Author: tastelikefeet
Date: 2023-07-28 16:34:02 +08:00
Committed by: GitHub
parent 2566d028cd
commit 972298813b
3 changed files with 28 additions and 19 deletions

modelscope/hub/check_model.py

@@ -21,20 +21,20 @@ def check_local_model_is_latest(
"""Check local model repo is latest.
Check local model repo is same as hub latest version.
"""
model_cache = None
# download with git
if os.path.exists(os.path.join(model_root_path, '.git')):
git_cmd_wrapper = GitCommandWrapper()
git_url = git_cmd_wrapper.get_repo_remote_url(model_root_path)
if git_url.endswith('.git'):
git_url = git_url[:-4]
u_parse = urlparse(git_url)
model_id = u_parse.path[1:]
else: # snapshot_download
model_cache = ModelFileSystemCache(model_root_path)
model_id = model_cache.get_model_id()
try:
model_cache = None
# download with git
if os.path.exists(os.path.join(model_root_path, '.git')):
git_cmd_wrapper = GitCommandWrapper()
git_url = git_cmd_wrapper.get_repo_remote_url(model_root_path)
if git_url.endswith('.git'):
git_url = git_url[:-4]
u_parse = urlparse(git_url)
model_id = u_parse.path[1:]
else: # snapshot_download
model_cache = ModelFileSystemCache(model_root_path)
model_id = model_cache.get_model_id()
# make headers
headers = {
'user-agent':
@@ -75,7 +75,8 @@ def check_local_model_is_latest(
                     continue
                 else:
                     logger.info(
-                        'Model is updated from modelscope hub, you can verify from https://www.modelscope.cn.'
+                        f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`. '
+                        'This is because you are using an older version or the file was updated manually.'
                     )
                     break
         else:
@@ -86,7 +87,8 @@ def check_local_model_is_latest(
                     continue
                 else:
                     logger.info(
-                        'Model is updated from modelscope hub, you can verify from https://www.modelscope.cn.'
+                        f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`. '
+                        'This is because you are using an older version or the file was updated manually.'
                     )
                     break
     except:  # noqa: E722
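
Taken together, the hunks above make the freshness check best-effort: the try/except now covers the model-id resolution as well, so a broken cache or missing network access can no longer raise out of model loading, and the log message names the exact file and revision that differ instead of a generic notice. A minimal sketch of the call as the trainer below issues it; the local path is a hypothetical placeholder:

    # Best-effort freshness check; any exception raised inside is swallowed
    # by the bare `except`, so this is safe to call without network access.
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.utils.constant import Invoke

    check_local_model_is_latest(
        '/path/to/local/model',  # hypothetical local snapshot directory
        user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})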

modelscope/trainers/trainer.py

@@ -15,6 +15,7 @@ from torch.utils.data import DataLoader, Dataset, Sampler
 from torch.utils.data.dataloader import default_collate
 from torch.utils.data.distributed import DistributedSampler
+from modelscope.hub.check_model import check_local_model_is_latest
 from modelscope.metainfo import Trainers
 from modelscope.metrics import build_metric, task_default_metrics
 from modelscope.metrics.prediction_saving_wrapper import \
@@ -27,6 +28,7 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
 from modelscope.msdatasets.ms_dataset import MsDataset
 from modelscope.outputs import ModelOutputBase
 from modelscope.preprocessors.base import Preprocessor
+from modelscope.swift import Swift
 from modelscope.trainers.hooks.builder import HOOKS
 from modelscope.trainers.hooks.priority import Priority, get_priority
 from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
@@ -34,7 +36,7 @@ from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.utils.config import Config, ConfigDict, JSONIteratorEncoder
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
                                        ConfigKeys, DistributedParallelType,
-                                       ModeKeys, ModelFile, ThirdParty,
+                                       Invoke, ModeKeys, ModelFile, ThirdParty,
                                        TrainerStages)
 from modelscope.utils.data_utils import to_device
 from modelscope.utils.device import create_device
@@ -45,7 +47,6 @@ from modelscope.utils.torch_utils import (compile_model, get_dist_info,
                                           get_local_rank, init_dist, is_dist,
                                           is_master, is_on_same_device,
                                           set_random_seed)
-from ..swift import Swift
 from .base import BaseTrainer
 from .builder import TRAINERS
 from .default_config import merge_cfg, merge_hooks, update_cfg
@@ -152,6 +153,10 @@ class EpochBasedTrainer(BaseTrainer):
             assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
             self.model_dir = os.path.dirname(cfg_file)
             self.input_model_id = None
+        if hasattr(model, 'model_dir'):
+            check_local_model_is_latest(
+                model.model_dir,
+                user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})
         super().__init__(cfg_file, arg_parse_fn)
         self.cfg_modify_fn = cfg_modify_fn
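
The trainer side of the commit: EpochBasedTrainer.__init__ now runs the freshness check whenever the model object it receives exposes a model_dir attribute, identifying itself through the new Invoke.LOCAL_TRAINER user-agent key. A hedged sketch of how the check is triggered; the model id and work_dir below are placeholders and the trainer arguments are abbreviated:

    # Building a trainer from a pretrained model now verifies, best-effort,
    # that the local snapshot still matches the hub before training starts.
    from modelscope.models import Model
    from modelscope.trainers import build_trainer

    model = Model.from_pretrained('damo/some-model-id')  # hypothetical id
    trainer = build_trainer(
        default_args=dict(model=model, work_dir='./work_dir'))
    # __init__ sees hasattr(model, 'model_dir') and calls
    # check_local_model_is_latest(model.model_dir,
    #                             user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})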

modelscope/utils/hf_util.py

@@ -95,8 +95,10 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
             else:
                 model_dir = pretrained_model_name_or_path
-            return module_class.from_pretrained(model_dir, *model_args,
-                                                **kwargs)
+            model = module_class.from_pretrained(model_dir, *model_args,
+                                                 **kwargs)
+            model.model_dir = model_dir
+            return model
 
     return ClassWrapper
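
The automodel side: instead of returning the raw transformers object, the wrapper now records the directory the weights were resolved to, which is exactly the attribute the trainer check above looks for. A short sketch, assuming the wrapped auto classes are exported from modelscope.utils.hf_util; the model id is a placeholder:

    # The wrapped AutoModel now carries the snapshot directory it was loaded
    # from, so downstream code such as the trainer check can locate it.
    from modelscope.utils.hf_util import AutoModel  # assumed export location

    model = AutoModel.from_pretrained('damo/some-model-id')  # hypothetical id
    print(model.model_dir)  # local directory the weights came from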