From ca946067e643dac1fd9920eb7bfd53c5cafbe320 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BF=8E=E8=88=AA?=
Date: Wed, 2 Nov 2022 16:15:32 +0800
Subject: [PATCH] fix finetune-task

---
 modelscope/preprocessors/ofa/summarization.py |  3 +-
 .../preprocessors/ofa/visual_entailment.py    | 11 +++++---
 .../ofa/visual_question_answering.py          | 28 ++++---------------
 3 files changed, 14 insertions(+), 28 deletions(-)

diff --git a/modelscope/preprocessors/ofa/summarization.py b/modelscope/preprocessors/ofa/summarization.py
index 176600a9..8568a543 100644
--- a/modelscope/preprocessors/ofa/summarization.py
+++ b/modelscope/preprocessors/ofa/summarization.py
@@ -72,6 +72,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
                 'noise_ratio', 0.0)
         target[noise_indices] = torch.randint(
             4,
-            len(self.src_dict) - self.code_dict_size - self.num_bins,
+            len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
+            - self.cfg.model.get('num_bins', 1000),
             size=(noise_indices.sum(), ))
         return target
diff --git a/modelscope/preprocessors/ofa/visual_entailment.py b/modelscope/preprocessors/ofa/visual_entailment.py
index aeba199c..fff5bbd3 100644
--- a/modelscope/preprocessors/ofa/visual_entailment.py
+++ b/modelscope/preprocessors/ofa/visual_entailment.py
@@ -61,7 +61,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         else:
             raise NotImplementedError

-        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
         sample['target'] = target_item
         sample['prev_output_tokens'] = prev_output_item

@@ -85,14 +85,17 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         if 'text2' not in data:
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get('prompt',
                                         ' does the image describe " {} "?')
             text = prompt.format(hypothesis)
         else:
             assert 'text' in data, f'text must be in the input {data.keys()}'
-            caption = self.pre_caption(data['text2'], self.max_src_length)
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            caption = self.pre_caption(data[self.column_map['text2']],
+                                       self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get(
                 'prompt', ' can image and text1 " {} " imply text2 " {} "?')
             text = prompt.format(caption, hypothesis)
diff --git a/modelscope/preprocessors/ofa/visual_question_answering.py b/modelscope/preprocessors/ofa/visual_question_answering.py
index 9f9ea4f7..c623a869 100644
--- a/modelscope/preprocessors/ofa/visual_question_answering.py
+++ b/modelscope/preprocessors/ofa/visual_question_answering.py
@@ -45,42 +45,24 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        src_item = sample['source']
-        ref = data[self.column_map['ref']]
-        predict_objects = data[self.column_map['predict_objects']]
-
-        ref_dict = {
-            item.split('|!+')[1]: float(item.split('|!+')[0])
-            for item in ref.split('&&')
-        }
-        answer = max(ref_dict, key=ref_dict.get)
-        sample['conf'] = torch.tensor([ref_dict[answer]])
         tgt_item = self.tokenize_text(
-            ' {}'.format(answer), add_bos=False, add_eos=False)
-
-        if self.add_object and predict_objects is not None:
-            predict_object_seq = ' '.join(
-                predict_objects.strip().split('&&')[:self.max_object_length])
-            predict_object_item = self.tokenize_text(
-                ' object: {}'.format(predict_object_seq), add_bos=False)
-            src_item = torch.cat([src_item, predict_object_item[:-1]])
+            ' {}'.format(sample['label']), add_bos=False, add_eos=False)

         if self.prompt_type == 'none':
             prev_output_item = torch.cat([self.bos_item, tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'src':
-            prev_output_item = torch.cat([src_item, tgt_item])
+            prev_output_item = torch.cat([sample['source'], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'prev_output':
-            prev_output_item = torch.cat([src_item[:-1], tgt_item])
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         else:
             raise NotImplementedError

-        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
         sample['prev_output_tokens'] = prev_output_item
         sample['target'] = target_item
-        sample['ref_dict'] = ref_dict

         if self.constraint_trie is not None:
             constraint_mask = torch.zeros(
@@ -101,7 +83,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
-        text = ' {}'.format(data[self.column_map['text']])
+        text = ' {}'.format(data[self.column_map['query']])
         inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item