diff --git a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py index c8081ee0..8d8b32ae 100644 --- a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py +++ b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py @@ -28,7 +28,8 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import ( from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.utils.checkpoint import save_checkpoint -from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModelFile, + ThirdParty) from modelscope.utils.logger import get_logger from modelscope.utils.metric import MeterBuffer from modelscope.utils.torch_utils import get_rank, synchronize @@ -62,14 +63,19 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer): train_ann: the path of train set annotation file. val_ann: the path of val set annotation file. num_classes: class number. - base_lr_per_img: learning rate per image. The final learning rate is base_lr_per_img*batch_size. + base_lr_per_img: learning rate per image. + The final learning rate is base_lr_per_img*batch_size. pretrain_model: the path of pretrained model. work_dir: the directory of work folder. exp_name: the name of experiment. + third_party: in which third party library this function is called. """ if model is not None: + third_party = kwargs.get(ThirdParty.KEY) + if third_party is not None: + kwargs.pop(ThirdParty.KEY) self.cache_path = self.get_or_download_model_dir( - model, model_revision) + model, model_revision, third_party) if cfg_file is None: self.cfg_file = os.path.join(self.cache_path, ModelFile.CONFIGURATION) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6400c468..68630c81 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -399,6 +399,7 @@ class ThirdParty(object): KEY = 'third_party' EASYCV = 'easycv' ADASEQ = 'adaseq' + ADADET = 'adadet' class ConfigFields(object):