feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer.

feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11957959

    * feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer.
This commit is contained in:
lee.lcy
2023-03-13 11:20:26 +08:00
committed by yuze.zyz
parent 38bcd54ee4
commit 104213e4bf
2 changed files with 10 additions and 3 deletions

View File

@@ -28,7 +28,8 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModelFile,
ThirdParty)
from modelscope.utils.logger import get_logger
from modelscope.utils.metric import MeterBuffer
from modelscope.utils.torch_utils import get_rank, synchronize
@@ -62,14 +63,19 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer):
train_ann: the path of train set annotation file.
val_ann: the path of val set annotation file.
num_classes: class number.
base_lr_per_img: learning rate per image. The final learning rate is base_lr_per_img*batch_size.
base_lr_per_img: learning rate per image.
The final learning rate is base_lr_per_img*batch_size.
pretrain_model: the path of pretrained model.
work_dir: the directory of work folder.
exp_name: the name of experiment.
third_party: in which third party library this function is called.
"""
if model is not None:
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
self.cache_path = self.get_or_download_model_dir(
model, model_revision)
model, model_revision, third_party)
if cfg_file is None:
self.cfg_file = os.path.join(self.cache_path,
ModelFile.CONFIGURATION)

View File

@@ -399,6 +399,7 @@ class ThirdParty(object):
KEY = 'third_party'
EASYCV = 'easycv'
ADASEQ = 'adaseq'
ADADET = 'adadet'
class ConfigFields(object):