Merge pull request #1 from hzwer/main

sync with upstream
This commit is contained in:
Andriy Toloshny
2021-02-02 17:01:11 +00:00
committed by GitHub
12 changed files with 480 additions and 41 deletions

View File

@@ -1,15 +1,16 @@
# RIFE Video Frame Interpolation v1.8
**Our paper has not been officially published yet, and our method and experimental results are under improvement. Due to the incorrect data reference, the latency measurement of Sepconv and TOFlow in our arxiv paper needs to be modified.**
## [arXiv](https://arxiv.org/abs/2011.06294) | [Project Page](https://rife-vfi.github.io) | [Reddit](https://www.reddit.com/r/linux/comments/jy4jjl/opensourced_realtime_video_frame_interpolation/) | [YouTube](https://www.youtube.com/watch?v=60DX2T3zyVo&feature=youtu.be) | [Bilibili](https://www.bilibili.com/video/BV1K541157te?from=search&seid=5131698847373645765)
**1.4 News: We have updated the v1.8 model optimized for 2D animation.**
1.4 News: We have updated the v1.8 model optimized for 2D animation.
**12.13 News: We have updated the v1.6 model and support UHD mode. Please check our [update log](https://github.com/hzwer/arXiv2020-RIFE/issues/41#issuecomment-737651979).**
12.13 News: We have updated the v1.6 model and support UHD mode. Please check our [update log](https://github.com/hzwer/arXiv2020-RIFE/issues/41#issuecomment-737651979).
**11.22 News: We noticed that a new Windows app is integrating RIFE; we hope everyone will try it and help improve it. You can download [Flowframes](https://nmkd.itch.io/flowframes) for free.**
11.22 News: We noticed that a new Windows app is integrating RIFE; we hope everyone will try it and help improve it. You can download [Flowframes](https://nmkd.itch.io/flowframes) for free.
**There is [a tutorial of RIFE](https://www.youtube.com/watch?v=gf_on-dbwyU&feature=emb_title) on Youtube.**
There is [a tutorial of RIFE](https://www.youtube.com/watch?v=gf_on-dbwyU&feature=emb_title) on Youtube.
**You can easily use [Colaboratory](https://colab.research.google.com/github/hzwer/arXiv2020-RIFE/blob/main/Colab_demo.ipynb) to try it out and generate [our YouTube demo](https://www.youtube.com/watch?v=LE2Dzl0oMHI).**
You can easily use [Colaboratory](https://colab.research.google.com/github/hzwer/arXiv2020-RIFE/blob/main/Colab_demo.ipynb) to try it out and generate [our YouTube demo](https://www.youtube.com/watch?v=LE2Dzl0oMHI).
Our model can run 30+FPS for 2X 720p interpolation on a 2080Ti GPU. Currently, our method supports 2X,4X,8X... interpolation for 1080p video, and multi-frame interpolation between a pair of images. Everyone is welcome to use our alpha version and make suggestions!
@@ -33,6 +34,8 @@ We are optimizing the visual effects and will support animation in the future. (
* Unzip and move the pretrained parameters to train_log/\*.pkl
**This model is designed to provide better visual effects for users and should not be used for benchmarking.**
### Run
**Video Frame Interpolation**
@@ -80,6 +83,27 @@ You can also use pngs to generate gif:
ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf "split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1" output/slomo.gif
```
### Run in docker
Place the pre-trained models in `train_log/*.pkl` (as above).
Building the container:
```
docker build -t rife -f docker/Dockerfile .
```
Running the container:
```
docker run --rm -it -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4
```
```
docker run --rm -it -v $PWD:/host rife:latest inference_img --img img0.png img1.png --exp=4
```
Using gpu acceleration (requires proper gpu drivers for docker):
```
docker run --rm -it --gpus all -v /dev/dri:/dev/dri -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4
```
## Evaluation
Download [RIFE model](https://drive.google.com/file/d/1c1R7iF-ypN6USo-D2YH_ORtaH3tukSlo/view?usp=sharing) or [RIFE2F1.5C model](https://drive.google.com/file/d/1ve9w-cRWotdvvbU1KcgtsSm12l-JUkeT/view?usp=sharing) reported by our paper.
@@ -119,7 +143,6 @@ python3 -m torch.distributed.launch --nproc_per_node=4 train.py --world_size=4
```
## Reference
<img src="demo/intro.png" alt="img" width=350 />
Optical Flow:
[ARFlow](https://github.com/lliuz/ARFlow) [pytorch-liteflownet](https://github.com/sniklaus/pytorch-liteflownet) [RAFT](https://github.com/princeton-vl/RAFT) [pytorch-PWCNet](https://github.com/sniklaus/pytorch-pwc)

247
RIFE_HDv2.py Normal file
View File

@@ -0,0 +1,247 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet_HDv2 import *
import torch.nn.functional as F
from model.loss import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased 2-D convolution followed by a per-channel PReLU.

    The returned ``nn.Sequential`` holds the convolution at index 0 and the
    activation at index 1.
    """
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = nn.PReLU(out_planes)
    return nn.Sequential(convolution, activation)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a biased transposed 2-D convolution followed by a per-channel PReLU.

    With the default arguments (kernel 4, stride 2, padding 1) the layer
    doubles the spatial resolution of its input.

    The ``kernel_size``, ``stride`` and ``padding`` arguments are forwarded to
    ``ConvTranspose2d``; previously they were accepted but silently ignored in
    favour of the hard-coded defaults.
    """
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=True),
        nn.PReLU(out_planes)
    )
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased 2-D convolution with no activation.

    The convolution is still wrapped in an ``nn.Sequential`` (at index 0).
    """
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    return nn.Sequential(convolution)
class Conv2(nn.Module):
    """Two stacked 3x3 conv+PReLU layers; only the first one applies ``stride``."""

    def __init__(self, in_planes, out_planes, stride=2):
        super().__init__()
        # First layer changes the channel count and (by default) halves the
        # resolution; the second keeps both channel count and resolution.
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        """Apply ``conv1`` then ``conv2`` to ``x``."""
        return self.conv2(self.conv1(x))
# Base channel width; ContextNet and FusionNet size their layers as multiples of it.
c = 32
class ContextNet(nn.Module):
    """Extract a four-level pyramid of image features warped by a flow field."""

    def __init__(self):
        super().__init__()
        self.conv0 = Conv2(3, c)
        self.conv1 = Conv2(c, c)
        self.conv2 = Conv2(c, 2*c)
        self.conv3 = Conv2(2*c, 4*c)
        self.conv4 = Conv2(4*c, 8*c)

    def forward(self, x, flow):
        """Return ``[f1, f2, f3, f4]``: features of ``x`` warped by ``flow``.

        After the ``conv0`` stem, each further stage downsamples the features;
        the flow is halved in resolution and magnitude to match before it is
        used to warp that stage's output.
        """
        x = self.conv0(x)
        warped_levels = []
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                                 align_corners=False) * 0.5
            warped_levels.append(warp(x, flow))
        return warped_levels
class FusionNet(nn.Module):
    """U-Net style network that fuses two flow-warped frames with context features.

    The encoder concatenates the matching ``ContextNet`` features of both
    frames at each level; the decoder uses skip connections back to the
    encoder outputs. The final layer emits 4 channels.
    """

    def __init__(self):
        super(FusionNet, self).__init__()
        # 10 input channels: two warped images plus the flow tensor
        # (presumably 3 + 3 + 4 — confirm against callers).
        self.conv0 = Conv2(10, c)
        self.down0 = Conv2(c, 2*c)
        self.down1 = Conv2(4*c, 4*c)
        self.down2 = Conv2(8*c, 8*c)
        self.down3 = Conv2(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1)

    def forward(self, img0, img1, flow, c0, c1, flow_gt):
        """Warp both frames with ``flow`` and run the fusion U-Net.

        Args:
            img0, img1: the two input frames.
            flow: flow tensor; channels ``[:2]`` warp ``img0`` and ``[2:4]``
                warp ``img1``.
            c0, c1: four-level context feature lists for ``img0`` / ``img1``.
            flow_gt: optional reference flow with the same channel layout, or
                ``None``.

        Returns:
            ``(x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt)``
            where ``x`` is the 4-channel network output and the ``*_gt``
            entries are ``None`` when ``flow_gt`` is ``None``.
        """
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        # Identity check: "== None" on a tensor relies on Tensor.__eq__
        # falling back correctly, whereas "is None" is unambiguous.
        if flow_gt is None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            warped_img0_gt = warp(img0, flow_gt[:, :2])
            warped_img1_gt = warp(img1, flow_gt[:, 2:4])
        x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
        s0 = self.down0(x)
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model:
    """Bundle of the flow, context and fusion networks plus their optimizer.

    Args:
        local_rank: when not ``-1``, each sub-network is wrapped in
            ``DistributedDataParallel`` bound to that device index.
    """

    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.contextnet = ContextNet()
        self.fusionnet = FusionNet()
        self.device()
        # A single optimizer drives all three networks.
        self.optimG = AdamW(itertools.chain(
            self.flownet.parameters(),
            self.contextnet.parameters(),
            self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
        self.schedulerG = optim.lr_scheduler.CyclicLR(
            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[
                               local_rank], output_device=local_rank)
            self.contextnet = DDP(self.contextnet, device_ids=[
                                  local_rank], output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet, device_ids=[
                                 local_rank], output_device=local_rank)

    def train(self):
        """Put all three sub-networks into training mode."""
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        """Put all three sub-networks into evaluation mode."""
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self):
        """Move all three sub-networks to the module-level ``device``."""
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank):
        """Load ``flownet.pkl``, ``contextnet.pkl`` and ``unet.pkl`` from ``path``.

        Weights are only loaded when ``rank <= 0``. When ``rank == -1`` the
        ``"module."`` key prefix written by a DDP-wrapped model is stripped so
        the weights fit the unwrapped networks.
        """
        def convert(param):
            if rank == -1:
                # Keys without "module." are kept as-is, so checkpoints saved
                # from an unwrapped model load too; previously they were
                # filtered out, leaving an empty state dict.
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                }
            else:
                return param
        if rank <= 0:
            self.flownet.load_state_dict(
                convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
            self.contextnet.load_state_dict(
                convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
            self.fusionnet.load_state_dict(
                convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))

    def save_model(self, path, rank):
        """Write the three state dicts under ``path``; only ``rank == 0`` saves."""
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

    def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
        """Synthesize the intermediate frame from ``imgs`` and ``flow``.

        ``imgs`` holds both frames along the channel axis (first three
        channels are ``img0``, the rest ``img1``).

        Returns only ``pred`` when ``training`` is false; otherwise returns
        ``(pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt,
        warped_img1_gt)``.
        """
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if UHD:
            # The flow net downsamples its input by 2 in UHD mode, so the flow
            # is scaled back up before it is applied to the images here.
            flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
                             align_corners=False) * 2.0
        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        # Channels 0-2: residual mapped to (-1, 1); channel 3: blend mask in (0, 1).
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = merged_img + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
        else:
            return pred

    def inference(self, img0, img1, UHD=False):
        """Estimate flow between ``img0`` and ``img1`` and return the predicted frame."""
        imgs = torch.cat((img0, img1), 1)
        flow, _ = self.flownet(imgs, UHD)
        return self.predict(imgs, flow, training=False, UHD=UHD)

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when ``training``, one optimizer step.

        The optimizer's learning rate is overwritten with ``learning_rate`` on
        every call. ``flow_gt`` is dereferenced in the training branch, so it
        must be provided when ``training`` is true. ``mul`` is unused here.

        Returns:
            ``(pred, merged_img, flow, loss_l1, loss_flow, loss_cons,
            loss_ter, loss_mask)``.
        """
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            # These quantities are reported only / used as targets, so no
            # gradient is tracked for them.
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(
                    merged_img - gt).sum(1, True).float().detach()
                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
                                          align_corners=False).detach()
                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
                                         align_corners=False) * 0.5).detach()
            # Consistency loss between every intermediate flow estimate and
            # the downscaled reference flow.
            loss_cons = 0
            for i in range(4):
                loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
                loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        # Smooth L1-like (Charbonnier-style) reconstruction loss.
        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
    # Smoke test: interpolate between a black frame and a noise frame.
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    model = Model()
    model.eval()
    # inference() takes the two frames separately and concatenates them
    # itself; passing the pre-concatenated tensor raised a TypeError.
    print(model.inference(img0, img1).shape)

View File

@@ -12,7 +12,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.load_model('./train_log')
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
model.eval()
model.device()

View File

@@ -13,7 +13,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.load_model('./train_log')
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
model.eval()
model.device()

View File

@@ -3,6 +3,7 @@ import torch.nn.functional as F
from math import exp
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
@@ -11,7 +12,7 @@ def gaussian(window_size, sigma):
def create_window(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda()
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
@@ -19,7 +20,7 @@ def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda()
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
return window

29
benchmark/testtime.py Normal file
View File

@@ -0,0 +1,29 @@
import cv2
import sys
sys.path.append('.')
import time
import torch
import torch.nn as nn
from model.RIFE import Model
# Benchmark: average wall-clock time of one Model.inference call at 640x480.
model = Model()
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

I0 = torch.rand(1, 3, 480, 640).to(device)
I1 = torch.rand(1, 3, 480, 640).to(device)
with torch.no_grad():
    # Warm-up passes so one-time setup costs are excluded from the timing.
    for _ in range(100):
        pred = model.inference(I0, I1)
    # Wait for queued GPU work before starting the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    time_stamp = time.time()
    for _ in range(100):
        pred = model.inference(I0, I1)
    # Wait for the timed GPU work to finish before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.time() - time_stamp
    print(elapsed / 100)

23
docker/Dockerfile Normal file
View File

@@ -0,0 +1,23 @@
FROM python:3.8-slim

# Install system dependencies; drop the apt package lists afterwards so they
# do not bloat the image layer.
RUN apt-get update && apt-get -y install \
    bash ffmpeg \
    && rm -rf /var/lib/apt/lists/*

# Set up RIFE: copy the build context into /rife and install Python deps.
WORKDIR /rife
COPY . .
RUN pip3 install --no-cache-dir -r requirements.txt

# Install the wrapper scripts on PATH. COPY is used instead of ADD because no
# archive extraction or remote fetch is needed.
COPY docker/inference_img /usr/local/bin/inference_img
COPY docker/inference_video /usr/local/bin/inference_video
RUN chmod +x /usr/local/bin/inference_img /usr/local/bin/inference_video

# Add pre-trained models.
# NOTE(review): likely already covered by "COPY . ." when the build context is
# the repository root — kept for explicitness; confirm before removing.
COPY train_log /rife/train_log

# Commands run relative to /host (the README mounts $PWD there).
WORKDIR /host
ENTRYPOINT ["/bin/bash"]

ENV NVIDIA_DRIVER_CAPABILITIES=all

2
docker/inference_img Normal file
View File

@@ -0,0 +1,2 @@
#!/bin/sh
# Forward all command-line arguments to RIFE's image interpolation script.
# "$@" must be quoted so arguments containing spaces (e.g. file names) are
# passed through intact instead of being split into several words.
exec python3 /rife/inference_img.py "$@"

2
docker/inference_video Normal file
View File

@@ -0,0 +1,2 @@
#!/bin/sh
# Forward all command-line arguments to RIFE's video interpolation script.
# "$@" must be quoted so arguments containing spaces (e.g. file names) are
# passed through intact instead of being split into several words.
exec python3 /rife/inference_video.py "$@"

View File

@@ -19,15 +19,22 @@ parser.add_argument('--exp', default=4, type=int)
args = parser.parse_args()
model = Model()
model.load_model('./train_log', -1)
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1)
model.eval()
model.device()
img0 = cv2.imread(args.img[0])
img1 = cv2.imread(args.img[1])
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
else:
img0 = cv2.imread(args.img[0])
img1 = cv2.imread(args.img[1])
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
n, c, h, w = img0.shape
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
@@ -48,4 +55,7 @@ for i in range(args.exp):
if not os.path.exists('output'):
os.mkdir('output')
for i in range(len(img_list)):
cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
else:
cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])

View File

@@ -9,9 +9,11 @@ import warnings
import _thread
import skvideo.io
from queue import Queue, Empty
from benchmark.pytorch_msssim import ssim_matlab
warnings.filterwarnings("ignore")
def transferAudio(sourceVideo, targetVideo):
def transferAudio(sourceVideo, targetVideo):
import shutil
import moviepy.editor
tempAudioFileName = "./temp/audio.mkv"
@@ -27,25 +29,26 @@ def transferAudio(sourceVideo, targetVideo):
os.makedirs("temp")
# extract audio from video
os.system("ffmpeg -y -i " + sourceVideo + " -c:a copy -vn " + tempAudioFileName)
os.rename(targetVideo, "noAudio_"+targetVideo)
# combine audio file and new video file
os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo)
if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to mp3
tempAudioFileName = "./temp/audio.mp3"
os.system("ffmpeg -y -i " + sourceVideo + " -c:a mp3 -vn " + tempAudioFileName)
os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo)
if (os.path.getsize(targetVideo) == 0): # if mp3 not supported by selected format
os.rename("noAudio_"+targetVideo, targetVideo)
targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
os.rename(targetVideo, targetNoAudio)
# combine audio file and new video file
os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo)
if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
tempAudioFileName = "./temp/audio.m4a"
os.system("ffmpeg -y -i " + sourceVideo + " -c:a aac -b:a 160k -vn " + tempAudioFileName)
os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo)
if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
os.rename(targetNoAudio, targetVideo)
print("Audio transfer failed. Interpolated video will have no audio")
else:
print("Lossless audio transfer failed. Audio was transcoded to mp3 instead.")
print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
# remove audio-less video
os.remove("noAudio_"+targetVideo)
os.remove(targetNoAudio)
else:
os.remove("noAudio_"+targetVideo)
os.remove(targetNoAudio)
# remove temp directory
shutil.rmtree("temp")
@@ -59,6 +62,7 @@ if torch.cuda.is_available():
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--video', dest='video', type=str, default=None)
parser.add_argument('--output', dest='output', type=str, default=None)
parser.add_argument('--img', dest='img', type=str, default=None)
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
@@ -74,7 +78,7 @@ if not args.img is None:
from model.RIFE_HD import Model
model = Model()
model.load_model('./train_log', -1)
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1)
model.eval()
model.device()
@@ -105,14 +109,19 @@ else:
tot_frame = len(videogen)
videogen.sort(key= lambda x:int(x[:-4]))
lastframe = cv2.imread(os.path.join(args.img, videogen[0]))[:, :, ::-1].copy()
videogen = videogen[1:]
videogen = videogen[1:]
h, w, _ = lastframe.shape
vid_out_name = None
vid_out = None
if args.png:
if not os.path.exists('vid_out'):
os.mkdir('vid_out')
else:
vid_out = cv2.VideoWriter('{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext), fourcc, args.fps, (w, h))
if args.output is not None:
vid_out_name = args.output
else:
vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext)
vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
def clear_write_buffer(user_args, write_buffer):
cnt = 0
@@ -172,15 +181,16 @@ while True:
I0 = I1
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
I1 = F.pad(I1, padding)
diff = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False)
- F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs()
if diff.max() < 2e-3 and args.skip:
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small, I1_small)
if ssim > 0.995 and args.skip:
if skip_frame % 100 == 0:
print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame))
skip_frame += 1
pbar.update(1)
continue
if diff.mean() > 0.15:
if ssim < 0.5:
output = []
for i in range((2 ** args.exp) - 1):
output.append(I0)
@@ -211,9 +221,9 @@ if not vid_out is None:
# move audio to new video file if appropriate
if args.png == False and fpsNotAssigned == True and not args.skip and not args.video is None:
outputVideoFileName = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, 2 ** args.exp, int(np.round(args.fps)), args.ext)
try:
transferAudio(args.video, outputVideoFileName)
transferAudio(args.video, vid_out_name)
except:
print("Audio transfer failed. Interpolated video will have no audio")
os.rename("noAudio_"+outputVideoFileName, outputVideoFileName)
targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
os.rename(targetNoAudio, vid_out_name)

92
model/IFNet_HDv2.py Normal file
View File

@@ -0,0 +1,92 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a biased 2-D convolution (no activation) wrapped in ``nn.Sequential``."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    return nn.Sequential(layer)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a biased 2-D convolution followed by a per-channel PReLU."""
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
class IFBlock(nn.Module):
    """One flow-estimation stage that works at ``1/scale`` of the input resolution.

    Args:
        in_planes: number of input channels.
        scale: factor by which the input is shrunk before the convolutions
            and by which the output is enlarged afterwards.
        c: channel width of the internal convolutions.
    """

    def __init__(self, in_planes, scale=1, c=64):
        super().__init__()
        self.scale = scale
        # Stride-2 stem followed by eight same-width conv+PReLU layers and a
        # plain 3x3 head that emits 4 channels.
        self.conv0 = conv(in_planes, c, 5, 2, 2)
        self.convblock = nn.Sequential(*[conv(c, c) for _ in range(8)])
        self.conv1 = nn.Conv2d(c, 4, 3, 1, 1)

    def forward(self, x):
        """Return the 4-channel output for ``x``.

        When ``scale != 1`` the input is resized by ``1/scale`` first and the
        result is resized by ``scale`` at the end.
        """
        rescale = self.scale != 1
        if rescale:
            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear",
                              align_corners=False)
        flow = self.conv1(self.convblock(self.conv0(x)))
        if rescale:
            flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
                                 align_corners=False)
        return flow
class IFNet(nn.Module):
    """Coarse-to-fine flow estimator built from four ``IFBlock`` stages."""

    def __init__(self):
        super().__init__()
        # Stages go from the coarsest scale (8) to the finest (1).
        self.block0 = IFBlock(6, scale=8, c=192)
        self.block1 = IFBlock(10, scale=4, c=128)
        self.block2 = IFBlock(10, scale=2, c=96)
        self.block3 = IFBlock(10, scale=1, c=48)

    def forward(self, x, UHD=False):
        """Estimate flow for the frame pair stacked in ``x``.

        ``x`` holds both frames along the channel axis (first three channels
        and the remaining ones). With ``UHD`` the input is halved in
        resolution first.

        Returns:
            ``(flow, [F1, F2, F3, F4])`` — the final estimate and the running
            estimate after each stage (the last list entry is ``flow``).
        """
        if UHD:
            x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
        frame0, frame1 = x[:, :3], x[:, 3:]

        flow = self.block0(x)
        estimates = [flow]
        for block in (self.block1, self.block2, self.block3):
            # The blocks' output is at half the input resolution, so upscale
            # the current estimate (and its magnitude) before warping.
            flow_large = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
                                       align_corners=False,
                                       recompute_scale_factor=False) * 2.0
            warped_img0 = warp(frame0, flow_large[:, :2])
            warped_img1 = warp(frame1, flow_large[:, 2:4])
            # Each stage predicts a residual that is added to the estimate.
            flow = flow + block(torch.cat((warped_img0, warped_img1, flow_large), 1))
            estimates.append(flow)
        return flow, estimates
if __name__ == '__main__':
    # Smoke test: estimate flow between a black frame and a noise frame.
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    imgs = torch.cat((img0, img1), 1)
    # The network must live on the same device as its inputs; leaving it on
    # the CPU fails whenever CUDA is available.
    flownet = IFNet().to(device)
    flow, _ = flownet(imgs)
    print(flow.shape)