From 06990f90dc5ece84afbeb1fce9b0bb20152765ee Mon Sep 17 00:00:00 2001 From: "lee.lcy" Date: Tue, 7 Feb 2023 06:10:40 +0000 Subject: [PATCH] add model_revision para to ImageDetectionDamoyoloTrainer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ImageDetectionDamoyoloTrainer添加model_revision参数 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11546342 --- modelscope/trainers/cv/image_detection_damoyolo_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py index e9c4cc20..fe827b74 100644 --- a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py +++ b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py @@ -30,7 +30,7 @@ from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader, from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.utils.checkpoint import save_checkpoint -from modelscope.utils.constant import ModelFile +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile from modelscope.utils.logger import get_logger from modelscope.utils.metric import MeterBuffer from modelscope.utils.torch_utils import get_rank, synchronize @@ -44,6 +44,7 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer): cfg_file: str = None, load_pretrain: bool = True, cache_path: str = None, + model_revision: str = DEFAULT_MODEL_REVISION, *args, **kwargs): """ High-level finetune api for Damoyolo. @@ -56,7 +57,8 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer): cache_path: cache path of model files. """ if model is not None: - self.cache_path = self.get_or_download_model_dir(model) + self.cache_path = self.get_or_download_model_dir( + model, model_revision) if cfg_file is None: self.cfg_file = os.path.join(self.cache_path, ModelFile.CONFIGURATION)