mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933] Add video-inpainting files
视频编辑的cr
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10026166
This commit is contained in:
3
data/test/videos/mask_dir/mask_00000_00320.png
Normal file
3
data/test/videos/mask_dir/mask_00000_00320.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b158f6029d9763d7f84042f7c5835f398c688fdbb6b3f4fe6431101d4118c66c
|
||||
size 2766
|
||||
3
data/test/videos/mask_dir/mask_00321_00633.png
Normal file
3
data/test/videos/mask_dir/mask_00321_00633.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0dcf46b93077e2229ab69cd6ddb80e2689546c575ee538bb2033fee1124ef3e3
|
||||
size 2761
|
||||
3
data/test/videos/video_inpainting_test.mp4
Normal file
3
data/test/videos/video_inpainting_test.mp4
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9c9870df5a86acaaec67063183dace795479cd0f05296f13058995f475149c56
|
||||
size 2957783
|
||||
@@ -38,6 +38,7 @@ class Models(object):
|
||||
mogface = 'mogface'
|
||||
mtcnn = 'mtcnn'
|
||||
ulfd = 'ulfd'
|
||||
video_inpainting = 'video-inpainting'
|
||||
|
||||
# EasyCV models
|
||||
yolox = 'YOLOX'
|
||||
@@ -169,6 +170,7 @@ class Pipelines(object):
|
||||
text_driven_segmentation = 'text-driven-segmentation'
|
||||
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
|
||||
shop_segmentation = 'shop-segmentation'
|
||||
video_inpainting = 'video-inpainting'
|
||||
|
||||
# nlp tasks
|
||||
sentence_similarity = 'sentence-similarity'
|
||||
|
||||
20
modelscope/models/cv/video_inpainting/__init__.py
Normal file
20
modelscope/models/cv/video_inpainting/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .inpainting_model import VideoInpainting
|
||||
|
||||
else:
|
||||
_import_structure = {'inpainting_model': ['VideoInpainting']}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
298
modelscope/models/cv/video_inpainting/inpainting.py
Normal file
298
modelscope/models/cv/video_inpainting/inpainting.py
Normal file
@@ -0,0 +1,298 @@
|
||||
""" VideoInpaintingProcess
|
||||
Base modules are adapted from https://github.com/researchmm/STTN,
|
||||
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
torch.backends.cudnn.enabled = False
|
||||
|
||||
w, h = 192, 96
|
||||
ref_length = 300
|
||||
neighbor_stride = 20
|
||||
default_fps = 24
|
||||
MAX_frame = 300
|
||||
|
||||
|
||||
def video_process(video_input_path):
|
||||
video_input = cv2.VideoCapture(video_input_path)
|
||||
success, frame = video_input.read()
|
||||
if success is False:
|
||||
decode_error = 'decode_error'
|
||||
w, h, fps = 0, 0, 0
|
||||
else:
|
||||
decode_error = None
|
||||
h, w = frame.shape[0:2]
|
||||
fps = video_input.get(cv2.CAP_PROP_FPS)
|
||||
video_input.release()
|
||||
|
||||
return decode_error, fps, w, h
|
||||
|
||||
|
||||
class Stack(object):
|
||||
|
||||
def __init__(self, roll=False):
|
||||
self.roll = roll
|
||||
|
||||
def __call__(self, img_group):
|
||||
mode = img_group[0].mode
|
||||
if mode == '1':
|
||||
img_group = [img.convert('L') for img in img_group]
|
||||
mode = 'L'
|
||||
if mode == 'L':
|
||||
return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
|
||||
elif mode == 'RGB':
|
||||
if self.roll:
|
||||
return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
|
||||
axis=2)
|
||||
else:
|
||||
return np.stack(img_group, axis=2)
|
||||
else:
|
||||
raise NotImplementedError(f'Image mode {mode}')
|
||||
|
||||
|
||||
class ToTorchFormatTensor(object):
|
||||
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
|
||||
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
|
||||
|
||||
def __init__(self, div=True):
|
||||
self.div = div
|
||||
|
||||
def __call__(self, pic):
|
||||
if isinstance(pic, np.ndarray):
|
||||
img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
|
||||
else:
|
||||
img = torch.ByteTensor(
|
||||
torch.ByteStorage.from_buffer(pic.tobytes()))
|
||||
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
|
||||
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
||||
img = img.float().div(255) if self.div else img.float()
|
||||
return img
|
||||
|
||||
|
||||
_to_tensors = transforms.Compose([Stack(), ToTorchFormatTensor()])
|
||||
|
||||
|
||||
def get_crop_mask_v1(mask):
|
||||
orig_h, orig_w, _ = mask.shape
|
||||
if (mask == 255).all():
|
||||
return mask, (0, int(orig_h), 0,
|
||||
int(orig_w)), [0, int(orig_h), 0,
|
||||
int(orig_w)
|
||||
], [0, int(orig_h), 0,
|
||||
int(orig_w)]
|
||||
|
||||
hs = np.min(np.where(mask == 0)[0])
|
||||
he = np.max(np.where(mask == 0)[0])
|
||||
ws = np.min(np.where(mask == 0)[1])
|
||||
we = np.max(np.where(mask == 0)[1])
|
||||
crop_box = [ws, hs, we, he]
|
||||
|
||||
mask_h = round(int(orig_h / 2) / 4) * 4
|
||||
mask_w = round(int(orig_w / 2) / 4) * 4
|
||||
|
||||
if (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we < mask_w):
|
||||
crop_mask = mask[:mask_h, :mask_w, :]
|
||||
res_pix = (0, mask_h, 0, mask_w)
|
||||
elif (hs < mask_h) and (he < mask_h) and (ws > mask_w) and (we > mask_w):
|
||||
crop_mask = mask[:mask_h, orig_w - mask_w:orig_w, :]
|
||||
res_pix = (0, mask_h, orig_w - mask_w, int(orig_w))
|
||||
elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w):
|
||||
crop_mask = mask[orig_h - mask_h:orig_h, :mask_w, :]
|
||||
res_pix = (orig_h - mask_h, int(orig_h), 0, mask_w)
|
||||
elif (hs > mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w):
|
||||
crop_mask = mask[orig_h - mask_h:orig_h, orig_w - mask_w:orig_w, :]
|
||||
res_pix = (orig_h - mask_h, int(orig_h), orig_w - mask_w, int(orig_w))
|
||||
|
||||
elif (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we > mask_w):
|
||||
crop_mask = mask[:mask_h, :, :]
|
||||
res_pix = (0, mask_h, 0, int(orig_w))
|
||||
elif (hs < mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w):
|
||||
crop_mask = mask[:, :mask_w, :]
|
||||
res_pix = (0, int(orig_h), 0, mask_w)
|
||||
elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we > mask_w):
|
||||
crop_mask = mask[orig_h - mask_h:orig_h, :, :]
|
||||
res_pix = (orig_h - mask_h, int(orig_h), 0, int(orig_w))
|
||||
elif (hs < mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w):
|
||||
crop_mask = mask[:, orig_w - mask_w:orig_w, :]
|
||||
res_pix = (0, int(orig_h), orig_w - mask_w, int(orig_w))
|
||||
else:
|
||||
crop_mask = mask
|
||||
res_pix = (0, int(orig_h), 0, int(orig_w))
|
||||
a = ws - res_pix[2]
|
||||
b = hs - res_pix[0]
|
||||
c = we - res_pix[2]
|
||||
d = he - res_pix[0]
|
||||
return crop_mask, res_pix, crop_box, [a, b, c, d]
|
||||
|
||||
|
||||
def get_ref_index(neighbor_ids, length):
|
||||
ref_index = []
|
||||
for i in range(0, length, ref_length):
|
||||
if i not in neighbor_ids:
|
||||
ref_index.append(i)
|
||||
return ref_index
|
||||
|
||||
|
||||
def read_mask_oneImage(mpath):
|
||||
masks = []
|
||||
print('mask_path: {}'.format(mpath))
|
||||
start = int(mpath.split('/')[-1].split('mask_')[1].split('_')[0])
|
||||
end = int(
|
||||
mpath.split('/')[-1].split('mask_')[1].split('_')[1].split('.')[0])
|
||||
m = Image.open(mpath)
|
||||
m = np.array(m.convert('L'))
|
||||
m = np.array(m > 0).astype(np.uint8)
|
||||
m = 1 - m
|
||||
for i in range(start - 1, end + 1):
|
||||
masks.append(Image.fromarray(m * 255))
|
||||
return masks
|
||||
|
||||
|
||||
def check_size(h, w):
|
||||
is_resize = False
|
||||
if h != 240:
|
||||
h = 240
|
||||
is_resize = True
|
||||
if w != 432:
|
||||
w = 432
|
||||
is_resize = True
|
||||
return is_resize
|
||||
|
||||
|
||||
def get_mask_list(mask_path):
|
||||
mask_names = os.listdir(mask_path)
|
||||
mask_names.sort()
|
||||
|
||||
abs_mask_path = []
|
||||
mask_list = []
|
||||
begin_list = []
|
||||
end_list = []
|
||||
|
||||
for mask_name in mask_names:
|
||||
mask_name_tmp = mask_name.split('mask_')[1]
|
||||
begin_list.append(int(mask_name_tmp.split('_')[0]))
|
||||
end_list.append(int(mask_name_tmp.split('_')[1].split('.')[0]))
|
||||
abs_mask_path.append(os.path.join(mask_path, mask_name))
|
||||
mask = cv2.imread(os.path.join(mask_path, mask_name))
|
||||
mask_list.append(mask)
|
||||
return mask_list, begin_list, end_list, abs_mask_path
|
||||
|
||||
|
||||
def inpainting_by_model_balance(model, video_inputPath, mask_path,
|
||||
video_savePath, fps, w_ori, h_ori):
|
||||
|
||||
video_ori = cv2.VideoCapture(video_inputPath)
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
video_save = cv2.VideoWriter(video_savePath, fourcc, fps, (w_ori, h_ori))
|
||||
|
||||
mask_list, begin_list, end_list, abs_mask_path = get_mask_list(mask_path)
|
||||
|
||||
img_npy = []
|
||||
|
||||
for index, mask in enumerate(mask_list):
|
||||
|
||||
masks = read_mask_oneImage(abs_mask_path[index])
|
||||
|
||||
mask, res_pix, crop_for_oriimg, crop_for_inpimg = get_crop_mask_v1(
|
||||
mask)
|
||||
mask_h, mask_w = mask.shape[0:2]
|
||||
is_resize = check_size(mask.shape[0], mask.shape[1])
|
||||
|
||||
begin = begin_list[index]
|
||||
end = end_list[index]
|
||||
print('begin: {}'.format(begin))
|
||||
print('end: {}'.format(end))
|
||||
|
||||
for i in range(begin, end + 1, MAX_frame):
|
||||
begin_time = time.time()
|
||||
if i + MAX_frame <= end:
|
||||
video_length = MAX_frame
|
||||
else:
|
||||
video_length = end - i + 1
|
||||
|
||||
for frame_count in range(video_length):
|
||||
_, frame = video_ori.read()
|
||||
img_npy.append(frame)
|
||||
frames_temp = []
|
||||
for f in img_npy:
|
||||
f = Image.fromarray(f)
|
||||
i_temp = f.crop(
|
||||
(res_pix[2], res_pix[0], res_pix[3], res_pix[1]))
|
||||
a = i_temp.resize((w, h), Image.NEAREST)
|
||||
frames_temp.append(a)
|
||||
feats_temp = _to_tensors(frames_temp).unsqueeze(0) * 2 - 1
|
||||
frames_temp = [np.array(f).astype(np.uint8) for f in frames_temp]
|
||||
masks_temp = []
|
||||
for m in masks[i - begin:i + video_length - begin]:
|
||||
|
||||
m_temp = m.crop(
|
||||
(res_pix[2], res_pix[0], res_pix[3], res_pix[1]))
|
||||
b = m_temp.resize((w, h), Image.NEAREST)
|
||||
masks_temp.append(b)
|
||||
binary_masks_temp = [
|
||||
np.expand_dims((np.array(m) != 0).astype(np.uint8), 2)
|
||||
for m in masks_temp
|
||||
]
|
||||
masks_temp = _to_tensors(masks_temp).unsqueeze(0)
|
||||
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
|
||||
comp_frames = [None] * video_length
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
feats_out = feats_temp * (1 - masks_temp).float()
|
||||
feats_out = feats_out.view(video_length, 3, h, w)
|
||||
feats_out = model.model.encoder(feats_out)
|
||||
_, c, feat_h, feat_w = feats_out.size()
|
||||
feats_out = feats_out.view(1, video_length, c, feat_h, feat_w)
|
||||
|
||||
for f in range(0, video_length, neighbor_stride):
|
||||
neighbor_ids = [
|
||||
i for i in range(
|
||||
max(0, f - neighbor_stride),
|
||||
min(video_length, f + neighbor_stride + 1))
|
||||
]
|
||||
ref_ids = get_ref_index(neighbor_ids, video_length)
|
||||
with torch.no_grad():
|
||||
pred_feat = model.model.infer(
|
||||
feats_out[0, neighbor_ids + ref_ids, :, :, :],
|
||||
masks_temp[0, neighbor_ids + ref_ids, :, :, :])
|
||||
pred_img = torch.tanh(
|
||||
model.model.decoder(
|
||||
pred_feat[:len(neighbor_ids), :, :, :])).detach()
|
||||
pred_img = (pred_img + 1) / 2
|
||||
pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
|
||||
for j in range(len(neighbor_ids)):
|
||||
idx = neighbor_ids[j]
|
||||
img = np.array(pred_img[j]).astype(
|
||||
np.uint8) * binary_masks_temp[idx] + frames_temp[
|
||||
idx] * (1 - binary_masks_temp[idx])
|
||||
if comp_frames[idx] is None:
|
||||
comp_frames[idx] = img
|
||||
else:
|
||||
comp_frames[idx] = comp_frames[idx].astype(
|
||||
np.float32) * 0.5 + img.astype(
|
||||
np.float32) * 0.5
|
||||
print('inpainting time:', time.time() - begin_time)
|
||||
for f in range(video_length):
|
||||
comp = np.array(comp_frames[f]).astype(
|
||||
np.uint8) * binary_masks_temp[f] + frames_temp[f] * (
|
||||
1 - binary_masks_temp[f])
|
||||
if is_resize:
|
||||
comp = cv2.resize(comp, (mask_w, mask_h))
|
||||
complete_frame = img_npy[f]
|
||||
a1, b1, c1, d1 = crop_for_oriimg
|
||||
a2, b2, c2, d2 = crop_for_inpimg
|
||||
complete_frame[b1:d1, a1:c1] = comp[b2:d2, a2:c2]
|
||||
video_save.write(complete_frame)
|
||||
|
||||
img_npy = []
|
||||
|
||||
video_ori.release()
|
||||
373
modelscope/models/cv/video_inpainting/inpainting_model.py
Normal file
373
modelscope/models/cv/video_inpainting/inpainting_model.py
Normal file
@@ -0,0 +1,373 @@
|
||||
""" VideoInpaintingNetwork
|
||||
Base modules are adapted from https://github.com/researchmm/STTN,
|
||||
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class BaseNetwork(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BaseNetwork, self).__init__()
|
||||
|
||||
def print_network(self):
|
||||
if isinstance(self, list):
|
||||
self = self[0]
|
||||
num_params = 0
|
||||
for param in self.parameters():
|
||||
num_params += param.numel()
|
||||
print(
|
||||
'Network [%s] was created. Total number of parameters: %.1f million. '
|
||||
'To see the architecture, do print(network).' %
|
||||
(type(self).__name__, num_params / 1000000))
|
||||
|
||||
def init_weights(self, init_type='normal', gain=0.02):
|
||||
'''
|
||||
initialize network's weights
|
||||
init_type: normal | xavier | kaiming | orthogonal
|
||||
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
|
||||
'''
|
||||
|
||||
def init_func(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('InstanceNorm2d') != -1:
|
||||
if hasattr(m, 'weight') and m.weight is not None:
|
||||
nn.init.constant_(m.weight.data, 1.0)
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
elif hasattr(m, 'weight') and (classname.find('Conv') != -1
|
||||
or classname.find('Linear') != -1):
|
||||
if init_type == 'normal':
|
||||
nn.init.normal_(m.weight.data, 0.0, gain)
|
||||
elif init_type == 'xavier':
|
||||
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
||||
elif init_type == 'xavier_uniform':
|
||||
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
|
||||
elif init_type == 'kaiming':
|
||||
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
||||
elif init_type == 'orthogonal':
|
||||
nn.init.orthogonal_(m.weight.data, gain=gain)
|
||||
elif init_type == 'none':
|
||||
m.reset_parameters()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'initialization method [%s] is not implemented'
|
||||
% init_type)
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
|
||||
self.apply(init_func)
|
||||
|
||||
for m in self.children():
|
||||
if hasattr(m, 'init_weights'):
|
||||
m.init_weights(init_type, gain)
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.video_inpainting, module_name=Models.video_inpainting)
|
||||
class VideoInpainting(TorchModel):
|
||||
|
||||
def __init__(self, model_dir, device_id=0, *args, **kwargs):
|
||||
super().__init__(
|
||||
model_dir=model_dir, device_id=device_id, *args, **kwargs)
|
||||
self.model = InpaintGenerator()
|
||||
pretrained_params = torch.load('{}/{}'.format(
|
||||
model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
|
||||
self.model.load_state_dict(pretrained_params['netG'])
|
||||
self.model.eval()
|
||||
self.device_id = device_id
|
||||
if self.device_id >= 0 and torch.cuda.is_available():
|
||||
self.model.to('cuda:{}'.format(self.device_id))
|
||||
logger.info('Use GPU: {}'.format(self.device_id))
|
||||
else:
|
||||
self.device_id = -1
|
||||
logger.info('Use CPU for inference')
|
||||
|
||||
|
||||
class InpaintGenerator(BaseNetwork):
|
||||
|
||||
def __init__(self, init_weights=True):
|
||||
super(InpaintGenerator, self).__init__()
|
||||
channel = 256
|
||||
stack_num = 6
|
||||
patchsize = [(48, 24), (16, 8), (8, 4), (4, 2)]
|
||||
blocks = []
|
||||
for _ in range(stack_num):
|
||||
blocks.append(TransformerBlock(patchsize, hidden=channel))
|
||||
self.transformer = nn.Sequential(*blocks)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
|
||||
self.decoder = nn.Sequential(
|
||||
deconv(channel, 128, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
deconv(64, 64, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
if init_weights:
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, masked_frames, masks):
|
||||
b, t, c, h, w = masked_frames.size()
|
||||
masks = masks.view(b * t, 1, h, w)
|
||||
enc_feat = self.encoder(masked_frames.view(b * t, c, h, w))
|
||||
_, c, h, w = enc_feat.size()
|
||||
masks = F.interpolate(masks, scale_factor=1.0 / 4)
|
||||
enc_feat = self.transformer({
|
||||
'x': enc_feat,
|
||||
'm': masks,
|
||||
'b': b,
|
||||
'c': c
|
||||
})['x']
|
||||
output = self.decoder(enc_feat)
|
||||
output = torch.tanh(output)
|
||||
return output
|
||||
|
||||
def infer(self, feat, masks):
|
||||
t, c, h, w = masks.size()
|
||||
masks = masks.view(t, c, h, w)
|
||||
masks = F.interpolate(masks, scale_factor=1.0 / 4)
|
||||
t, c, _, _ = feat.size()
|
||||
enc_feat = self.transformer({
|
||||
'x': feat,
|
||||
'm': masks,
|
||||
'b': 1,
|
||||
'c': c
|
||||
})['x']
|
||||
return enc_feat
|
||||
|
||||
|
||||
class deconv(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_channel,
|
||||
output_channel,
|
||||
kernel_size=3,
|
||||
padding=0):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
input_channel,
|
||||
output_channel,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(
|
||||
x, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Compute 'Scaled Dot Product Attention
|
||||
"""
|
||||
|
||||
def forward(self, query, key, value, m):
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
|
||||
query.size(-1))
|
||||
scores.masked_fill(m, -1e9)
|
||||
p_attn = F.softmax(scores, dim=-1)
|
||||
p_val = torch.matmul(p_attn, value)
|
||||
return p_val, p_attn
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""
|
||||
Take in model size and number of heads.
|
||||
"""
|
||||
|
||||
def __init__(self, patchsize, d_model):
|
||||
super().__init__()
|
||||
self.patchsize = patchsize
|
||||
self.query_embedding = nn.Conv2d(
|
||||
d_model, d_model, kernel_size=1, padding=0)
|
||||
self.value_embedding = nn.Conv2d(
|
||||
d_model, d_model, kernel_size=1, padding=0)
|
||||
self.key_embedding = nn.Conv2d(
|
||||
d_model, d_model, kernel_size=1, padding=0)
|
||||
self.output_linear = nn.Sequential(
|
||||
nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True))
|
||||
self.attention = Attention()
|
||||
|
||||
def forward(self, x, m, b, c):
|
||||
bt, _, h, w = x.size()
|
||||
t = bt // b
|
||||
d_k = c // len(self.patchsize)
|
||||
output = []
|
||||
_query = self.query_embedding(x)
|
||||
_key = self.key_embedding(x)
|
||||
_value = self.value_embedding(x)
|
||||
for (width, height), query, key, value in zip(
|
||||
self.patchsize,
|
||||
torch.chunk(_query, len(self.patchsize), dim=1),
|
||||
torch.chunk(_key, len(self.patchsize), dim=1),
|
||||
torch.chunk(_value, len(self.patchsize), dim=1)):
|
||||
out_w, out_h = w // width, h // height
|
||||
mm = m.view(b, t, 1, out_h, height, out_w, width)
|
||||
mm = mm.permute(0, 1, 3, 5, 2, 4,
|
||||
6).contiguous().view(b, t * out_h * out_w,
|
||||
height * width)
|
||||
mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat(
|
||||
1, t * out_h * out_w, 1)
|
||||
query = query.view(b, t, d_k, out_h, height, out_w, width)
|
||||
query = query.permute(0, 1, 3, 5, 2, 4,
|
||||
6).contiguous().view(b, t * out_h * out_w,
|
||||
d_k * height * width)
|
||||
key = key.view(b, t, d_k, out_h, height, out_w, width)
|
||||
key = key.permute(0, 1, 3, 5, 2, 4,
|
||||
6).contiguous().view(b, t * out_h * out_w,
|
||||
d_k * height * width)
|
||||
value = value.view(b, t, d_k, out_h, height, out_w, width)
|
||||
value = value.permute(0, 1, 3, 5, 2, 4,
|
||||
6).contiguous().view(b, t * out_h * out_w,
|
||||
d_k * height * width)
|
||||
y, _ = self.attention(query, key, value, mm)
|
||||
y = y.view(b, t, out_h, out_w, d_k, height, width)
|
||||
y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w)
|
||||
output.append(y)
|
||||
output = torch.cat(output, 1)
|
||||
x = self.output_linear(output)
|
||||
return x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, d_model):
|
||||
super(FeedForward, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""
|
||||
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
|
||||
"""
|
||||
|
||||
def __init__(self, patchsize, hidden=128): # hidden=128
|
||||
super().__init__()
|
||||
self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
|
||||
self.feed_forward = FeedForward(hidden)
|
||||
|
||||
def forward(self, x):
|
||||
x, m, b, c = x['x'], x['m'], x['b'], x['c']
|
||||
x = x + self.attention(x, m, b, c)
|
||||
x = x + self.feed_forward(x)
|
||||
return {'x': x, 'm': m, 'b': b, 'c': c}
|
||||
|
||||
|
||||
class Discriminator(BaseNetwork):
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
use_sigmoid=False,
|
||||
use_spectral_norm=True,
|
||||
init_weights=True):
|
||||
super(Discriminator, self).__init__()
|
||||
self.use_sigmoid = use_sigmoid
|
||||
nf = 64
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
spectral_norm(
|
||||
nn.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=nf * 1,
|
||||
kernel_size=(3, 5, 5),
|
||||
stride=(1, 2, 2),
|
||||
padding=1,
|
||||
bias=not use_spectral_norm), use_spectral_norm),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
spectral_norm(
|
||||
nn.Conv3d(
|
||||
nf * 1,
|
||||
nf * 2,
|
||||
kernel_size=(3, 5, 5),
|
||||
stride=(1, 2, 2),
|
||||
padding=(1, 2, 2),
|
||||
bias=not use_spectral_norm), use_spectral_norm),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
spectral_norm(
|
||||
nn.Conv3d(
|
||||
nf * 2,
|
||||
nf * 4,
|
||||
kernel_size=(3, 5, 5),
|
||||
stride=(1, 2, 2),
|
||||
padding=(1, 2, 2),
|
||||
bias=not use_spectral_norm), use_spectral_norm),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
spectral_norm(
|
||||
nn.Conv3d(
|
||||
nf * 4,
|
||||
nf * 4,
|
||||
kernel_size=(3, 5, 5),
|
||||
stride=(1, 2, 2),
|
||||
padding=(1, 2, 2),
|
||||
bias=not use_spectral_norm), use_spectral_norm),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
spectral_norm(
|
||||
nn.Conv3d(
|
||||
nf * 4,
|
||||
nf * 4,
|
||||
kernel_size=(3, 5, 5),
|
||||
stride=(1, 2, 2),
|
||||
padding=(1, 2, 2),
|
||||
bias=not use_spectral_norm), use_spectral_norm),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv3d(
|
||||
nf * 4,
|
||||
nf * 4,
|
||||
kernel_size=(3, 5, 5),
|
||||
stride=(1, 2, 2),
|
||||
padding=(1, 2, 2)))
|
||||
|
||||
if init_weights:
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, xs):
|
||||
xs_t = torch.transpose(xs, 0, 1)
|
||||
xs_t = xs_t.unsqueeze(0)
|
||||
feat = self.conv(xs_t)
|
||||
if self.use_sigmoid:
|
||||
feat = torch.sigmoid(feat)
|
||||
out = torch.transpose(feat, 1, 2)
|
||||
return out
|
||||
|
||||
|
||||
def spectral_norm(module, mode=True):
|
||||
if mode:
|
||||
return _spectral_norm(module)
|
||||
return module
|
||||
@@ -610,4 +610,9 @@ TASK_OUTPUTS = {
|
||||
# "img_embedding": np.array with shape [1, D],
|
||||
# }
|
||||
Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING],
|
||||
|
||||
# {
|
||||
# 'output': ['Done' / 'Decode_Error']
|
||||
# }
|
||||
Tasks.video_inpainting: [OutputKeys.OUTPUT]
|
||||
}
|
||||
|
||||
@@ -168,6 +168,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
|
||||
Tasks.shop_segmentation: (Pipelines.shop_segmentation,
|
||||
'damo/cv_vitb16_segmentation_shop-seg'),
|
||||
Tasks.video_inpainting: (Pipelines.video_inpainting,
|
||||
'damo/cv_video-inpainting'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
47
modelscope/pipelines/cv/video_inpainting_pipeline.py
Normal file
47
modelscope/pipelines/cv/video_inpainting_pipeline.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.video_inpainting import inpainting
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.video_inpainting, module_name=Pipelines.video_inpainting)
|
||||
class VideoInpaintingPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create video inpainting pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
|
||||
super().__init__(model=model, **kwargs)
|
||||
logger.info('load model done')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
return input
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
decode_error, fps, w, h = inpainting.video_process(
|
||||
input['video_input_path'])
|
||||
|
||||
if decode_error is not None:
|
||||
return {OutputKeys.OUTPUT: 'decode_error'}
|
||||
|
||||
inpainting.inpainting_by_model_balance(self.model,
|
||||
input['video_input_path'],
|
||||
input['mask_path'],
|
||||
input['video_output_path'], fps,
|
||||
w, h)
|
||||
|
||||
return {OutputKeys.OUTPUT: 'Done'}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
@@ -70,6 +70,9 @@ class CVTasks(object):
|
||||
crowd_counting = 'crowd-counting'
|
||||
movie_scene_segmentation = 'movie-scene-segmentation'
|
||||
|
||||
# video editing
|
||||
video_inpainting = 'video-inpainting'
|
||||
|
||||
# reid and tracking
|
||||
video_single_object_tracking = 'video-single-object-tracking'
|
||||
video_summarization = 'video-summarization'
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
|
||||
39
tests/pipelines/test_video_inpainting.py
Normal file
39
tests/pipelines/test_video_inpainting.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class VideoInpaintingTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model = 'damo/cv_video-inpainting'
|
||||
self.mask_dir = 'data/test/videos/mask_dir'
|
||||
self.video_in = 'data/test/videos/video_inpainting_test.mp4'
|
||||
self.video_out = 'out.mp4'
|
||||
self.input = {
|
||||
'video_input_path': self.video_in,
|
||||
'video_output_path': self.video_out,
|
||||
'mask_path': self.mask_dir
|
||||
}
|
||||
|
||||
def pipeline_inference(self, pipeline: Pipeline, input: str):
|
||||
result = pipeline(input)
|
||||
print(result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub(self):
|
||||
video_inpainting = pipeline(Tasks.video_inpainting, model=self.model)
|
||||
self.pipeline_inference(video_inpainting, self.input)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_modelhub_default_model(self):
|
||||
video_inpainting = pipeline(Tasks.video_inpainting)
|
||||
self.pipeline_inference(video_inpainting, self.input)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user