add task preprocess

This commit is contained in:
翎航
2022-11-01 14:28:57 +08:00
parent fd3679b547
commit 9c04fec99c
7 changed files with 237 additions and 25 deletions

View File

@@ -43,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = sample['label']
target = target.translate(self.transtab).strip()
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])

View File

@@ -85,11 +85,11 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = ' {}'.format(data[self.column_map['text']])
sample['ref_dict'] = {data[self.column_map['text']]: 1.0}
target = ' {}'.format(sample['label'])
sample['ref_dict'] = {sample['label']: 1.0}
sample['target'] = self.tokenize_text(target, add_bos=False)
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, sample['target']])
[self.bos_item, sample['target'][:-1]])
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(sample['prev_output_tokens']),

View File

@@ -109,6 +109,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
}
if 'text' in self.column_map and self.column_map['text'] in data:
target = data[self.column_map['text']]
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
sample['label'] = target
sample['label'] = unicodedata2.normalize(
'NFKC', convert(target, 'zh-hans'))
return sample

View File

@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
import torch
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
@@ -24,9 +26,27 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
self).__init__(cfg, model_dir, mode, *args, **kwargs)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target_str = sample['label'].lower()
target = super().pre_caption(target_str, max_words=self.max_tgt_length)
target = target.replace('[unk]', 'unk').replace('<unk>', 'unk')
sample['target'] = self.tokenize_text(target, add_bos=False)
noise_target_item = self.add_noise_to_tgt(
sample['target'][:-1].clone())
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, noise_target_item])
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
source = super().pre_caption(
data['text'], max_words=self.max_src_length)
source = source.strip()[:self.max_src_length]
data[self.column_map['text']], max_words=self.max_src_length)
# source = source.strip()[:self.max_src_length]
source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
prompt = self.cfg.model.get(
'prompt', ' " {} " Summarize the article with a title: ')
@@ -42,4 +62,16 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
'source': inputs,
'decoder_prompt': decoder_prompt,
}
if 'summary' in self.column_map and self.column_map['summary'] in data:
sample['label'] = data[self.column_map['summary']]
return sample
def add_noise_to_tgt(self, target):
noise_indices = torch.FloatTensor(
target.size(0)).uniform_() < self.cfg.model.get(
'noise_ratio', 0.0)
target[noise_indices] = torch.randint(
4,
len(self.src_dict) - self.code_dict_size - self.num_bins,
size=(noise_indices.sum(), ))
return target

View File

