From 9b3a92e65df35cbab7848fb7c057563cb0a56fa9 Mon Sep 17 00:00:00 2001
From: "james.wjg"
Date: Thu, 1 Dec 2022 19:16:56 +0800
Subject: [PATCH] cv/language_guided_video_summarization: add finetune support

Add finetune support for cv/language_guided_video_summarization.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10790262
---
 .../msdatasets/task_datasets/__init__.py      |  3 +
 ...uage_guided_video_summarization_dataset.py | 93 +++++++++++++++++++
 ...uage_guided_video_summarization_trainer.py | 81 ++++++++++++++++++
 3 files changed, 177 insertions(+)
 create mode 100644 modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py
 create mode 100644 tests/trainers/test_language_guided_video_summarization_trainer.py

diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py
index 043010bf..3494c8da 100644
--- a/modelscope/msdatasets/task_datasets/__init__.py
+++ b/modelscope/msdatasets/task_datasets/__init__.py
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
     from .movie_scene_segmentation import MovieSceneSegmentationDataset
     from .video_summarization_dataset import VideoSummarizationDataset
+    from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset
     from .image_inpainting import ImageInpaintingDataset
     from .text_ranking_dataset import TextRankingDataset
     from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
@@ -25,6 +26,8 @@ else:
         'image_instance_segmentation_coco_dataset':
         ['ImageInstanceSegmentationCocoDataset'],
         'video_summarization_dataset': ['VideoSummarizationDataset'],
+        'language_guided_video_summarization_dataset':
+        ['LanguageGuidedVideoSummarizationDataset'],
         'movie_scene_segmentation': ['MovieSceneSegmentationDataset'],
         'image_inpainting': ['ImageInpaintingDataset'],
         'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
diff --git a/modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py b/modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py
new file mode 100644
index 00000000..ef7ec9d8
--- /dev/null
+++ b/modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py
@@ -0,0 +1,93 @@
+# Part of the implementation is borrowed and modified from PGL-SUM,
+# publicly available at https://github.com/e-apostolidis/PGL-SUM,
+# under the license https://github.com/e-apostolidis/PGL-SUM/blob/master/LICENSE.md.
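+#
+# Serves precomputed per-video CLIP features for fine-tuning: frame-level
+# image embeddings ('features_clip_image'), one video-level text embedding
+# ('features_clip_txt') and the summarization ground truth, all read from a
+# single HDF5 file.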
+
+import os
+
+import h5py
+import json
+import numpy as np
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.msdatasets.task_datasets.torch_base_dataset import \
+    TorchTaskDataset
+from modelscope.utils.constant import Tasks
+
+
+@TASK_DATASETS.register_module(
+    Tasks.language_guided_video_summarization,
+    module_name=Models.language_guided_video_summarization)
+class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset):
+
+    def __init__(self, mode, opt, root_dir):
+        self.mode = mode
+        self.data_filename = os.path.join(root_dir, opt.dataset_file)
+        self.split_filename = os.path.join(root_dir, opt.split_file)
+        self.split_index = opt.split_index
+
+        self.list_image_features = []
+        self.list_text_features = []
+        self.list_gtscores = []
+        self.list_user_summary = []
+        self.list_change_points = []
+        self.list_n_frames = []
+        self.list_positions = []
+
+        # Pick the split with the requested index; fail loudly instead of
+        # deferring an AttributeError if the split file does not contain it.
+        self.split = None
+        with open(self.split_filename) as f:
+            data = json.load(f)
+        for i, split in enumerate(data):
+            if i == self.split_index:
+                self.split = split
+                break
+        if self.split is None:
+            raise ValueError(f'split index {self.split_index} not found '
+                             f'in {self.split_filename}')
+
+        # Pre-load features and annotations for every video in the split.
+        with h5py.File(self.data_filename, 'r') as hdf:
+            for video_name in self.split[self.mode + '_keys']:
+                clip_image_features = torch.Tensor(
+                    np.array(hdf[video_name + '/features_clip_image']))
+                # One text embedding per video, repeated for every frame.
+                clip_txt_features = torch.Tensor(
+                    np.array(hdf[video_name + '/features_clip_txt'])).reshape(
+                        1, -1)
+                clip_txt_features = clip_txt_features.repeat(
+                    clip_image_features.size(0), 1)
+
+                gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore']))
+                user_summary = np.array(hdf[f'{video_name}/user_summary'])
+                change_points = np.array(hdf[f'{video_name}/change_points'])
+                n_frames = np.array(hdf[f'{video_name}/n_frames'])
+                positions = np.array(hdf[f'{video_name}/picks'])
+
+                self.list_image_features.append(clip_image_features)
+                self.list_text_features.append(clip_txt_features)
+                self.list_gtscores.append(gtscore)
+                self.list_user_summary.append(user_summary)
+                self.list_change_points.append(change_points)
+                self.list_n_frames.append(n_frames)
+                self.list_positions.append(positions)
+
+    def __len__(self):
+        return len(self.split[self.mode + '_keys'])
+
+    def __getitem__(self, index):
+        return dict(
+            frame_features=self.list_image_features[index],
+            txt_features=self.list_text_features[index],
+            gtscore=self.list_gtscores[index],
+            user_summary=self.list_user_summary[index],
+            change_points=self.list_change_points[index],
+            n_frames=self.list_n_frames[index],
+            positions=self.list_positions[index])
diff --git a/tests/trainers/test_language_guided_video_summarization_trainer.py b/tests/trainers/test_language_guided_video_summarization_trainer.py
new file mode 100644
index 00000000..3ff0e102
--- /dev/null
+++ b/tests/trainers/test_language_guided_video_summarization_trainer.py
@@ -0,0 +1,81 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
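+#
+# Fine-tunes the CLIP-It language-guided video summarization model for two
+# epochs on the split shipped with the model snapshot and checks that a log
+# file and per-epoch checkpoints are produced.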
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models.cv.language_guided_video_summarization import \
+    ClipItVideoSummarization
+from modelscope.msdatasets.task_datasets import \
+    LanguageGuidedVideoSummarizationDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class LanguageGuidedVideoSummarizationTrainerTest(unittest.TestCase):
+
+    def setUp(self):
+        logger.info('Testing %s.%s' %
+                    (type(self).__name__, self._testMethodName))
+        # mkdtemp creates the directory; cleanup is left to tearDown.
+        self.tmp_dir = tempfile.mkdtemp()
+
+        self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en'
+        self.cache_path = snapshot_download(self.model_id)
+        self.config = Config.from_file(
+            os.path.join(self.cache_path, ModelFile.CONFIGURATION))
+        self.dataset_train = LanguageGuidedVideoSummarizationDataset(
+            'train', self.config.dataset, self.cache_path)
+        self.dataset_val = LanguageGuidedVideoSummarizationDataset(
+            'test', self.config.dataset, self.cache_path)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir, ignore_errors=True)
+        super().tearDown()
+
+    def _assert_training_artifacts(self, trainer):
+        # Training should leave a json log and one checkpoint per epoch.
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i + 1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=self.dataset_train,
+            eval_dataset=self.dataset_val,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        self._assert_training_artifacts(trainer)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer_with_model_and_args(self):
+        model = ClipItVideoSummarization.from_pretrained(self.cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.dataset_train,
+            eval_dataset=self.dataset_val,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        self._assert_training_artifacts(trainer)
+
+
+if __name__ == '__main__':
+    unittest.main()
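
Usage sketch: fine-tuning outside the test harness follows the same pattern
as the tests above. The snippet below is illustrative only; 'work_dir' is a
placeholder path, and it assumes the split file and HDF5 features ship with
the model snapshot, as setUp above implies.

    import os

    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.msdatasets.task_datasets import \
        LanguageGuidedVideoSummarizationDataset
    from modelscope.trainers import build_trainer
    from modelscope.utils.config import Config
    from modelscope.utils.constant import ModelFile

    model_id = 'damo/cv_clip-it_video-summarization_language-guided_en'
    cache_path = snapshot_download(model_id)
    config = Config.from_file(
        os.path.join(cache_path, ModelFile.CONFIGURATION))

    # 'train' and 'test' select the key lists inside the bundled split file.
    train_dataset = LanguageGuidedVideoSummarizationDataset(
        'train', config.dataset, cache_path)
    eval_dataset = LanguageGuidedVideoSummarizationDataset(
        'test', config.dataset, cache_path)

    trainer = build_trainer(
        default_args=dict(
            model=model_id,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            max_epochs=2,
            work_dir='./finetune_work_dir'))  # placeholder output directory
    trainer.train()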