[to #42322933] add fine-tune code for referring video object segmentation

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10539423
This commit is contained in:
shuying.shu
2022-10-27 19:43:54 +08:00
committed by yingda.chen
parent 3b75623be4
commit ddcb57440d
21 changed files with 1414 additions and 19 deletions

View File

@@ -305,6 +305,7 @@ class Trainers(object):
face_detection_scrfd = 'face-detection-scrfd'
card_detection_scrfd = 'card-detection-scrfd'
image_inpainting = 'image-inpainting'
referring_video_object_segmentation = 'referring-video-object-segmentation'
image_classification_team = 'image-classification-team'
# nlp trainers
@@ -423,6 +424,8 @@ class Metrics(object):
image_inpainting_metric = 'image-inpainting-metric'
# metric for ocr
NED = 'ned'
# metric for referring-video-object-segmentation task
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'
class Optimizers(object):

View File

@@ -20,6 +20,7 @@ if TYPE_CHECKING:
from .accuracy_metric import AccuracyMetric
from .bleu_metric import BleuMetric
from .image_inpainting_metric import ImageInpaintingMetric
from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric
else:
_import_structure = {
@@ -40,6 +41,8 @@ else:
'image_inpainting_metric': ['ImageInpaintingMetric'],
'accuracy_metric': ['AccuracyMetric'],
'bleu_metric': ['BleuMetric'],
'referring_video_object_segmentation_metric':
['ReferringVideoObjectSegmentationMetric'],
}
import sys

View File

@@ -43,6 +43,8 @@ task_default_metrics = {
Tasks.visual_question_answering: [Metrics.text_gen_metric],
Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
Tasks.image_inpainting: [Metrics.image_inpainting_metric],
Tasks.referring_video_object_segmentation:
[Metrics.referring_video_object_segmentation_metric],
}

View File

@@ -0,0 +1,108 @@
# Part of the implementation is borrowed and modified from MTTR,
# publicly available at https://github.com/mttr2021/MTTR
from typing import Dict
import numpy as np
import torch
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools.mask import decode
from tqdm import tqdm
from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys
@METRICS.register_module(
group_key=default_group,
module_name=Metrics.referring_video_object_segmentation_metric)
class ReferringVideoObjectSegmentationMetric(Metric):
"""The metric computation class for movie scene segmentation classes.
"""
def __init__(self,
ann_file=None,
calculate_precision_and_iou_metrics=True):
self.ann_file = ann_file
self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics
self.preds = []
def add(self, outputs: Dict, inputs: Dict):
preds_batch = outputs['pred']
self.preds.extend(preds_batch)
def evaluate(self):
coco_gt = COCO(self.ann_file)
coco_pred = coco_gt.loadRes(self.preds)
coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm')
coco_eval.params.useCats = 0
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
ap_labels = [
'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S',
'AP 0.5:0.95 M', 'AP 0.5:0.95 L'
]
ap_metrics = coco_eval.stats[:6]
eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)}
if self.calculate_precision_and_iou_metrics:
precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(
coco_gt, coco_pred)
eval_metrics.update({
f'P@{k}': m
for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)
})
eval_metrics.update({
'overall_iou': overall_iou,
'mean_iou': mean_iou
})
return eval_metrics
def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6):
outputs = outputs.int()
intersection = (outputs & labels).float().sum(
(1, 2)) # Will be zero if Truth=0 or Prediction=0
union = (outputs | labels).float().sum(
(1, 2)) # Will be zero if both are 0
iou = (intersection + EPS) / (union + EPS
) # EPS is used to avoid division by zero
return iou, intersection, union
def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO):
print('evaluating precision@k & iou metrics...')
counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]}
total_intersection_area = 0
total_union_area = 0
ious_list = []
for instance in tqdm(coco_gt.imgs.keys()
): # each image_id contains exactly one instance
gt_annot = coco_gt.imgToAnns[instance][0]
gt_mask = decode(gt_annot['segmentation'])
pred_annots = coco_pred.imgToAnns[instance]
pred_annot = sorted(
pred_annots,
key=lambda a: a['score'])[-1] # choose pred with highest score
pred_mask = decode(pred_annot['segmentation'])
iou, intersection, union = compute_iou(
torch.tensor(pred_mask).unsqueeze(0),
torch.tensor(gt_mask).unsqueeze(0))
iou, intersection, union = iou.item(), intersection.item(), union.item(
)
for iou_threshold in counters_by_iou.keys():
if iou > iou_threshold:
counters_by_iou[iou_threshold] += 1
total_intersection_area += intersection
total_union_area += union
ious_list.append(iou)
num_samples = len(ious_list)
precision_at_k = np.array(list(counters_by_iou.values())) / num_samples
overall_iou = total_intersection_area / total_union_area
mean_iou = np.mean(ious_list)
return precision_at_k, overall_iou, mean_iou
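
As a quick sanity check of the IoU helper above, a minimal stand-alone sketch on toy binary masks (the import path is assumed from the metrics package layout in this change; values are illustrative):

import torch

from modelscope.metrics.referring_video_object_segmentation_metric import compute_iou

# two 1 x 4 x 4 binary masks that overlap on exactly one pixel
pred = torch.zeros(1, 4, 4, dtype=torch.int)
pred[0, :2, :2] = 1   # 4 predicted foreground pixels
gt = torch.zeros(1, 4, 4, dtype=torch.int)
gt[0, 1:3, 1:3] = 1   # 4 ground-truth foreground pixels, 1 of them shared

iou, intersection, union = compute_iou(pred, gt)
# intersection == 1, union == 7, so iou is roughly 1/7; EPS keeps the ratio finite
# when both masks are empty.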

View File

@@ -5,11 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .model import MovieSceneSegmentation
from .model import ReferringVideoObjectSegmentation
else:
_import_structure = {
'model': ['MovieSceneSegmentation'],
'model': ['ReferringVideoObjectSegmentation'],
}
import sys

View File

@@ -1,4 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed and modified from MTTR,
# publicly available at https://github.com/mttr2021/MTTR
import os.path as osp
from typing import Any, Dict
@@ -10,7 +12,9 @@ from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess,
from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher,
ReferYoutubeVOSPostProcess, SetCriterion,
flatten_temporal_batch_dims,
nested_tensor_from_videos_list)
logger = get_logger()
@@ -35,16 +39,66 @@ class ReferringVideoObjectSegmentation(TorchModel):
params_dict = params_dict['model_state_dict']
self.model.load_state_dict(params_dict, strict=True)
dataset_name = self.cfg.pipeline.dataset_name
if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences':
self.postprocessor = A2DSentencesPostProcess()
elif dataset_name == 'ref_youtube_vos':
self.postprocessor = ReferYoutubeVOSPostProcess()
self.set_postprocessor(self.cfg.pipeline.dataset_name)
self.set_criterion()
def set_device(self, device, name):
self.device = device
self._device_name = name
def set_postprocessor(self, dataset_name):
if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name:
self.postprocessor = A2DSentencesPostProcess() # fine-tune
elif 'ref_youtube_vos' in dataset_name:
self.postprocessor = ReferYoutubeVOSPostProcess() # inference
else:
assert False, f'postprocessing for dataset: {dataset_name} is not supported'
def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
return inputs
def forward(self, inputs: Dict[str, Any]):
samples = inputs['samples']
targets = inputs['targets']
text_queries = inputs['text_queries']
valid_indices = torch.tensor(
[i for i, t in enumerate(targets) if None not in t])
targets = [targets[i] for i in valid_indices.tolist()]
if self._device_name == 'gpu':
samples = samples.to(self.device)
valid_indices = valid_indices.to(self.device)
if isinstance(text_queries, tuple):
text_queries = list(text_queries)
outputs = self.model(samples, valid_indices, text_queries)
losses = -1
if self.training:
loss_dict = self.criterion(outputs, targets)
weight_dict = self.criterion.weight_dict
losses = sum(loss_dict[k] * weight_dict[k]
for k in loss_dict.keys() if k in weight_dict)
predictions = []
if not self.training:
outputs.pop('aux_outputs', None)
outputs, targets = flatten_temporal_batch_dims(outputs, targets)
processed_outputs = self.postprocessor(
outputs,
resized_padded_sample_size=samples.tensors.shape[-2:],
resized_sample_sizes=[t['size'] for t in targets],
orig_sample_sizes=[t['orig_size'] for t in targets])
image_ids = [t['image_id'] for t in targets]
predictions = []
for p, image_id in zip(processed_outputs, image_ids):
for s, m in zip(p['scores'], p['rle_masks']):
predictions.append({
'image_id': image_id,
'category_id':
1, # dummy label, as categories are not predicted in ref-vos
'segmentation': m,
'score': s.item()
})
re = dict(pred=predictions, loss=losses)
return re
def inference(self, **kwargs):
window = kwargs['window']
@@ -63,3 +117,26 @@ class ReferringVideoObjectSegmentation(TorchModel):
def postprocess(self, inputs: Dict[str, Any], **kwargs):
return inputs
def set_criterion(self):
matcher = HungarianMatcher(
cost_is_referred=self.cfg.matcher.set_cost_is_referred,
cost_dice=self.cfg.matcher.set_cost_dice)
weight_dict = {
'loss_is_referred': self.cfg.loss.is_referred_loss_coef,
'loss_dice': self.cfg.loss.dice_loss_coef,
'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef
}
if self.cfg.loss.aux_loss:
aux_weight_dict = {}
for i in range(self.cfg.model.num_decoder_layers - 1):
aux_weight_dict.update(
{k + f'_{i}': v
for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
self.criterion = SetCriterion(
matcher=matcher,
weight_dict=weight_dict,
eos_coef=self.cfg.loss.eos_coef)
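
To make the auxiliary-loss bookkeeping in set_criterion concrete, a small stand-alone sketch of the weight-dict expansion with illustrative coefficients (the real values come from the model's configuration, which is not part of this diff):

# illustrative coefficients; in the model they are read from cfg.loss and cfg.model
weight_dict = {
    'loss_is_referred': 2.0,
    'loss_dice': 5.0,
    'loss_sigmoid_focal': 2.0,
}
num_decoder_layers = 3

aux_weight_dict = {}
for i in range(num_decoder_layers - 1):
    aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)

# weight_dict now also holds 'loss_is_referred_0', 'loss_dice_0', ..., 'loss_sigmoid_focal_1',
# matching the '_{i}' suffixes that SetCriterion.forward appends to the losses of each
# intermediate decoder layer, so the weighted sum in forward() covers every entry.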

View File

@@ -1,4 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .misc import nested_tensor_from_videos_list
from .criterion import SetCriterion, flatten_temporal_batch_dims
from .matcher import HungarianMatcher
from .misc import interpolate, nested_tensor_from_videos_list
from .mttr import MTTR
from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess

View File

@@ -0,0 +1,198 @@
# The implementation is adopted from MTTR,
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
# Modified from DETR https://github.com/facebookresearch/detr
import torch
from torch import nn
from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized,
nested_tensor_from_tensor_list)
from .segmentation import dice_loss, sigmoid_focal_loss
class SetCriterion(nn.Module):
""" This class computes the loss for MTTR.
The process happens in two steps:
1) we compute the hungarian assignment between the ground-truth and predicted sequences.
2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction)
"""
def __init__(self, matcher, weight_dict, eos_coef):
""" Create the criterion.
Parameters:
matcher: module able to compute a matching between targets and proposals
weight_dict: dict containing as key the names of the losses and as values their relative weight.
eos_coef: relative classification weight applied to the un-referred category
"""
super().__init__()
self.matcher = matcher
self.weight_dict = weight_dict
self.eos_coef = eos_coef
# make sure that only loss functions with non-zero weights are computed:
losses_to_compute = []
if weight_dict['loss_dice'] > 0 or weight_dict[
'loss_sigmoid_focal'] > 0:
losses_to_compute.append('masks')
if weight_dict['loss_is_referred'] > 0:
losses_to_compute.append('is_referred')
self.losses = losses_to_compute
def forward(self, outputs, targets):
aux_outputs_list = outputs.pop('aux_outputs', None)
# compute the losses for the output of the last decoder layer:
losses = self.compute_criterion(
outputs, targets, losses_to_compute=self.losses)
# In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer.
if aux_outputs_list is not None:
aux_losses_to_compute = self.losses.copy()
for i, aux_outputs in enumerate(aux_outputs_list):
losses_dict = self.compute_criterion(aux_outputs, targets,
aux_losses_to_compute)
losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()}
losses.update(losses_dict)
return losses
def compute_criterion(self, outputs, targets, losses_to_compute):
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs, targets)
# T & B dims are flattened so loss functions can be computed per frame (but with same indices per video).
# also, indices are repeated so the same indices can be used for frames of the same video.
T = len(targets)
outputs, targets = flatten_temporal_batch_dims(outputs, targets)
# repeat the indices list T times so the same indices can be used for each video frame
indices = T * indices
# Compute the average number of target masks across all nodes, for normalization purposes
num_masks = sum(len(t['masks']) for t in targets)
num_masks = torch.as_tensor([num_masks],
dtype=torch.float,
device=indices[0][0].device)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_masks)
num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
# Compute all the requested losses
losses = {}
for loss in losses_to_compute:
losses.update(
self.get_loss(
loss, outputs, targets, indices, num_masks=num_masks))
return losses
def loss_is_referred(self, outputs, targets, indices, **kwargs):
device = outputs['pred_is_referred'].device
bs = outputs['pred_is_referred'].shape[0]
pred_is_referred = outputs['pred_is_referred'].log_softmax(
dim=-1) # note that log-softmax is used here
target_is_referred = torch.zeros_like(pred_is_referred)
# extract indices of object queries that were matched with text-referred target objects
query_referred_indices = self._get_query_referred_indices(
indices, targets)
# by default penalize compared to the no-object class (last token)
target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device)
if 'is_ref_inst_visible' in targets[
0]: # visibility labels are available per-frame for the referred object:
is_ref_inst_visible_per_frame = torch.stack(
[t['is_ref_inst_visible'] for t in targets])
ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero(
).squeeze()
# keep only the matched query indices of the frames in which the referred object is visible:
visible_query_referred_indices = query_referred_indices[
ref_inst_visible_frame_indices]
target_is_referred[ref_inst_visible_frame_indices,
visible_query_referred_indices] = torch.tensor(
[1.0, 0.0], device=device)
else: # assume that the referred object is visible in every frame:
target_is_referred[torch.arange(bs),
query_referred_indices] = torch.tensor(
[1.0, 0.0], device=device)
loss = -(pred_is_referred * target_is_referred).sum(-1)
# apply no-object class weights:
eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device)
eos_coef[torch.arange(bs), query_referred_indices] = 1.0
loss = loss * eos_coef
bs = len(indices)
loss = loss.sum() / bs # sum and normalize the loss by the batch size
losses = {'loss_is_referred': loss}
return losses
def loss_masks(self, outputs, targets, indices, num_masks, **kwargs):
assert 'pred_masks' in outputs
src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs['pred_masks']
src_masks = src_masks[src_idx]
masks = [t['masks'] for t in targets]
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
# upsample predictions to the target size
src_masks = interpolate(
src_masks[:, None],
size=target_masks.shape[-2:],
mode='bilinear',
align_corners=False)
src_masks = src_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
losses = {
'loss_sigmoid_focal':
sigmoid_focal_loss(src_masks, target_masks, num_masks),
'loss_dice':
dice_loss(src_masks, target_masks, num_masks),
}
return losses
@staticmethod
def _get_src_permutation_idx(indices):
# permute predictions following indices
batch_idx = torch.cat(
[torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
@staticmethod
def _get_tgt_permutation_idx(indices):
# permute targets following indices
batch_idx = torch.cat(
[torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
@staticmethod
def _get_query_referred_indices(indices, targets):
"""
extract indices of object queries that were matched with text-referred target objects
"""
query_referred_indices = []
for (query_idxs, target_idxs), target in zip(indices, targets):
ref_query_idx = query_idxs[torch.where(
target_idxs == target['referred_instance_idx'])[0]]
query_referred_indices.append(ref_query_idx)
query_referred_indices = torch.cat(query_referred_indices)
return query_referred_indices
def get_loss(self, loss, outputs, targets, indices, **kwargs):
loss_map = {
'masks': self.loss_masks,
'is_referred': self.loss_is_referred,
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
return loss_map[loss](outputs, targets, indices, **kwargs)
def flatten_temporal_batch_dims(outputs, targets):
for k in outputs.keys():
if isinstance(outputs[k], torch.Tensor):
outputs[k] = outputs[k].flatten(0, 1)
else: # list
outputs[k] = [i for step_t in outputs[k] for i in step_t]
targets = [
frame_t_target for step_t in targets for frame_t_target in step_t
]
return outputs, targets
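
A toy illustration of the [time, batch] flattening that both the criterion and the model's eval path rely on (shapes are illustrative; the import path follows the utils package touched by this change):

import torch

from modelscope.models.cv.referring_video_object_segmentation.utils import \
    flatten_temporal_batch_dims

T, B, Q = 2, 3, 5   # time steps, batch size, object queries
outputs = {
    'pred_is_referred': torch.zeros(T, B, Q, 2),
    'pred_masks': torch.zeros(T, B, Q, 8, 8),
}
targets = [[{'masks': torch.zeros(1, 8, 8)} for _ in range(B)] for _ in range(T)]

outputs, targets = flatten_temporal_batch_dims(outputs, targets)
# outputs['pred_is_referred'].shape == (T * B, Q, 2) and len(targets) == T * B,
# so the per-frame loss terms above can treat every frame as an independent sample.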

View File

@@ -0,0 +1,163 @@
# The implementation is adopted from MTTR,
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
# Modified from DETR https://github.com/facebookresearch/detr
# Module to compute the matching cost and solve the corresponding LSAP.
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn
from .misc import interpolate, nested_tensor_from_tensor_list
class HungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1):
"""Creates the matcher
Params:
cost_is_referred: This is the relative weight of the reference cost in the total matching cost
cost_dice: This is the relative weight of the dice cost in the total matching cost
"""
super().__init__()
self.cost_is_referred = cost_is_referred
self.cost_dice = cost_dice
assert cost_is_referred != 0 or cost_dice != 0, 'all costs cannot be 0'
@torch.inference_mode()
def forward(self, outputs, targets):
""" Performs the matching
Params:
outputs: A dict that contains at least these entries:
"pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits
"pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted masks logits
targets: A list of lists of targets (outer - time steps, inner - batch samples). Each target is a dict
which contains mask and reference ground-truth information for a single frame.
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_masks)
"""
t, bs, num_queries = outputs['pred_masks'].shape[:3]
# We flatten to compute the cost matrices in a batch
out_masks = outputs['pred_masks'].flatten(
1, 2) # [t, batch_size * num_queries, mask_h, mask_w]
# preprocess and concat the target masks
tgt_masks = [[
m for v in t_step_batch for m in v['masks'].unsqueeze(1)
] for t_step_batch in targets]
# pad the target masks to a uniform shape
tgt_masks, valid = list(
zip(*[
nested_tensor_from_tensor_list(t).decompose()
for t in tgt_masks
]))
tgt_masks = torch.stack(tgt_masks).squeeze(2)
# upsample predicted masks to target mask size
out_masks = interpolate(
out_masks,
size=tgt_masks.shape[-2:],
mode='bilinear',
align_corners=False)
# Compute the soft-tokens cost:
if self.cost_is_referred > 0:
cost_is_referred = compute_is_referred_cost(outputs, targets)
else:
cost_is_referred = 0
# Compute the DICE coefficient between the masks:
if self.cost_dice > 0:
cost_dice = -dice_coef(out_masks, tgt_masks)
else:
cost_dice = 0
# Final cost matrix
C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice
C = C.view(bs, num_queries, -1).cpu()
num_traj_per_batch = [
len(v['masks']) for v in targets[0]
] # number of instance trajectories in each batch
indices = [
linear_sum_assignment(c[i])
for i, c in enumerate(C.split(num_traj_per_batch, -1))
]
device = out_masks.device
return [(torch.as_tensor(i, dtype=torch.int64, device=device),
torch.as_tensor(j, dtype=torch.int64, device=device))
for i, j in indices]
def dice_coef(inputs, targets, smooth=1.0):
"""
Compute the DICE coefficient, similar to generalized IOU for masks
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
"""
inputs = inputs.sigmoid().flatten(2).unsqueeze(2)
targets = targets.flatten(2).unsqueeze(1)
numerator = 2 * (inputs * targets).sum(-1)
denominator = inputs.sum(-1) + targets.sum(-1)
coef = (numerator + smooth) / (denominator + smooth)
coef = coef.mean(
0) # average on the temporal dim to get instance trajectory scores
return coef
def compute_is_referred_cost(outputs, targets):
pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax(
dim=-1) # [t, b*nq, 2]
device = pred_is_referred.device
t = pred_is_referred.shape[0]
# number of instance trajectories in each batch
num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]],
device=device)
total_trajectories = num_traj_per_batch.sum()
# note that ref_indices are shared across time steps:
ref_indices = torch.tensor(
[v['referred_instance_idx'] for v in targets[0]], device=device)
# convert ref_indices to fit flattened batch targets:
ref_indices += torch.cat(
(torch.zeros(1, dtype=torch.long,
device=device), num_traj_per_batch.cumsum(0)[:-1]))
# number of instance trajectories in each batch
target_is_referred = torch.zeros((t, total_trajectories, 2), device=device)
# 'no object' class by default (for un-referred objects)
target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device)
if 'is_ref_inst_visible' in targets[0][
0]: # visibility labels are available per-frame for the referred object:
is_ref_inst_visible = torch.stack([
torch.stack([t['is_ref_inst_visible'] for t in t_step])
for t_step in targets
]).permute(1, 0)
for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible):
is_visible = is_visible.nonzero().squeeze()
target_is_referred[is_visible,
ref_idx, :] = torch.tensor([1.0, 0.0],
device=device)
else: # assume that the referred object is visible in every frame:
target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0],
device=device)
cost_is_referred = -(pred_is_referred.unsqueeze(2)
* target_is_referred.unsqueeze(1)).sum(dim=-1).mean(
dim=0)
return cost_is_referred
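
The shape contract of dice_coef is easy to lose track of, so here is a minimal sketch with illustrative sizes (dice_coef lives in the matcher module added above, so the import path is assumed accordingly):

import torch

from modelscope.models.cv.referring_video_object_segmentation.utils.matcher import dice_coef

T, NQ, NT, H, W = 2, 4, 3, 16, 16            # time, queries, target trajectories, mask size
pred_mask_logits = torch.randn(T, NQ, H, W)  # already upsampled to the target resolution
tgt_masks = torch.randint(0, 2, (T, NT, H, W)).float()

coef = dice_coef(pred_mask_logits, tgt_masks)
# coef has shape [NQ, NT]: one soft-Dice score per (query, trajectory) pair, averaged over
# time, and the matcher uses -coef as the mask term of its assignment cost.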

View File

@@ -122,8 +122,8 @@ class MultimodalTransformer(nn.Module):
with torch.inference_mode(mode=self.freeze_text_encoder):
encoded_text = self.text_encoder(**tokenized_queries)
# Transpose memory because pytorch's attention expects sequence first
txt_memory = rearrange(encoded_text.last_hidden_state,
'b s c -> s b c')
tmp_last_hidden_state = encoded_text.last_hidden_state.clone()
txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c')
txt_memory = self.txt_proj(
txt_memory) # change text embeddings dim to model dim
# Invert the attention mask we get from huggingface because it's the opposite convention in pytorch's transformer

View File

@@ -123,7 +123,8 @@ class WindowAttention3D(nn.Module):
# define a parameter table of relative position bias
wd, wh, ww = window_size
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads))
torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1),
num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_d = torch.arange(self.window_size[0])

View File

@@ -13,6 +13,7 @@ if TYPE_CHECKING:
from .video_summarization_dataset import VideoSummarizationDataset
from .image_inpainting import ImageInpaintingDataset
from .text_ranking_dataset import TextRankingDataset
from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
else:
_import_structure = {
@@ -29,6 +30,8 @@ else:
'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
'image_portrait_enhancement_dataset':
['ImagePortraitEnhancementDataset'],
'referring_video_object_segmentation':
['ReferringVideoObjectSegmentationDataset'],
}
import sys

View File

@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .referring_video_object_segmentation_dataset import \
ReferringVideoObjectSegmentationDataset

View File

@@ -0,0 +1,361 @@
# Part of the implementation is borrowed and modified from MTTR,
# publicly available at https://github.com/mttr2021/MTTR
from glob import glob
from os import path as osp
import h5py
import json
import numpy as np
import pandas
import torch
import torch.distributed as dist
import torchvision.transforms.functional as F
from pycocotools.mask import area, encode
from torchvision.io import read_video
from tqdm import tqdm
from modelscope.metainfo import Models
from modelscope.models.cv.referring_video_object_segmentation.utils import \
nested_tensor_from_videos_list
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from . import transformers as T
LOGGER = get_logger()
def get_image_id(video_id, frame_idx, ref_instance_a2d_id):
image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}'
return image_id
@TASK_DATASETS.register_module(
Tasks.referring_video_object_segmentation,
module_name=Models.referring_video_object_segmentation)
class ReferringVideoObjectSegmentationDataset(TorchTaskDataset):
def __init__(self, **kwargs):
split_config = kwargs['split_config']
LOGGER.info(kwargs)
data_cfg = kwargs.get('cfg').data_kwargs
trans_cfg = kwargs.get('cfg').transformers_kwargs
distributed = data_cfg.get('distributed', False)
self.data_root = next(iter(split_config.values()))
if not osp.exists(self.data_root):
self.data_root = osp.dirname(self.data_root)
assert osp.exists(self.data_root)
self.window_size = data_cfg.get('window_size', 8)
self.mask_annotations_dir = osp.join(
self.data_root, 'text_annotations/annotation_with_instances')
self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320')
self.subset_type = next(iter(split_config.keys()))
self.text_annotations = self.get_text_annotations(
self.data_root, self.subset_type, distributed)
self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg)
self.collator = Collator()
self.ann_file = osp.join(
self.data_root,
data_cfg.get('ann_file',
'a2d_sentences_test_annotations_in_coco_format.json'))
# create ground-truth test annotations for the evaluation process if necessary:
if self.subset_type == 'test' and not osp.exists(self.ann_file):
if (distributed and dist.get_rank() == 0) or not distributed:
create_a2d_sentences_ground_truth_test_annotations(
self.data_root, self.subset_type,
self.mask_annotations_dir, self.ann_file)
if distributed:
dist.barrier()
def __len__(self):
return len(self.text_annotations)
def __getitem__(self, idx):
text_query, video_id, frame_idx, instance_id = self.text_annotations[
idx]
text_query = ' '.join(
text_query.lower().split()) # clean up the text query
# read the source window frames:
video_frames, _, _ = read_video(
osp.join(self.videos_dir, f'{video_id}.mp4'),
pts_unit='sec') # (T, H, W, C)
# get a window of window_size frames with frame frame_idx in the middle.
# note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx
start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + (
self.window_size + 1) // 2
# extract the window source frames:
source_frames = []
for i in range(start_idx, end_idx):
i = min(max(i, 0),
len(video_frames)
- 1) # pad out of range indices with edge frames
source_frames.append(
F.to_pil_image(video_frames[i].permute(2, 0, 1)))
# read the instance mask:
frame_annot_path = osp.join(self.mask_annotations_dir, video_id,
f'{frame_idx:05d}.h5')
f = h5py.File(frame_annot_path, 'r')
instances = list(f['instance'])
instance_idx = instances.index(
instance_id) # existence was already validated during init
instance_masks = np.array(f['reMask'])
if len(instances) == 1:
instance_masks = instance_masks[np.newaxis, ...]
instance_masks = torch.tensor(instance_masks).transpose(1, 2)
mask_rles = [encode(mask) for mask in instance_masks.numpy()]
mask_areas = area(mask_rles).astype(float)
f.close()
# create the target dict for the center frame:
target = {
'masks': instance_masks,
'orig_size': instance_masks.
shape[-2:], # original frame shape without any augmentations
# size with augmentations, will be changed inside transforms if necessary
'size': instance_masks.shape[-2:],
'referred_instance_idx': torch.tensor(
instance_idx), # idx in 'masks' of the text referred instance
'area': torch.tensor(mask_areas),
'iscrowd':
torch.zeros(len(instance_masks)
), # for compatibility with DETR COCO transforms
'image_id': get_image_id(video_id, frame_idx, instance_id)
}
# create dummy targets for adjacent frames:
targets = self.window_size * [None]
center_frame_idx = self.window_size // 2
targets[center_frame_idx] = target
source_frames, targets, text_query = self.transforms(
source_frames, targets, text_query)
return source_frames, targets, text_query
@staticmethod
def get_text_annotations(root_path, subset, distributed):
saved_annotations_file_path = osp.join(
root_path, f'sentences_single_frame_{subset}_annotations.json')
if osp.exists(saved_annotations_file_path):
with open(saved_annotations_file_path, 'r') as f:
text_annotations_by_frame = [tuple(a) for a in json.load(f)]
return text_annotations_by_frame
elif (distributed and dist.get_rank() == 0) or not distributed:
print(f'building a2d sentences {subset} text annotations...')
# without header=None, pandas will treat the first row as a header and drop the first sample
a2d_data_info = pandas.read_csv(
osp.join(root_path, 'Release/videoset.csv'), header=None)
# 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset'
a2d_data_info.columns = [
'vid', '', '', '', '', '', '', '', 'subset'
]
with open(
osp.join(root_path, 'text_annotations/missed_videos.txt'),
'r') as f:
unused_videos = f.read().splitlines()
subsets = {'train': 0, 'test': 1}
# filter unused videos and videos which do not belong to our train/test subset:
used_videos = a2d_data_info[
~a2d_data_info.vid.isin(unused_videos)
& (a2d_data_info.subset == subsets[subset])]
used_videos_ids = list(used_videos['vid'])
text_annotations = pandas.read_csv(
osp.join(root_path, 'text_annotations/annotation.txt'))
# filter the text annotations based on the used videos:
used_text_annotations = text_annotations[
text_annotations.video_id.isin(used_videos_ids)]
# remove a single dataset annotation mistake in video: T6bNPuKV-wY
used_text_annotations = used_text_annotations[
used_text_annotations['instance_id'] != '1 (copy)']
# convert data-frame to list of tuples:
used_text_annotations = list(
used_text_annotations.to_records(index=False))
text_annotations_by_frame = []
mask_annotations_dir = osp.join(
root_path, 'text_annotations/annotation_with_instances')
for video_id, instance_id, text_query in tqdm(
used_text_annotations):
frame_annot_paths = sorted(
glob(osp.join(mask_annotations_dir, video_id, '*.h5')))
instance_id = int(instance_id)
for p in frame_annot_paths:
f = h5py.File(p)
instances = list(f['instance'])
if instance_id in instances:
# in case this instance does not appear in this frame it has no ground-truth mask, and thus this
# frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out:
# https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98
frame_idx = int(p.split('/')[-1].split('.')[0])
text_query = text_query.lower(
) # lower the text query prior to augmentation & tokenization
text_annotations_by_frame.append(
(text_query, video_id, frame_idx, instance_id))
with open(saved_annotations_file_path, 'w') as f:
json.dump(text_annotations_by_frame, f)
if distributed:
dist.barrier()
with open(saved_annotations_file_path, 'r') as f:
text_annotations_by_frame = [tuple(a) for a in json.load(f)]
return text_annotations_by_frame
class A2dSentencesTransforms:
def __init__(self, subset_type, horizontal_flip_augmentations,
resize_and_crop_augmentations, train_short_size,
train_max_size, eval_short_size, eval_max_size, **kwargs):
self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations
normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
scales = [
train_short_size
] # no more scales for now due to GPU memory constraints. might be changed later
transforms = []
if resize_and_crop_augmentations:
if subset_type == 'train':
transforms.append(
T.RandomResize(scales, max_size=train_max_size))
elif subset_type == 'test':
transforms.append(
T.RandomResize([eval_short_size], max_size=eval_max_size))
transforms.extend([T.ToTensor(), normalize])
self.size_transforms = T.Compose(transforms)
def __call__(self, source_frames, targets, text_query):
if self.h_flip_augmentation and torch.rand(1) > 0.5:
source_frames = [F.hflip(f) for f in source_frames]
targets[len(targets) // 2]['masks'] = F.hflip(
targets[len(targets) // 2]['masks'])
# Note - it is possible for both 'right' and 'left' to appear together in the same query, hence this fix:
text_query = text_query.replace('left', '@').replace(
'right', 'left').replace('@', 'right')
source_frames, targets = list(
zip(*[
self.size_transforms(f, t)
for f, t in zip(source_frames, targets)
]))
source_frames = torch.stack(source_frames) # [T, 3, H, W]
return source_frames, targets, text_query
class Collator:
def __call__(self, batch):
samples, targets, text_queries = list(zip(*batch))
samples = nested_tensor_from_videos_list(samples) # [T, B, C, H, W]
# convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch
targets = list(zip(*targets))
batch_dict = {
'samples': samples,
'targets': targets,
'text_queries': text_queries
}
return batch_dict
def get_text_annotations_gt(root_path, subset):
# without header=None, pandas will treat the first row as a header and drop the first sample
a2d_data_info = pandas.read_csv(
osp.join(root_path, 'Release/videoset.csv'), header=None)
# 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset'
a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset']
with open(osp.join(root_path, 'text_annotations/missed_videos.txt'),
'r') as f:
unused_videos = f.read().splitlines()
subsets = {'train': 0, 'test': 1}
# filter unused videos and videos which do not belong to our train/test subset:
used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos)
& (a2d_data_info.subset == subsets[subset])]
used_videos_ids = list(used_videos['vid'])
text_annotations = pandas.read_csv(
osp.join(root_path, 'text_annotations/annotation.txt'))
# filter the text annotations based on the used videos:
used_text_annotations = text_annotations[text_annotations.video_id.isin(
used_videos_ids)]
# convert data-frame to list of tuples:
used_text_annotations = list(used_text_annotations.to_records(index=False))
return used_text_annotations
def create_a2d_sentences_ground_truth_test_annotations(dataset_path,
subset_type,
mask_annotations_dir,
output_path):
text_annotations = get_text_annotations_gt(dataset_path, subset_type)
# Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly
# expected by pycocotools as it is the convention of the original coco dataset annotations.
categories_dict = [{
'id': 1,
'name': 'dummy_class'
}] # dummy class, as categories are not used/predicted in RVOS
images_dict = []
annotations_dict = []
images_set = set()
instance_id_counter = 1
for annot in tqdm(text_annotations):
video_id, instance_id, text_query = annot
annot_paths = sorted(
glob(osp.join(mask_annotations_dir, video_id, '*.h5')))
for p in annot_paths:
f = h5py.File(p)
instances = list(f['instance'])
try:
instance_idx = instances.index(int(instance_id))
# in case this instance does not appear in this frame it has no ground-truth mask, and thus this
# frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out:
# https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98
except ValueError:
continue # instance_id does not appear in current frame
mask = f['reMask'][instance_idx] if len(
instances) > 1 else np.array(f['reMask'])
mask = mask.transpose()
frame_idx = int(p.split('/')[-1].split('.')[0])
image_id = get_image_id(video_id, frame_idx, instance_id)
assert image_id not in images_set, f'error: image id: {image_id} appeared twice'
images_set.add(image_id)
images_dict.append({
'id': image_id,
'height': mask.shape[0],
'width': mask.shape[1]
})
mask_rle = encode(mask)
mask_rle['counts'] = mask_rle['counts'].decode('ascii')
mask_area = float(area(mask_rle))
bbox = f['reBBox'][:, instance_idx] if len(
instances) > 1 else np.array(
f['reBBox']).squeeze() # x1y1x2y2 form
bbox_xywh = [
bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
]
instance_annot = {
'id': instance_id_counter,
'image_id': image_id,
'category_id':
1, # dummy class, as categories are not used/predicted in ref-vos
'segmentation': mask_rle,
'area': mask_area,
'bbox': bbox_xywh,
'iscrowd': 0,
}
annotations_dict.append(instance_annot)
instance_id_counter += 1
dataset_dict = {
'categories': categories_dict,
'images': images_dict,
'annotations': annotations_dict
}
with open(output_path, 'w') as f:
json.dump(dataset_dict, f)
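
For orientation, a short sketch of the COCO-style image ids produced above and of the A2D-Sentences layout this dataset class expects under data_root (reconstructed from the paths used in the code; the video id is a dummy value and the import path is assumed from the package layout):

from modelscope.msdatasets.task_datasets.referring_video_object_segmentation.referring_video_object_segmentation_dataset import \
    get_image_id

# every (video, frame, referred instance) triple becomes one COCO "image":
print(get_image_id('dummy_video', 25, 3))   # -> 'v_dummy_video_f_25_i_3'

# expected layout under data_root:
#   Release/CLIPS320/<video_id>.mp4          source clips, read with torchvision's read_video
#   Release/videoset.csv                     per-video metadata, including the train/test flag
#   text_annotations/annotation.txt          (video_id, instance_id, text query) rows
#   text_annotations/missed_videos.txt       videos excluded from both subsets
#   text_annotations/annotation_with_instances/<video_id>/<frame_idx:05d>.h5
#                                            per-frame 'instance', 'reMask' and 'reBBox' data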

View File

@@ -0,0 +1,294 @@
# The implementation is adopted from MTTR,
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
# Modified from DETR https://github.com/facebookresearch/detr
import random
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from modelscope.models.cv.referring_video_object_segmentation.utils import \
interpolate
def crop(image, target, region):
cropped_image = F.crop(image, *region)
target = target.copy()
i, j, h, w = region
# should we do something wrt the original size?
target['size'] = torch.tensor([h, w])
fields = ['labels', 'area', 'iscrowd']
if 'boxes' in target:
boxes = target['boxes']
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
target['boxes'] = cropped_boxes.reshape(-1, 4)
target['area'] = area
fields.append('boxes')
if 'masks' in target:
# FIXME should we update the area here if there are no boxes?
target['masks'] = target['masks'][:, i:i + h, j:j + w]
fields.append('masks')
# remove elements whose boxes or masks have zero area
if 'boxes' in target or 'masks' in target:
# favor boxes selection when defining which elements to keep
# this is compatible with previous implementation
if 'boxes' in target:
cropped_boxes = target['boxes'].reshape(-1, 2, 2)
keep = torch.all(
cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
else:
keep = target['masks'].flatten(1).any(1)
for field in fields:
target[field] = target[field][keep]
return cropped_image, target
def hflip(image, target):
flipped_image = F.hflip(image)
w, h = image.size
target = target.copy()
if 'boxes' in target:
boxes = target['boxes']
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
[-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
target['boxes'] = boxes
if 'masks' in target:
target['masks'] = target['masks'].flip(-1)
return flipped_image, target
def resize(image, target, size, max_size=None):
# size can be min_size (scalar) or (w, h) tuple
def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(
round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
size = get_size(image.size, size, max_size)
rescaled_image = F.resize(image, size)
if target is None:
return rescaled_image, None
ratios = tuple(
float(s) / float(s_orig)
for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
target = target.copy()
if 'boxes' in target:
boxes = target['boxes']
scaled_boxes = boxes * torch.as_tensor(
[ratio_width, ratio_height, ratio_width, ratio_height])
target['boxes'] = scaled_boxes
if 'area' in target:
area = target['area']
scaled_area = area * (ratio_width * ratio_height)
target['area'] = scaled_area
h, w = size
target['size'] = torch.tensor([h, w])
if 'masks' in target:
target['masks'] = interpolate(
target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5
return rescaled_image, target
def pad(image, target, padding):
# assumes that we only pad on the bottom right corners
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
if target is None:
return padded_image, None
target = target.copy()
# should we do something wrt the original size?
target['size'] = torch.tensor(padded_image.size[::-1])
if 'masks' in target:
target['masks'] = torch.nn.functional.pad(
target['masks'], (0, padding[0], 0, padding[1]))
return padded_image, target
class RandomCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
region = T.RandomCrop.get_params(img, self.size)
return crop(img, target, region)
class RandomSizeCrop(object):
def __init__(self, min_size: int, max_size: int):
self.min_size = min_size
self.max_size = max_size
def __call__(self, img: PIL.Image.Image, target: dict):
w = random.randint(self.min_size, min(img.width, self.max_size))
h = random.randint(self.min_size, min(img.height, self.max_size))
region = T.RandomCrop.get_params(img, [h, w])
return crop(img, target, region)
class CenterCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
image_width, image_height = img.size
crop_height, crop_width = self.size
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
return crop(img, target,
(crop_top, crop_left, crop_height, crop_width))
class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return hflip(img, target)
return img, target
class RandomResize(object):
def __init__(self, sizes, max_size=None):
assert isinstance(sizes, (list, tuple))
self.sizes = sizes
self.max_size = max_size
def __call__(self, img, target=None):
size = random.choice(self.sizes)
return resize(img, target, size, self.max_size)
class RandomPad(object):
def __init__(self, max_pad):
self.max_pad = max_pad
def __call__(self, img, target):
pad_x = random.randint(0, self.max_pad)
pad_y = random.randint(0, self.max_pad)
return pad(img, target, (pad_x, pad_y))
class RandomSelect(object):
"""
Randomly selects between transforms1 and transforms2,
with probability p for transforms1 and (1 - p) for transforms2
"""
def __init__(self, transforms1, transforms2, p=0.5):
self.transforms1 = transforms1
self.transforms2 = transforms2
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return self.transforms1(img, target)
return self.transforms2(img, target)
class ToTensor(object):
def __call__(self, img, target):
return F.to_tensor(img), target
class RandomErasing(object):
def __init__(self, *args, **kwargs):
self.eraser = T.RandomErasing(*args, **kwargs)
def __call__(self, img, target):
return self.eraser(img), target
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target=None):
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image, None
target = target.copy()
h, w = image.shape[-2:]
if 'boxes' in target:
boxes = target['boxes']
boxes = box_xyxy_to_cxcywh(boxes)
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
target['boxes'] = boxes
return image, target
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
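
A minimal stand-alone sketch of composing these transforms the way the eval branch of A2dSentencesTransforms does (module path and the short/max sizes are illustrative assumptions):

import torch
from PIL import Image

from modelscope.msdatasets.task_datasets.referring_video_object_segmentation import \
    transformers as T

size_transforms = T.Compose([
    T.RandomResize([360], max_size=640),   # stand-ins for eval_short_size / eval_max_size
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.new('RGB', (427, 240))         # dummy (width, height) frame
target = {
    'masks': torch.zeros(1, 240, 427, dtype=torch.bool),
    'size': torch.tensor([240, 427]),
    'orig_size': torch.tensor([240, 427]),
}
img_t, target = size_transforms(img, target)
# img_t is a normalized [3, 360, 640] tensor and target['masks'] / target['size'] are
# resized consistently with the image.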

View File

@@ -157,7 +157,13 @@ class ReferringVideoObjectSegmentationPipeline(Pipeline):
* text_border_height_per_query, 0, 0))
W, H = vid_frame.size
draw = ImageDraw.Draw(vid_frame)
font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30)
if self.model.cfg.pipeline.output_font:
font = ImageFont.truetype(
font=self.model.cfg.pipeline.output_font,
size=self.model.cfg.pipeline.output_font_size)
else:
font = ImageFont.load_default()
for i, (text_query, color) in enumerate(
zip(self.text_queries, colors), start=1):
w, h = draw.textsize(text_query, font=font)
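
The font used to render the text queries now comes from the pipeline section of the model configuration; a hypothetical snippet mirroring the fallback above (output_font and output_font_size are the cfg.pipeline keys read by the pipeline; the empty value here simply exercises the default-font branch):

from PIL import ImageFont

output_font = ''        # e.g. 'DejaVuSansMono.ttf' when cfg.pipeline provides one
output_font_size = 30

font = (ImageFont.truetype(font=output_font, size=output_font_size)
        if output_font else ImageFont.load_default())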

View File

@@ -9,7 +9,8 @@ if TYPE_CHECKING:
from .builder import build_trainer
from .cv import (ImageInstanceSegmentationTrainer,
ImagePortraitEnhancementTrainer,
MovieSceneSegmentationTrainer, ImageInpaintingTrainer)
MovieSceneSegmentationTrainer, ImageInpaintingTrainer,
ReferringVideoObjectSegmentationTrainer)
from .multi_modal import CLIPTrainer
from .nlp import SequenceClassificationTrainer, TextRankingTrainer
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments

View File

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer
from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer
from .image_inpainting_trainer import ImageInpaintingTrainer
from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer
else:
_import_structure = {
@@ -17,7 +18,9 @@ else:
'image_portrait_enhancement_trainer':
['ImagePortraitEnhancementTrainer'],
'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'],
'image_inpainting_trainer': ['ImageInpaintingTrainer']
'image_inpainting_trainer': ['ImageInpaintingTrainer'],
'referring_video_object_segmentation_trainer':
['ReferringVideoObjectSegmentationTrainer']
}
import sys

View File

@@ -0,0 +1,63 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import torch
from modelscope.metainfo import Trainers
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils.constant import ModeKeys
@TRAINERS.register_module(
module_name=Trainers.referring_video_object_segmentation)
class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model.set_postprocessor(self.cfg.dataset.name)
self.train_data_collator = self.train_dataset.collator
self.eval_data_collator = self.eval_dataset.collator
device_name = kwargs.get('device', 'gpu')
self.model.set_device(self.device, device_name)
def train(self, *args, **kwargs):
self.model.criterion.train()
super().train(*args, **kwargs)
def evaluate(self, checkpoint_path=None):
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
from modelscope.trainers.hooks import CheckpointHook
CheckpointHook.load_checkpoint(checkpoint_path, self)
self.model.eval()
self._mode = ModeKeys.EVAL
if self.eval_dataset is None:
self.eval_dataloader = self.get_eval_data_loader()
else:
self.eval_dataloader = self._build_dataloader_with_dataset(
self.eval_dataset,
dist=self._dist,
seed=self._seed,
collate_fn=self.eval_data_collator,
**self.cfg.evaluation.get('dataloader', {}))
self.data_loader = self.eval_dataloader
from modelscope.metrics import build_metric
ann_file = self.eval_dataset.ann_file
metric_classes = []
for metric in self.metrics:
metric.update({'ann_file': ann_file})
metric_classes.append(build_metric(metric))
for m in metric_classes:
m.trainer = self
metric_values = self.evaluation_loop(self.eval_dataloader,
metric_classes)
self._metric_values = metric_values
return metric_values
def prediction_step(self, model, inputs):
pass
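
For completeness, a condensed end-to-end sketch of fine-tuning with this trainer, assembled from the unit test at the bottom of this change (model id and toy dataset are the ones used there):

import os

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile

model_id = 'damo/cv_swin-t_referring_video-object-segmentation'
cache_path = snapshot_download(model_id)
cfg = Config.from_file(os.path.join(cache_path, ModelFile.CONFIGURATION))

train_dataset = MsDataset.load('referring_vos_toydata', namespace='damo',
                               split='train', cfg=cfg.dataset, test_mode=False)
eval_dataset = MsDataset.load('referring_vos_toydata', namespace='damo',
                              split='test', cfg=cfg.dataset, test_mode=True)

trainer = build_trainer(
    name=Trainers.referring_video_object_segmentation,
    default_args=dict(
        model=model_id,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        work_dir='./work_dir'))
trainer.train()
metrics = trainer.evaluate()
# metrics holds the COCO-style AP entries plus P@0.5..0.9, overall_iou and mean_iou
# produced by ReferringVideoObjectSegmentationMetric.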

View File

@@ -62,7 +62,10 @@ def single_gpu_test(trainer,
if 'nsentences' in data:
batch_size = data['nsentences']
else:
batch_size = len(next(iter(data.values())))
try:
batch_size = len(next(iter(data.values())))
except Exception:
batch_size = data_loader.batch_size
else:
batch_size = len(data)
for _ in range(batch_size):

View File

@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.cv.referring_video_object_segmentation import \
ReferringVideoObjectSegmentation
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level
class TestReferringVideoObjectSegmentationTrainer(unittest.TestCase):
model_id = 'damo/cv_swin-t_referring_video-object-segmentation'
dataset_name = 'referring_vos_toydata'
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
cache_path = snapshot_download(self.model_id)
config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)
max_epochs = cfg.train.max_epochs
train_data_cfg = ConfigDict(
name=self.dataset_name,
split='train',
test_mode=False,
cfg=cfg.dataset)
test_data_cfg = ConfigDict(
name=self.dataset_name,
split='test',
test_mode=True,
cfg=cfg.dataset)
self.train_dataset = MsDataset.load(
dataset_name=train_data_cfg.name,
split=train_data_cfg.split,
cfg=train_data_cfg.cfg,
namespace='damo',
test_mode=train_data_cfg.test_mode)
assert next(
iter(self.train_dataset.config_kwargs['split_config'].values()))
self.test_dataset = MsDataset.load(
dataset_name=test_data_cfg.name,
split=test_data_cfg.split,
cfg=test_data_cfg.cfg,
namespace='damo',
test_mode=test_data_cfg.test_mode)
assert next(
iter(self.test_dataset.config_kwargs['split_config'].values()))
self.max_epochs = max_epochs
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self):
kwargs = dict(
model=self.model_id,
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
work_dir='./work_dir')
trainer = build_trainer(
name=Trainers.referring_video_object_segmentation,
default_args=kwargs)
trainer.train()
results_files = os.listdir(trainer.work_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_trainer_with_model_and_args(self):
cache_path = snapshot_download(self.model_id)
model = ReferringVideoObjectSegmentation.from_pretrained(cache_path)
kwargs = dict(
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
model=model,
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
work_dir='./work_dir')
trainer = build_trainer(
name=Trainers.referring_video_object_segmentation,
default_args=kwargs)
trainer.train()
results_files = os.listdir(trainer.work_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
if __name__ == '__main__':
unittest.main()