@@ -38,8 +38,51 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
])
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = ' {}'.format(sample['label'])
sample['ref_dict'] = {sample['label']: 1.0}
tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False)
if self.prompt_type == 'none':
prev_output_item = torch.cat([self.bos_item, tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'src':
prev_output_item = torch.cat([sample['source'], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'prev_output':
prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
else:
raise NotImplementedError
target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
sample['target'] = target_item
sample['prev_output_tokens'] = prev_output_item
if self.constraint_trie is not None:
constraint_mask = torch.zeros(
(len(target_item), len(self.tgt_dict))).bool()
start_idx = len(target_item) - len(tgt_item) - 1
for i in range(
len(target_item) - len(tgt_item) - 1, len(target_item)):
constraint_prefix_token = [
self.tgt_dict.bos()
] + target_item[start_idx:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
if 'text2' not in data:
hypothesis = self.pre_caption(data['text'], self.max_src_length)
@@ -68,4 +111,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
if 'relation' in self.column_map and self.column_map[
'relation'] in data:
sample['label'] = data[self.column_map['relation']]
return sample

View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
@@ -27,24 +28,95 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
"""
super(OfaVisualGroundingPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
if self.mode == ModeKeys.TRAIN:
# for positioning
self.positioning_transform = transforms.Compose([
transforms.RandomResize([self.patch_image_size],
max_size=self.patch_image_size),
transforms.ToTensor(),
transforms.Normalize(
mean=self.mean,
std=self.std,
max_image_size=self.max_image_size)
])
else:
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(self.patch_image_size, self.patch_image_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
w, h = image.size
b_tgt = {
'boxes': [],
'labels': [],
'area': [],
'size': torch.tensor([h, w])
}
x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split(
',')
region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
b_tgt['boxes'] = torch.tensor(
[[float(x0), float(y0), float(x1),
float(y1)]])
b_tgt['labels'] = np.array([0])
b_tgt['area'] = [(float(x1) - float(x0)) * (float(y1) - float(y0))]
patch_image, patch_boxes = self.positioning_transform(image, b_tgt)
resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1]
quant_x0 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round()))
quant_y0 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round()))
quant_x1 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round()))
quant_y1 = '<bin_{}>'.format(
int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round()))
region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1,
quant_y1)
src_caption = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' which region does the text " {} " describe?')
text = prompt.format(src_caption)
src_item = self.tokenize_text(text)
target_item = self.tokenize_text(
region_coord, add_bos=False) # !!! use_bpe=False
prev_output_item = torch.cat([self.bos_item, target_item[:-1]])
sample = {
'source': src_item,
'patch_image': patch_image,
'patch_mask': torch.tensor([True]),
'target': target_item,
'prev_output_tokens': prev_output_item,
'w_resize_ratio': resize_w / w,
'h_resize_ratio': resize_h / h,
'region_coord': region
}
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
w, h = image.size
patch_image = self.patch_resize_transform(image)
w_resize_ratio = torch.tensor(self.patch_image_size / w)
h_resize_ratio = torch.tensor(self.patch_image_size / h)
src_caption = self.pre_caption(data['text'], self.max_src_length)
src_caption = self.pre_caption(data[self.column_map['text']],
self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' which region does the text " {} " describe?')
text = prompt.format(src_caption)

View File

@@ -38,10 +38,70 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
])
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
src_item = sample['source']
ref = data[self.column_map['ref']]
predict_objects = data[self.column_map['predict_objects']]
ref_dict = {
item.split('|!+')[1]: float(item.split('|!+')[0])
for item in ref.split('&&')
}
answer = max(ref_dict, key=ref_dict.get)
sample['conf'] = torch.tensor([ref_dict[answer]])
tgt_item = self.tokenize_text(
' {}'.format(answer), add_bos=False, add_eos=False)
if self.add_object and predict_objects is not None:
predict_object_seq = ' '.join(
predict_objects.strip().split('&&')[:self.max_object_length])
predict_object_item = self.tokenize_text(
' object: {}'.format(predict_object_seq), add_bos=False)
src_item = torch.cat([src_item, predict_object_item[:-1]])
if self.prompt_type == 'none':
prev_output_item = torch.cat([self.bos_item, tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'src':
prev_output_item = torch.cat([src_item, tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
elif self.prompt_type == 'prev_output':
prev_output_item = torch.cat([src_item[:-1], tgt_item])
target_item = torch.cat([prev_output_item[1:], self.eos_item])
else:
raise NotImplementedError
target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
sample['prev_output_tokens'] = prev_output_item
sample['target'] = target_item
sample['ref_dict'] = ref_dict
if self.constraint_trie is not None:
constraint_mask = torch.zeros(
(len(target_item), len(self.tgt_dict))).bool()
start_idx = len(target_item) - len(tgt_item) - 1
for i in range(
len(target_item) - len(tgt_item) - 1, len(target_item)):
constraint_prefix_token = [
self.tgt_dict.bos()
] + target_item[start_idx:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
sample['constraint_mask'] = constraint_mask
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
text = ' {}'.format(data['text'])
text = ' {}'.format(data[self.column_map['text']])
inputs = self.tokenize_text(text)
if self.prompt_type == 'none':
decoder_prompt = self.bos_item
@@ -57,4 +117,6 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
if 'answer' in self.column_map and self.column_map['answer'] in data:
sample['label'] = data[self.column_map['answer']]
return sample