mirror of https://github.com/modelscope/modelscope.git
fix all comments
@@ -282,7 +282,7 @@ class Trainers(object):
 
     # multi-modal trainers
     clip_multi_modal_embedding = 'clip-multi-modal-embedding'
-    ofa_tasks = 'ofa'
+    ofa = 'ofa'
 
     # cv trainers
     image_instance_segmentation = 'image-instance-segmentation'
@@ -74,9 +74,7 @@ class OfaPreprocessor(Preprocessor):
                 data[key] = item
         return data
 
-    def _compatible_with_pretrain(self, data):
-        # The images used in pretraining were all converted through PIL, and PIL save
-        # usually applies lossy compression, so this logic keeps inputs consistent with pretraining.
+    def _ofa_input_compatibility_conversion(self, data):
         if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
             if isinstance(data['image'], str):
                 image = load_image(data['image'])
@@ -95,7 +93,7 @@ class OfaPreprocessor(Preprocessor):
             data = input
         else:
             data = self._build_dict(input)
-        data = self._compatible_with_pretrain(data)
+        data = self._ofa_input_compatibility_conversion(data)
         sample = self.preprocess(data)
         str_data = dict()
         for k, v in data.items():
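
The comment removed above explained the motivation: pretraining images all went through PIL, and PIL save usually applies lossy compression, so inputs are re-encoded the same way at inference to stay consistent. Below is a minimal standalone sketch of that idea, assuming a JPEG re-encode and a file-path input; the free-function name is hypothetical and this is not the repository's implementation:

    import io

    from PIL import Image


    def ofa_input_compatibility_conversion(data: dict) -> dict:
        # Hypothetical standalone version of the renamed method: push the image
        # through a lossy PIL save/reload cycle so it carries the same compression
        # artifacts as the images seen during pretraining.
        if 'image' in data and isinstance(data['image'], str):
            image = Image.open(data['image']).convert('RGB')
            buffer = io.BytesIO()
            image.save(buffer, format='JPEG')  # lossy re-encode, as PIL did in pretraining
            buffer.seek(0)
            data['image'] = Image.open(buffer).convert('RGB')
        return data

The sketch only covers the string-path branch shown in the diff; already-decoded image inputs would presumably skip or adapt this step.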
@@ -27,7 +27,7 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
                                 get_schedule)
 
 
-@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
+@TRAINERS.register_module(module_name=Trainers.ofa)
 class OFATrainer(EpochBasedTrainer):
 
     def __init__(
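
For context, a hedged sketch of how the renamed registration is consumed through build_trainer (also exercised by the test hunk below); the model id and work_dir are placeholders, not values from this commit:

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer

    # After this change the OFA trainer is registered under Trainers.ofa ('ofa'),
    # so the registry resolves it by that name. Arguments are placeholders.
    args = dict(
        model='<your-ofa-model-id>',    # placeholder model id
        work_dir='/tmp/ofa_trainer')    # placeholder output directory
    trainer = build_trainer(name=Trainers.ofa, default_args=args)
    trainer.train()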
@@ -93,7 +93,7 @@ class TestOfaTrainer(unittest.TestCase):
                 split='validation[:10]'),
             metrics=[Metrics.BLEU],
             cfg_file=config_file)
-        trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args)
+        trainer = build_trainer(name=Trainers.ofa, default_args=args)
         trainer.train()
 
         self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,