add model_revision para to ImageDetectionDamoyoloTrainer

ImageDetectionDamoyoloTrainer添加model_revision参数

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11546342
This commit is contained in:
lee.lcy
2023-02-07 06:10:40 +00:00
committed by wenmeng.zwm
parent 7298bd2bb4
commit 06990f90dc

View File

@@ -30,7 +30,7 @@ from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader,
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.constant import ModelFile
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.metric import MeterBuffer
from modelscope.utils.torch_utils import get_rank, synchronize
@@ -44,6 +44,7 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer):
cfg_file: str = None,
load_pretrain: bool = True,
cache_path: str = None,
model_revision: str = DEFAULT_MODEL_REVISION,
*args,
**kwargs):
""" High-level finetune api for Damoyolo.
@@ -56,7 +57,8 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer):
cache_path: cache path of model files.
"""
if model is not None:
self.cache_path = self.get_or_download_model_dir(model)
self.cache_path = self.get_or_download_model_dir(
model, model_revision)
if cfg_file is None:
self.cfg_file = os.path.join(self.cache_path,
ModelFile.CONFIGURATION)