[to #42322933] Add fine-tune code for referring video object segmentation
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10539423
@@ -305,6 +305,7 @@ class Trainers(object):
    face_detection_scrfd = 'face-detection-scrfd'
    card_detection_scrfd = 'card-detection-scrfd'
    image_inpainting = 'image-inpainting'
    referring_video_object_segmentation = 'referring-video-object-segmentation'
    image_classification_team = 'image-classification-team'

    # nlp trainers
@@ -423,6 +424,8 @@ class Metrics(object):
    image_inpainting_metric = 'image-inpainting-metric'
    # metric for ocr
    NED = 'ned'
    # metric for referring-video-object-segmentation task
    referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'


class Optimizers(object):
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
    from .accuracy_metric import AccuracyMetric
    from .bleu_metric import BleuMetric
    from .image_inpainting_metric import ImageInpaintingMetric
    from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric

else:
    _import_structure = {
@@ -40,6 +41,8 @@ else:
        'image_inpainting_metric': ['ImageInpaintingMetric'],
        'accuracy_metric': ['AccuracyMetric'],
        'bleu_metric': ['BleuMetric'],
        'referring_video_object_segmentation_metric':
        ['ReferringVideoObjectSegmentationMetric'],
    }

    import sys
@@ -43,6 +43,8 @@ task_default_metrics = {
    Tasks.visual_question_answering: [Metrics.text_gen_metric],
    Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
    Tasks.image_inpainting: [Metrics.image_inpainting_metric],
    Tasks.referring_video_object_segmentation:
    [Metrics.referring_video_object_segmentation_metric],
}
modelscope/metrics/referring_video_object_segmentation_metric.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# Part of the implementation is borrowed and modified from MTTR,
# publicly available at https://github.com/mttr2021/MTTR
from typing import Dict

import numpy as np
import torch
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools.mask import decode
from tqdm import tqdm

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
    group_key=default_group,
    module_name=Metrics.referring_video_object_segmentation_metric)
class ReferringVideoObjectSegmentationMetric(Metric):
    """The metric computation class for the referring video object segmentation task.
    """

    def __init__(self,
                 ann_file=None,
                 calculate_precision_and_iou_metrics=True):
        self.ann_file = ann_file
        self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics
        self.preds = []

    def add(self, outputs: Dict, inputs: Dict):
        preds_batch = outputs['pred']
        self.preds.extend(preds_batch)

    def evaluate(self):
        coco_gt = COCO(self.ann_file)
        coco_pred = coco_gt.loadRes(self.preds)
        coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm')
        coco_eval.params.useCats = 0

        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        ap_labels = [
            'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S',
            'AP 0.5:0.95 M', 'AP 0.5:0.95 L'
        ]
        ap_metrics = coco_eval.stats[:6]
        eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)}
        if self.calculate_precision_and_iou_metrics:
            precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(
                coco_gt, coco_pred)
            eval_metrics.update({
                f'P@{k}': m
                for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)
            })
            eval_metrics.update({
                'overall_iou': overall_iou,
                'mean_iou': mean_iou
            })

        return eval_metrics


def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6):
    outputs = outputs.int()
    intersection = (outputs & labels).float().sum(
        (1, 2))  # Will be zero if Truth=0 or Prediction=0
    union = (outputs | labels).float().sum(
        (1, 2))  # Will be zero if both are 0
    iou = (intersection + EPS) / (union + EPS)  # EPS is used to avoid division by zero
    return iou, intersection, union


def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO):
    print('evaluating precision@k & iou metrics...')
    counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]}
    total_intersection_area = 0
    total_union_area = 0
    ious_list = []
    for instance in tqdm(coco_gt.imgs.keys()
                         ):  # each image_id contains exactly one instance
        gt_annot = coco_gt.imgToAnns[instance][0]
        gt_mask = decode(gt_annot['segmentation'])
        pred_annots = coco_pred.imgToAnns[instance]
        pred_annot = sorted(
            pred_annots,
            key=lambda a: a['score'])[-1]  # choose pred with highest score
        pred_mask = decode(pred_annot['segmentation'])
        iou, intersection, union = compute_iou(
            torch.tensor(pred_mask).unsqueeze(0),
            torch.tensor(gt_mask).unsqueeze(0))
        iou, intersection, union = iou.item(), intersection.item(), union.item()
        for iou_threshold in counters_by_iou.keys():
            if iou > iou_threshold:
                counters_by_iou[iou_threshold] += 1
        total_intersection_area += intersection
        total_union_area += union
        ious_list.append(iou)
    num_samples = len(ious_list)
    precision_at_k = np.array(list(counters_by_iou.values())) / num_samples
    overall_iou = total_intersection_area / total_union_area
    mean_iou = np.mean(ious_list)
    return precision_at_k, overall_iou, mean_iou
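For orientation, here is a minimal, hedged sketch of driving this metric class on its own, outside the trainer. The annotation path and the dummy mask are placeholders, and the prediction dict simply mirrors the COCO-style entries produced by the model's forward pass; nothing below is shipped in this diff.

    import numpy as np
    from pycocotools.mask import encode

    # Build one fake RLE prediction; in practice these come from outputs['pred'].
    dummy_mask = np.asfortranarray(np.zeros((320, 320), dtype=np.uint8))
    dummy_mask[100:200, 100:200] = 1
    rle_mask = encode(dummy_mask)
    rle_mask['counts'] = rle_mask['counts'].decode('ascii')

    metric = ReferringVideoObjectSegmentationMetric(
        ann_file='path/to/a2d_sentences_test_annotations_in_coco_format.json')  # placeholder path
    metric.add(
        outputs={
            'pred': [{
                'image_id': 'v_demo_f_1_i_1',  # must exist in ann_file
                'category_id': 1,              # dummy category, as in the model code
                'segmentation': rle_mask,
                'score': 0.9,
            }]
        },
        inputs={})
    results = metric.evaluate()  # mAP/AP entries plus P@K, overall_iou, mean_iou when enabled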
@@ -5,11 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .model import MovieSceneSegmentation
    from .model import ReferringVideoObjectSegmentation

else:
    _import_structure = {
        'model': ['MovieSceneSegmentation'],
        'model': ['ReferringVideoObjectSegmentation'],
    }

    import sys
@@ -1,4 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed and modified from MTTR,
# publicly available at https://github.com/mttr2021/MTTR

import os.path as osp
from typing import Any, Dict

@@ -10,7 +12,9 @@ from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess,
from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher,
                    ReferYoutubeVOSPostProcess, SetCriterion,
                    flatten_temporal_batch_dims,
                    nested_tensor_from_videos_list)

logger = get_logger()
@@ -35,16 +39,66 @@ class ReferringVideoObjectSegmentation(TorchModel):
        params_dict = params_dict['model_state_dict']
        self.model.load_state_dict(params_dict, strict=True)

        dataset_name = self.cfg.pipeline.dataset_name
        if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences':
            self.postprocessor = A2DSentencesPostProcess()
        elif dataset_name == 'ref_youtube_vos':
            self.postprocessor = ReferYoutubeVOSPostProcess()
        self.set_postprocessor(self.cfg.pipeline.dataset_name)
        self.set_criterion()

    def set_device(self, device, name):
        self.device = device
        self._device_name = name

    def set_postprocessor(self, dataset_name):
        if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name:
            self.postprocessor = A2DSentencesPostProcess()  # fine-tune
        elif 'ref_youtube_vos' in dataset_name:
            self.postprocessor = ReferYoutubeVOSPostProcess()  # inference
        else:
            assert False, f'postprocessing for dataset: {dataset_name} is not supported'

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        return inputs
    def forward(self, inputs: Dict[str, Any]):
        samples = inputs['samples']
        targets = inputs['targets']
        text_queries = inputs['text_queries']

        valid_indices = torch.tensor(
            [i for i, t in enumerate(targets) if None not in t])
        targets = [targets[i] for i in valid_indices.tolist()]
        if self._device_name == 'gpu':
            samples = samples.to(self.device)
            valid_indices = valid_indices.to(self.device)
        if isinstance(text_queries, tuple):
            text_queries = list(text_queries)

        outputs = self.model(samples, valid_indices, text_queries)
        losses = -1
        if self.training:
            loss_dict = self.criterion(outputs, targets)
            weight_dict = self.criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k]
                         for k in loss_dict.keys() if k in weight_dict)

        predictions = []
        if not self.training:
            outputs.pop('aux_outputs', None)
            outputs, targets = flatten_temporal_batch_dims(outputs, targets)
            processed_outputs = self.postprocessor(
                outputs,
                resized_padded_sample_size=samples.tensors.shape[-2:],
                resized_sample_sizes=[t['size'] for t in targets],
                orig_sample_sizes=[t['orig_size'] for t in targets])
            image_ids = [t['image_id'] for t in targets]
            predictions = []
            for p, image_id in zip(processed_outputs, image_ids):
                for s, m in zip(p['scores'], p['rle_masks']):
                    predictions.append({
                        'image_id': image_id,
                        'category_id': 1,  # dummy label, as categories are not predicted in ref-vos
                        'segmentation': m,
                        'score': s.item()
                    })

        re = dict(pred=predictions, loss=losses)
        return re

    def inference(self, **kwargs):
        window = kwargs['window']
@@ -63,3 +117,26 @@ class ReferringVideoObjectSegmentation(TorchModel):

    def postprocess(self, inputs: Dict[str, Any], **kwargs):
        return inputs

    def set_criterion(self):
        matcher = HungarianMatcher(
            cost_is_referred=self.cfg.matcher.set_cost_is_referred,
            cost_dice=self.cfg.matcher.set_cost_dice)
        weight_dict = {
            'loss_is_referred': self.cfg.loss.is_referred_loss_coef,
            'loss_dice': self.cfg.loss.dice_loss_coef,
            'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef
        }

        if self.cfg.loss.aux_loss:
            aux_weight_dict = {}
            for i in range(self.cfg.model.num_decoder_layers - 1):
                aux_weight_dict.update(
                    {k + f'_{i}': v
                     for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)

        self.criterion = SetCriterion(
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=self.cfg.loss.eos_coef)
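As a quick aside on the training branch of forward() above: the criterion returns a dict of per-term losses, and only the keys present in weight_dict contribute to the scalar that is backpropagated. A tiny self-contained sketch with made-up numbers (the keys mirror set_criterion(); the weights and values are illustrative only):

    import torch

    weight_dict = {'loss_is_referred': 2.0, 'loss_dice': 5.0, 'loss_sigmoid_focal': 2.0}
    loss_dict = {
        'loss_is_referred': torch.tensor(0.30),
        'loss_dice': torch.tensor(0.70),
        'loss_sigmoid_focal': torch.tensor(0.10),
        'loss_dice_0': torch.tensor(0.90),  # aux-layer term; counted only if its key is in weight_dict
    }
    losses = sum(loss_dict[k] * weight_dict[k]
                 for k in loss_dict.keys() if k in weight_dict)
    print(float(losses))  # 2.0*0.30 + 5.0*0.70 + 2.0*0.10 = 4.3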
@@ -1,4 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .misc import nested_tensor_from_videos_list
from .criterion import SetCriterion, flatten_temporal_batch_dims
from .matcher import HungarianMatcher
from .misc import interpolate, nested_tensor_from_videos_list
from .mttr import MTTR
from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess
@@ -0,0 +1,198 @@
# The implementation is adopted from MTTR,
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
# Modified from DETR https://github.com/facebookresearch/detr
import torch
from torch import nn

from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized,
                   nested_tensor_from_tensor_list)
from .segmentation import dice_loss, sigmoid_focal_loss


class SetCriterion(nn.Module):
    """ This class computes the loss for MTTR.
    The process happens in two steps:
        1) we compute the hungarian assignment between the ground-truth and predicted sequences.
        2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction)
    """

    def __init__(self, matcher, weight_dict, eos_coef):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the un-referred category
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        # make sure that only loss functions with non-zero weights are computed:
        losses_to_compute = []
        if weight_dict['loss_dice'] > 0 or weight_dict['loss_sigmoid_focal'] > 0:
            losses_to_compute.append('masks')
        if weight_dict['loss_is_referred'] > 0:
            losses_to_compute.append('is_referred')
        self.losses = losses_to_compute

    def forward(self, outputs, targets):
        aux_outputs_list = outputs.pop('aux_outputs', None)
        # compute the losses for the output of the last decoder layer:
        losses = self.compute_criterion(
            outputs, targets, losses_to_compute=self.losses)

        # In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer.
        if aux_outputs_list is not None:
            aux_losses_to_compute = self.losses.copy()
            for i, aux_outputs in enumerate(aux_outputs_list):
                losses_dict = self.compute_criterion(aux_outputs, targets,
                                                     aux_losses_to_compute)
                losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()}
                losses.update(losses_dict)

        return losses

    def compute_criterion(self, outputs, targets, losses_to_compute):
        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs, targets)

        # T & B dims are flattened so loss functions can be computed per frame (but with same indices per video).
        # also, indices are repeated so the same indices can be used for frames of the same video.
        T = len(targets)
        outputs, targets = flatten_temporal_batch_dims(outputs, targets)
        # repeat the indices list T times so the same indices can be used for each video frame
        indices = T * indices

        # Compute the average number of target masks across all nodes, for normalization purposes
        num_masks = sum(len(t['masks']) for t in targets)
        num_masks = torch.as_tensor([num_masks],
                                    dtype=torch.float,
                                    device=indices[0][0].device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in losses_to_compute:
            losses.update(
                self.get_loss(
                    loss, outputs, targets, indices, num_masks=num_masks))
        return losses

    def loss_is_referred(self, outputs, targets, indices, **kwargs):
        device = outputs['pred_is_referred'].device
        bs = outputs['pred_is_referred'].shape[0]
        pred_is_referred = outputs['pred_is_referred'].log_softmax(
            dim=-1)  # note that log-softmax is used here
        target_is_referred = torch.zeros_like(pred_is_referred)
        # extract indices of object queries that were matched with text-referred target objects
        query_referred_indices = self._get_query_referred_indices(
            indices, targets)
        # by default penalize compared to the no-object class (last token)
        target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device)
        if 'is_ref_inst_visible' in targets[0]:
            # visibility labels are available per-frame for the referred object:
            is_ref_inst_visible_per_frame = torch.stack(
                [t['is_ref_inst_visible'] for t in targets])
            ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero().squeeze()
            # keep only the matched query indices of the frames in which the referred object is visible:
            visible_query_referred_indices = query_referred_indices[
                ref_inst_visible_frame_indices]
            target_is_referred[ref_inst_visible_frame_indices,
                               visible_query_referred_indices] = torch.tensor(
                                   [1.0, 0.0], device=device)
        else:  # assume that the referred object is visible in every frame:
            target_is_referred[torch.arange(bs),
                               query_referred_indices] = torch.tensor(
                                   [1.0, 0.0], device=device)
        loss = -(pred_is_referred * target_is_referred).sum(-1)
        # apply no-object class weights:
        eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device)
        eos_coef[torch.arange(bs), query_referred_indices] = 1.0
        loss = loss * eos_coef
        bs = len(indices)
        loss = loss.sum() / bs  # sum and normalize the loss by the batch size
        losses = {'loss_is_referred': loss}
        return losses

    def loss_masks(self, outputs, targets, indices, num_masks, **kwargs):
        assert 'pred_masks' in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs['pred_masks']
        src_masks = src_masks[src_idx]
        masks = [t['masks'] for t in targets]
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # upsample predictions to the target size
        src_masks = interpolate(
            src_masks[:, None],
            size=target_masks.shape[-2:],
            mode='bilinear',
            align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            'loss_sigmoid_focal':
            sigmoid_focal_loss(src_masks, target_masks, num_masks),
            'loss_dice':
            dice_loss(src_masks, target_masks, num_masks),
        }
        return losses

    @staticmethod
    def _get_src_permutation_idx(indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    @staticmethod
    def _get_tgt_permutation_idx(indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    @staticmethod
    def _get_query_referred_indices(indices, targets):
        """
        extract indices of object queries that were matched with text-referred target objects
        """
        query_referred_indices = []
        for (query_idxs, target_idxs), target in zip(indices, targets):
            ref_query_idx = query_idxs[torch.where(
                target_idxs == target['referred_instance_idx'])[0]]
            query_referred_indices.append(ref_query_idx)
        query_referred_indices = torch.cat(query_referred_indices)
        return query_referred_indices

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            'masks': self.loss_masks,
            'is_referred': self.loss_is_referred,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)


def flatten_temporal_batch_dims(outputs, targets):
    for k in outputs.keys():
        if isinstance(outputs[k], torch.Tensor):
            outputs[k] = outputs[k].flatten(0, 1)
        else:  # list
            outputs[k] = [i for step_t in outputs[k] for i in step_t]
    targets = [
        frame_t_target for step_t in targets for frame_t_target in step_t
    ]
    return outputs, targets
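To make the tensor bookkeeping in compute_criterion() concrete, a small sketch of flatten_temporal_batch_dims on dummy inputs (shapes are illustrative): the leading time and batch dimensions of every tensor output are merged, and the per-time-step target lists are chained into one flat list, so each loss is computed per frame.

    import torch

    T, B, NQ = 3, 2, 5  # time steps, batch size, object queries (illustrative)
    outputs = {
        'pred_is_referred': torch.randn(T, B, NQ, 2),
        'pred_masks': torch.randn(T, B, NQ, 32, 32),
    }
    targets = [[{'masks': torch.zeros(1, 32, 32)} for _ in range(B)] for _ in range(T)]

    outputs, targets = flatten_temporal_batch_dims(outputs, targets)
    print(outputs['pred_is_referred'].shape)  # torch.Size([6, 5, 2]): T and B merged
    print(len(targets))                       # 6 frame-level target dicts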
@@ -0,0 +1,163 @@
# The implementation is adopted from MTTR,
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
# Modified from DETR https://github.com/facebookresearch/detr
# Module to compute the matching cost and solve the corresponding LSAP.

import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from .misc import interpolate, nested_tensor_from_tensor_list


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1):
        """Creates the matcher

        Params:
            cost_is_referred: This is the relative weight of the reference cost in the total matching cost
            cost_dice: This is the relative weight of the dice cost in the total matching cost
        """
        super().__init__()
        self.cost_is_referred = cost_is_referred
        self.cost_dice = cost_dice
        assert cost_is_referred != 0 or cost_dice != 0, 'all costs cant be 0'

    @torch.inference_mode()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: A dict that contains at least these entries:
                "pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits
                "pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted masks logits

            targets: A list of lists of targets (outer - time steps, inner - batch samples). each target is a dict
                which contains mask and reference ground truth information for a single frame.

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_masks)
        """
        t, bs, num_queries = outputs['pred_masks'].shape[:3]

        # We flatten to compute the cost matrices in a batch
        out_masks = outputs['pred_masks'].flatten(
            1, 2)  # [t, batch_size * num_queries, mask_h, mask_w]

        # preprocess and concat the target masks
        tgt_masks = [[
            m for v in t_step_batch for m in v['masks'].unsqueeze(1)
        ] for t_step_batch in targets]
        # pad the target masks to a uniform shape
        tgt_masks, valid = list(
            zip(*[
                nested_tensor_from_tensor_list(t).decompose()
                for t in tgt_masks
            ]))
        tgt_masks = torch.stack(tgt_masks).squeeze(2)

        # upsample predicted masks to target mask size
        out_masks = interpolate(
            out_masks,
            size=tgt_masks.shape[-2:],
            mode='bilinear',
            align_corners=False)

        # Compute the soft-tokens cost:
        if self.cost_is_referred > 0:
            cost_is_referred = compute_is_referred_cost(outputs, targets)
        else:
            cost_is_referred = 0

        # Compute the DICE coefficient between the masks:
        if self.cost_dice > 0:
            cost_dice = -dice_coef(out_masks, tgt_masks)
        else:
            cost_dice = 0

        # Final cost matrix
        C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice
        C = C.view(bs, num_queries, -1).cpu()

        num_traj_per_batch = [
            len(v['masks']) for v in targets[0]
        ]  # number of instance trajectories in each batch
        indices = [
            linear_sum_assignment(c[i])
            for i, c in enumerate(C.split(num_traj_per_batch, -1))
        ]
        device = out_masks.device
        return [(torch.as_tensor(i, dtype=torch.int64, device=device),
                 torch.as_tensor(j, dtype=torch.int64, device=device))
                for i, j in indices]


def dice_coef(inputs, targets, smooth=1.0):
    """
    Compute the DICE coefficient, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid().flatten(2).unsqueeze(2)
    targets = targets.flatten(2).unsqueeze(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    coef = (numerator + smooth) / (denominator + smooth)
    coef = coef.mean(
        0)  # average on the temporal dim to get instance trajectory scores
    return coef


def compute_is_referred_cost(outputs, targets):
    pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax(
        dim=-1)  # [t, b*nq, 2]
    device = pred_is_referred.device
    t = pred_is_referred.shape[0]
    # number of instance trajectories in each batch
    num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]],
                                      device=device)
    total_trajectories = num_traj_per_batch.sum()
    # note that ref_indices are shared across time steps:
    ref_indices = torch.tensor(
        [v['referred_instance_idx'] for v in targets[0]], device=device)
    # convert ref_indices to fit flattened batch targets:
    ref_indices += torch.cat(
        (torch.zeros(1, dtype=torch.long,
                     device=device), num_traj_per_batch.cumsum(0)[:-1]))
    # number of instance trajectories in each batch
    target_is_referred = torch.zeros((t, total_trajectories, 2), device=device)
    # 'no object' class by default (for un-referred objects)
    target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device)
    if 'is_ref_inst_visible' in targets[0][0]:
        # visibility labels are available per-frame for the referred object:
        is_ref_inst_visible = torch.stack([
            torch.stack([t['is_ref_inst_visible'] for t in t_step])
            for t_step in targets
        ]).permute(1, 0)
        for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible):
            is_visible = is_visible.nonzero().squeeze()
            target_is_referred[is_visible,
                               ref_idx, :] = torch.tensor([1.0, 0.0],
                                                          device=device)
    else:  # assume that the referred object is visible in every frame:
        target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0],
                                                             device=device)
    cost_is_referred = -(pred_is_referred.unsqueeze(2)
                         * target_is_referred.unsqueeze(1)).sum(dim=-1).mean(
                             dim=0)
    return cost_is_referred
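For intuition about the final assignment step above, a toy run of scipy's linear_sum_assignment on a hand-written cost matrix (values are arbitrary): each row is an object query, each column a target trajectory, and the solver returns the cheapest one-to-one pairing, which is exactly what is done per batch element on the split cost matrix.

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[0.9, 0.1],   # 3 queries x 2 target trajectories
                     [0.2, 0.8],
                     [0.5, 0.6]])
    row_ind, col_ind = linear_sum_assignment(cost)
    print(row_ind, col_ind)              # [0 1] [1 0]: query 0 -> target 1, query 1 -> target 0
    print(cost[row_ind, col_ind].sum())  # 0.3, the minimal total matching cost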
@@ -122,8 +122,8 @@ class MultimodalTransformer(nn.Module):
        with torch.inference_mode(mode=self.freeze_text_encoder):
            encoded_text = self.text_encoder(**tokenized_queries)
        # Transpose memory because pytorch's attention expects sequence first
        txt_memory = rearrange(encoded_text.last_hidden_state,
                               'b s c -> s b c')
        tmp_last_hidden_state = encoded_text.last_hidden_state.clone()
        txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c')
        txt_memory = self.txt_proj(
            txt_memory)  # change text embeddings dim to model dim
        # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer
@@ -123,7 +123,8 @@ class WindowAttention3D(nn.Module):
        # define a parameter table of relative position bias
        wd, wh, ww = window_size
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads))
            torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1),
                        num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.window_size[0])
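For concreteness, the size of that relative position bias table for an assumed 3D window of (2, 7, 7) (a common Video Swin setting, not something this hunk fixes):

    wd, wh, ww = 2, 7, 7  # assumed window size (depth, height, width)
    num_entries = (2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1)
    print(num_entries)  # 3 * 13 * 13 = 507 rows, one per relative (d, h, w) offset, each of length num_heads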
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
    from .video_summarization_dataset import VideoSummarizationDataset
    from .image_inpainting import ImageInpaintingDataset
    from .text_ranking_dataset import TextRankingDataset
    from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset

else:
    _import_structure = {
@@ -29,6 +30,8 @@ else:
        'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
        'image_portrait_enhancement_dataset':
        ['ImagePortraitEnhancementDataset'],
        'referring_video_object_segmentation':
        ['ReferringVideoObjectSegmentationDataset'],
    }
    import sys
@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .referring_video_object_segmentation_dataset import \
    ReferringVideoObjectSegmentationDataset
@@ -0,0 +1,361 @@
# Part of the implementation is borrowed and modified from MTTR,
# publicly available at https://github.com/mttr2021/MTTR

from glob import glob
from os import path as osp

import h5py
import json
import numpy as np
import pandas
import torch
import torch.distributed as dist
import torchvision.transforms.functional as F
from pycocotools.mask import area, encode
from torchvision.io import read_video
from tqdm import tqdm

from modelscope.metainfo import Models
from modelscope.models.cv.referring_video_object_segmentation.utils import \
    nested_tensor_from_videos_list
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
    TorchTaskDataset
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from . import transformers as T

LOGGER = get_logger()


def get_image_id(video_id, frame_idx, ref_instance_a2d_id):
    image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}'
    return image_id


@TASK_DATASETS.register_module(
    Tasks.referring_video_object_segmentation,
    module_name=Models.referring_video_object_segmentation)
class ReferringVideoObjectSegmentationDataset(TorchTaskDataset):

    def __init__(self, **kwargs):
        split_config = kwargs['split_config']
        LOGGER.info(kwargs)
        data_cfg = kwargs.get('cfg').data_kwargs
        trans_cfg = kwargs.get('cfg').transformers_kwargs
        distributed = data_cfg.get('distributed', False)

        self.data_root = next(iter(split_config.values()))
        if not osp.exists(self.data_root):
            self.data_root = osp.dirname(self.data_root)
        assert osp.exists(self.data_root)

        self.window_size = data_cfg.get('window_size', 8)
        self.mask_annotations_dir = osp.join(
            self.data_root, 'text_annotations/annotation_with_instances')
        self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320')
        self.subset_type = next(iter(split_config.keys()))
        self.text_annotations = self.get_text_annotations(
            self.data_root, self.subset_type, distributed)
        self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg)
        self.collator = Collator()
        self.ann_file = osp.join(
            self.data_root,
            data_cfg.get('ann_file',
                         'a2d_sentences_test_annotations_in_coco_format.json'))

        # create ground-truth test annotations for the evaluation process if necessary:
        if self.subset_type == 'test' and not osp.exists(self.ann_file):
            if (distributed and dist.get_rank() == 0) or not distributed:
                create_a2d_sentences_ground_truth_test_annotations(
                    self.data_root, self.subset_type,
                    self.mask_annotations_dir, self.ann_file)
            if distributed:
                dist.barrier()

    def __len__(self):
        return len(self.text_annotations)

    def __getitem__(self, idx):
        text_query, video_id, frame_idx, instance_id = self.text_annotations[idx]

        text_query = ' '.join(
            text_query.lower().split())  # clean up the text query

        # read the source window frames:
        video_frames, _, _ = read_video(
            osp.join(self.videos_dir, f'{video_id}.mp4'),
            pts_unit='sec')  # (T, H, W, C)
        # get a window of window_size frames with frame frame_idx in the middle.
        # note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx
        start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + (
            self.window_size + 1) // 2

        # extract the window source frames:
        source_frames = []
        for i in range(start_idx, end_idx):
            i = min(max(i, 0),
                    len(video_frames) - 1)  # pad out of range indices with edge frames
            source_frames.append(
                F.to_pil_image(video_frames[i].permute(2, 0, 1)))

        # read the instance mask:
        frame_annot_path = osp.join(self.mask_annotations_dir, video_id,
                                    f'{frame_idx:05d}.h5')
        f = h5py.File(frame_annot_path, 'r')
        instances = list(f['instance'])
        instance_idx = instances.index(
            instance_id)  # existence was already validated during init

        instance_masks = np.array(f['reMask'])
        if len(instances) == 1:
            instance_masks = instance_masks[np.newaxis, ...]
        instance_masks = torch.tensor(instance_masks).transpose(1, 2)
        mask_rles = [encode(mask) for mask in instance_masks.numpy()]
        mask_areas = area(mask_rles).astype(np.float)
        f.close()

        # create the target dict for the center frame:
        target = {
            'masks': instance_masks,
            'orig_size': instance_masks.shape[-2:],  # original frame shape without any augmentations
            # size with augmentations, will be changed inside transforms if necessary
            'size': instance_masks.shape[-2:],
            'referred_instance_idx': torch.tensor(
                instance_idx),  # idx in 'masks' of the text referred instance
            'area': torch.tensor(mask_areas),
            'iscrowd': torch.zeros(len(instance_masks)),  # for compatibility with DETR COCO transforms
            'image_id': get_image_id(video_id, frame_idx, instance_id)
        }

        # create dummy targets for adjacent frames:
        targets = self.window_size * [None]
        center_frame_idx = self.window_size // 2
        targets[center_frame_idx] = target
        source_frames, targets, text_query = self.transforms(
            source_frames, targets, text_query)
        return source_frames, targets, text_query

    @staticmethod
    def get_text_annotations(root_path, subset, distributed):
        saved_annotations_file_path = osp.join(
            root_path, f'sentences_single_frame_{subset}_annotations.json')
        if osp.exists(saved_annotations_file_path):
            with open(saved_annotations_file_path, 'r') as f:
                text_annotations_by_frame = [tuple(a) for a in json.load(f)]
            return text_annotations_by_frame
        elif (distributed and dist.get_rank() == 0) or not distributed:
            print(f'building a2d sentences {subset} text annotations...')
            # without 'header == None' pandas will ignore the first sample...
            a2d_data_info = pandas.read_csv(
                osp.join(root_path, 'Release/videoset.csv'), header=None)
            # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset'
            a2d_data_info.columns = [
                'vid', '', '', '', '', '', '', '', 'subset'
            ]
            with open(
                    osp.join(root_path, 'text_annotations/missed_videos.txt'),
                    'r') as f:
                unused_videos = f.read().splitlines()
            subsets = {'train': 0, 'test': 1}
            # filter unused videos and videos which do not belong to our train/test subset:
            used_videos = a2d_data_info[
                ~a2d_data_info.vid.isin(unused_videos)
                & (a2d_data_info.subset == subsets[subset])]
            used_videos_ids = list(used_videos['vid'])
            text_annotations = pandas.read_csv(
                osp.join(root_path, 'text_annotations/annotation.txt'))
            # filter the text annotations based on the used videos:
            used_text_annotations = text_annotations[
                text_annotations.video_id.isin(used_videos_ids)]
            # remove a single dataset annotation mistake in video: T6bNPuKV-wY
            used_text_annotations = used_text_annotations[
                used_text_annotations['instance_id'] != '1 (copy)']
            # convert data-frame to list of tuples:
            used_text_annotations = list(
                used_text_annotations.to_records(index=False))
            text_annotations_by_frame = []
            mask_annotations_dir = osp.join(
                root_path, 'text_annotations/annotation_with_instances')
            for video_id, instance_id, text_query in tqdm(
                    used_text_annotations):
                frame_annot_paths = sorted(
                    glob(osp.join(mask_annotations_dir, video_id, '*.h5')))
                instance_id = int(instance_id)
                for p in frame_annot_paths:
                    f = h5py.File(p)
                    instances = list(f['instance'])
                    if instance_id in instances:
                        # in case this instance does not appear in this frame it has no ground-truth mask, and thus this
                        # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out:
                        # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98
                        frame_idx = int(p.split('/')[-1].split('.')[0])
                        text_query = text_query.lower(
                        )  # lower the text query prior to augmentation & tokenization
                        text_annotations_by_frame.append(
                            (text_query, video_id, frame_idx, instance_id))
            with open(saved_annotations_file_path, 'w') as f:
                json.dump(text_annotations_by_frame, f)
        if distributed:
            dist.barrier()
        with open(saved_annotations_file_path, 'r') as f:
            text_annotations_by_frame = [tuple(a) for a in json.load(f)]
        return text_annotations_by_frame


class A2dSentencesTransforms:

    def __init__(self, subset_type, horizontal_flip_augmentations,
                 resize_and_crop_augmentations, train_short_size,
                 train_max_size, eval_short_size, eval_max_size, **kwargs):
        self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations
        normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        scales = [
            train_short_size
        ]  # no more scales for now due to GPU memory constraints. might be changed later
        transforms = []
        if resize_and_crop_augmentations:
            if subset_type == 'train':
                transforms.append(
                    T.RandomResize(scales, max_size=train_max_size))
            elif subset_type == 'test':
                transforms.append(
                    T.RandomResize([eval_short_size], max_size=eval_max_size)),
        transforms.extend([T.ToTensor(), normalize])
        self.size_transforms = T.Compose(transforms)

    def __call__(self, source_frames, targets, text_query):
        if self.h_flip_augmentation and torch.rand(1) > 0.5:
            source_frames = [F.hflip(f) for f in source_frames]
            targets[len(targets) // 2]['masks'] = F.hflip(
                targets[len(targets) // 2]['masks'])
            # Note - it is possible for both 'right' and 'left' to appear together in the same query, hence this fix:
            text_query = text_query.replace('left', '@').replace(
                'right', 'left').replace('@', 'right')
        source_frames, targets = list(
            zip(*[
                self.size_transforms(f, t)
                for f, t in zip(source_frames, targets)
            ]))
        source_frames = torch.stack(source_frames)  # [T, 3, H, W]
        return source_frames, targets, text_query


class Collator:

    def __call__(self, batch):
        samples, targets, text_queries = list(zip(*batch))
        samples = nested_tensor_from_videos_list(samples)  # [T, B, C, H, W]
        # convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch
        targets = list(zip(*targets))
        batch_dict = {
            'samples': samples,
            'targets': targets,
            'text_queries': text_queries
        }
        return batch_dict


def get_text_annotations_gt(root_path, subset):
    # without 'header == None' pandas will ignore the first sample...
    a2d_data_info = pandas.read_csv(
        osp.join(root_path, 'Release/videoset.csv'), header=None)
    # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset'
    a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset']
    with open(osp.join(root_path, 'text_annotations/missed_videos.txt'),
              'r') as f:
        unused_videos = f.read().splitlines()
    subsets = {'train': 0, 'test': 1}
    # filter unused videos and videos which do not belong to our train/test subset:
    used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos)
                                & (a2d_data_info.subset == subsets[subset])]
    used_videos_ids = list(used_videos['vid'])
    text_annotations = pandas.read_csv(
        osp.join(root_path, 'text_annotations/annotation.txt'))
    # filter the text annotations based on the used videos:
    used_text_annotations = text_annotations[text_annotations.video_id.isin(
        used_videos_ids)]
    # convert data-frame to list of tuples:
    used_text_annotations = list(used_text_annotations.to_records(index=False))
    return used_text_annotations


def create_a2d_sentences_ground_truth_test_annotations(dataset_path,
                                                       subset_type,
                                                       mask_annotations_dir,
                                                       output_path):
    text_annotations = get_text_annotations_gt(dataset_path, subset_type)

    # Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly
    # expected by pycocotools as it is the convention of the original coco dataset annotations.

    categories_dict = [{
        'id': 1,
        'name': 'dummy_class'
    }]  # dummy class, as categories are not used/predicted in RVOS

    images_dict = []
    annotations_dict = []
    images_set = set()
    instance_id_counter = 1
    for annot in tqdm(text_annotations):
        video_id, instance_id, text_query = annot
        annot_paths = sorted(
            glob(osp.join(mask_annotations_dir, video_id, '*.h5')))
        for p in annot_paths:
            f = h5py.File(p)
            instances = list(f['instance'])
            try:
                instance_idx = instances.index(int(instance_id))
                # in case this instance does not appear in this frame it has no ground-truth mask, and thus this
                # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out:
                # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98
            except ValueError:
                continue  # instance_id does not appear in current frame
            mask = f['reMask'][instance_idx] if len(
                instances) > 1 else np.array(f['reMask'])
            mask = mask.transpose()

            frame_idx = int(p.split('/')[-1].split('.')[0])
            image_id = get_image_id(video_id, frame_idx, instance_id)
            assert image_id not in images_set, f'error: image id: {image_id} appeared twice'
            images_set.add(image_id)
            images_dict.append({
                'id': image_id,
                'height': mask.shape[0],
                'width': mask.shape[1]
            })

            mask_rle = encode(mask)
            mask_rle['counts'] = mask_rle['counts'].decode('ascii')
            mask_area = float(area(mask_rle))
            bbox = f['reBBox'][:, instance_idx] if len(
                instances) > 1 else np.array(
                    f['reBBox']).squeeze()  # x1y1x2y2 form
            bbox_xywh = [
                bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
            ]
            instance_annot = {
                'id': instance_id_counter,
                'image_id': image_id,
                'category_id': 1,  # dummy class, as categories are not used/predicted in ref-vos
                'segmentation': mask_rle,
                'area': mask_area,
                'bbox': bbox_xywh,
                'iscrowd': 0,
            }
            annotations_dict.append(instance_annot)
            instance_id_counter += 1
    dataset_dict = {
        'categories': categories_dict,
        'images': images_dict,
        'annotations': annotations_dict
    }
    with open(output_path, 'w') as f:
        json.dump(dataset_dict, f)
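As a sanity check on the window extraction in __getitem__ above, a tiny self-contained sketch of the index arithmetic with illustrative numbers: a window of 8 frames around 1-indexed frame 3 of a 10-frame clip, with out-of-range indices clamped to the edge frames.

    window_size, frame_idx, num_frames = 8, 3, 10  # illustrative values
    start_idx = frame_idx - 1 - window_size // 2
    end_idx = frame_idx - 1 + (window_size + 1) // 2
    window = [min(max(i, 0), num_frames - 1) for i in range(start_idx, end_idx)]
    print(start_idx, end_idx)  # -2 6
    print(window)              # [0, 0, 0, 1, 2, 3, 4, 5]: the left edge frame is repeated as padding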
@@ -0,0 +1,294 @@
# The implementation is adopted from MTTR,
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
# Modified from DETR https://github.com/facebookresearch/detr

import random

import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

from modelscope.models.cv.referring_video_object_segmentation.utils import \
    interpolate


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target['size'] = torch.tensor([h, w])

    fields = ['labels', 'area', 'iscrowd']

    if 'boxes' in target:
        boxes = target['boxes']
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target['boxes'] = cropped_boxes.reshape(-1, 4)
        target['area'] = area
        fields.append('boxes')

    if 'masks' in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append('masks')

    # remove elements for which the boxes or masks that have zero area
    if 'boxes' in target or 'masks' in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if 'boxes' in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(
                cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if 'boxes' in target:
        boxes = target['boxes']
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
            [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target['boxes'] = boxes

    if 'masks' in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(
                    round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(
        float(s) / float(s_orig)
        for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if 'boxes' in target:
        boxes = target['boxes']
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height])
        target['boxes'] = scaled_boxes

    if 'area' in target:
        area = target['area']
        scaled_area = area * (ratio_width * ratio_height)
        target['area'] = scaled_area

    h, w = size
    target['size'] = torch.tensor([h, w])

    if 'masks' in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5

    return rescaled_image, target


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target['size'] = torch.tensor(padded_image.size[::-1])
    if 'masks' in target:
        target['masks'] = torch.nn.functional.pad(
            target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target


class RandomCrop(object):

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)


class RandomSizeCrop(object):

    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])
        return crop(img, target, region)


class CenterCrop(object):

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.))
        crop_left = int(round((image_width - crop_width) / 2.))
        return crop(img, target,
                    (crop_top, crop_left, crop_height, crop_width))


class RandomHorizontalFlip(object):

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target


class RandomResize(object):

    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)


class RandomPad(object):

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad(img, target, (pad_x, pad_y))


class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """

    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)


class ToTensor(object):

    def __call__(self, img, target):
        return F.to_tensor(img), target


class RandomErasing(object):

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        return self.eraser(img), target


class Normalize(object):

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if 'boxes' in target:
            boxes = target['boxes']
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target['boxes'] = boxes
        return image, target


class Compose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
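A short usage sketch of these DETR-style transforms: every transform takes and returns an (image, target) pair, which is what lets A2dSentencesTransforms chain RandomResize, ToTensor and Normalize above. The sizes and the dummy frame below are illustrative, and the classes are assumed to be importable the same way the dataset does it (from . import transformers as T).

    import torch
    from PIL import Image

    pipeline = Compose([
        RandomResize([360], max_size=640),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    img = Image.new('RGB', (480, 320))  # dummy frame, width x height
    target = {
        'masks': torch.zeros(1, 320, 480, dtype=torch.bool),
        'size': torch.tensor([320, 480]),
        'area': torch.tensor([0.0]),
        'iscrowd': torch.zeros(1),
    }
    img_t, target = pipeline(img, target)
    print(img_t.shape)     # torch.Size([3, 360, 540]): short side resized to 360
    print(target['size'])  # tensor([360, 540]); masks are resized to match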
@@ -157,7 +157,13 @@ class ReferringVideoObjectSegmentationPipeline(Pipeline):
                * text_border_height_per_query, 0, 0))
            W, H = vid_frame.size
            draw = ImageDraw.Draw(vid_frame)
            font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30)

            if self.model.cfg.pipeline.output_font:
                font = ImageFont.truetype(
                    font=self.model.cfg.pipeline.output_font,
                    size=self.model.cfg.pipeline.output_font_size)
            else:
                font = ImageFont.load_default()
            for i, (text_query, color) in enumerate(
                    zip(self.text_queries, colors), start=1):
                w, h = draw.textsize(text_query, font=font)
@@ -9,7 +9,8 @@ if TYPE_CHECKING:
    from .builder import build_trainer
    from .cv import (ImageInstanceSegmentationTrainer,
                     ImagePortraitEnhancementTrainer,
                     MovieSceneSegmentationTrainer, ImageInpaintingTrainer)
                     MovieSceneSegmentationTrainer, ImageInpaintingTrainer,
                     ReferringVideoObjectSegmentationTrainer)
    from .multi_modal import CLIPTrainer
    from .nlp import SequenceClassificationTrainer, TextRankingTrainer
    from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
    from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer
    from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer
    from .image_inpainting_trainer import ImageInpaintingTrainer
    from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer

else:
    _import_structure = {
@@ -17,7 +18,9 @@ else:
        'image_portrait_enhancement_trainer':
        ['ImagePortraitEnhancementTrainer'],
        'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'],
        'image_inpainting_trainer': ['ImageInpaintingTrainer']
        'image_inpainting_trainer': ['ImageInpaintingTrainer'],
        'referring_video_object_segmentation_trainer':
        ['ReferringVideoObjectSegmentationTrainer']
    }

    import sys
@@ -0,0 +1,63 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import torch

from modelscope.metainfo import Trainers
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils.constant import ModeKeys


@TRAINERS.register_module(
    module_name=Trainers.referring_video_object_segmentation)
class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model.set_postprocessor(self.cfg.dataset.name)
        self.train_data_collator = self.train_dataset.collator
        self.eval_data_collator = self.eval_dataset.collator

        device_name = kwargs.get('device', 'gpu')
        self.model.set_device(self.device, device_name)

    def train(self, *args, **kwargs):
        self.model.criterion.train()
        super().train(*args, **kwargs)

    def evaluate(self, checkpoint_path=None):
        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            from modelscope.trainers.hooks import CheckpointHook
            CheckpointHook.load_checkpoint(checkpoint_path, self)
        self.model.eval()
        self._mode = ModeKeys.EVAL
        if self.eval_dataset is None:
            self.eval_dataloader = self.get_eval_data_loader()
        else:
            self.eval_dataloader = self._build_dataloader_with_dataset(
                self.eval_dataset,
                dist=self._dist,
                seed=self._seed,
                collate_fn=self.eval_data_collator,
                **self.cfg.evaluation.get('dataloader', {}))
        self.data_loader = self.eval_dataloader

        from modelscope.metrics import build_metric
        ann_file = self.eval_dataset.ann_file
        metric_classes = []
        for metric in self.metrics:
            metric.update({'ann_file': ann_file})
            metric_classes.append(build_metric(metric))

        for m in metric_classes:
            m.trainer = self

        metric_values = self.evaluation_loop(self.eval_dataloader,
                                             metric_classes)

        self._metric_values = metric_values
        return metric_values

    def prediction_step(self, model, inputs):
        pass
@@ -62,7 +62,10 @@ def single_gpu_test(trainer,
        if 'nsentences' in data:
            batch_size = data['nsentences']
        else:
            batch_size = len(next(iter(data.values())))
            try:
                batch_size = len(next(iter(data.values())))
            except Exception:
                batch_size = data_loader.batch_size
    else:
        batch_size = len(data)
    for _ in range(batch_size):
@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.cv.movie_scene_segmentation import \
    MovieSceneSegmentationModel
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestImageInstanceSegmentationTrainer(unittest.TestCase):

    model_id = 'damo/cv_swin-t_referring_video-object-segmentation'
    dataset_name = 'referring_vos_toydata'

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

        cache_path = snapshot_download(self.model_id)
        config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
        cfg = Config.from_file(config_path)

        max_epochs = cfg.train.max_epochs

        train_data_cfg = ConfigDict(
            name=self.dataset_name,
            split='train',
            test_mode=False,
            cfg=cfg.dataset)

        test_data_cfg = ConfigDict(
            name=self.dataset_name,
            split='test',
            test_mode=True,
            cfg=cfg.dataset)

        self.train_dataset = MsDataset.load(
            dataset_name=train_data_cfg.name,
            split=train_data_cfg.split,
            cfg=train_data_cfg.cfg,
            namespace='damo',
            test_mode=train_data_cfg.test_mode)
        assert next(
            iter(self.train_dataset.config_kwargs['split_config'].values()))

        self.test_dataset = MsDataset.load(
            dataset_name=test_data_cfg.name,
            split=test_data_cfg.split,
            cfg=test_data_cfg.cfg,
            namespace='damo',
            test_mode=test_data_cfg.test_mode)
        assert next(
            iter(self.test_dataset.config_kwargs['split_config'].values()))

        self.max_epochs = max_epochs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            work_dir='./work_dir')

        trainer = build_trainer(
            name=Trainers.referring_video_object_segmentation,
            default_args=kwargs)
        trainer.train()
        results_files = os.listdir(trainer.work_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):

        cache_path = snapshot_download(self.model_id)
        model = MovieSceneSegmentationModel.from_pretrained(cache_path)
        kwargs = dict(
            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            work_dir='./work_dir')

        trainer = build_trainer(
            name=Trainers.referring_video_object_segmentation,
            default_args=kwargs)
        trainer.train()
        results_files = os.listdir(trainer.work_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)


if __name__ == '__main__':
    unittest.main()