diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 8c9964b8..2df6f2a0 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -402,6 +402,7 @@ class Metrics(object):
 
     # accuracy
     accuracy = 'accuracy'
+    multi_average_precision = 'mAP'
    audio_noise_metric = 'audio-noise-metric'
 
     # text gen
diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py
index b9e402c5..e2fe67f8 100644
--- a/modelscope/metrics/builder.py
+++ b/modelscope/metrics/builder.py
@@ -24,6 +24,7 @@ class MetricKeys(object):
     ROUGE_1 = 'rouge-1'
     ROUGE_L = 'rouge-l'
     NED = 'ned'  # ocr metric
+    mAP = 'mAP'
 
     BatchAcc = 'inbatch_t2i_recall_at_1'
 
diff --git a/modelscope/metrics/map_metric.py b/modelscope/metrics/map_metric.py
new file mode 100644
index 00000000..aac76f22
--- /dev/null
+++ b/modelscope/metrics/map_metric.py
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.multi_average_precision)
+class AveragePrecisionMetric(Metric):
+    """The metric computation class for multi-class average precision (mAP).
+
+    This metric class calculates the average precision over all input batches.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+        self.thresh = kwargs.get('threshold', 0.5)
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
+        assert type(ground_truths) == type(eval_results)
+        for truth in ground_truths:
+            self.labels.append(truth)
+        for result in eval_results:
+            if isinstance(result, str):
+                self.preds.append(result.strip().replace(' ', ''))
+            else:
+                self.preds.append(result)
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
+        return {MetricKeys.mAP: scores.mean().item()}
+
+    def _calculate_ap_score(self, preds, labels, thresh=0.5):
+        hyps = np.array(preds)
+        refs = np.array(labels)
+        a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
+        b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
+        interacts = np.concatenate([a, b], axis=1)
+        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
+            hyps[:, 3] - hyps[:, 1])
+        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
+        interacts_w = interacts[:, 2] - interacts[:, 0]
+        interacts_h = interacts[:, 3] - interacts[:, 1]
+        area_interacts = interacts_w * interacts_h
+        ious = area_interacts / (
+            area_predictions + area_targets - area_interacts + 1e-6)
+        return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
diff --git a/modelscope/preprocessors/ofa/image_captioning.py b/modelscope/preprocessors/ofa/image_captioning.py
index af623297..5fb83908 100644
--- a/modelscope/preprocessors/ofa/image_captioning.py
+++ b/modelscope/preprocessors/ofa/image_captioning.py
@@ -43,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: sample = self._build_infer_sample(data) - target = data[self.column_map['text']] + target = sample['label'] target = target.translate(self.transtab).strip() target_token_list = target.strip().split() target = ' '.join(target_token_list[:self.max_tgt_length]) diff --git a/modelscope/preprocessors/ofa/image_classification.py b/modelscope/preprocessors/ofa/image_classification.py index 49968823..038a9e15 100644 --- a/modelscope/preprocessors/ofa/image_classification.py +++ b/modelscope/preprocessors/ofa/image_classification.py @@ -1,13 +1,20 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import functools from typing import Any, Dict import torch -from PIL import Image +from PIL import Image, ImageFile +from timm.data import create_transform from torchvision import transforms from modelscope.preprocessors.image import load_image from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor +from .utils.vision_helper import RandomAugment + +ImageFile.LOAD_TRUNCATED_IMAGES = True +ImageFile.MAX_IMAGE_PIXELS = None +Image.MAX_IMAGE_PIXELS = None class OfaImageClassificationPreprocessor(OfaBasePreprocessor): @@ -28,18 +35,77 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor): super(OfaImageClassificationPreprocessor, self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform - self.patch_resize_transform = transforms.Compose([ - lambda image: image.convert('RGB'), - transforms.Resize( - (self.patch_image_size, self.patch_image_size), - interpolation=transforms.InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize(mean=self.mean, std=self.std), - ]) + if self.mode != ModeKeys.TRAIN: + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + else: + self.patch_resize_transform = create_transform( + input_size=self.patch_image_size, + is_training=True, + color_jitter=0.4, + auto_augment='rand-m9-mstd0.5-inc1', + interpolation='bicubic', + re_prob=0.25, + re_mode='pixel', + re_count=1, + mean=self.mean, + std=self.std) + self.patch_resize_transform = transforms.Compose( + functools.reduce(lambda x, y: x + y, [ + [ + lambda image: image.convert('RGB'), + ], + self.patch_resize_transform.transforms[:2], + [self.patch_resize_transform.transforms[2]], + [ + RandomAugment( + 2, + 7, + isPIL=True, + augs=[ + 'Identity', 'AutoContrast', 'Equalize', + 'Brightness', 'Sharpness', 'ShearX', 'ShearY', + 'TranslateX', 'TranslateY', 'Rotate' + ]), + ], + self.patch_resize_transform.transforms[3:], + ])) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: - image = data['image'] if isinstance( - data['image'], Image.Image) else load_image(data['image']) + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = ' {}'.format(sample['label']) + sample['ref_dict'] = {sample['label']: 1.0} + sample['target'] = self.tokenize_text(target, add_bos=False) + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, sample['target'][:-1]]) + + if self.constraint_trie is not None: + constraint_mask = 
torch.zeros((len(sample['prev_output_tokens']), + len(self.tgt_dict))).bool() + for i in range(len(sample['prev_output_tokens'])): + constraint_prefix_token = sample[ + 'prev_output_tokens'][:i + 1].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) patch_image = self.patch_resize_transform(image) prompt = self.cfg.model.get('prompt', ' what does the image describe?') inputs = self.tokenize_text(prompt) @@ -48,4 +114,6 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor): 'patch_image': patch_image, 'patch_mask': torch.tensor([True]) } + if 'text' in self.column_map and self.column_map['text'] in data: + sample['label'] = data[self.column_map['text']] return sample diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py index 58e3ea6e..e15be93f 100644 --- a/modelscope/preprocessors/ofa/ocr_recognition.py +++ b/modelscope/preprocessors/ofa/ocr_recognition.py @@ -11,9 +11,6 @@ from zhconv import convert from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor -IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) -IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) - def ocr_resize(img, patch_image_size, is_document=False): img = img.convert('RGB') @@ -112,6 +109,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): } if 'text' in self.column_map and self.column_map['text'] in data: target = data[self.column_map['text']] - target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans')) - sample['label'] = target + sample['label'] = unicodedata2.normalize( + 'NFKC', convert(target, 'zh-hans')) return sample diff --git a/modelscope/preprocessors/ofa/summarization.py b/modelscope/preprocessors/ofa/summarization.py index cfd3c23d..d33e9d25 100644 --- a/modelscope/preprocessors/ofa/summarization.py +++ b/modelscope/preprocessors/ofa/summarization.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
 from typing import Any, Dict
 
+import torch
+
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
 
@@ -24,9 +26,26 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
                  self).__init__(cfg, model_dir, mode, *args, **kwargs)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target_str = sample['label'].lower()
+        target = super().pre_caption(target_str, max_words=self.max_tgt_length)
+        target = target.replace('[unk]', 'unk').replace('<unk>', 'unk')
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        noise_target_item = self.add_noise_to_tgt(
+            sample['target'][:-1].clone())
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, noise_target_item])
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         source = super().pre_caption(
-            data['text'], max_words=self.max_src_length)
-        source = source.strip()[:self.max_src_length]
+            data[self.column_map['text']], max_words=self.max_src_length)
         source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
         prompt = self.cfg.model.get(
             'prompt', ' " {} " Summarize the article with a title: ')
@@ -42,4 +61,17 @@
             'source': inputs,
             'decoder_prompt': decoder_prompt,
         }
+        if 'summary' in self.column_map and self.column_map['summary'] in data:
+            sample['label'] = data[self.column_map['summary']]
         return sample
+
+    def add_noise_to_tgt(self, target):
+        noise_indices = torch.FloatTensor(
+            target.size(0)).uniform_() < self.cfg.model.get(
+                'noise_ratio', 0.0)
+        target[noise_indices] = torch.randint(
+            4,
+            len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
+            - self.cfg.model.get('num_bins', 1000),
+            size=(noise_indices.sum(), ))
+        return target
diff --git a/modelscope/preprocessors/ofa/visual_entailment.py b/modelscope/preprocessors/ofa/visual_entailment.py
index 61c3cc6a..fff5bbd3 100644
--- a/modelscope/preprocessors/ofa/visual_entailment.py
+++ b/modelscope/preprocessors/ofa/visual_entailment.py
@@ -38,18 +38,64 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         ])
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = ' {}'.format(sample['label'])
+        sample['ref_dict'] = {sample['label']: 1.0}
+        tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)
+
+        if self.prompt_type == 'none':
+            prev_output_item = torch.cat([self.bos_item, tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'src':
+            prev_output_item = torch.cat([sample['source'], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        elif self.prompt_type == 'prev_output':
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
+            target_item = torch.cat([prev_output_item[1:], self.eos_item])
+        else:
+            raise NotImplementedError
+
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
+        sample['target'] = target_item
+
sample['prev_output_tokens'] = prev_output_item + + if self.constraint_trie is not None: + constraint_mask = torch.zeros( + (len(target_item), len(self.tgt_dict))).bool() + start_idx = len(target_item) - len(tgt_item) - 1 + for i in range( + len(target_item) - len(tgt_item) - 1, len(target_item)): + constraint_prefix_token = [ + self.tgt_dict.bos() + ] + target_item[start_idx:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) patch_image = self.patch_resize_transform(image) if 'text2' not in data: - hypothesis = self.pre_caption(data['text'], self.max_src_length) + hypothesis = self.pre_caption(data[self.column_map['text']], + self.max_src_length) prompt = self.cfg.model.get('prompt', ' does the image describe " {} "?') text = prompt.format(hypothesis) else: assert 'text' in data, f'text must be in the input {data.keys()}' - caption = self.pre_caption(data['text2'], self.max_src_length) - hypothesis = self.pre_caption(data['text'], self.max_src_length) + caption = self.pre_caption(data[self.column_map['text2']], + self.max_src_length) + hypothesis = self.pre_caption(data[self.column_map['text']], + self.max_src_length) prompt = self.cfg.model.get( 'prompt', ' can image and text1 " {} " imply text2 " {} "?') text = prompt.format(caption, hypothesis) @@ -68,4 +114,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): 'patch_mask': torch.tensor([True]), 'decoder_prompt': decoder_prompt, } + if 'relation' in self.column_map and self.column_map[ + 'relation'] in data: + sample['label'] = data[self.column_map['relation']] return sample diff --git a/modelscope/preprocessors/ofa/visual_grounding.py b/modelscope/preprocessors/ofa/visual_grounding.py index 8b116463..2da79670 100644 --- a/modelscope/preprocessors/ofa/visual_grounding.py +++ b/modelscope/preprocessors/ofa/visual_grounding.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
 from typing import Any, Dict
 
+import numpy as np
 import torch
 from PIL import Image
 from torchvision import transforms
@@ -8,6 +9,7 @@ from torchvision import transforms
 from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
+from .utils import transforms as T
 
 
 class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
@@ -27,24 +29,98 @@
         """
         super(OfaVisualGroundingPreprocessor,
               self).__init__(cfg, model_dir, mode, *args, **kwargs)
-        # Initialize transform
-        self.patch_resize_transform = transforms.Compose([
-            lambda image: image.convert('RGB'),
-            transforms.Resize(
-                (self.patch_image_size, self.patch_image_size),
-                interpolation=transforms.InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=self.mean, std=self.std),
-        ])
+
+        self.num_bins = self.cfg.model.get('num_bins', 1000)
+        if self.mode == ModeKeys.TRAIN:
+            # for positioning
+            self.positioning_transform = T.Compose([
+                T.RandomResize([self.patch_image_size],
+                               max_size=self.patch_image_size),
+                T.ToTensor(),
+                T.Normalize(
+                    mean=self.mean,
+                    std=self.std,
+                    max_image_size=self.max_image_size)
+            ])
+        else:
+            # Initialize transform
+            self.patch_resize_transform = transforms.Compose([
+                lambda image: image.convert('RGB'),
+                transforms.Resize(
+                    (self.patch_image_size, self.patch_image_size),
+                    interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=self.mean, std=self.std),
+            ])
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
+        w, h = image.size
+        boxes_target = {
+            'boxes': [],
+            'labels': [],
+            'area': [],
+            'size': torch.tensor([h, w])
+        }
+        x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
+            ',')
+        region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
+        boxes_target['boxes'] = torch.tensor(
+            [[float(x0), float(y0), float(x1),
+              float(y1)]])
+        boxes_target['labels'] = np.array([0])
+        area = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
+        boxes_target['area'] = torch.tensor(area)
+
+        patch_image, patch_boxes = self.positioning_transform(
+            image, boxes_target)
+        resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
+        quant_x0 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
+        quant_y0 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
+        quant_x1 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
+        quant_y1 = '<bin_{}>'.format(
+            int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
+        region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
+                                            quant_y1)
+        src_caption = self.pre_caption(data[self.column_map['text']],
+                                       self.max_src_length)
+        prompt = self.cfg.model.get(
+            'prompt', ' which region does the text " {} " describe?')
+        text = prompt.format(src_caption)
+        src_item = self.tokenize_text(text)
+        target_item = self.tokenize_text(
+            region_coord, add_bos=False)  # !!!
use_bpe=False + prev_output_item = torch.cat([self.bos_item, target_item[:-1]]) + + sample = { + 'source': src_item, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'target': target_item, + 'prev_output_tokens': prev_output_item, + 'w_resize_ratio': resize_w / w, + 'h_resize_ratio': resize_h / h, + 'region_coord': region + } + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) w, h = image.size patch_image = self.patch_resize_transform(image) w_resize_ratio = torch.tensor(self.patch_image_size / w) h_resize_ratio = torch.tensor(self.patch_image_size / h) - src_caption = self.pre_caption(data['text'], self.max_src_length) + src_caption = self.pre_caption(data[self.column_map['text']], + self.max_src_length) prompt = self.cfg.model.get( 'prompt', ' which region does the text " {} " describe?') text = prompt.format(src_caption) @@ -56,4 +132,10 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): 'w_resize_ratio': w_resize_ratio, 'h_resize_ratio': h_resize_ratio, } + + if 'region_coord' in self.column_map and self.column_map[ + 'region_coord'] in data: + x0, y0, x1, y1 = data[ + self.column_map['region_coord']].strip().split(',') + sample['label'] = [float(x0), float(y0), float(x1), float(y1)] return sample diff --git a/modelscope/preprocessors/ofa/visual_question_answering.py b/modelscope/preprocessors/ofa/visual_question_answering.py index 11104e7e..b83cf935 100644 --- a/modelscope/preprocessors/ofa/visual_question_answering.py +++ b/modelscope/preprocessors/ofa/visual_question_answering.py @@ -38,10 +38,52 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): ]) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: - image = data['image'] if isinstance( - data['image'], Image.Image) else load_image(data['image']) + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + tgt_item = self.tokenize_text( + ' {}'.format(sample['label']), add_bos=False, add_eos=False) + + if self.prompt_type == 'none': + prev_output_item = torch.cat([self.bos_item, tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + elif self.prompt_type == 'src': + prev_output_item = torch.cat([sample['source'], tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + elif self.prompt_type == 'prev_output': + prev_output_item = torch.cat([sample['source'][:-1], tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + else: + raise NotImplementedError + target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id + + sample['prev_output_tokens'] = prev_output_item + sample['target'] = target_item + + if self.constraint_trie is not None: + constraint_mask = torch.zeros( + (len(target_item), len(self.tgt_dict))).bool() + start_idx = len(target_item) - len(tgt_item) - 1 + for i in range( + len(target_item) - len(tgt_item) - 1, len(target_item)): + constraint_prefix_token = [ + self.tgt_dict.bos() + ] + target_item[start_idx:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = 
self.get_img_pil(data[self.column_map['image']]) patch_image = self.patch_resize_transform(image) - text = ' {}'.format(data['text']) + text = ' {}'.format(data[self.column_map['text']]) inputs = self.tokenize_text(text) if self.prompt_type == 'none': decoder_prompt = self.bos_item @@ -57,4 +99,6 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): 'patch_mask': torch.tensor([True]), 'decoder_prompt': decoder_prompt, } + if 'answer' in self.column_map and self.column_map['answer'] in data: + sample['label'] = data[self.column_map['answer']] return sample diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py index f8028c6c..71494768 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -34,6 +34,7 @@ class OFATrainer(EpochBasedTrainer): self, model: Optional[Union[TorchModel, nn.Module, str]] = None, cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, arg_parse_fn: Optional[Callable] = None, data_collator: Optional[Union[Callable, Dict[str, Callable]]] = None, @@ -49,7 +50,8 @@ class OFATrainer(EpochBasedTrainer): **kwargs): model = Model.from_pretrained(model, revision=model_revision) model_dir = model.model_dir - cfg = Config.from_file(cfg_file) + self.cfg_modify_fn = cfg_modify_fn + cfg = self.rebuild_config(Config.from_file(cfg_file)) if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: work_dir = cfg.train.work_dir else: @@ -57,10 +59,12 @@ class OFATrainer(EpochBasedTrainer): tokenizer_files = { 'zh': [ 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt', - 'config.json' + 'config.json', 'ans2label.json' + ], + 'en': [ + 'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json', + 'ans2label.json' ], - 'en': - ['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'], } for filename in tokenizer_files[cfg.model.get('language', 'en')]: finetune_file = os.path.join(work_dir, filename) @@ -127,6 +131,11 @@ class OFATrainer(EpochBasedTrainer): **kwargs, ) + def rebuild_config(self, cfg: Config): + if self.cfg_modify_fn is not None: + cfg = self.cfg_modify_fn(cfg) + return cfg + def train_step(self, model, inputs): model.train() loss, sample_size, logging_output = self.criterion(model, inputs) diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index 85c21881..0516e569 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -9,6 +9,7 @@ from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset from modelscope.trainers import build_trainer from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.hub import read_config from modelscope.utils.test_utils import test_level @@ -78,6 +79,7 @@ class TestOfaTrainer(unittest.TestCase): json.dump(self.finetune_cfg, writer) pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh' + args = dict( model=pretrained_model, work_dir=WORKSPACE,
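Usage sketch (not part of the patch above): a minimal example of how the new cfg_modify_fn hook on OFATrainer and the mAP metric registered in map_metric.py might be wired together when fine-tuning an OFA visual grounding model. The model id, dataset id and config keys below are illustrative assumptions, not values taken from this diff.

# A hypothetical fine-tuning sketch; identifiers marked as placeholders are assumptions.
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer


def cfg_modify_fn(cfg):
    # Evaluate with the newly registered 'mAP' metric and keep the run short.
    cfg.evaluation.metrics = ['mAP']
    cfg.train.max_epochs = 1
    return cfg


# Placeholder dataset id; any dataset whose columns match the visual grounding
# preprocessor's column_map ('image', 'text', 'region_coord') should work.
train_ds = MsDataset.load('some_visual_grounding_dataset', split='train')
eval_ds = MsDataset.load('some_visual_grounding_dataset', split='validation')

args = dict(
    model='damo/ofa_visual-grounding_refcoco_large_en',  # placeholder model id
    cfg_file='./workspace/configuration.json',  # placeholder finetune config path
    work_dir='./workspace',
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()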