From 9a84f4bb4bcb10a674ba402555b83d34cd1ae78c Mon Sep 17 00:00:00 2001 From: "lee.lcy" Date: Mon, 13 Mar 2023 11:20:26 +0800 Subject: [PATCH] feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer. feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11957959 * feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer. (cherry picked from commit 104213e4bf6817b54aa3b7609e237eadd3c03769) --- .../trainers/cv/image_detection_damoyolo_trainer.py | 12 +++++++++--- modelscope/utils/constant.py | 1 + 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py index c8081ee0..8d8b32ae 100644 --- a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py +++ b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py @@ -28,7 +28,8 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import ( from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.utils.checkpoint import save_checkpoint -from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModelFile, + ThirdParty) from modelscope.utils.logger import get_logger from modelscope.utils.metric import MeterBuffer from modelscope.utils.torch_utils import get_rank, synchronize @@ -62,14 +63,19 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer): train_ann: the path of train set annotation file. val_ann: the path of val set annotation file. num_classes: class number. - base_lr_per_img: learning rate per image. The final learning rate is base_lr_per_img*batch_size. + base_lr_per_img: learning rate per image. + The final learning rate is base_lr_per_img*batch_size. pretrain_model: the path of pretrained model. work_dir: the directory of work folder. exp_name: the name of experiment. + third_party: in which third party library this function is called. """ if model is not None: + third_party = kwargs.get(ThirdParty.KEY) + if third_party is not None: + kwargs.pop(ThirdParty.KEY) self.cache_path = self.get_or_download_model_dir( - model, model_revision) + model, model_revision, third_party) if cfg_file is None: self.cfg_file = os.path.join(self.cache_path, ModelFile.CONFIGURATION) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6400c468..68630c81 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -399,6 +399,7 @@ class ThirdParty(object): KEY = 'third_party' EASYCV = 'easycv' ADASEQ = 'adaseq' + ADADET = 'adadet' class ConfigFields(object):