mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
Merge branch ofa/task_finetune into master
Title: [to #42322933]add finetune & merge master 新增ofa其它任务的finetune能力 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10656541
This commit is contained in:
@@ -402,6 +402,7 @@ class Metrics(object):
|
||||
|
||||
# accuracy
|
||||
accuracy = 'accuracy'
|
||||
multi_average_precision = 'mAP'
|
||||
audio_noise_metric = 'audio-noise-metric'
|
||||
|
||||
# text gen
|
||||
|
||||
@@ -24,6 +24,7 @@ class MetricKeys(object):
|
||||
ROUGE_1 = 'rouge-1'
|
||||
ROUGE_L = 'rouge-l'
|
||||
NED = 'ned' # ocr metric
|
||||
mAP = 'mAP'
|
||||
BatchAcc = 'inbatch_t2i_recall_at_1'
|
||||
|
||||
|
||||
|
||||
67
modelscope/metrics/map_metric.py
Normal file
67
modelscope/metrics/map_metric.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.registry import default_group
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group, module_name=Metrics.multi_average_precision)
|
||||
class AveragePrecisionMetric(Metric):
|
||||
"""The metric computation class for multi avarage precision classes.
|
||||
|
||||
This metric class calculates multi avarage precision for the whole input batches.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.preds = []
|
||||
self.labels = []
|
||||
self.thresh = kwargs.get('threshold', 0.5)
|
||||
|
||||
def add(self, outputs: Dict, inputs: Dict):
|
||||
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
|
||||
ground_truths = inputs[label_name]
|
||||
eval_results = outputs[label_name]
|
||||
for key in [
|
||||
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
|
||||
OutputKeys.LABELS, OutputKeys.SCORES
|
||||
]:
|
||||
if key in outputs and outputs[key] is not None:
|
||||
eval_results = outputs[key]
|
||||
break
|
||||
assert type(ground_truths) == type(eval_results)
|
||||
for truth in ground_truths:
|
||||
self.labels.append(truth)
|
||||
for result in eval_results:
|
||||
if isinstance(truth, str):
|
||||
self.preds.append(result.strip().replace(' ', ''))
|
||||
else:
|
||||
self.preds.append(result)
|
||||
|
||||
def evaluate(self):
|
||||
assert len(self.preds) == len(self.labels)
|
||||
scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
|
||||
return {MetricKeys.mAP: scores.mean().item()}
|
||||
|
||||
def _calculate_ap_score(self, preds, labels, thresh=0.5):
|
||||
hyps = np.array(preds)
|
||||
refs = np.array(labels)
|
||||
a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
|
||||
b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
|
||||
interacts = np.concatenate([a, b], axis=1)
|
||||
area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
|
||||
hyps[:, 3] - hyps[:, 1])
|
||||
area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
|
||||
interacts_w = interacts[:, 2] - interacts[:, 0]
|
||||
interacts_h = interacts[:, 3] - interacts[:, 1]
|
||||
area_interacts = interacts_w * interacts_h
|
||||
ious = area_interacts / (
|
||||
area_predictions + area_targets - area_interacts + 1e-6)
|
||||
return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
|
||||
@@ -43,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
|
||||
|
||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample = self._build_infer_sample(data)
|
||||
target = data[self.column_map['text']]
|
||||
target = sample['label']
|
||||
target = target.translate(self.transtab).strip()
|
||||
target_token_list = target.strip().split()
|
||||
target = ' '.join(target_token_list[:self.max_tgt_length])
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFile
|
||||
from timm.data import create_transform
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.preprocessors.image import load_image
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from .base import OfaBasePreprocessor
|
||||
from .utils.vision_helper import RandomAugment
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
ImageFile.MAX_IMAGE_PIXELS = None
|
||||
Image.MAX_IMAGE_PIXELS = None
|
||||
|
||||
|
||||
class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
|
||||
@@ -28,18 +35,77 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
|
||||
super(OfaImageClassificationPreprocessor,
|
||||
self).__init__(cfg, model_dir, mode, *args, **kwargs)
|
||||
# Initialize transform
|
||||
self.patch_resize_transform = transforms.Compose([
|
||||
lambda image: image.convert('RGB'),
|
||||
transforms.Resize(
|
||||
(self.patch_image_size, self.patch_image_size),
|
||||
interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=self.mean, std=self.std),
|
||||
])
|
||||
if self.mode != ModeKeys.TRAIN:
|
||||
self.patch_resize_transform = transforms.Compose([
|
||||
lambda image: image.convert('RGB'),
|
||||
transforms.Resize(
|
||||
(self.patch_image_size, self.patch_image_size),
|
||||
interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=self.mean, std=self.std),
|
||||
])
|
||||
else:
|
||||
self.patch_resize_transform = create_transform(
|
||||
input_size=self.patch_image_size,
|
||||
is_training=True,
|
||||
color_jitter=0.4,
|
||||
auto_augment='rand-m9-mstd0.5-inc1',
|
||||
interpolation='bicubic',
|
||||
re_prob=0.25,
|
||||
re_mode='pixel',
|
||||
re_count=1,
|
||||
mean=self.mean,
|
||||
std=self.std)
|
||||
self.patch_resize_transform = transforms.Compose(
|
||||
functools.reduce(lambda x, y: x + y, [
|
||||
[
|
||||
lambda image: image.convert('RGB'),
|
||||
],
|
||||
self.patch_resize_transform.transforms[:2],
|
||||
[self.patch_resize_transform.transforms[2]],
|
||||
[
|
||||
RandomAugment(
|
||||
2,
|
||||
7,
|
||||
isPIL=True,
|
||||
augs=[
|
||||
'Identity', 'AutoContrast', 'Equalize',
|
||||
'Brightness', 'Sharpness', 'ShearX', 'ShearY',
|
||||
'TranslateX', 'TranslateY', 'Rotate'
|
||||
]),
|
||||
],
|
||||
self.patch_resize_transform.transforms[3:],
|
||||
]))
|
||||
|
||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = data['image'] if isinstance(
|
||||
data['image'], Image.Image) else load_image(data['image'])
|
||||
if self.mode == ModeKeys.TRAIN:
|
||||
return self._build_train_sample(data)
|
||||
else:
|
||||
return self._build_infer_sample(data)
|
||||
|
||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample = self._build_infer_sample(data)
|
||||
target = ' {}'.format(sample['label'])
|
||||
sample['ref_dict'] = {sample['label']: 1.0}
|
||||
sample['target'] = self.tokenize_text(target, add_bos=False)
|
||||
sample['prev_output_tokens'] = torch.cat(
|
||||
[self.bos_item, sample['target'][:-1]])
|
||||
|
||||
if self.constraint_trie is not None:
|
||||
constraint_mask = torch.zeros((len(sample['prev_output_tokens']),
|
||||
len(self.tgt_dict))).bool()
|
||||
for i in range(len(sample['prev_output_tokens'])):
|
||||
constraint_prefix_token = sample[
|
||||
'prev_output_tokens'][:i + 1].tolist()
|
||||
constraint_nodes = self.constraint_trie.get_next_layer(
|
||||
constraint_prefix_token)
|
||||
constraint_mask[i][constraint_nodes] = True
|
||||
sample['constraint_mask'] = constraint_mask
|
||||
|
||||
return sample
|
||||
|
||||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = self.get_img_pil(data[self.column_map['image']])
|
||||
patch_image = self.patch_resize_transform(image)
|
||||
prompt = self.cfg.model.get('prompt', ' what does the image describe?')
|
||||
inputs = self.tokenize_text(prompt)
|
||||
@@ -48,4 +114,6 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
|
||||
'patch_image': patch_image,
|
||||
'patch_mask': torch.tensor([True])
|
||||
}
|
||||
if 'text' in self.column_map and self.column_map['text'] in data:
|
||||
sample['label'] = data[self.column_map['text']]
|
||||
return sample
|
||||
|
||||
@@ -11,9 +11,6 @@ from zhconv import convert
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from .base import OfaBasePreprocessor
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
|
||||
def ocr_resize(img, patch_image_size, is_document=False):
|
||||
img = img.convert('RGB')
|
||||
@@ -112,6 +109,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
|
||||
}
|
||||
if 'text' in self.column_map and self.column_map['text'] in data:
|
||||
target = data[self.column_map['text']]
|
||||
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
|
||||
sample['label'] = target
|
||||
sample['label'] = unicodedata2.normalize(
|
||||
'NFKC', convert(target, 'zh-hans'))
|
||||
return sample
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from .base import OfaBasePreprocessor
|
||||
|
||||
@@ -24,9 +26,26 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
|
||||
self).__init__(cfg, model_dir, mode, *args, **kwargs)
|
||||
|
||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if self.mode == ModeKeys.TRAIN:
|
||||
return self._build_train_sample(data)
|
||||
else:
|
||||
return self._build_infer_sample(data)
|
||||
|
||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample = self._build_infer_sample(data)
|
||||
target_str = sample['label'].lower()
|
||||
target = super().pre_caption(target_str, max_words=self.max_tgt_length)
|
||||
target = target.replace('[unk]', 'unk').replace('<unk>', 'unk')
|
||||
sample['target'] = self.tokenize_text(target, add_bos=False)
|
||||
noise_target_item = self.add_noise_to_tgt(
|
||||
sample['target'][:-1].clone())
|
||||
sample['prev_output_tokens'] = torch.cat(
|
||||
[self.bos_item, noise_target_item])
|
||||
return sample
|
||||
|
||||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
source = super().pre_caption(
|
||||
data['text'], max_words=self.max_src_length)
|
||||
source = source.strip()[:self.max_src_length]
|
||||
data[self.column_map['text']], max_words=self.max_src_length)
|
||||
source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
|
||||
prompt = self.cfg.model.get(
|
||||
'prompt', ' " {} " Summarize the article with a title: ')
|
||||
@@ -42,4 +61,17 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
|
||||
'source': inputs,
|
||||
'decoder_prompt': decoder_prompt,
|
||||
}
|
||||
if 'summary' in self.column_map and self.column_map['summary'] in data:
|
||||
sample['label'] = data[self.column_map['summary']]
|
||||
return sample
|
||||
|
||||
def add_noise_to_tgt(self, target):
|
||||
noise_indices = torch.FloatTensor(
|
||||
target.size(0)).uniform_() < self.cfg.model.get(
|
||||
'noise_ratio', 0.0)
|
||||
target[noise_indices] = torch.randint(
|
||||
4,
|
||||
len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
|
||||
- self.cfg.model.get('num_bins', 1000),
|
||||
size=(noise_indices.sum(), ))
|
||||
return target
|
||||
|
||||
@@ -38,18 +38,64 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
|
||||
])
|
||||
|
||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = data['image'] if isinstance(
|
||||
data['image'], Image.Image) else load_image(data['image'])
|
||||
if self.mode == ModeKeys.TRAIN:
|
||||
return self._build_train_sample(data)
|
||||
else:
|
||||
return self._build_infer_sample(data)
|
||||
|
||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample = self._build_infer_sample(data)
|
||||
target = ' {}'.format(sample['label'])
|
||||
sample['ref_dict'] = {sample['label']: 1.0}
|
||||
tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)
|
||||
|
||||
if self.prompt_type == 'none':
|
||||
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
||||
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
||||
elif self.prompt_type == 'src':
|
||||
prev_output_item = torch.cat([sample['source'], tgt_item])
|
||||
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
||||
elif self.prompt_type == 'prev_output':
|
||||
prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
|
||||
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
|
||||
sample['target'] = target_item
|
||||
sample['prev_output_tokens'] = prev_output_item
|
||||
|
||||
if self.constraint_trie is not None:
|
||||
constraint_mask = torch.zeros(
|
||||
(len(target_item), len(self.tgt_dict))).bool()
|
||||
start_idx = len(target_item) - len(tgt_item) - 1
|
||||
for i in range(
|
||||
len(target_item) - len(tgt_item) - 1, len(target_item)):
|
||||
constraint_prefix_token = [
|
||||
self.tgt_dict.bos()
|
||||
] + target_item[start_idx:i].tolist()
|
||||
constraint_nodes = self.constraint_trie.get_next_layer(
|
||||
constraint_prefix_token)
|
||||
constraint_mask[i][constraint_nodes] = True
|
||||
sample['constraint_mask'] = constraint_mask
|
||||
|
||||
return sample
|
||||
|
||||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = self.get_img_pil(data[self.column_map['image']])
|
||||
patch_image = self.patch_resize_transform(image)
|
||||
if 'text2' not in data:
|
||||
hypothesis = self.pre_caption(data['text'], self.max_src_length)
|
||||
hypothesis = self.pre_caption(data[self.column_map['text']],
|
||||
self.max_src_length)
|
||||
prompt = self.cfg.model.get('prompt',
|
||||
' does the image describe " {} "?')
|
||||
text = prompt.format(hypothesis)
|
||||
else:
|
||||
assert 'text' in data, f'text must be in the input {data.keys()}'
|
||||
caption = self.pre_caption(data['text2'], self.max_src_length)
|
||||
hypothesis = self.pre_caption(data['text'], self.max_src_length)
|
||||
caption = self.pre_caption(data[self.column_map['text2']],
|
||||
self.max_src_length)
|
||||
hypothesis = self.pre_caption(data[self.column_map['text']],
|
||||
self.max_src_length)
|
||||
prompt = self.cfg.model.get(
|
||||
'prompt', ' can image and text1 " {} " imply text2 " {} "?')
|
||||
text = prompt.format(caption, hypothesis)
|
||||
@@ -68,4 +114,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
|
||||
'patch_mask': torch.tensor([True]),
|
||||
'decoder_prompt': decoder_prompt,
|
||||
}
|
||||
if 'relation' in self.column_map and self.column_map[
|
||||
'relation'] in data:
|
||||
sample['label'] = data[self.column_map['relation']]
|
||||
return sample
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
@@ -8,6 +9,7 @@ from torchvision import transforms
|
||||
from modelscope.preprocessors.image import load_image
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from .base import OfaBasePreprocessor
|
||||
from .utils import transforms as T
|
||||
|
||||
|
||||
class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
|
||||
@@ -27,24 +29,98 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
|
||||
"""
|
||||
super(OfaVisualGroundingPreprocessor,
|
||||
self).__init__(cfg, model_dir, mode, *args, **kwargs)
|
||||
# Initialize transform
|
||||
self.patch_resize_transform = transforms.Compose([
|
||||
lambda image: image.convert('RGB'),
|
||||
transforms.Resize(
|
||||
(self.patch_image_size, self.patch_image_size),
|
||||
interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=self.mean, std=self.std),
|
||||
])
|
||||
|
||||
self.num_bins = self.cfg.model.get('num_bins', 1000)
|
||||
if self.mode == ModeKeys.TRAIN:
|
||||
# for positioning
|
||||
self.positioning_transform = T.Compose([
|
||||
T.RandomResize([self.patch_image_size],
|
||||
max_size=self.patch_image_size),
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=self.mean,
|
||||
std=self.std,
|
||||
max_image_size=self.max_image_size)
|
||||
])
|
||||
else:
|
||||
# Initialize transform
|
||||
self.patch_resize_transform = transforms.Compose([
|
||||
lambda image: image.convert('RGB'),
|
||||
transforms.Resize(
|
||||
(self.patch_image_size, self.patch_image_size),
|
||||
interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=self.mean, std=self.std),
|
||||
])
|
||||
|
||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = data['image'] if isinstance(
|
||||
data['image'], Image.Image) else load_image(data['image'])
|
||||
if self.mode == ModeKeys.TRAIN:
|
||||
return self._build_train_sample(data)
|
||||
else:
|
||||
return self._build_infer_sample(data)
|
||||
|
||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = self.get_img_pil(data[self.column_map['image']])
|
||||
w, h = image.size
|
||||
boxes_target = {
|
||||
'boxes': [],
|
||||
'labels': [],
|
||||
'area': [],
|
||||
'size': torch.tensor([h, w])
|
||||
}
|
||||
x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
|
||||
',')
|
||||
region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
|
||||
boxes_target['boxes'] = torch.tensor(
|
||||
[[float(x0), float(y0), float(x1),
|
||||
float(y1)]])
|
||||
boxes_target['labels'] = np.array([0])
|
||||
area = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
|
||||
boxes_target['area'] = torch.tensor(area)
|
||||
|
||||
patch_image, patch_boxes = self.positioning_transform(
|
||||
image, boxes_target)
|
||||
resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
|
||||
quant_x0 = '<bin_{}>'.format(
|
||||
int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
|
||||
quant_y0 = '<bin_{}>'.format(
|
||||
int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
|
||||
quant_x1 = '<bin_{}>'.format(
|
||||
int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
|
||||
quant_y1 = '<bin_{}>'.format(
|
||||
int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
|
||||
region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
|
||||
quant_y1)
|
||||
src_caption = self.pre_caption(data[self.column_map['text']],
|
||||
self.max_src_length)
|
||||
prompt = self.cfg.model.get(
|
||||
'prompt', ' which region does the text " {} " describe?')
|
||||
text = prompt.format(src_caption)
|
||||
src_item = self.tokenize_text(text)
|
||||
target_item = self.tokenize_text(
|
||||
region_coord, add_bos=False) # !!! use_bpe=False
|
||||
prev_output_item = torch.cat([self.bos_item, target_item[:-1]])
|
||||
|
||||
sample = {
|
||||
'source': src_item,
|
||||
'patch_image': patch_image,
|
||||
'patch_mask': torch.tensor([True]),
|
||||
'target': target_item,
|
||||
'prev_output_tokens': prev_output_item,
|
||||
'w_resize_ratio': resize_w / w,
|
||||
'h_resize_ratio': resize_h / h,
|
||||
'region_coord': region
|
||||
}
|
||||
return sample
|
||||
|
||||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = self.get_img_pil(data[self.column_map['image']])
|
||||
w, h = image.size
|
||||
patch_image = self.patch_resize_transform(image)
|
||||
w_resize_ratio = torch.tensor(self.patch_image_size / w)
|
||||
h_resize_ratio = torch.tensor(self.patch_image_size / h)
|
||||
src_caption = self.pre_caption(data['text'], self.max_src_length)
|
||||
src_caption = self.pre_caption(data[self.column_map['text']],
|
||||
self.max_src_length)
|
||||
prompt = self.cfg.model.get(
|
||||
'prompt', ' which region does the text " {} " describe?')
|
||||
text = prompt.format(src_caption)
|
||||
@@ -56,4 +132,10 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
|
||||
'w_resize_ratio': w_resize_ratio,
|
||||
'h_resize_ratio': h_resize_ratio,
|
||||
}
|
||||
|
||||
if 'region_coord' in self.column_map and self.column_map[
|
||||
'region_coord'] in data:
|
||||
x0, y0, x1, y1 = data[
|
||||
self.column_map['region_coord']].strip().split(',')
|
||||
sample['label'] = [float(x0), float(y0), float(x1), float(y1)]
|
||||
return sample
|
||||
|
||||
@@ -38,10 +38,52 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
|
||||
])
|
||||
|
||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = data['image'] if isinstance(
|
||||
data['image'], Image.Image) else load_image(data['image'])
|
||||
if self.mode == ModeKeys.TRAIN:
|
||||
return self._build_train_sample(data)
|
||||
else:
|
||||
return self._build_infer_sample(data)
|
||||
|
||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample = self._build_infer_sample(data)
|
||||
tgt_item = self.tokenize_text(
|
||||
' {}'.format(sample['label']), add_bos=False, add_eos=False)
|
||||
|
||||
if self.prompt_type == 'none':
|
||||
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
||||
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
||||
elif self.prompt_type == 'src':
|
||||
prev_output_item = torch.cat([sample['source'], tgt_item])
|
||||
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
||||
elif self.prompt_type == 'prev_output':
|
||||
prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
|
||||
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
|
||||
|
||||
sample['prev_output_tokens'] = prev_output_item
|
||||
sample['target'] = target_item
|
||||
|
||||
if self.constraint_trie is not None:
|
||||
constraint_mask = torch.zeros(
|
||||
(len(target_item), len(self.tgt_dict))).bool()
|
||||
start_idx = len(target_item) - len(tgt_item) - 1
|
||||
for i in range(
|
||||
len(target_item) - len(tgt_item) - 1, len(target_item)):
|
||||
constraint_prefix_token = [
|
||||
self.tgt_dict.bos()
|
||||
] + target_item[start_idx:i].tolist()
|
||||
constraint_nodes = self.constraint_trie.get_next_layer(
|
||||
constraint_prefix_token)
|
||||
constraint_mask[i][constraint_nodes] = True
|
||||
sample['constraint_mask'] = constraint_mask
|
||||
|
||||
return sample
|
||||
|
||||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
image = self.get_img_pil(data[self.column_map['image']])
|
||||
patch_image = self.patch_resize_transform(image)
|
||||
text = ' {}'.format(data['text'])
|
||||
text = ' {}'.format(data[self.column_map['text']])
|
||||
inputs = self.tokenize_text(text)
|
||||
if self.prompt_type == 'none':
|
||||
decoder_prompt = self.bos_item
|
||||
@@ -57,4 +99,6 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
|
||||
'patch_mask': torch.tensor([True]),
|
||||
'decoder_prompt': decoder_prompt,
|
||||
}
|
||||
if 'answer' in self.column_map and self.column_map['answer'] in data:
|
||||
sample['label'] = data[self.column_map['answer']]
|
||||
return sample
|
||||
|
||||
@@ -34,6 +34,7 @@ class OFATrainer(EpochBasedTrainer):
|
||||
self,
|
||||
model: Optional[Union[TorchModel, nn.Module, str]] = None,
|
||||
cfg_file: Optional[str] = None,
|
||||
cfg_modify_fn: Optional[Callable] = None,
|
||||
arg_parse_fn: Optional[Callable] = None,
|
||||
data_collator: Optional[Union[Callable, Dict[str,
|
||||
Callable]]] = None,
|
||||
@@ -49,7 +50,8 @@ class OFATrainer(EpochBasedTrainer):
|
||||
**kwargs):
|
||||
model = Model.from_pretrained(model, revision=model_revision)
|
||||
model_dir = model.model_dir
|
||||
cfg = Config.from_file(cfg_file)
|
||||
self.cfg_modify_fn = cfg_modify_fn
|
||||
cfg = self.rebuild_config(Config.from_file(cfg_file))
|
||||
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
|
||||
work_dir = cfg.train.work_dir
|
||||
else:
|
||||
@@ -57,10 +59,12 @@ class OFATrainer(EpochBasedTrainer):
|
||||
tokenizer_files = {
|
||||
'zh': [
|
||||
'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
|
||||
'config.json'
|
||||
'config.json', 'ans2label.json'
|
||||
],
|
||||
'en': [
|
||||
'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
|
||||
'ans2label.json'
|
||||
],
|
||||
'en':
|
||||
['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'],
|
||||
}
|
||||
for filename in tokenizer_files[cfg.model.get('language', 'en')]:
|
||||
finetune_file = os.path.join(work_dir, filename)
|
||||
@@ -127,6 +131,11 @@ class OFATrainer(EpochBasedTrainer):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def rebuild_config(self, cfg: Config):
|
||||
if self.cfg_modify_fn is not None:
|
||||
cfg = self.cfg_modify_fn(cfg)
|
||||
return cfg
|
||||
|
||||
def train_step(self, model, inputs):
|
||||
model.train()
|
||||
loss, sample_size, logging_output = self.criterion(model, inputs)
|
||||
|
||||
@@ -9,6 +9,7 @@ from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.constant import DownloadMode, ModelFile
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@@ -78,6 +79,7 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
json.dump(self.finetune_cfg, writer)
|
||||
|
||||
pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
|
||||
|
||||
args = dict(
|
||||
model=pretrained_model,
|
||||
work_dir=WORKSPACE,
|
||||
|
||||
Reference in New Issue
Block a user