Mirror of https://github.com/hzwer/ECCV2022-RIFE.git, synced 2026-02-24 04:19:41 +01:00
35 README.md
@@ -1,15 +1,16 @@
# RIFE Video Frame Interpolation v1.8

**Our paper has not been officially published yet, and our method and experimental results are still being improved. Due to an incorrect data reference, the latency measurements of SepConv and TOFlow in our arXiv paper need to be corrected.**

## [arXiv](https://arxiv.org/abs/2011.06294) | [Project Page](https://rife-vfi.github.io) | [Reddit](https://www.reddit.com/r/linux/comments/jy4jjl/opensourced_realtime_video_frame_interpolation/) | [YouTube](https://www.youtube.com/watch?v=60DX2T3zyVo&feature=youtu.be) | [Bilibili](https://www.bilibili.com/video/BV1K541157te?from=search&seid=5131698847373645765)

**1.4 News: We have updated the v1.8 model, which is optimized for 2D animation.**

**12.13 News: We have updated the v1.6 model and added support for UHD mode. Please check our [update log](https://github.com/hzwer/arXiv2020-RIFE/issues/41#issuecomment-737651979).**

**11.22 News: A new Windows app is integrating RIFE; we hope everyone will try it and help its developers improve it. You can download [Flowframes](https://nmkd.itch.io/flowframes) for free.**

**There is [a tutorial on RIFE](https://www.youtube.com/watch?v=gf_on-dbwyU&feature=emb_title) on YouTube.**

**You can easily try RIFE in [Colaboratory](https://colab.research.google.com/github/hzwer/arXiv2020-RIFE/blob/main/Colab_demo.ipynb) and reproduce [our YouTube demo](https://www.youtube.com/watch?v=LE2Dzl0oMHI).**

Our model runs at 30+ FPS for 2X 720p interpolation on a 2080Ti GPU. It currently supports 2X, 4X, 8X, ... interpolation for 1080p video, as well as multi-frame interpolation between a pair of images. Everyone is welcome to use our alpha version and make suggestions!
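Throughout the commands in this README, the `--exp` flag sets the interpolation factor as a power of two. A quick sketch of the arithmetic (our own illustration, not code from this repository):

```
# Each pass doubles the frame count, so --exp=e multiplies the frame
# rate by 2**e. The 24 fps source rate is an illustrative value.
source_fps = 24
for exp in (1, 2, 3):
    print(f"--exp={exp}: {source_fps} fps -> {source_fps * 2 ** exp} fps")
```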
@@ -33,6 +34,8 @@ We are optimizing the visual effects and will support animation in the future. (
* Unzip and move the pretrained parameters to train_log/\*.pkl

**This model is designed to provide better visual effects for users and should not be used for benchmarking.**

### Run

**Video Frame Interpolation**
@@ -80,6 +83,27 @@ You can also use pngs to generate gif:
ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf "split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1" output/slomo.gif
```
### Run in Docker

Place the pre-trained models in `train_log/*.pkl` (as above).

Building the container:
```
docker build -t rife -f docker/Dockerfile .
```

Running the container:
```
docker run --rm -it -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4
```
```
docker run --rm -it -v $PWD:/host rife:latest inference_img --img img0.png img1.png --exp=4
```

Using GPU acceleration (requires the proper GPU drivers for Docker):
```
docker run --rm -it --gpus all -v /dev/dri:/dev/dri -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4
```
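Before starting a long render it is worth confirming that PyTorch inside the container actually sees the GPU, since it silently falls back to the CPU otherwise. A quick check (our own suggestion, not part of this repository):

```
import torch

# Explicit check: a missing driver does not raise, it just disables CUDA.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```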
## Evaluation

Download the [RIFE model](https://drive.google.com/file/d/1c1R7iF-ypN6USo-D2YH_ORtaH3tukSlo/view?usp=sharing) or the [RIFE2F1.5C model](https://drive.google.com/file/d/1ve9w-cRWotdvvbU1KcgtsSm12l-JUkeT/view?usp=sharing) reported in our paper.

@@ -119,7 +143,6 @@ python3 -m torch.distributed.launch --nproc_per_node=4 train.py --world_size=4
```

## Reference

<img src="demo/intro.png" alt="img" width=350 />

Optical Flow:
[ARFlow](https://github.com/lliuz/ARFlow) [pytorch-liteflownet](https://github.com/sniklaus/pytorch-liteflownet) [RAFT](https://github.com/princeton-vl/RAFT) [pytorch-PWCNet](https://github.com/sniklaus/pytorch-pwc)
247 RIFE_HDv2.py Normal file
@@ -0,0 +1,247 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet_HDv2 import *
import torch.nn.functional as F
from model.loss import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    # Pass the parameters through instead of hardcoding 4/2/1 so the
    # signature and the behavior agree.
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=True),
        nn.PReLU(out_planes)
    )


def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )
class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2):
        super(Conv2, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


c = 32


class ContextNet(nn.Module):
    def __init__(self):
        super(ContextNet, self).__init__()
        self.conv0 = Conv2(3, c)
        self.conv1 = Conv2(c, c)
        self.conv2 = Conv2(c, 2*c)
        self.conv3 = Conv2(2*c, 4*c)
        self.conv4 = Conv2(4*c, 8*c)

    def forward(self, x, flow):
        # Build a feature pyramid of x and warp each level by the flow,
        # halving the flow's resolution and magnitude at every scale.
        x = self.conv0(x)
        x = self.conv1(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f1 = warp(x, flow)
        x = self.conv2(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f2 = warp(x, flow)
        x = self.conv3(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f3 = warp(x, flow)
        x = self.conv4(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f4 = warp(x, flow)
        return [f1, f2, f3, f4]
class FusionNet(nn.Module):
    def __init__(self):
        super(FusionNet, self).__init__()
        self.conv0 = Conv2(10, c)
        self.down0 = Conv2(c, 2*c)
        self.down1 = Conv2(4*c, 4*c)
        self.down2 = Conv2(8*c, 8*c)
        self.down3 = Conv2(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1)

    def forward(self, img0, img1, flow, c0, c1, flow_gt):
        # U-Net style fusion: encode the two warped frames together with
        # the flow, mixing in the warped context pyramids at each scale.
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        if flow_gt is None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            warped_img0_gt = warp(img0, flow_gt[:, :2])
            warped_img1_gt = warp(img1, flow_gt[:, 2:4])
        x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
        s0 = self.down0(x)
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.contextnet = ContextNet()
        self.fusionnet = FusionNet()
        self.device()
        self.optimG = AdamW(itertools.chain(
            self.flownet.parameters(),
            self.contextnet.parameters(),
            self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
        self.schedulerG = optim.lr_scheduler.CyclicLR(
            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank],
                               output_device=local_rank)
            self.contextnet = DDP(self.contextnet, device_ids=[local_rank],
                                  output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet, device_ids=[local_rank],
                                 output_device=local_rank)

    def train(self):
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self):
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)
    def load_model(self, path, rank):
        def convert(param):
            # Strip the "module." prefix that DDP adds to parameter names
            # so checkpoints load into a non-distributed model.
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param
        if rank <= 0:
            self.flownet.load_state_dict(
                convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
            self.contextnet.load_state_dict(
                convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
            self.fusionnet.load_state_dict(
                convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))

    def save_model(self, path, rank):
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
    def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if UHD:
            # In UHD mode the flow was estimated on a half-resolution input,
            # so upsample it (and double its magnitude) before refinement.
            flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
                                 align_corners=False) * 2.0
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
                             align_corners=False) * 2.0
        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        # Blend the two warped frames with the predicted soft mask, then
        # add the residual and clamp to a valid image range.
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = merged_img + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
        else:
            return pred

    def inference(self, img0, img1, UHD=False):
        imgs = torch.cat((img0, img1), 1)
        flow, _ = self.flownet(imgs, UHD)
        return self.predict(imgs, flow, training=False, UHD=UHD)
    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(
                    merged_img - gt).sum(1, True).float().detach()
                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
                                          align_corners=False).detach()
                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
                                         align_corners=False) * 0.5).detach()
            loss_cons = 0
            for i in range(4):
                loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
                loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    model = Model()
    model.eval()
    # inference() takes the two frames separately and concatenates them
    # internally, so pass img0 and img1 rather than a pre-concatenated tensor.
    print(model.inference(img0, img1).shape)
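A minimal usage sketch for this class (our own example, not from the repository; it assumes the pretrained `flownet.pkl`, `contextnet.pkl`, and `unet.pkl` checkpoints have been unpacked into `train_log/`, and that this file is importable under the name used below):

```
import torch
from RIFE_HDv2 import Model  # hypothetical import path; adjust to where this file lives

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log', -1)  # rank -1 strips the DDP "module." prefixes
model.eval()

# Two frames as 1x3xHxW tensors in [0, 1]; H and W are multiples of 32
# here so the feature pyramid divides evenly.
img0 = torch.rand(1, 3, 256, 448, device=device)
img1 = torch.rand(1, 3, 256, 448, device=device)
with torch.no_grad():
    mid = model.inference(img0, img1)  # the frame halfway between img0 and img1
print(mid.shape)  # torch.Size([1, 3, 256, 448])
```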
@@ -12,7 +12,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
-model.load_model('./train_log')
+model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
model.eval()
model.device()
@@ -13,7 +13,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
-model.load_model('./train_log')
+model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
model.eval()
model.device()
@@ -3,6 +3,7 @@ import torch.nn.functional as F
from math import exp
import numpy as np

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
@@ -11,7 +12,7 @@ def gaussian(window_size, sigma):

def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
-    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda()
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

@@ -19,7 +20,7 @@ def create_window_3d(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t())
    _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
-    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda()
+    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
    return window
29 benchmark/testtime.py Normal file
@@ -0,0 +1,29 @@
import cv2
import sys
sys.path.append('.')
import time
import torch
import torch.nn as nn
from model.RIFE import Model

model = Model()
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

I0 = torch.rand(1, 3, 480, 640).to(device)
I1 = torch.rand(1, 3, 480, 640).to(device)
with torch.no_grad():
    # Warm up for 100 iterations so cudnn autotuning and lazy
    # initialization do not distort the measurement.
    for i in range(100):
        pred = model.inference(I0, I1)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    time_stamp = time.time()
    for i in range(100):
        pred = model.inference(I0, I1)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Average seconds per interpolated frame over 100 runs.
    print((time.time() - time_stamp) / 100)
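The printed figure is average seconds per interpolated frame; converting it to an effective frame rate is just a reciprocal. A quick sketch with an illustrative number (ours, not measured output):

```
avg_latency = 0.03  # seconds per call, illustrative value only
print(f"{1.0 / avg_latency:.0f} interpolated frames per second")  # ~33
```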
23 docker/Dockerfile Normal file
@@ -0,0 +1,23 @@
FROM python:3.8-slim

# install deps
RUN apt-get update && apt-get -y install \
    bash ffmpeg

# setup RIFE
WORKDIR /rife
COPY . .
RUN pip3 install -r requirements.txt

ADD docker/inference_img /usr/local/bin/inference_img
RUN chmod +x /usr/local/bin/inference_img
ADD docker/inference_video /usr/local/bin/inference_video
RUN chmod +x /usr/local/bin/inference_video

# add pre-trained models
COPY train_log /rife/train_log

WORKDIR /host
ENTRYPOINT ["/bin/bash"]

ENV NVIDIA_DRIVER_CAPABILITIES all
2 docker/inference_img Normal file
@@ -0,0 +1,2 @@
#!/bin/sh
# Quote "$@" so arguments containing spaces are forwarded intact.
python3 /rife/inference_img.py "$@"
2 docker/inference_video Normal file
@@ -0,0 +1,2 @@
#!/bin/sh
# Quote "$@" so arguments containing spaces are forwarded intact.
python3 /rife/inference_video.py "$@"
@@ -19,15 +19,22 @@ parser.add_argument('--exp', default=4, type=int)
args = parser.parse_args()

model = Model()
-model.load_model('./train_log', -1)
+model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1)
model.eval()
model.device()

-img0 = cv2.imread(args.img[0])
-img1 = cv2.imread(args.img[1])
-
-img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
-img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+    img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+    img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+    img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+    img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+
+else:
+    img0 = cv2.imread(args.img[0])
+    img1 = cv2.imread(args.img[1])
+    img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+    img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)

n, c, h, w = img0.shape
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
@@ -48,4 +55,7 @@ for i in range(args.exp):
if not os.path.exists('output'):
    os.mkdir('output')
for i in range(len(img_list)):
-    cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
+    if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+        cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+    else:
+        cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
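The `ph`/`pw` lines above round the frame size up to the next multiple of 32, which the model's feature pyramid requires. A quick check of the formula (our own illustration):

```
# Round up to the next multiple of 32: identical to math.ceil(h / 32) * 32.
for h in (256, 270, 480, 481):
    ph = ((h - 1) // 32 + 1) * 32
    print(h, "->", ph)  # 256->256, 270->288, 480->480, 481->512
```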
@@ -9,9 +9,11 @@ import warnings
import _thread
import skvideo.io
from queue import Queue, Empty
from benchmark.pytorch_msssim import ssim_matlab

warnings.filterwarnings("ignore")

def transferAudio(sourceVideo, targetVideo):
    import shutil
    import moviepy.editor
    tempAudioFileName = "./temp/audio.mkv"
@@ -27,25 +29,26 @@ def transferAudio(sourceVideo, targetVideo):
        os.makedirs("temp")
    # extract audio from video
    os.system("ffmpeg -y -i " + sourceVideo + " -c:a copy -vn " + tempAudioFileName)

-    os.rename(targetVideo, "noAudio_"+targetVideo)
-    # combine audio file and new video file
-    os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo)
-
-    if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to mp3
-        tempAudioFileName = "./temp/audio.mp3"
-        os.system("ffmpeg -y -i " + sourceVideo + " -c:a mp3 -vn " + tempAudioFileName)
-        os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo)
-        if (os.path.getsize(targetVideo) == 0): # if mp3 not supported by selected format
-            os.rename("noAudio_"+targetVideo, targetVideo)
+    targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
+    os.rename(targetVideo, targetNoAudio)
+    # combine audio file and new video file
+    os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo)
+
+    if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
+        tempAudioFileName = "./temp/audio.m4a"
+        os.system("ffmpeg -y -i " + sourceVideo + " -c:a aac -b:a 160k -vn " + tempAudioFileName)
+        os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo)
+        if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
+            os.rename(targetNoAudio, targetVideo)
            print("Audio transfer failed. Interpolated video will have no audio")
        else:
-            print("Lossless audio transfer failed. Audio was transcoded to mp3 instead.")
+            print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")

        # remove audio-less video
-        os.remove("noAudio_"+targetVideo)
+        os.remove(targetNoAudio)
    else:
-        os.remove("noAudio_"+targetVideo)
+        os.remove(targetNoAudio)

    # remove temp directory
    shutil.rmtree("temp")
@@ -59,6 +62,7 @@ if torch.cuda.is_available():

parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--video', dest='video', type=str, default=None)
+parser.add_argument('--output', dest='output', type=str, default=None)
parser.add_argument('--img', dest='img', type=str, default=None)
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
@@ -74,7 +78,7 @@ if not args.img is None:

from model.RIFE_HD import Model
model = Model()
-model.load_model('./train_log', -1)
+model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1)
model.eval()
model.device()
@@ -105,14 +109,19 @@ else:
    tot_frame = len(videogen)
    videogen.sort(key= lambda x:int(x[:-4]))
    lastframe = cv2.imread(os.path.join(args.img, videogen[0]))[:, :, ::-1].copy()
    videogen = videogen[1:]
h, w, _ = lastframe.shape
+vid_out_name = None
vid_out = None
if args.png:
    if not os.path.exists('vid_out'):
        os.mkdir('vid_out')
else:
-    vid_out = cv2.VideoWriter('{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext), fourcc, args.fps, (w, h))
+    if args.output is not None:
+        vid_out_name = args.output
+    else:
+        vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext)
+    vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))

def clear_write_buffer(user_args, write_buffer):
    cnt = 0
@@ -172,15 +181,16 @@ while True:
    I0 = I1
    I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
    I1 = F.pad(I1, padding)
-    diff = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False)
-            - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs()
-    if diff.max() < 2e-3 and args.skip:
+    I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
+    I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+    ssim = ssim_matlab(I0_small, I1_small)
+    if ssim > 0.995 and args.skip:
        if skip_frame % 100 == 0:
            print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame))
        skip_frame += 1
        pbar.update(1)
        continue
-    if diff.mean() > 0.15:
+    if ssim < 0.5:
        output = []
        for i in range((2 ** args.exp) - 1):
            output.append(I0)
@@ -211,9 +221,9 @@ if not vid_out is None:

# move audio to new video file if appropriate
if args.png == False and fpsNotAssigned == True and not args.skip and not args.video is None:
-    outputVideoFileName = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, 2 ** args.exp, int(np.round(args.fps)), args.ext)
    try:
-        transferAudio(args.video, outputVideoFileName)
+        transferAudio(args.video, vid_out_name)
    except:
        print("Audio transfer failed. Interpolated video will have no audio")
-        os.rename("noAudio_"+outputVideoFileName, outputVideoFileName)
+        targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
+        os.rename(targetNoAudio, vid_out_name)
92 model/IFNet_HDv2.py Normal file
@@ -0,0 +1,92 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )

class IFBlock(nn.Module):
    def __init__(self, in_planes, scale=1, c=64):
        super(IFBlock, self).__init__()
        self.scale = scale
        self.conv0 = conv(in_planes, c, 5, 2, 2)
        self.convblock = nn.Sequential(
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
        )
        self.conv1 = nn.Conv2d(c, 4, 3, 1, 1)

    def forward(self, x):
        # Estimate flow at a reduced resolution (1/scale), then upsample
        # the result back so every block refines the same-size flow field.
        if self.scale != 1:
            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear",
                              align_corners=False)
        x = self.conv0(x)
        x = self.convblock(x)
        x = self.conv1(x)
        flow = x
        if self.scale != 1:
            flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
                                 align_corners=False)
        return flow


class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(6, scale=8, c=192)
        self.block1 = IFBlock(10, scale=4, c=128)
        self.block2 = IFBlock(10, scale=2, c=96)
        self.block3 = IFBlock(10, scale=1, c=48)

    def forward(self, x, UHD=False):
        if UHD:
            x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
        # Coarse-to-fine: each block predicts a residual flow on top of the
        # sum of all previous estimates, using freshly warped input frames.
        flow0 = self.block0(x)
        F1 = flow0
        F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
        warped_img0 = warp(x[:, :3], F1_large[:, :2])
        warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
        flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
        F2 = (flow0 + flow1)
        F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
        warped_img0 = warp(x[:, :3], F2_large[:, :2])
        warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
        flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
        F3 = (flow0 + flow1 + flow2)
        F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
        warped_img0 = warp(x[:, :3], F3_large[:, :2])
        warped_img1 = warp(x[:, 3:], F3_large[:, 2:4])
        flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1))
        F4 = (flow0 + flow1 + flow2 + flow3)
        return F4, [F1, F2, F3, F4]

if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    imgs = torch.cat((img0, img1), 1)
    flownet = IFNet()
    flow, _ = flownet(imgs)
    print(flow.shape)
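A small numeric sketch of the scale bookkeeping above (our own illustration): each `IFBlock` produces its flow at half the resolution of its input because of the stride-2 `conv0`, so the `*_large` flows are upsampled by 2 and doubled in magnitude before warping the full-size frames.

```
import torch
import torch.nn.functional as F

# A flow of +1 px at half resolution corresponds to +2 px at full
# resolution, hence the "* 2.0" that accompanies every upsampling above.
flow_half = torch.full((1, 2, 4, 4), 1.0)          # 1-px flow on a 4x4 grid
flow_full = F.interpolate(flow_half, scale_factor=2.0,
                          mode="bilinear", align_corners=False) * 2.0
print(flow_full.shape, flow_full.mean().item())    # (1, 2, 8, 8), 2.0
```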