From ddcb57440d798b44fd90192bd58a60b69811ae43 Mon Sep 17 00:00:00 2001 From: "shuying.shu" Date: Thu, 27 Oct 2022 19:43:54 +0800 Subject: [PATCH] [to #42322933]add fine-tune code for referring video object segmentation Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10539423 --- modelscope/metainfo.py | 3 + modelscope/metrics/__init__.py | 3 + modelscope/metrics/builder.py | 2 + ...erring_video_object_segmentation_metric.py | 108 ++++++ .../__init__.py | 4 +- .../model.py | 95 ++++- .../utils/__init__.py | 4 +- .../utils/criterion.py | 198 ++++++++++ .../utils/matcher.py | 163 ++++++++ .../utils/multimodal_transformer.py | 4 +- .../utils/swin_transformer.py | 3 +- .../msdatasets/task_datasets/__init__.py | 3 + .../__init__.py | 3 + ...rring_video_object_segmentation_dataset.py | 361 ++++++++++++++++++ .../transformers.py | 294 ++++++++++++++ ...ring_video_object_segmentation_pipeline.py | 8 +- modelscope/trainers/__init__.py | 3 +- modelscope/trainers/cv/__init__.py | 5 +- ...rring_video_object_segmentation_trainer.py | 63 +++ modelscope/trainers/utils/inference.py | 5 +- ...rring_video_object_segmentation_trainer.py | 101 +++++ 21 files changed, 1414 insertions(+), 19 deletions(-) create mode 100644 modelscope/metrics/referring_video_object_segmentation_metric.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py create mode 100644 modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py create mode 100644 modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py create mode 100644 modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py create mode 100644 modelscope/trainers/cv/referring_video_object_segmentation_trainer.py create mode 100644 tests/trainers/test_referring_video_object_segmentation_trainer.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index af60f072..2aeb86da 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -305,6 +305,7 @@ class Trainers(object): face_detection_scrfd = 'face-detection-scrfd' card_detection_scrfd = 'card-detection-scrfd' image_inpainting = 'image-inpainting' + referring_video_object_segmentation = 'referring-video-object-segmentation' image_classification_team = 'image-classification-team' # nlp trainers @@ -423,6 +424,8 @@ class Metrics(object): image_inpainting_metric = 'image-inpainting-metric' # metric for ocr NED = 'ned' + # metric for referring-video-object-segmentation task + referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' class Optimizers(object): diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index c022eaf4..f106f054 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .accuracy_metric import AccuracyMetric from .bleu_metric import BleuMetric from .image_inpainting_metric import ImageInpaintingMetric + from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric else: _import_structure = { @@ -40,6 +41,8 @@ else: 'image_inpainting_metric': ['ImageInpaintingMetric'], 'accuracy_metric': ['AccuracyMetric'], 'bleu_metric': ['BleuMetric'], + 'referring_video_object_segmentation_metric': + ['ReferringVideoObjectSegmentationMetric'], } import sys diff --git 
a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index da3b64c7..2b61c1ae 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -43,6 +43,8 @@ task_default_metrics = { Tasks.visual_question_answering: [Metrics.text_gen_metric], Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], Tasks.image_inpainting: [Metrics.image_inpainting_metric], + Tasks.referring_video_object_segmentation: + [Metrics.referring_video_object_segmentation_metric], } diff --git a/modelscope/metrics/referring_video_object_segmentation_metric.py b/modelscope/metrics/referring_video_object_segmentation_metric.py new file mode 100644 index 00000000..5a0af30b --- /dev/null +++ b/modelscope/metrics/referring_video_object_segmentation_metric.py @@ -0,0 +1,108 @@ +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR +from typing import Dict + +import numpy as np +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from pycocotools.mask import decode +from tqdm import tqdm + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.referring_video_object_segmentation_metric) +class ReferringVideoObjectSegmentationMetric(Metric): + """The metric computation class for movie scene segmentation classes. + """ + + def __init__(self, + ann_file=None, + calculate_precision_and_iou_metrics=True): + self.ann_file = ann_file + self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics + self.preds = [] + + def add(self, outputs: Dict, inputs: Dict): + preds_batch = outputs['pred'] + self.preds.extend(preds_batch) + + def evaluate(self): + coco_gt = COCO(self.ann_file) + coco_pred = coco_gt.loadRes(self.preds) + coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') + coco_eval.params.useCats = 0 + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + ap_labels = [ + 'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', + 'AP 0.5:0.95 M', 'AP 0.5:0.95 L' + ] + ap_metrics = coco_eval.stats[:6] + eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)} + if self.calculate_precision_and_iou_metrics: + precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics( + coco_gt, coco_pred) + eval_metrics.update({ + f'P@{k}': m + for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k) + }) + eval_metrics.update({ + 'overall_iou': overall_iou, + 'mean_iou': mean_iou + }) + + return eval_metrics + + +def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6): + outputs = outputs.int() + intersection = (outputs & labels).float().sum( + (1, 2)) # Will be zero if Truth=0 or Prediction=0 + union = (outputs | labels).float().sum( + (1, 2)) # Will be zero if both are 0 + iou = (intersection + EPS) / (union + EPS + ) # EPS is used to avoid division by zero + return iou, intersection, union + + +def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): + print('evaluating precision@k & iou metrics...') + counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} + total_intersection_area = 0 + total_union_area = 0 + ious_list = [] + for instance in tqdm(coco_gt.imgs.keys() + ): # each image_id contains exactly one instance + gt_annot = coco_gt.imgToAnns[instance][0] + gt_mask = 
decode(gt_annot['segmentation']) + pred_annots = coco_pred.imgToAnns[instance] + pred_annot = sorted( + pred_annots, + key=lambda a: a['score'])[-1] # choose pred with highest score + pred_mask = decode(pred_annot['segmentation']) + iou, intersection, union = compute_iou( + torch.tensor(pred_mask).unsqueeze(0), + torch.tensor(gt_mask).unsqueeze(0)) + iou, intersection, union = iou.item(), intersection.item(), union.item( + ) + for iou_threshold in counters_by_iou.keys(): + if iou > iou_threshold: + counters_by_iou[iou_threshold] += 1 + total_intersection_area += intersection + total_union_area += union + ious_list.append(iou) + num_samples = len(ious_list) + precision_at_k = np.array(list(counters_by_iou.values())) / num_samples + overall_iou = total_intersection_area / total_union_area + mean_iou = np.mean(ious_list) + return precision_at_k, overall_iou, mean_iou diff --git a/modelscope/models/cv/referring_video_object_segmentation/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/__init__.py index 58dbf7b0..4c97bd7b 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/__init__.py +++ b/modelscope/models/cv/referring_video_object_segmentation/__init__.py @@ -5,11 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .model import MovieSceneSegmentation + from .model import ReferringVideoObjectSegmentation else: _import_structure = { - 'model': ['MovieSceneSegmentation'], + 'model': ['ReferringVideoObjectSegmentation'], } import sys diff --git a/modelscope/models/cv/referring_video_object_segmentation/model.py b/modelscope/models/cv/referring_video_object_segmentation/model.py index 902a3416..91f7ea91 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/model.py +++ b/modelscope/models/cv/referring_video_object_segmentation/model.py @@ -1,4 +1,6 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
+# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR + import os.path as osp from typing import Any, Dict @@ -10,7 +12,9 @@ from modelscope.models.builder import MODELS from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger -from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess, +from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher, + ReferYoutubeVOSPostProcess, SetCriterion, + flatten_temporal_batch_dims, nested_tensor_from_videos_list) logger = get_logger() @@ -35,16 +39,66 @@ class ReferringVideoObjectSegmentation(TorchModel): params_dict = params_dict['model_state_dict'] self.model.load_state_dict(params_dict, strict=True) - dataset_name = self.cfg.pipeline.dataset_name - if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences': - self.postprocessor = A2DSentencesPostProcess() - elif dataset_name == 'ref_youtube_vos': - self.postprocessor = ReferYoutubeVOSPostProcess() + self.set_postprocessor(self.cfg.pipeline.dataset_name) + self.set_criterion() + + def set_device(self, device, name): + self.device = device + self._device_name = name + + def set_postprocessor(self, dataset_name): + if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name: + self.postprocessor = A2DSentencesPostProcess() # fine-tune + elif 'ref_youtube_vos' in dataset_name: + self.postprocessor = ReferYoutubeVOSPostProcess() # inference else: assert False, f'postprocessing for dataset: {dataset_name} is not supported' - def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: - return inputs + def forward(self, inputs: Dict[str, Any]): + samples = inputs['samples'] + targets = inputs['targets'] + text_queries = inputs['text_queries'] + + valid_indices = torch.tensor( + [i for i, t in enumerate(targets) if None not in t]) + targets = [targets[i] for i in valid_indices.tolist()] + if self._device_name == 'gpu': + samples = samples.to(self.device) + valid_indices = valid_indices.to(self.device) + if isinstance(text_queries, tuple): + text_queries = list(text_queries) + + outputs = self.model(samples, valid_indices, text_queries) + losses = -1 + if self.training: + loss_dict = self.criterion(outputs, targets) + weight_dict = self.criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] + for k in loss_dict.keys() if k in weight_dict) + + predictions = [] + if not self.training: + outputs.pop('aux_outputs', None) + outputs, targets = flatten_temporal_batch_dims(outputs, targets) + processed_outputs = self.postprocessor( + outputs, + resized_padded_sample_size=samples.tensors.shape[-2:], + resized_sample_sizes=[t['size'] for t in targets], + orig_sample_sizes=[t['orig_size'] for t in targets]) + image_ids = [t['image_id'] for t in targets] + predictions = [] + for p, image_id in zip(processed_outputs, image_ids): + for s, m in zip(p['scores'], p['rle_masks']): + predictions.append({ + 'image_id': image_id, + 'category_id': + 1, # dummy label, as categories are not predicted in ref-vos + 'segmentation': m, + 'score': s.item() + }) + + re = dict(pred=predictions, loss=losses) + return re def inference(self, **kwargs): window = kwargs['window'] @@ -63,3 +117,26 @@ class ReferringVideoObjectSegmentation(TorchModel): def postprocess(self, inputs: Dict[str, Any], **kwargs): return inputs + + def set_criterion(self): + matcher = HungarianMatcher( + 
cost_is_referred=self.cfg.matcher.set_cost_is_referred, + cost_dice=self.cfg.matcher.set_cost_dice) + weight_dict = { + 'loss_is_referred': self.cfg.loss.is_referred_loss_coef, + 'loss_dice': self.cfg.loss.dice_loss_coef, + 'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef + } + + if self.cfg.loss.aux_loss: + aux_weight_dict = {} + for i in range(self.cfg.model.num_decoder_layers - 1): + aux_weight_dict.update( + {k + f'_{i}': v + for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + self.criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, + eos_coef=self.cfg.loss.eos_coef) diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py index 796bd6f4..fbb75b00 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .misc import nested_tensor_from_videos_list +from .criterion import SetCriterion, flatten_temporal_batch_dims +from .matcher import HungarianMatcher +from .misc import interpolate, nested_tensor_from_videos_list from .mttr import MTTR from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py b/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py new file mode 100644 index 00000000..a4d2f0ff --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py @@ -0,0 +1,198 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +import torch +from torch import nn + +from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized, + nested_tensor_from_tensor_list) +from .segmentation import dice_loss, sigmoid_focal_loss + + +class SetCriterion(nn.Module): + """ This class computes the loss for MTTR. + The process happens in two steps: + 1) we compute the hungarian assignment between the ground-truth and predicted sequences. + 2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction) + """ + + def __init__(self, matcher, weight_dict, eos_coef): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. 
+ eos_coef: relative classification weight applied to the un-referred category + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + # make sure that only loss functions with non-zero weights are computed: + losses_to_compute = [] + if weight_dict['loss_dice'] > 0 or weight_dict[ + 'loss_sigmoid_focal'] > 0: + losses_to_compute.append('masks') + if weight_dict['loss_is_referred'] > 0: + losses_to_compute.append('is_referred') + self.losses = losses_to_compute + + def forward(self, outputs, targets): + aux_outputs_list = outputs.pop('aux_outputs', None) + # compute the losses for the output of the last decoder layer: + losses = self.compute_criterion( + outputs, targets, losses_to_compute=self.losses) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer. + if aux_outputs_list is not None: + aux_losses_to_compute = self.losses.copy() + for i, aux_outputs in enumerate(aux_outputs_list): + losses_dict = self.compute_criterion(aux_outputs, targets, + aux_losses_to_compute) + losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()} + losses.update(losses_dict) + + return losses + + def compute_criterion(self, outputs, targets, losses_to_compute): + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs, targets) + + # T & B dims are flattened so loss functions can be computed per frame (but with same indices per video). + # also, indices are repeated so so the same indices can be used for frames of the same video. + T = len(targets) + outputs, targets = flatten_temporal_batch_dims(outputs, targets) + # repeat the indices list T times so the same indices can be used for each video frame + indices = T * indices + + # Compute the average number of target masks across all nodes, for normalization purposes + num_masks = sum(len(t['masks']) for t in targets) + num_masks = torch.as_tensor([num_masks], + dtype=torch.float, + device=indices[0][0].device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_masks) + num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in losses_to_compute: + losses.update( + self.get_loss( + loss, outputs, targets, indices, num_masks=num_masks)) + return losses + + def loss_is_referred(self, outputs, targets, indices, **kwargs): + device = outputs['pred_is_referred'].device + bs = outputs['pred_is_referred'].shape[0] + pred_is_referred = outputs['pred_is_referred'].log_softmax( + dim=-1) # note that log-softmax is used here + target_is_referred = torch.zeros_like(pred_is_referred) + # extract indices of object queries that where matched with text-referred target objects + query_referred_indices = self._get_query_referred_indices( + indices, targets) + # by default penalize compared to the no-object class (last token) + target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) + if 'is_ref_inst_visible' in targets[ + 0]: # visibility labels are available per-frame for the referred object: + is_ref_inst_visible_per_frame = torch.stack( + [t['is_ref_inst_visible'] for t in targets]) + ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero( + ).squeeze() + # keep only the matched query indices of the frames in which the referred object is visible: + visible_query_referred_indices = query_referred_indices[ + ref_inst_visible_frame_indices] + 
target_is_referred[ref_inst_visible_frame_indices, + visible_query_referred_indices] = torch.tensor( + [1.0, 0.0], device=device) + else: # assume that the referred object is visible in every frame: + target_is_referred[torch.arange(bs), + query_referred_indices] = torch.tensor( + [1.0, 0.0], device=device) + loss = -(pred_is_referred * target_is_referred).sum(-1) + # apply no-object class weights: + eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device) + eos_coef[torch.arange(bs), query_referred_indices] = 1.0 + loss = loss * eos_coef + bs = len(indices) + loss = loss.sum() / bs # sum and normalize the loss by the batch size + losses = {'loss_is_referred': loss} + return losses + + def loss_masks(self, outputs, targets, indices, num_masks, **kwargs): + assert 'pred_masks' in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs['pred_masks'] + src_masks = src_masks[src_idx] + masks = [t['masks'] for t in targets] + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = interpolate( + src_masks[:, None], + size=target_masks.shape[-2:], + mode='bilinear', + align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + 'loss_sigmoid_focal': + sigmoid_focal_loss(src_masks, target_masks, num_masks), + 'loss_dice': + dice_loss(src_masks, target_masks, num_masks), + } + return losses + + @staticmethod + def _get_src_permutation_idx(indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + @staticmethod + def _get_tgt_permutation_idx(indices): + # permute targets following indices + batch_idx = torch.cat( + [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + @staticmethod + def _get_query_referred_indices(indices, targets): + """ + extract indices of object queries that where matched with text-referred target objects + """ + query_referred_indices = [] + for (query_idxs, target_idxs), target in zip(indices, targets): + ref_query_idx = query_idxs[torch.where( + target_idxs == target['referred_instance_idx'])[0]] + query_referred_indices.append(ref_query_idx) + query_referred_indices = torch.cat(query_referred_indices) + return query_referred_indices + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + 'masks': self.loss_masks, + 'is_referred': self.loss_is_referred, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' 
+ return loss_map[loss](outputs, targets, indices, **kwargs) + + +def flatten_temporal_batch_dims(outputs, targets): + for k in outputs.keys(): + if isinstance(outputs[k], torch.Tensor): + outputs[k] = outputs[k].flatten(0, 1) + else: # list + outputs[k] = [i for step_t in outputs[k] for i in step_t] + targets = [ + frame_t_target for step_t in targets for frame_t_target in step_t + ] + return outputs, targets diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py b/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py new file mode 100644 index 00000000..4f9b88e5 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py @@ -0,0 +1,163 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +# Module to compute the matching cost and solve the corresponding LSAP. + +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from .misc import interpolate, nested_tensor_from_tensor_list + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1): + """Creates the matcher + + Params: + cost_is_referred: This is the relative weight of the reference cost in the total matching cost + cost_dice: This is the relative weight of the dice cost in the total matching cost + """ + super().__init__() + self.cost_is_referred = cost_is_referred + self.cost_dice = cost_dice + assert cost_is_referred != 0 or cost_dice != 0, 'all costs cant be 0' + + @torch.inference_mode() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: A dict that contains at least these entries: + "pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits + "pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted masks logits + + targets: A list of lists of targets (outer - time steps, inner - batch samples). each target is a dict + which contain mask and reference ground truth information for a single frame. 
+ + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_masks) + """ + t, bs, num_queries = outputs['pred_masks'].shape[:3] + + # We flatten to compute the cost matrices in a batch + out_masks = outputs['pred_masks'].flatten( + 1, 2) # [t, batch_size * num_queries, mask_h, mask_w] + + # preprocess and concat the target masks + tgt_masks = [[ + m for v in t_step_batch for m in v['masks'].unsqueeze(1) + ] for t_step_batch in targets] + # pad the target masks to a uniform shape + tgt_masks, valid = list( + zip(*[ + nested_tensor_from_tensor_list(t).decompose() + for t in tgt_masks + ])) + tgt_masks = torch.stack(tgt_masks).squeeze(2) + + # upsample predicted masks to target mask size + out_masks = interpolate( + out_masks, + size=tgt_masks.shape[-2:], + mode='bilinear', + align_corners=False) + + # Compute the soft-tokens cost: + if self.cost_is_referred > 0: + cost_is_referred = compute_is_referred_cost(outputs, targets) + else: + cost_is_referred = 0 + + # Compute the DICE coefficient between the masks: + if self.cost_dice > 0: + cost_dice = -dice_coef(out_masks, tgt_masks) + else: + cost_dice = 0 + + # Final cost matrix + C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice + C = C.view(bs, num_queries, -1).cpu() + + num_traj_per_batch = [ + len(v['masks']) for v in targets[0] + ] # number of instance trajectories in each batch + indices = [ + linear_sum_assignment(c[i]) + for i, c in enumerate(C.split(num_traj_per_batch, -1)) + ] + device = out_masks.device + return [(torch.as_tensor(i, dtype=torch.int64, device=device), + torch.as_tensor(j, dtype=torch.int64, device=device)) + for i, j in indices] + + +def dice_coef(inputs, targets, smooth=1.0): + """ + Compute the DICE coefficient, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). 
+ """ + inputs = inputs.sigmoid().flatten(2).unsqueeze(2) + targets = targets.flatten(2).unsqueeze(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + coef = (numerator + smooth) / (denominator + smooth) + coef = coef.mean( + 0) # average on the temporal dim to get instance trajectory scores + return coef + + +def compute_is_referred_cost(outputs, targets): + pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax( + dim=-1) # [t, b*nq, 2] + device = pred_is_referred.device + t = pred_is_referred.shape[0] + # number of instance trajectories in each batch + num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]], + device=device) + total_trajectories = num_traj_per_batch.sum() + # note that ref_indices are shared across time steps: + ref_indices = torch.tensor( + [v['referred_instance_idx'] for v in targets[0]], device=device) + # convert ref_indices to fit flattened batch targets: + ref_indices += torch.cat( + (torch.zeros(1, dtype=torch.long, + device=device), num_traj_per_batch.cumsum(0)[:-1])) + # number of instance trajectories in each batch + target_is_referred = torch.zeros((t, total_trajectories, 2), device=device) + # 'no object' class by default (for un-referred objects) + target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) + if 'is_ref_inst_visible' in targets[0][ + 0]: # visibility labels are available per-frame for the referred object: + is_ref_inst_visible = torch.stack([ + torch.stack([t['is_ref_inst_visible'] for t in t_step]) + for t_step in targets + ]).permute(1, 0) + for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible): + is_visible = is_visible.nonzero().squeeze() + target_is_referred[is_visible, + ref_idx, :] = torch.tensor([1.0, 0.0], + device=device) + else: # assume that the referred object is visible in every frame: + target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0], + device=device) + cost_is_referred = -(pred_is_referred.unsqueeze(2) + * target_is_referred.unsqueeze(1)).sum(dim=-1).mean( + dim=0) + return cost_is_referred diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py index 8c24e397..39962715 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py @@ -122,8 +122,8 @@ class MultimodalTransformer(nn.Module): with torch.inference_mode(mode=self.freeze_text_encoder): encoded_text = self.text_encoder(**tokenized_queries) # Transpose memory because pytorch's attention expects sequence first - txt_memory = rearrange(encoded_text.last_hidden_state, - 'b s c -> s b c') + tmp_last_hidden_state = encoded_text.last_hidden_state.clone() + txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c') txt_memory = self.txt_proj( txt_memory) # change text embeddings dim to model dim # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py index 9a08ef48..faaf6e10 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py @@ -123,7 
+123,8 @@ class WindowAttention3D(nn.Module): # define a parameter table of relative position bias wd, wh, ww = window_size self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads)) + torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_d = torch.arange(self.window_size[0]) diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py index 92764155..043010bf 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from .video_summarization_dataset import VideoSummarizationDataset from .image_inpainting import ImageInpaintingDataset from .text_ranking_dataset import TextRankingDataset + from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset else: _import_structure = { @@ -29,6 +30,8 @@ else: 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], 'image_portrait_enhancement_dataset': ['ImagePortraitEnhancementDataset'], + 'referring_video_object_segmentation': + ['ReferringVideoObjectSegmentationDataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py new file mode 100644 index 00000000..7c1b724e --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .referring_video_object_segmentation_dataset import \ + ReferringVideoObjectSegmentationDataset diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py new file mode 100644 index 00000000..c90351e9 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py @@ -0,0 +1,361 @@ +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR + +from glob import glob +from os import path as osp + +import h5py +import json +import numpy as np +import pandas +import torch +import torch.distributed as dist +import torchvision.transforms.functional as F +from pycocotools.mask import area, encode +from torchvision.io import read_video +from tqdm import tqdm + +from modelscope.metainfo import Models +from modelscope.models.cv.referring_video_object_segmentation.utils import \ + nested_tensor_from_videos_list +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from . 
import transformers as T + +LOGGER = get_logger() + + +def get_image_id(video_id, frame_idx, ref_instance_a2d_id): + image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}' + return image_id + + +@TASK_DATASETS.register_module( + Tasks.referring_video_object_segmentation, + module_name=Models.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationDataset(TorchTaskDataset): + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + LOGGER.info(kwargs) + data_cfg = kwargs.get('cfg').data_kwargs + trans_cfg = kwargs.get('cfg').transformers_kwargs + distributed = data_cfg.get('distributed', False) + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + + self.window_size = data_cfg.get('window_size', 8) + self.mask_annotations_dir = osp.join( + self.data_root, 'text_annotations/annotation_with_instances') + self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320') + self.subset_type = next(iter(split_config.keys())) + self.text_annotations = self.get_text_annotations( + self.data_root, self.subset_type, distributed) + self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg) + self.collator = Collator() + self.ann_file = osp.join( + self.data_root, + data_cfg.get('ann_file', + 'a2d_sentences_test_annotations_in_coco_format.json')) + + # create ground-truth test annotations for the evaluation process if necessary: + if self.subset_type == 'test' and not osp.exists(self.ann_file): + if (distributed and dist.get_rank() == 0) or not distributed: + create_a2d_sentences_ground_truth_test_annotations( + self.data_root, self.subset_type, + self.mask_annotations_dir, self.ann_file) + if distributed: + dist.barrier() + + def __len__(self): + return len(self.text_annotations) + + def __getitem__(self, idx): + text_query, video_id, frame_idx, instance_id = self.text_annotations[ + idx] + + text_query = ' '.join( + text_query.lower().split()) # clean up the text query + + # read the source window frames: + video_frames, _, _ = read_video( + osp.join(self.videos_dir, f'{video_id}.mp4'), + pts_unit='sec') # (T, H, W, C) + # get a window of window_size frames with frame frame_idx in the middle. + # note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx + start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + ( + self.window_size + 1) // 2 + + # extract the window source frames: + source_frames = [] + for i in range(start_idx, end_idx): + i = min(max(i, 0), + len(video_frames) + - 1) # pad out of range indices with edge frames + source_frames.append( + F.to_pil_image(video_frames[i].permute(2, 0, 1))) + + # read the instance mask: + frame_annot_path = osp.join(self.mask_annotations_dir, video_id, + f'{frame_idx:05d}.h5') + f = h5py.File(frame_annot_path, 'r') + instances = list(f['instance']) + instance_idx = instances.index( + instance_id) # existence was already validated during init + + instance_masks = np.array(f['reMask']) + if len(instances) == 1: + instance_masks = instance_masks[np.newaxis, ...] + instance_masks = torch.tensor(instance_masks).transpose(1, 2) + mask_rles = [encode(mask) for mask in instance_masks.numpy()] + mask_areas = area(mask_rles).astype(np.float) + f.close() + + # create the target dict for the center frame: + target = { + 'masks': instance_masks, + 'orig_size': instance_masks. 
+ shape[-2:], # original frame shape without any augmentations + # size with augmentations, will be changed inside transforms if necessary + 'size': instance_masks.shape[-2:], + 'referred_instance_idx': torch.tensor( + instance_idx), # idx in 'masks' of the text referred instance + 'area': torch.tensor(mask_areas), + 'iscrowd': + torch.zeros(len(instance_masks) + ), # for compatibility with DETR COCO transforms + 'image_id': get_image_id(video_id, frame_idx, instance_id) + } + + # create dummy targets for adjacent frames: + targets = self.window_size * [None] + center_frame_idx = self.window_size // 2 + targets[center_frame_idx] = target + source_frames, targets, text_query = self.transforms( + source_frames, targets, text_query) + return source_frames, targets, text_query + + @staticmethod + def get_text_annotations(root_path, subset, distributed): + saved_annotations_file_path = osp.join( + root_path, f'sentences_single_frame_{subset}_annotations.json') + if osp.exists(saved_annotations_file_path): + with open(saved_annotations_file_path, 'r') as f: + text_annotations_by_frame = [tuple(a) for a in json.load(f)] + return text_annotations_by_frame + elif (distributed and dist.get_rank() == 0) or not distributed: + print(f'building a2d sentences {subset} text annotations...') + # without 'header == None' pandas will ignore the first sample... + a2d_data_info = pandas.read_csv( + osp.join(root_path, 'Release/videoset.csv'), header=None) + # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' + a2d_data_info.columns = [ + 'vid', '', '', '', '', '', '', '', 'subset' + ] + with open( + osp.join(root_path, 'text_annotations/missed_videos.txt'), + 'r') as f: + unused_videos = f.read().splitlines() + subsets = {'train': 0, 'test': 1} + # filter unused videos and videos which do not belong to our train/test subset: + used_videos = a2d_data_info[ + ~a2d_data_info.vid.isin(unused_videos) + & (a2d_data_info.subset == subsets[subset])] + used_videos_ids = list(used_videos['vid']) + text_annotations = pandas.read_csv( + osp.join(root_path, 'text_annotations/annotation.txt')) + # filter the text annotations based on the used videos: + used_text_annotations = text_annotations[ + text_annotations.video_id.isin(used_videos_ids)] + # remove a single dataset annotation mistake in video: T6bNPuKV-wY + used_text_annotations = used_text_annotations[ + used_text_annotations['instance_id'] != '1 (copy)'] + # convert data-frame to list of tuples: + used_text_annotations = list( + used_text_annotations.to_records(index=False)) + text_annotations_by_frame = [] + mask_annotations_dir = osp.join( + root_path, 'text_annotations/annotation_with_instances') + for video_id, instance_id, text_query in tqdm( + used_text_annotations): + frame_annot_paths = sorted( + glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) + instance_id = int(instance_id) + for p in frame_annot_paths: + f = h5py.File(p) + instances = list(f['instance']) + if instance_id in instances: + # in case this instance does not appear in this frame it has no ground-truth mask, and thus this + # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. 
check out: + # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 + frame_idx = int(p.split('/')[-1].split('.')[0]) + text_query = text_query.lower( + ) # lower the text query prior to augmentation & tokenization + text_annotations_by_frame.append( + (text_query, video_id, frame_idx, instance_id)) + with open(saved_annotations_file_path, 'w') as f: + json.dump(text_annotations_by_frame, f) + if distributed: + dist.barrier() + with open(saved_annotations_file_path, 'r') as f: + text_annotations_by_frame = [tuple(a) for a in json.load(f)] + return text_annotations_by_frame + + +class A2dSentencesTransforms: + + def __init__(self, subset_type, horizontal_flip_augmentations, + resize_and_crop_augmentations, train_short_size, + train_max_size, eval_short_size, eval_max_size, **kwargs): + self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations + normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + scales = [ + train_short_size + ] # no more scales for now due to GPU memory constraints. might be changed later + transforms = [] + if resize_and_crop_augmentations: + if subset_type == 'train': + transforms.append( + T.RandomResize(scales, max_size=train_max_size)) + elif subset_type == 'test': + transforms.append( + T.RandomResize([eval_short_size], max_size=eval_max_size)), + transforms.extend([T.ToTensor(), normalize]) + self.size_transforms = T.Compose(transforms) + + def __call__(self, source_frames, targets, text_query): + if self.h_flip_augmentation and torch.rand(1) > 0.5: + source_frames = [F.hflip(f) for f in source_frames] + targets[len(targets) // 2]['masks'] = F.hflip( + targets[len(targets) // 2]['masks']) + # Note - is it possible for both 'right' and 'left' to appear together in the same query. hence this fix: + text_query = text_query.replace('left', '@').replace( + 'right', 'left').replace('@', 'right') + source_frames, targets = list( + zip(*[ + self.size_transforms(f, t) + for f, t in zip(source_frames, targets) + ])) + source_frames = torch.stack(source_frames) # [T, 3, H, W] + return source_frames, targets, text_query + + +class Collator: + + def __call__(self, batch): + samples, targets, text_queries = list(zip(*batch)) + samples = nested_tensor_from_videos_list(samples) # [T, B, C, H, W] + # convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch + targets = list(zip(*targets)) + batch_dict = { + 'samples': samples, + 'targets': targets, + 'text_queries': text_queries + } + return batch_dict + + +def get_text_annotations_gt(root_path, subset): + # without 'header == None' pandas will ignore the first sample... 
+ a2d_data_info = pandas.read_csv( + osp.join(root_path, 'Release/videoset.csv'), header=None) + # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' + a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset'] + with open(osp.join(root_path, 'text_annotations/missed_videos.txt'), + 'r') as f: + unused_videos = f.read().splitlines() + subsets = {'train': 0, 'test': 1} + # filter unused videos and videos which do not belong to our train/test subset: + used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos) + & (a2d_data_info.subset == subsets[subset])] + used_videos_ids = list(used_videos['vid']) + text_annotations = pandas.read_csv( + osp.join(root_path, 'text_annotations/annotation.txt')) + # filter the text annotations based on the used videos: + used_text_annotations = text_annotations[text_annotations.video_id.isin( + used_videos_ids)] + # convert data-frame to list of tuples: + used_text_annotations = list(used_text_annotations.to_records(index=False)) + return used_text_annotations + + +def create_a2d_sentences_ground_truth_test_annotations(dataset_path, + subset_type, + mask_annotations_dir, + output_path): + text_annotations = get_text_annotations_gt(dataset_path, subset_type) + + # Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly + # expected by pycocotools as it is the convention of the original coco dataset annotations. + + categories_dict = [{ + 'id': 1, + 'name': 'dummy_class' + }] # dummy class, as categories are not used/predicted in RVOS + + images_dict = [] + annotations_dict = [] + images_set = set() + instance_id_counter = 1 + for annot in tqdm(text_annotations): + video_id, instance_id, text_query = annot + annot_paths = sorted( + glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) + for p in annot_paths: + f = h5py.File(p) + instances = list(f['instance']) + try: + instance_idx = instances.index(int(instance_id)) + # in case this instance does not appear in this frame it has no ground-truth mask, and thus this + # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. 
check out: + # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 + except ValueError: + continue # instance_id does not appear in current frame + mask = f['reMask'][instance_idx] if len( + instances) > 1 else np.array(f['reMask']) + mask = mask.transpose() + + frame_idx = int(p.split('/')[-1].split('.')[0]) + image_id = get_image_id(video_id, frame_idx, instance_id) + assert image_id not in images_set, f'error: image id: {image_id} appeared twice' + images_set.add(image_id) + images_dict.append({ + 'id': image_id, + 'height': mask.shape[0], + 'width': mask.shape[1] + }) + + mask_rle = encode(mask) + mask_rle['counts'] = mask_rle['counts'].decode('ascii') + mask_area = float(area(mask_rle)) + bbox = f['reBBox'][:, instance_idx] if len( + instances) > 1 else np.array( + f['reBBox']).squeeze() # x1y1x2y2 form + bbox_xywh = [ + bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] + ] + instance_annot = { + 'id': instance_id_counter, + 'image_id': image_id, + 'category_id': + 1, # dummy class, as categories are not used/predicted in ref-vos + 'segmentation': mask_rle, + 'area': mask_area, + 'bbox': bbox_xywh, + 'iscrowd': 0, + } + annotations_dict.append(instance_annot) + instance_id_counter += 1 + dataset_dict = { + 'categories': categories_dict, + 'images': images_dict, + 'annotations': annotations_dict + } + with open(output_path, 'w') as f: + json.dump(dataset_dict, f) diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py new file mode 100644 index 00000000..a5067b1b --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py @@ -0,0 +1,294 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr + +import random + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +from modelscope.models.cv.referring_video_object_segmentation.utils import \ + interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target['size'] = torch.tensor([h, w]) + + fields = ['labels', 'area', 'iscrowd'] + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? 
+ target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if 'boxes' in target or 'masks' in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target['boxes'] = boxes + + if 'masks' in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int( + round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple( + float(s) / float(s_orig) + for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + h, w = size + target['size'] = torch.tensor([h, w]) + + if 'masks' in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5 + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? 
+ target['size'] = torch.tensor(padded_image.size[::-1]) + if 'masks' in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class RandomCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, + (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad(object): + + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor(object): + + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if 'boxes' in target: + boxes = target['boxes'] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target['boxes'] = boxes + return image, target + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + 
format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string diff --git a/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py index d264b386..cfbf2607 100644 --- a/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py +++ b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py @@ -157,7 +157,13 @@ class ReferringVideoObjectSegmentationPipeline(Pipeline): * text_border_height_per_query, 0, 0)) W, H = vid_frame.size draw = ImageDraw.Draw(vid_frame) - font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) + + if self.model.cfg.pipeline.output_font: + font = ImageFont.truetype( + font=self.model.cfg.pipeline.output_font, + size=self.model.cfg.pipeline.output_font_size) + else: + font = ImageFont.load_default() for i, (text_query, color) in enumerate( zip(self.text_queries, colors), start=1): w, h = draw.textsize(text_query, font=font) diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index d914489c..37fdcc12 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -9,7 +9,8 @@ if TYPE_CHECKING: from .builder import build_trainer from .cv import (ImageInstanceSegmentationTrainer, ImagePortraitEnhancementTrainer, - MovieSceneSegmentationTrainer, ImageInpaintingTrainer) + MovieSceneSegmentationTrainer, ImageInpaintingTrainer, + ReferringVideoObjectSegmentationTrainer) from .multi_modal import CLIPTrainer from .nlp import SequenceClassificationTrainer, TextRankingTrainer from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index d09fd75c..32c38de2 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer from .image_inpainting_trainer import ImageInpaintingTrainer + from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer else: _import_structure = { @@ -17,7 +18,9 @@ else: 'image_portrait_enhancement_trainer': ['ImagePortraitEnhancementTrainer'], 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], - 'image_inpainting_trainer': ['ImageInpaintingTrainer'] + 'image_inpainting_trainer': ['ImageInpaintingTrainer'], + 'referring_video_object_segmentation_trainer': + ['ReferringVideoObjectSegmentationTrainer'] } import sys diff --git a/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py b/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py new file mode 100644 index 00000000..c15df3a5 --- /dev/null +++ b/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os + +import torch + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import ModeKeys + + +@TRAINERS.register_module( + module_name=Trainers.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model.set_postprocessor(self.cfg.dataset.name) + self.train_data_collator = self.train_dataset.collator + self.eval_data_collator = self.eval_dataset.collator + + device_name = kwargs.get('device', 'gpu') + self.model.set_device(self.device, device_name) + + def train(self, *args, **kwargs): + self.model.criterion.train() + super().train(*args, **kwargs) + + def evaluate(self, checkpoint_path=None): + if checkpoint_path is not None and os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import CheckpointHook + CheckpointHook.load_checkpoint(checkpoint_path, self) + self.model.eval() + self._mode = ModeKeys.EVAL + if self.eval_dataset is None: + self.eval_dataloader = self.get_eval_data_loader() + else: + self.eval_dataloader = self._build_dataloader_with_dataset( + self.eval_dataset, + dist=self._dist, + seed=self._seed, + collate_fn=self.eval_data_collator, + **self.cfg.evaluation.get('dataloader', {})) + self.data_loader = self.eval_dataloader + + from modelscope.metrics import build_metric + ann_file = self.eval_dataset.ann_file + metric_classes = [] + for metric in self.metrics: + metric.update({'ann_file': ann_file}) + metric_classes.append(build_metric(metric)) + + for m in metric_classes: + m.trainer = self + + metric_values = self.evaluation_loop(self.eval_dataloader, + metric_classes) + + self._metric_values = metric_values + return metric_values + + def prediction_step(self, model, inputs): + pass diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py index d6187b5f..6e4e7a19 100644 --- a/modelscope/trainers/utils/inference.py +++ b/modelscope/trainers/utils/inference.py @@ -62,7 +62,10 @@ def single_gpu_test(trainer, if 'nsentences' in data: batch_size = data['nsentences'] else: - batch_size = len(next(iter(data.values()))) + try: + batch_size = len(next(iter(data.values()))) + except Exception: + batch_size = data_loader.batch_size else: batch_size = len(data) for _ in range(batch_size): diff --git a/tests/trainers/test_referring_video_object_segmentation_trainer.py b/tests/trainers/test_referring_video_object_segmentation_trainer.py new file mode 100644 index 00000000..c1dc040d --- /dev/null +++ b/tests/trainers/test_referring_video_object_segmentation_trainer.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os
+import shutil
+import tempfile
+import unittest
+import zipfile
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.models.cv.referring_video_object_segmentation import \
+    ReferringVideoObjectSegmentation
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestReferringVideoObjectSegmentationTrainer(unittest.TestCase):
+
+    model_id = 'damo/cv_swin-t_referring_video-object-segmentation'
+    dataset_name = 'referring_vos_toydata'
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+
+        cache_path = snapshot_download(self.model_id)
+        config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
+        cfg = Config.from_file(config_path)
+
+        max_epochs = cfg.train.max_epochs
+
+        train_data_cfg = ConfigDict(
+            name=self.dataset_name,
+            split='train',
+            test_mode=False,
+            cfg=cfg.dataset)
+
+        test_data_cfg = ConfigDict(
+            name=self.dataset_name,
+            split='test',
+            test_mode=True,
+            cfg=cfg.dataset)
+
+        self.train_dataset = MsDataset.load(
+            dataset_name=train_data_cfg.name,
+            split=train_data_cfg.split,
+            cfg=train_data_cfg.cfg,
+            namespace='damo',
+            test_mode=train_data_cfg.test_mode)
+        assert next(
+            iter(self.train_dataset.config_kwargs['split_config'].values()))
+
+        self.test_dataset = MsDataset.load(
+            dataset_name=test_data_cfg.name,
+            split=test_data_cfg.split,
+            cfg=test_data_cfg.cfg,
+            namespace='damo',
+            test_mode=test_data_cfg.test_mode)
+        assert next(
+            iter(self.test_dataset.config_kwargs['split_config'].values()))
+
+        self.max_epochs = max_epochs
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir='./work_dir')
+
+        trainer = build_trainer(
+            name=Trainers.referring_video_object_segmentation,
+            default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(trainer.work_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_model_and_args(self):
+
+        cache_path = snapshot_download(self.model_id)
+        model = ReferringVideoObjectSegmentation.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir='./work_dir')
+
+        trainer = build_trainer(
+            name=Trainers.referring_video_object_segmentation,
+            default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(trainer.work_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()
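
A minimal standalone fine-tuning sketch (reviewer note, not part of the diff): it simply mirrors the unit test above, so the model id, toy dataset name and namespace='damo' are taken from that test, while the work directory and the use of trainer.evaluate() at the end are illustrative assumptions.

import os

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile

# Download the model and read its configuration; the task dataset takes its
# window size, annotation file and transform settings from cfg.dataset.
model_id = 'damo/cv_swin-t_referring_video-object-segmentation'
cache_path = snapshot_download(model_id)
cfg = Config.from_file(os.path.join(cache_path, ModelFile.CONFIGURATION))

# Build the train/eval splits of the toy dataset used by the test above.
train_dataset = MsDataset.load(
    dataset_name='referring_vos_toydata',
    split='train',
    namespace='damo',
    cfg=cfg.dataset,
    test_mode=False)
eval_dataset = MsDataset.load(
    dataset_name='referring_vos_toydata',
    split='test',
    namespace='damo',
    cfg=cfg.dataset,
    test_mode=True)

# Fine-tune with the newly registered trainer, then run COCO-style evaluation
# (mAP, precision@K and IoU come from the metric added in this patch).
trainer = build_trainer(
    name=Trainers.referring_video_object_segmentation,
    default_args=dict(
        model=model_id,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        work_dir='./work_dir'))
trainer.train()
print(trainer.evaluate())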