cv/language_guided_video_summarization: add finetune support
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10790262
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
     from .movie_scene_segmentation import MovieSceneSegmentationDataset
     from .video_summarization_dataset import VideoSummarizationDataset
+    from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset
     from .image_inpainting import ImageInpaintingDataset
     from .text_ranking_dataset import TextRankingDataset
     from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
@@ -25,6 +26,8 @@ else:
         'image_instance_segmentation_coco_dataset':
         ['ImageInstanceSegmentationCocoDataset'],
         'video_summarization_dataset': ['VideoSummarizationDataset'],
+        'language_guided_video_summarization_dataset':
+        ['LanguageGuidedVideoSummarizationDataset'],
         'movie_scene_segmentation': ['MovieSceneSegmentationDataset'],
         'image_inpainting': ['ImageInpaintingDataset'],
         'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
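With the TYPE_CHECKING import and the lazy-import map entry above (presumably the `task_datasets` package `__init__.py`), the new dataset class becomes importable from the package root, which is how the test further down pulls it in:

    from modelscope.msdatasets.task_datasets import LanguageGuidedVideoSummarizationDataset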
@@ -0,0 +1,90 @@
# Part of the implementation is borrowed and modified from PGL-SUM,
# publicly available at https://github.com/e-apostolidis/PGL-SUM, follow the
# license https://github.com/e-apostolidis/PGL-SUM/blob/master/LICENSE.md.

import os

import h5py
import json
import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
    TorchTaskDataset
from modelscope.utils.constant import Tasks


@TASK_DATASETS.register_module(
    Tasks.language_guided_video_summarization,
    module_name=Models.language_guided_video_summarization)
class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset):

    def __init__(self, mode, opt, root_dir):
        self.mode = mode
        self.data_filename = os.path.join(root_dir, opt.dataset_file)
        self.split_filename = os.path.join(root_dir, opt.split_file)
        self.split_index = opt.split_index
        hdf = h5py.File(self.data_filename, 'r')
        self.list_image_features = []
        self.list_text_features = []
        self.list_gtscores = []
        self.list_user_summary = []
        self.list_change_points = []
        self.list_n_frames = []
        self.list_positions = []

        with open(self.split_filename) as f:
            data = json.loads(f.read())
            for i, split in enumerate(data):
                if i == self.split_index:
                    self.split = split
                    break
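        # Judging from the parsing above, the split file is a JSON list of
        # split dicts, each carrying '<mode>_keys' lists of video names, e.g.
        # (illustrative names, not from the commit):
        # [{"train_keys": ["video_1"], "test_keys": ["video_2"]}, ...]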

        for video_name in self.split[self.mode + '_keys']:
            clip_image_features = torch.Tensor(
                np.array(hdf[video_name + '/features_clip_image']))
            clip_txt_features = torch.Tensor(
                np.array(hdf[video_name + '/features_clip_txt'])).reshape(
                    1, -1)
            clip_txt_features = clip_txt_features.repeat(
                clip_image_features.size(0), 1)

            gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore']))
            user_summary = np.array(hdf[f'{video_name}/user_summary'])
            change_points = np.array(hdf[f'{video_name}/change_points'])
            n_frames = np.array(hdf[f'{video_name}/n_frames'])
            positions = np.array(hdf[f'{video_name}/picks'])

            self.list_image_features.append(clip_image_features)
            self.list_text_features.append(clip_txt_features)
            self.list_gtscores.append(gtscore)
            self.list_user_summary.append(user_summary)
            self.list_change_points.append(change_points)
            self.list_n_frames.append(n_frames)
            self.list_positions.append(positions)

        hdf.close()

    def __len__(self):
        self.len = len(self.split[self.mode + '_keys'])
        return self.len

    def __getitem__(self, index):
        clip_image_features = self.list_image_features[index]
        clip_txt_features = self.list_text_features[index]
        gtscore = self.list_gtscores[index]
        user_summary = self.list_user_summary[index]
        change_points = self.list_change_points[index]
        n_frames = self.list_n_frames[index]
        positions = self.list_positions[index]

        return dict(
            frame_features=clip_image_features,
            txt_features=clip_txt_features,
            gtscore=gtscore,
            user_summary=user_summary,
            change_points=change_points,
            n_frames=n_frames,
            positions=positions)
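For orientation, a minimal usage sketch (not part of the commit): `opt` only needs the three attributes read in `__init__`, and the HDF5 file must provide, per video, the keys accessed above (`features_clip_image`, `features_clip_txt`, `gtscore`, `user_summary`, `change_points`, `n_frames`, `picks`). The file names below are hypothetical:

    from types import SimpleNamespace

    opt = SimpleNamespace(
        dataset_file='summe_clip.h5',  # hypothetical HDF5 file name
        split_file='splits.json',  # hypothetical split file name
        split_index=0)
    dataset = LanguageGuidedVideoSummarizationDataset('train', opt, '/path/to/data')
    sample = dataset[0]
    # txt_features is the single text embedding tiled to one row per frame step
    print(sample['frame_features'].shape, sample['txt_features'].shape)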
@@ -0,0 +1,76 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.language_guided_video_summarization import \
    ClipItVideoSummarization
from modelscope.msdatasets.task_datasets import \
    LanguageGuidedVideoSummarizationDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class LanguageGuidedVideoSummarizationTrainerTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en'
        self.cache_path = snapshot_download(self.model_id)
        self.config = Config.from_file(
            os.path.join(self.cache_path, ModelFile.CONFIGURATION))
        self.dataset_train = LanguageGuidedVideoSummarizationDataset(
            'train', self.config.dataset, self.cache_path)
        self.dataset_val = LanguageGuidedVideoSummarizationDataset(
            'test', self.config.dataset, self.cache_path)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        model = ClipItVideoSummarization.from_pretrained(self.cache_path)
        kwargs = dict(
            cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)


if __name__ == '__main__':
    unittest.main()
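Outside the unittest harness, the same fine-tune can be launched directly; this sketch merely restates what `test_trainer` above does, with a hypothetical work directory and datasets built as in `setUp`:

    from modelscope.trainers import build_trainer

    trainer = build_trainer(
        default_args=dict(
            model='damo/cv_clip-it_video-summarization_language-guided_en',
            train_dataset=dataset_train,  # built as in setUp above
            eval_dataset=dataset_val,
            max_epochs=2,
            work_dir='./work_dir'))  # hypothetical path
    trainer.train()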