mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-25 04:29:22 +01:00
fix finetune-task
@@ -72,6 +72,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
                 'noise_ratio', 0.0)
             target[noise_indices] = torch.randint(
                 4,
-                len(self.src_dict) - self.code_dict_size - self.num_bins,
+                len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
+                - self.cfg.model.get('num_bins', 1000),
                 size=(noise_indices.sum(), ))
             return target
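Note on this hunk: the OFA vocabulary ends with a tail of image-code tokens and coordinate-bin tokens, so noise tokens must be sampled from below that tail; the bound is now read from the config instead of hard-coded attributes. A minimal sketch of the arithmetic, where the vocabulary size is illustrative and only the num_codes/num_bins defaults come from the diff:

import torch

vocab_size = 59457          # hypothetical len(self.src_dict)
num_codes = 8192            # cfg.model.get('num_codes', 8192)
num_bins = 1000             # cfg.model.get('num_bins', 1000)
high = vocab_size - num_codes - num_bins  # exclusive upper bound for text ids

noise_ratio = 0.2
target = torch.randint(4, high, size=(10, ))
noise_indices = torch.rand(target.shape) < noise_ratio
# Replace a noise_ratio fraction of target tokens with random text tokens;
# ids 0-3 (special tokens such as bos/pad/eos/unk) are excluded by the
# lower bound of 4, and code/bin tokens are excluded by the upper bound.
target[noise_indices] = torch.randint(4, high, size=(noise_indices.sum(), ))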
@@ -61,7 +61,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         else:
             raise NotImplementedError

-        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
         sample['target'] = target_item
         sample['prev_output_tokens'] = prev_output_item

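Note on this hunk: the padding id now comes from the HuggingFace-style tokenizer rather than a fairseq dictionary. The slice fills every target position before the answer span with the pad id, so a loss configured with ignore_index equal to the pad id only trains on the answer tokens. A minimal sketch with hypothetical token ids (pad=1, eos=2):

import torch
import torch.nn.functional as F

pad_token_id = 1                       # hypothetical tokenizer.pad_token_id
prompt = torch.tensor([5, 6, 7, 8])    # prompt/source token ids (illustrative)
tgt_item = torch.tensor([87, 23, 9])   # answer token ids (illustrative)

prev_output = torch.cat([prompt, tgt_item])
target = torch.cat([prev_output[1:], torch.tensor([2])])  # shift left, add eos

# Mask everything before the answer so the loss ignores the prompt prefix.
target[:-len(tgt_item) - 1] = pad_token_id

logits = torch.randn(len(target), 100)  # stand-in decoder output
loss = F.cross_entropy(logits, target, ignore_index=pad_token_id)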
@@ -85,14 +85,17 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         if 'text2' not in data:
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get('prompt',
                                         ' does the image describe " {} "?')
             text = prompt.format(hypothesis)
         else:
             assert 'text' in data, f'text must be in the input {data.keys()}'
-            caption = self.pre_caption(data['text2'], self.max_src_length)
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            caption = self.pre_caption(data[self.column_map['text2']],
+                                       self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get(
                 'prompt', ' can image and text1 " {} " imply text2 " {} "?')
             text = prompt.format(caption, hypothesis)
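Note on this hunk: field names are now resolved through self.column_map before reading the input dict, so datasets whose columns are not literally named 'text'/'text2' can be remapped. A minimal sketch of the lookup plus prompt formatting, with a hypothetical column_map and a trivial stand-in for pre_caption:

# Hypothetical remapping: hypothesis stored under 'premise2', caption
# under 'sentence'; only the prompt template comes from the diff.
column_map = {'image': 'image', 'text': 'premise2', 'text2': 'sentence'}
data = {'image': '...', 'premise2': 'a dog runs', 'sentence': 'an animal moves'}

def pre_caption(text, max_len):
    # Stand-in for OfaBasePreprocessor.pre_caption: lowercase and truncate.
    return ' '.join(text.lower().strip().split()[:max_len])

caption = pre_caption(data[column_map['text2']], 64)
hypothesis = pre_caption(data[column_map['text']], 64)
prompt = ' can image and text1 " {} " imply text2 " {} "?'
text = prompt.format(caption, hypothesis)
# -> ' can image and text1 " an animal moves " imply text2 " a dog runs "?'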
@@ -45,42 +45,24 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        src_item = sample['source']
-        ref = data[self.column_map['ref']]
-        predict_objects = data[self.column_map['predict_objects']]
-
-        ref_dict = {
-            item.split('|!+')[1]: float(item.split('|!+')[0])
-            for item in ref.split('&&')
-        }
-        answer = max(ref_dict, key=ref_dict.get)
-        sample['conf'] = torch.tensor([ref_dict[answer]])
         tgt_item = self.tokenize_text(
-            ' {}'.format(answer), add_bos=False, add_eos=False)
-
-        if self.add_object and predict_objects is not None:
-            predict_object_seq = ' '.join(
-                predict_objects.strip().split('&&')[:self.max_object_length])
-            predict_object_item = self.tokenize_text(
-                ' object: {}'.format(predict_object_seq), add_bos=False)
-            src_item = torch.cat([src_item, predict_object_item[:-1]])
+            ' {}'.format(sample['label']), add_bos=False, add_eos=False)

         if self.prompt_type == 'none':
             prev_output_item = torch.cat([self.bos_item, tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'src':
-            prev_output_item = torch.cat([src_item, tgt_item])
+            prev_output_item = torch.cat([sample['source'], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'prev_output':
-            prev_output_item = torch.cat([src_item[:-1], tgt_item])
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         else:
             raise NotImplementedError
-        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id

         sample['prev_output_tokens'] = prev_output_item
         sample['target'] = target_item
-        sample['ref_dict'] = ref_dict

         if self.constraint_trie is not None:
             constraint_mask = torch.zeros(
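Note on this hunk: _build_train_sample no longer parses the 'score|!+answer' reference string itself; it takes the source tokens and the label from _build_infer_sample, then derives prev_output_tokens and target according to prompt_type. A minimal sketch of the three branches, with bos=0 and eos=2 as hypothetical special ids:

import torch

bos, eos = torch.tensor([0]), torch.tensor([2])   # hypothetical special ids
source = torch.tensor([0, 11, 12, 13, 2])          # question tokens, bos..eos
tgt_item = torch.tensor([42, 43])                  # tokenized answer label

def build(prompt_type):
    if prompt_type == 'none':
        prev_output = torch.cat([bos, tgt_item])
    elif prompt_type == 'src':
        prev_output = torch.cat([source, tgt_item])
    elif prompt_type == 'prev_output':
        prev_output = torch.cat([source[:-1], tgt_item])  # drop source eos
    else:
        raise NotImplementedError
    # Teacher forcing: target is prev_output shifted left by one, plus eos.
    target = torch.cat([prev_output[1:], eos])
    return prev_output, target

for pt in ('none', 'src', 'prev_output'):
    print(pt, *build(pt))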
@@ -101,7 +83,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
-        text = ' {}'.format(data[self.column_map['text']])
+        text = ' {}'.format(data[self.column_map['query']])
         inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
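Note on this hunk: the question is now read through column_map['query'] instead of column_map['text']. Assuming the default identity column mapping, a VQA input record after this change would look like the following hypothetical sketch (field values are not taken from the repo):

data = {
    'image': 'path/to/image.jpg',
    'query': 'what color is the car?',
}
text = ' {}'.format(data['query'])  # leading space, as in the diff
# _build_train_sample now reads the answer from sample['label'], so the
# old 'ref' / 'predict_objects' columns are no longer required here.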