mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
1. add check model for training 2. add model_dir for automodel (#422)
This commit is contained in:
@@ -21,20 +21,20 @@ def check_local_model_is_latest(
|
||||
"""Check local model repo is latest.
|
||||
Check local model repo is same as hub latest version.
|
||||
"""
|
||||
model_cache = None
|
||||
# download with git
|
||||
if os.path.exists(os.path.join(model_root_path, '.git')):
|
||||
git_cmd_wrapper = GitCommandWrapper()
|
||||
git_url = git_cmd_wrapper.get_repo_remote_url(model_root_path)
|
||||
if git_url.endswith('.git'):
|
||||
git_url = git_url[:-4]
|
||||
u_parse = urlparse(git_url)
|
||||
model_id = u_parse.path[1:]
|
||||
else: # snapshot_download
|
||||
model_cache = ModelFileSystemCache(model_root_path)
|
||||
model_id = model_cache.get_model_id()
|
||||
|
||||
try:
|
||||
model_cache = None
|
||||
# download with git
|
||||
if os.path.exists(os.path.join(model_root_path, '.git')):
|
||||
git_cmd_wrapper = GitCommandWrapper()
|
||||
git_url = git_cmd_wrapper.get_repo_remote_url(model_root_path)
|
||||
if git_url.endswith('.git'):
|
||||
git_url = git_url[:-4]
|
||||
u_parse = urlparse(git_url)
|
||||
model_id = u_parse.path[1:]
|
||||
else: # snapshot_download
|
||||
model_cache = ModelFileSystemCache(model_root_path)
|
||||
model_id = model_cache.get_model_id()
|
||||
|
||||
# make headers
|
||||
headers = {
|
||||
'user-agent':
|
||||
@@ -75,7 +75,8 @@ def check_local_model_is_latest(
|
||||
continue
|
||||
else:
|
||||
logger.info(
|
||||
'Model is updated from modelscope hub, you can verify from https://www.modelscope.cn.'
|
||||
f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,'
|
||||
f'This is because you are using an older version or the file is updated manually.'
|
||||
)
|
||||
break
|
||||
else:
|
||||
@@ -86,7 +87,8 @@ def check_local_model_is_latest(
|
||||
continue
|
||||
else:
|
||||
logger.info(
|
||||
'Model is updated from modelscope hub, you can verify from https://www.modelscope.cn.'
|
||||
f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,'
|
||||
f'This is because you are using an older version or the file is updated manually.'
|
||||
)
|
||||
break
|
||||
except: # noqa: E722
|
||||
|
||||
@@ -15,6 +15,7 @@ from torch.utils.data import DataLoader, Dataset, Sampler
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from modelscope.hub.check_model import check_local_model_is_latest
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.metrics import build_metric, task_default_metrics
|
||||
from modelscope.metrics.prediction_saving_wrapper import \
|
||||
@@ -27,6 +28,7 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
|
||||
from modelscope.msdatasets.ms_dataset import MsDataset
|
||||
from modelscope.outputs import ModelOutputBase
|
||||
from modelscope.preprocessors.base import Preprocessor
|
||||
from modelscope.swift import Swift
|
||||
from modelscope.trainers.hooks.builder import HOOKS
|
||||
from modelscope.trainers.hooks.priority import Priority, get_priority
|
||||
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
|
||||
@@ -34,7 +36,7 @@ from modelscope.trainers.optimizer.builder import build_optimizer
|
||||
from modelscope.utils.config import Config, ConfigDict, JSONIteratorEncoder
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
|
||||
ConfigKeys, DistributedParallelType,
|
||||
ModeKeys, ModelFile, ThirdParty,
|
||||
Invoke, ModeKeys, ModelFile, ThirdParty,
|
||||
TrainerStages)
|
||||
from modelscope.utils.data_utils import to_device
|
||||
from modelscope.utils.device import create_device
|
||||
@@ -45,7 +47,6 @@ from modelscope.utils.torch_utils import (compile_model, get_dist_info,
|
||||
get_local_rank, init_dist, is_dist,
|
||||
is_master, is_on_same_device,
|
||||
set_random_seed)
|
||||
from ..swift import Swift
|
||||
from .base import BaseTrainer
|
||||
from .builder import TRAINERS
|
||||
from .default_config import merge_cfg, merge_hooks, update_cfg
|
||||
@@ -152,6 +153,10 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
|
||||
self.model_dir = os.path.dirname(cfg_file)
|
||||
self.input_model_id = None
|
||||
if hasattr(model, 'model_dir'):
|
||||
check_local_model_is_latest(
|
||||
model.model_dir,
|
||||
user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})
|
||||
|
||||
super().__init__(cfg_file, arg_parse_fn)
|
||||
self.cfg_modify_fn = cfg_modify_fn
|
||||
|
||||
@@ -95,8 +95,10 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
|
||||
else:
|
||||
model_dir = pretrained_model_name_or_path
|
||||
|
||||
return module_class.from_pretrained(model_dir, *model_args,
|
||||
**kwargs)
|
||||
model = module_class.from_pretrained(model_dir, *model_args,
|
||||
**kwargs)
|
||||
model.model_dir = model_dir
|
||||
return model
|
||||
|
||||
return ClassWrapper
|
||||
|
||||
|
||||
Reference in New Issue
Block a user