mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
add model_revision para to ImageDetectionDamoyoloTrainer
ImageDetectionDamoyoloTrainer添加model_revision参数 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11546342
This commit is contained in:
@@ -30,7 +30,7 @@ from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader,
|
||||
from modelscope.trainers.base import BaseTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.checkpoint import save_checkpoint
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.metric import MeterBuffer
|
||||
from modelscope.utils.torch_utils import get_rank, synchronize
|
||||
@@ -44,6 +44,7 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer):
|
||||
cfg_file: str = None,
|
||||
load_pretrain: bool = True,
|
||||
cache_path: str = None,
|
||||
model_revision: str = DEFAULT_MODEL_REVISION,
|
||||
*args,
|
||||
**kwargs):
|
||||
""" High-level finetune api for Damoyolo.
|
||||
@@ -56,7 +57,8 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer):
|
||||
cache_path: cache path of model files.
|
||||
"""
|
||||
if model is not None:
|
||||
self.cache_path = self.get_or_download_model_dir(model)
|
||||
self.cache_path = self.get_or_download_model_dir(
|
||||
model, model_revision)
|
||||
if cfg_file is None:
|
||||
self.cfg_file = os.path.join(self.cache_path,
|
||||
ModelFile.CONFIGURATION)
|
||||
|
||||
Reference in New Issue
Block a user