mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer.
feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11957959
* feat(thirdparty): add ADADET && add thirdparty arg for damoyolo trainer.
This commit is contained in:
@@ -28,7 +28,8 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
|
||||
from modelscope.trainers.base import BaseTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.checkpoint import save_checkpoint
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModelFile,
|
||||
ThirdParty)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.metric import MeterBuffer
|
||||
from modelscope.utils.torch_utils import get_rank, synchronize
|
||||
@@ -62,14 +63,19 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer):
|
||||
train_ann: the path of train set annotation file.
|
||||
val_ann: the path of val set annotation file.
|
||||
num_classes: class number.
|
||||
base_lr_per_img: learning rate per image. The final learning rate is base_lr_per_img*batch_size.
|
||||
base_lr_per_img: learning rate per image.
|
||||
The final learning rate is base_lr_per_img*batch_size.
|
||||
pretrain_model: the path of pretrained model.
|
||||
work_dir: the directory of work folder.
|
||||
exp_name: the name of experiment.
|
||||
third_party: in which third party library this function is called.
|
||||
"""
|
||||
if model is not None:
|
||||
third_party = kwargs.get(ThirdParty.KEY)
|
||||
if third_party is not None:
|
||||
kwargs.pop(ThirdParty.KEY)
|
||||
self.cache_path = self.get_or_download_model_dir(
|
||||
model, model_revision)
|
||||
model, model_revision, third_party)
|
||||
if cfg_file is None:
|
||||
self.cfg_file = os.path.join(self.cache_path,
|
||||
ModelFile.CONFIGURATION)
|
||||
|
||||
@@ -399,6 +399,7 @@ class ThirdParty(object):
|
||||
KEY = 'third_party'
|
||||
EASYCV = 'easycv'
|
||||
ADASEQ = 'adaseq'
|
||||
ADADET = 'adadet'
|
||||
|
||||
|
||||
class ConfigFields(object):
|
||||
|
||||
Reference in New Issue
Block a user