fix all comments

This commit is contained in:
行嗔
2022-10-25 10:18:33 +08:00
parent df5bd86048
commit 2288a0fdf3
4 changed files with 5 additions and 7 deletions

View File

@@ -282,7 +282,7 @@ class Trainers(object):
# multi-modal trainers
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa_tasks = 'ofa'
ofa = 'ofa'
# cv trainers
image_instance_segmentation = 'image-instance-segmentation'

View File

@@ -74,9 +74,7 @@ class OfaPreprocessor(Preprocessor):
data[key] = item
return data
def _compatible_with_pretrain(self, data):
# 预训练的时候使用的image都是经过pil转换的PIL save的时候一般会进行有损压缩为了保证和预训练一致
# 所以增加了这个逻辑
def _ofa_input_compatibility_conversion(self, data):
if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
if isinstance(data['image'], str):
image = load_image(data['image'])
@@ -95,7 +93,7 @@ class OfaPreprocessor(Preprocessor):
data = input
else:
data = self._build_dict(input)
data = self._compatible_with_pretrain(data)
data = self._ofa_input_compatibility_conversion(data)
sample = self.preprocess(data)
str_data = dict()
for k, v in data.items():

View File

@@ -27,7 +27,7 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
get_schedule)
@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
@TRAINERS.register_module(module_name=Trainers.ofa)
class OFATrainer(EpochBasedTrainer):
def __init__(

View File

@@ -93,7 +93,7 @@ class TestOfaTrainer(unittest.TestCase):
split='validation[:10]'),
metrics=[Metrics.BLEU],
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,