mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #47521140]feat: tools statastics collection
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11565214
This commit is contained in:
@@ -7,7 +7,7 @@ from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE, Pipelines
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.utils.config import ConfigDict, check_config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, Tasks
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
from .base import Pipeline
|
||||
|
||||
@@ -11,7 +11,8 @@ from PIL import ImageFile
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.pipelines.util import is_official_hub_path
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
|
||||
ModelFile, ThirdParty)
|
||||
from modelscope.utils.device import create_device
|
||||
|
||||
|
||||
@@ -39,7 +40,10 @@ class EasyCVPipeline(object):
|
||||
model_dir = snapshot_download(
|
||||
model_id=model,
|
||||
revision=DEFAULT_MODEL_REVISION,
|
||||
user_agent={Invoke.KEY: Invoke.PIPELINE})
|
||||
user_agent={
|
||||
Invoke.KEY: Invoke.PIPELINE,
|
||||
ThirdParty.KEY: ThirdParty.EASYCV
|
||||
})
|
||||
|
||||
assert osp.isdir(model_dir)
|
||||
model_files = glob.glob(
|
||||
|
||||
@@ -9,7 +9,7 @@ from modelscope.hub.check_model import check_local_model_is_latest
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Invoke
|
||||
from modelscope.utils.constant import Invoke, ThirdParty
|
||||
from .utils.log_buffer import LogBuffer
|
||||
|
||||
|
||||
@@ -37,17 +37,35 @@ class BaseTrainer(ABC):
|
||||
self.visualization_buffer = LogBuffer()
|
||||
self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
|
||||
def get_or_download_model_dir(self, model, model_revision=None):
|
||||
def get_or_download_model_dir(self,
|
||||
model,
|
||||
model_revision=None,
|
||||
third_party=None):
|
||||
""" Get local model directory or download model if necessary.
|
||||
|
||||
Args:
|
||||
model (str): model id or path to local model directory.
|
||||
model_revision (str, optional): model version number.
|
||||
third_party (str, optional): in which third party library
|
||||
this function is called.
|
||||
"""
|
||||
if os.path.exists(model):
|
||||
model_cache_dir = model if os.path.isdir(
|
||||
model) else os.path.dirname(model)
|
||||
check_local_model_is_latest(
|
||||
model_cache_dir, user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})
|
||||
model_cache_dir,
|
||||
user_agent={
|
||||
Invoke.KEY: Invoke.LOCAL_TRAINER,
|
||||
ThirdParty.KEY: third_party
|
||||
})
|
||||
else:
|
||||
model_cache_dir = snapshot_download(
|
||||
model,
|
||||
revision=model_revision,
|
||||
user_agent={Invoke.KEY: Invoke.TRAINER})
|
||||
user_agent={
|
||||
Invoke.KEY: Invoke.TRAINER,
|
||||
ThirdParty.KEY: third_party
|
||||
})
|
||||
return model_cache_dir
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -33,7 +33,7 @@ from modelscope.trainers.optimizer.builder import build_optimizer
|
||||
from modelscope.utils.config import Config, ConfigDict, JSONIteratorEncoder
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
|
||||
ConfigKeys, ModeKeys, ModelFile,
|
||||
TrainerStages)
|
||||
ThirdParty, TrainerStages)
|
||||
from modelscope.utils.data_utils import to_device
|
||||
from modelscope.utils.device import create_device
|
||||
from modelscope.utils.file_utils import func_receive_dict_inputs
|
||||
@@ -121,8 +121,12 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self._stop_training = False
|
||||
|
||||
if isinstance(model, str):
|
||||
third_party = kwargs.get(ThirdParty.KEY)
|
||||
if third_party is not None:
|
||||
kwargs.pop(ThirdParty.KEY)
|
||||
|
||||
self.model_dir = self.get_or_download_model_dir(
|
||||
model, model_revision)
|
||||
model, model_revision, third_party)
|
||||
if cfg_file is None:
|
||||
cfg_file = os.path.join(self.model_dir,
|
||||
ModelFile.CONFIGURATION)
|
||||
|
||||
@@ -385,6 +385,12 @@ class Invoke(object):
|
||||
PREPROCESSOR = 'preprocessor'
|
||||
|
||||
|
||||
class ThirdParty(object):
|
||||
KEY = 'third_party'
|
||||
EASYCV = 'easycv'
|
||||
ADASEQ = 'adaseq'
|
||||
|
||||
|
||||
class ConfigFields(object):
|
||||
""" First level keyword in configuration file
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user