diff --git a/README.md b/README.md index 7a3b820..9043dd0 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,16 @@ # RIFE Video Frame Interpolation v1.8 +**Our paper has not been officially published yet, and our method and experimental results are under improvement. Due to the incorrect data reference, the latency measurement of Sepconv and TOFlow in our arxiv paper needs to be modified.** ## [arXiv](https://arxiv.org/abs/2011.06294) | [Project Page](https://rife-vfi.github.io) | [Reddit](https://www.reddit.com/r/linux/comments/jy4jjl/opensourced_realtime_video_frame_interpolation/) | [YouTube](https://www.youtube.com/watch?v=60DX2T3zyVo&feature=youtu.be) | [Bilibili](https://www.bilibili.com/video/BV1K541157te?from=search&seid=5131698847373645765) -**1.4 News: We have updated the v1.8 model optimized for 2D animation.** +1.4 News: We have updated the v1.8 model optimized for 2D animation. -**12.13 News: We have updated the v1.6 model and support UHD mode. Please check our [update log](https://github.com/hzwer/arXiv2020-RIFE/issues/41#issuecomment-737651979).** +12.13 News: We have updated the v1.6 model and support UHD mode. Please check our [update log](https://github.com/hzwer/arXiv2020-RIFE/issues/41#issuecomment-737651979). -**11.22 News: We notice a new windows app is trying to integrate RIFE, we hope everyone to try and help them improve. You can download [Flowframes](https://nmkd.itch.io/flowframes) for free.** +11.22 News: We notice a new windows app is trying to integrate RIFE, we hope everyone to try and help them improve. You can download [Flowframes](https://nmkd.itch.io/flowframes) for free. -**There is [a tutorial of RIFE](https://www.youtube.com/watch?v=gf_on-dbwyU&feature=emb_title) on Youtube.** +There is [a tutorial of RIFE](https://www.youtube.com/watch?v=gf_on-dbwyU&feature=emb_title) on Youtube. 
-**You can easily use [colaboratory](https://colab.research.google.com/github/hzwer/arXiv2020-RIFE/blob/main/Colab_demo.ipynb) to have a try and generate the [our youtube demo](https://www.youtube.com/watch?v=LE2Dzl0oMHI).** +You can easily use [colaboratory](https://colab.research.google.com/github/hzwer/arXiv2020-RIFE/blob/main/Colab_demo.ipynb) to have a try and generate the [our youtube demo](https://www.youtube.com/watch?v=LE2Dzl0oMHI). Our model can run 30+FPS for 2X 720p interpolation on a 2080Ti GPU. Currently, our method supports 2X,4X,8X... interpolation for 1080p video, and multi-frame interpolation between a pair of images. Everyone is welcome to use our alpha version and make suggestions! @@ -33,6 +34,8 @@ We are optimizing the visual effects and will support animation in the future. ( * Unzip and move the pretrained parameters to train_log/\*.pkl +**This model is designed to provide better visual effects for users and should not be used for benchmarking.** + ### Run **Video Frame Interpolation** @@ -80,6 +83,27 @@ You can also use pngs to generate gif: ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf "split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1" output/slomo.gif ``` +### Run in docker +Place the pre-trained models in `train_log/\*.pkl` (as above) + +Building the container: +``` +docker build -t rife -f docker/Dockerfile . 
+``` + +Running the container: +``` +docker run --rm -it -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4 +``` +``` +docker run --rm -it -v $PWD:/host rife:latest inference_img --img img0.png img1.png --exp=4 +``` + +Using gpu acceleration (requires proper gpu drivers for docker): +``` +docker run --rm -it --gpus all -v /dev/dri:/dev/dri -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4 +``` + ## Evaluation Download [RIFE model](https://drive.google.com/file/d/1c1R7iF-ypN6USo-D2YH_ORtaH3tukSlo/view?usp=sharing) or [RIFE2F1.5C model](https://drive.google.com/file/d/1ve9w-cRWotdvvbU1KcgtsSm12l-JUkeT/view?usp=sharing) reported by our paper. @@ -119,7 +143,6 @@ python3 -m torch.distributed.launch --nproc_per_node=4 train.py --world_size=4 ``` ## Reference -img Optical Flow: [ARFlow](https://github.com/lliuz/ARFlow) [pytorch-liteflownet](https://github.com/sniklaus/pytorch-liteflownet) [RAFT](https://github.com/princeton-vl/RAFT) [pytorch-PWCNet](https://github.com/sniklaus/pytorch-pwc) diff --git a/RIFE_HDv2.py b/RIFE_HDv2.py new file mode 100644 index 0000000..9f19ae2 --- /dev/null +++ b/RIFE_HDv2.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +from model.IFNet_HDv2 import * +import torch.nn.functional as F +from model.loss import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + 
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=4, stride=2, padding=1, bias=True), + nn.PReLU(out_planes) + ) + +def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + +c = 32 + +class ContextNet(nn.Module): + def __init__(self): + super(ContextNet, self).__init__() + self.conv0 = Conv2(3, c) + self.conv1 = Conv2(c, c) + self.conv2 = Conv2(c, 2*c) + self.conv3 = Conv2(2*c, 4*c) + self.conv4 = Conv2(4*c, 8*c) + + def forward(self, x, flow): + x = self.conv0(x) + x = self.conv1(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5 + f2 = warp(x, flow) + x = self.conv3(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5 + f3 = warp(x, flow) + x = self.conv4(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5 + f4 = warp(x, flow) + return [f1, f2, f3, f4] + + +class FusionNet(nn.Module): + def __init__(self): + super(FusionNet, self).__init__() + self.conv0 = Conv2(10, c) + self.down0 = Conv2(c, 2*c) + self.down1 = Conv2(4*c, 4*c) + self.down2 = Conv2(8*c, 8*c) + self.down3 = Conv2(16*c, 16*c) + self.up0 = deconv(32*c, 8*c) + self.up1 = deconv(16*c, 4*c) + self.up2 = deconv(8*c, 2*c) + self.up3 = deconv(4*c, c) + self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1) + + def forward(self, img0, img1, 
flow, c0, c1, flow_gt): + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + if flow_gt is None: + warped_img0_gt, warped_img1_gt = None, None + else: + warped_img0_gt = warp(img0, flow_gt[:, :2]) + warped_img1_gt = warp(img1, flow_gt[:, 2:4]) + x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1)) + s0 = self.down0(x) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt + + +class Model: + def __init__(self, local_rank=-1): + self.flownet = IFNet() + self.contextnet = ContextNet() + self.fusionnet = FusionNet() + self.device() + self.optimG = AdamW(itertools.chain( + self.flownet.parameters(), + self.contextnet.parameters(), + self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) + self.schedulerG = optim.lr_scheduler.CyclicLR( + self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) + self.epe = EPE() + self.ter = Ternary() + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[ + local_rank], output_device=local_rank) + self.contextnet = DDP(self.contextnet, device_ids=[ + local_rank], output_device=local_rank) + self.fusionnet = DDP(self.fusionnet, device_ids=[ + local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + self.contextnet.train() + self.fusionnet.train() + + def eval(self): + self.flownet.eval() + self.contextnet.eval() + self.fusionnet.eval() + + def device(self): + self.flownet.to(device) + self.contextnet.to(device) + self.fusionnet.to(device) + + def load_model(self, path, rank): + def convert(param): + if rank == -1: + return { + k.replace("module.", ""): v + for
k, v in param.items() + if "module." in k + } + else: + return param + if rank <= 0: + self.flownet.load_state_dict( + convert(torch.load('{}/flownet.pkl'.format(path), map_location=device))) + self.contextnet.load_state_dict( + convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device))) + self.fusionnet.load_state_dict( + convert(torch.load('{}/unet.pkl'.format(path), map_location=device))) + + def save_model(self, path, rank): + if rank == 0: + torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path)) + torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) + torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) + + def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if UHD: + flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", + align_corners=False) * 2.0 + refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet( + img0, img1, flow, c0, c1, flow_gt) + res = torch.sigmoid(refine_output[:, :3]) * 2 - 1 + mask = torch.sigmoid(refine_output[:, 3:4]) + merged_img = warped_img0 * mask + warped_img1 * (1 - mask) + pred = merged_img + res + pred = torch.clamp(pred, 0, 1) + if training: + return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt + else: + return pred + + def inference(self, img0, img1, UHD=False): + imgs = torch.cat((img0, img1), 1) + flow, _ = self.flownet(imgs, UHD) + return self.predict(imgs, flow, training=False, UHD=UHD) + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group['lr'] = learning_rate + if training: + self.train() + else: + self.eval() + flow, flow_list = self.flownet(imgs) + pred, 
mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict( + imgs, flow, flow_gt=flow_gt) + loss_ter = self.ter(pred, gt).mean() + if training: + with torch.no_grad(): + loss_flow = torch.abs(warped_img0_gt - gt).mean() + loss_mask = torch.abs( + merged_img - gt).sum(1, True).float().detach() + loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear", + align_corners=False).detach() + flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5).detach() + loss_cons = 0 + for i in range(4): + loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1) + loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1) + loss_cons = loss_cons.mean() * 0.01 + else: + loss_cons = torch.tensor([0]) + loss_flow = torch.abs(warped_img0 - gt).mean() + loss_mask = 1 + loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean() + if training: + self.optimG.zero_grad() + loss_G = loss_l1 + loss_cons + loss_ter + loss_G.backward() + self.optimG.step() + return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask + + +if __name__ == '__main__': + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal( + 0, 1, (3, 3, 256, 256))).float().to(device) + imgs = torch.cat((img0, img1), 1) + model = Model() + model.eval() + print(model.inference(img0, img1).shape) diff --git a/benchmark/MiddleBury_Other.py b/benchmark/MiddleBury_Other.py index 18b4e16..7c0f77f 100644 --- a/benchmark/MiddleBury_Other.py +++ b/benchmark/MiddleBury_Other.py @@ -12,7 +12,7 @@ from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Model() -model.load_model('./train_log') +model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) model.eval() model.device() diff --git a/benchmark/Vimeo90K.py b/benchmark/Vimeo90K.py index c984b4c..cd70e48 100644 --- a/benchmark/Vimeo90K.py +++ b/benchmark/Vimeo90K.py
@@ -13,7 +13,7 @@ from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Model() -model.load_model('./train_log') +model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) model.eval() model.device() diff --git a/benchmark/pytorch_msssim/__init__.py b/benchmark/pytorch_msssim/__init__.py index 118d265..a4d3032 100644 --- a/benchmark/pytorch_msssim/__init__.py +++ b/benchmark/pytorch_msssim/__init__.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from math import exp import numpy as np +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) @@ -11,7 +12,7 @@ def gaussian(window_size, sigma): def create_window(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda() + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() return window @@ -19,7 +20,7 @@ def create_window_3d(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()) _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) - window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda() + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) return window diff --git a/benchmark/testtime.py b/benchmark/testtime.py new file mode 100644 index 0000000..c695d5f --- /dev/null +++ b/benchmark/testtime.py @@ -0,0 +1,29 @@ +import cv2 +import sys +sys.path.append('.') +import time +import torch +import torch.nn as nn +from model.RIFE import Model + +model = Model() +model.eval() +device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + +I0 = torch.rand(1, 3, 480, 640).to(device) +I1 = torch.rand(1, 3, 480, 640).to(device) +with torch.no_grad(): + for i in range(100): + pred = model.inference(I0, I1) + if torch.cuda.is_available(): + torch.cuda.synchronize() + time_stamp = time.time() + for i in range(100): + pred = model.inference(I0, I1) + if torch.cuda.is_available(): + torch.cuda.synchronize() + print((time.time() - time_stamp) / 100) diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..801dbb7 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.8-slim + +# install deps +RUN apt-get update && apt-get -y install \ + bash ffmpeg + +# setup RIFE +WORKDIR /rife +COPY . . +RUN pip3 install -r requirements.txt + +ADD docker/inference_img /usr/local/bin/inference_img +RUN chmod +x /usr/local/bin/inference_img +ADD docker/inference_video /usr/local/bin/inference_video +RUN chmod +x /usr/local/bin/inference_video + +# add pre-trained models +COPY train_log /rife/train_log + +WORKDIR /host +ENTRYPOINT ["/bin/bash"] + +ENV NVIDIA_DRIVER_CAPABILITIES all \ No newline at end of file diff --git a/docker/inference_img b/docker/inference_img new file mode 100644 index 0000000..5557be4 --- /dev/null +++ b/docker/inference_img @@ -0,0 +1,2 @@ +#!/bin/sh +python3 /rife/inference_img.py $@ diff --git a/docker/inference_video b/docker/inference_video new file mode 100644 index 0000000..d718c5c --- /dev/null +++ b/docker/inference_video @@ -0,0 +1,2 @@ +#!/bin/sh +python3 /rife/inference_video.py $@ diff --git a/inference_img.py b/inference_img.py index 1e6ea25..633bd55 100644 --- a/inference_img.py +++ b/inference_img.py @@ -19,15 +19,22 @@ parser.add_argument('--exp', default=4, type=int) args = parser.parse_args() model = Model() -model.load_model('./train_log', -1) 
+model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1) model.eval() model.device() - -img0 = cv2.imread(args.img[0]) -img1 = cv2.imread(args.img[1]) -img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) -img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) +if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) + +else: + img0 = cv2.imread(args.img[0]) + img1 = cv2.imread(args.img[1]) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + n, c, h, w = img0.shape ph = ((h - 1) // 32 + 1) * 32 pw = ((w - 1) // 32 + 1) * 32 @@ -48,4 +55,7 @@ for i in range(args.exp): if not os.path.exists('output'): os.mkdir('output') for i in range(len(img_list)): - cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) + if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + else: + cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) diff --git a/inference_video.py b/inference_video.py index c75d387..5e221fd 100644 --- a/inference_video.py +++ b/inference_video.py @@ -9,9 +9,11 @@ import warnings import _thread import skvideo.io from queue import Queue, Empty +from benchmark.pytorch_msssim import ssim_matlab + warnings.filterwarnings("ignore") -def transferAudio(sourceVideo, targetVideo): +def 
transferAudio(sourceVideo, targetVideo): import shutil import moviepy.editor tempAudioFileName = "./temp/audio.mkv" @@ -27,25 +29,26 @@ def transferAudio(sourceVideo, targetVideo): os.makedirs("temp") # extract audio from video os.system("ffmpeg -y -i " + sourceVideo + " -c:a copy -vn " + tempAudioFileName) - - os.rename(targetVideo, "noAudio_"+targetVideo) - # combine audio file and new video file - os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo) - if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to mp3 - tempAudioFileName = "./temp/audio.mp3" - os.system("ffmpeg -y -i " + sourceVideo + " -c:a mp3 -vn " + tempAudioFileName) - os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo) - if (os.path.getsize(targetVideo) == 0): # if mp3 not supported by selected format - os.rename("noAudio_"+targetVideo, targetVideo) + targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1] + os.rename(targetVideo, targetNoAudio) + # combine audio file and new video file + os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo) + + if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac + tempAudioFileName = "./temp/audio.m4a" + os.system("ffmpeg -y -i " + sourceVideo + " -c:a aac -b:a 160k -vn " + tempAudioFileName) + os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo) + if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format + os.rename(targetNoAudio, targetVideo) print("Audio transfer failed. Interpolated video will have no audio") else: - print("Lossless audio transfer failed. Audio was transcoded to mp3 instead.") + print("Lossless audio transfer failed. 
Audio was transcoded to AAC (M4A) instead.") # remove audio-less video - os.remove("noAudio_"+targetVideo) + os.remove(targetNoAudio) else: - os.remove("noAudio_"+targetVideo) + os.remove(targetNoAudio) # remove temp directory shutil.rmtree("temp") @@ -59,6 +62,7 @@ if torch.cuda.is_available(): parser = argparse.ArgumentParser(description='Interpolation for a pair of images') parser.add_argument('--video', dest='video', type=str, default=None) +parser.add_argument('--output', dest='output', type=str, default=None) parser.add_argument('--img', dest='img', type=str, default=None) parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video') parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video') @@ -74,7 +78,7 @@ if not args.img is None: from model.RIFE_HD import Model model = Model() -model.load_model('./train_log', -1) +model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1) model.eval() model.device() @@ -105,14 +109,19 @@ else: tot_frame = len(videogen) videogen.sort(key= lambda x:int(x[:-4])) lastframe = cv2.imread(os.path.join(args.img, videogen[0]))[:, :, ::-1].copy() - videogen = videogen[1:] + videogen = videogen[1:] h, w, _ = lastframe.shape +vid_out_name = None vid_out = None if args.png: if not os.path.exists('vid_out'): os.mkdir('vid_out') else: - vid_out = cv2.VideoWriter('{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext), fourcc, args.fps, (w, h)) + if args.output is not None: + vid_out_name = args.output + else: + vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext) + vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h)) def clear_write_buffer(user_args, write_buffer): cnt = 0 @@ -172,15 +181,16 @@ while True: I0 = I1 I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. 
I1 = F.pad(I1, padding) - diff = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs() - if diff.max() < 2e-3 and args.skip: + I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small, I1_small) + if ssim > 0.995 and args.skip: if skip_frame % 100 == 0: print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame)) skip_frame += 1 pbar.update(1) continue - if diff.mean() > 0.15: + if ssim < 0.5: output = [] for i in range((2 ** args.exp) - 1): output.append(I0) @@ -211,9 +221,9 @@ if not vid_out is None: # move audio to new video file if appropriate if args.png == False and fpsNotAssigned == True and not args.skip and not args.video is None: - outputVideoFileName = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, 2 ** args.exp, int(np.round(args.fps)), args.ext) try: - transferAudio(args.video, outputVideoFileName) + transferAudio(args.video, vid_out_name) except: print("Audio transfer failed. 
Interpolated video will have no audio") - os.rename("noAudio_"+outputVideoFileName, outputVideoFileName) + targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1] + os.rename(targetNoAudio, vid_out_name) diff --git a/model/IFNet_HDv2.py b/model/IFNet_HDv2.py new file mode 100644 index 0000000..6ee5be5 --- /dev/null +++ b/model/IFNet_HDv2.py @@ -0,0 +1,92 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from model.warplayer import warp + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + +class IFBlock(nn.Module): + def __init__(self, in_planes, scale=1, c=64): + super(IFBlock, self).__init__() + self.scale = scale + self.conv0 = conv(in_planes, c, 5, 2, 2) + self.convblock = nn.Sequential( + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + ) + self.conv1 = nn.Conv2d(c, 4, 3, 1, 1) + + def forward(self, x): + if self.scale != 1: + x = F.interpolate(x, scale_factor=1. 
/ self.scale, mode="bilinear", + align_corners=False) + x = self.conv0(x) + x = self.convblock(x) + x = self.conv1(x) + flow = x + if self.scale != 1: + flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", + align_corners=False) + return flow + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(6, scale=8, c=192) + self.block1 = IFBlock(10, scale=4, c=128) + self.block2 = IFBlock(10, scale=2, c=96) + self.block3 = IFBlock(10, scale=1, c=48) + + def forward(self, x, UHD=False): + if UHD: + x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False) + flow0 = self.block0(x) + F1 = flow0 + F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 + warped_img0 = warp(x[:, :3], F1_large[:, :2]) + warped_img1 = warp(x[:, 3:], F1_large[:, 2:4]) + flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1)) + F2 = (flow0 + flow1) + F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 + warped_img0 = warp(x[:, :3], F2_large[:, :2]) + warped_img1 = warp(x[:, 3:], F2_large[:, 2:4]) + flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1)) + F3 = (flow0 + flow1 + flow2) + F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 + warped_img0 = warp(x[:, :3], F3_large[:, :2]) + warped_img1 = warp(x[:, 3:], F3_large[:, 2:4]) + flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1)) + F4 = (flow0 + flow1 + flow2 + flow3) + return F4, [F1, F2, F3, F4] + +if __name__ == '__main__': + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal( + 0, 1, (3, 3, 256, 256))).float().to(device) + imgs = torch.cat((img0, img1), 1) + flownet = IFNet() + flow, _ = flownet(imgs) + print(flow.shape)