From eb3923bfba38e4f6aaf621b8cd7b476278854d71 Mon Sep 17 00:00:00 2001 From: Andriy Toloshny Date: Mon, 11 Jan 2021 02:01:35 +0000 Subject: [PATCH 01/16] added openexr support --- inference_img.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/inference_img.py b/inference_img.py index 1e6ea25..76537b8 100644 --- a/inference_img.py +++ b/inference_img.py @@ -8,8 +8,8 @@ import warnings warnings.filterwarnings("ignore") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -torch.set_grad_enabled(False) if torch.cuda.is_available(): + torch.set_grad_enabled(False) torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True @@ -22,12 +22,19 @@ model = Model() model.load_model('./train_log', -1) model.eval() model.device() - -img0 = cv2.imread(args.img[0]) -img1 = cv2.imread(args.img[1]) -img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) -img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) +if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) + +else: + img0 = cv2.imread(args.img[0]) + img1 = cv2.imread(args.img[1]) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + n, c, h, w = img0.shape ph = ((h - 1) // 32 + 1) * 32 pw = ((w - 1) // 32 + 1) * 32 @@ -48,4 +55,7 @@ for i in range(args.exp): if not os.path.exists('output'): os.mkdir('output') for i in range(len(img_list)): - cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) + if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + else: + cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) From 399a675c161ee6ea10c96f9363886accd7118fa5 Mon Sep 17 00:00:00 2001 From: Andriy Toloshny Date: Mon, 11 Jan 2021 02:08:49 +0000 Subject: [PATCH 02/16] added openexr support --- inference_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference_img.py b/inference_img.py index 76537b8..24a9ab0 100644 --- a/inference_img.py +++ b/inference_img.py @@ -8,8 +8,8 @@ import warnings warnings.filterwarnings("ignore") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) if torch.cuda.is_available(): - torch.set_grad_enabled(False) torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True From 7948b22c05c217a576e980c9230152a0023ae425 Mon Sep 17 00:00:00 2001 From: N00MKRAD Date: Mon, 11 Jan 2021 12:16:24 +0100 Subject: [PATCH 03/16] Use AAC at 160k instead of MP3 at default bitrate for audio fallback --- inference_video.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/inference_video.py b/inference_video.py index c75d387..4976851 100644 --- a/inference_video.py +++ b/inference_video.py @@ -32,15 +32,15 @@ def transferAudio(sourceVideo, targetVideo): # combine audio file and new video file os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo) - if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to mp3 - tempAudioFileName = "./temp/audio.mp3" - os.system("ffmpeg -y -i " + sourceVideo + " -c:a mp3 -vn " + tempAudioFileName) + if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac + tempAudioFileName = "./temp/audio.m4a" + os.system("ffmpeg -y -i " + sourceVideo + " -c:a aac -b:a 160k -vn " + tempAudioFileName) os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo) - if (os.path.getsize(targetVideo) == 0): # if mp3 not supported by selected format + if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format os.rename("noAudio_"+targetVideo, targetVideo) print("Audio transfer failed. Interpolated video will have no audio") else: - print("Lossless audio transfer failed. Audio was transcoded to mp3 instead.") + print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.") # remove audio-less video os.remove("noAudio_"+targetVideo) From ea994bbe34941478644af6a70cc7ba0e9d5e45fa Mon Sep 17 00:00:00 2001 From: ko1N Date: Sun, 17 Jan 2021 15:24:57 +0100 Subject: [PATCH 04/16] Added Docker setup --- README.md | 13 +++++++++++++ benchmark/MiddleBury_Other.py | 2 +- benchmark/Vimeo90K.py | 2 +- docker/Dockerfile | 21 +++++++++++++++++++++ docker/rife.sh | 2 ++ inference_img.py | 2 +- inference_video.py | 15 ++++++++++----- 7 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 docker/Dockerfile create mode 100644 docker/rife.sh diff --git a/README.md b/README.md index 7a3b820..26be3c6 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,19 @@ You can also use pngs to generate gif: ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf "split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1" output/slomo.gif ``` +### Run in docker +Place the pre-trained models in the `./docker/pretrained_models directory` + +Building the container: +``` +docker build -t rife -f docker/Dockerfile . +``` + +Running the container: +``` +docker run --rm -it -v $PWD:/host rife:latest --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4 +``` + ## Evaluation Download [RIFE model](https://drive.google.com/file/d/1c1R7iF-ypN6USo-D2YH_ORtaH3tukSlo/view?usp=sharing) or [RIFE2F1.5C model](https://drive.google.com/file/d/1ve9w-cRWotdvvbU1KcgtsSm12l-JUkeT/view?usp=sharing) reported by our paper. diff --git a/benchmark/MiddleBury_Other.py b/benchmark/MiddleBury_Other.py index 18b4e16..7c0f77f 100644 --- a/benchmark/MiddleBury_Other.py +++ b/benchmark/MiddleBury_Other.py @@ -12,7 +12,7 @@ from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Model() -model.load_model('./train_log') +model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) model.eval() model.device() diff --git a/benchmark/Vimeo90K.py b/benchmark/Vimeo90K.py index c984b4c..cd70e48 100644 --- a/benchmark/Vimeo90K.py +++ b/benchmark/Vimeo90K.py @@ -13,7 +13,7 @@ from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Model() -model.load_model('./train_log') +model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) model.eval() model.device() diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..84ad60c --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,21 @@ +FROM python:3.8-slim + +# install deps +RUN apt-get update && apt-get -y install \ + bash ffmpeg + +# setup RIFE +WORKDIR /rife +COPY . . +RUN pip3 install -r requirements.txt + +ADD docker/rife.sh /usr/local/bin/rife +RUN chmod +x /usr/local/bin/rife + +# add pre-trained models +COPY docker/pretrained_models /rife/train_log + +WORKDIR /host +ENTRYPOINT ["rife"] + +ENV NVIDIA_DRIVER_CAPABILITIES all \ No newline at end of file diff --git a/docker/rife.sh b/docker/rife.sh new file mode 100644 index 0000000..d718c5c --- /dev/null +++ b/docker/rife.sh @@ -0,0 +1,2 @@ +#!/bin/sh +python3 /rife/inference_video.py $@ diff --git a/inference_img.py b/inference_img.py index 24a9ab0..633bd55 100644 --- a/inference_img.py +++ b/inference_img.py @@ -19,7 +19,7 @@ parser.add_argument('--exp', default=4, type=int) args = parser.parse_args() model = Model() -model.load_model('./train_log', -1) +model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1) model.eval() model.device() diff --git a/inference_video.py b/inference_video.py index 4976851..28fd756 100644 --- a/inference_video.py +++ b/inference_video.py @@ -59,6 +59,7 @@ if torch.cuda.is_available(): parser = argparse.ArgumentParser(description='Interpolation for a pair of images') parser.add_argument('--video', dest='video', type=str, default=None) +parser.add_argument('--output', dest='output', type=str, default=None) parser.add_argument('--img', dest='img', type=str, default=None) parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video') parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video') @@ -74,7 +75,7 @@ if not args.img is None: from model.RIFE_HD import Model model = Model() -model.load_model('./train_log', -1) +model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1) model.eval() model.device() @@ -107,12 +108,17 @@ else: lastframe = cv2.imread(os.path.join(args.img, videogen[0]))[:, :, ::-1].copy() videogen = videogen[1:] h, w, _ = lastframe.shape +vid_out_name = None vid_out = None if args.png: if not os.path.exists('vid_out'): os.mkdir('vid_out') else: - vid_out = cv2.VideoWriter('{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext), fourcc, args.fps, (w, h)) + if args.output is not None: + vid_out_name = args.output + else: + vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext) + vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h)) def clear_write_buffer(user_args, write_buffer): cnt = 0 @@ -211,9 +217,8 @@ if not vid_out is None: # move audio to new video file if appropriate if args.png == False and fpsNotAssigned == True and not args.skip and not args.video is None: - outputVideoFileName = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, 2 ** args.exp, int(np.round(args.fps)), args.ext) try: - transferAudio(args.video, outputVideoFileName) + transferAudio(args.video, vid_out_name) except: print("Audio transfer failed. Interpolated video will have no audio") - os.rename("noAudio_"+outputVideoFileName, outputVideoFileName) + os.rename("noAudio_"+vid_out_name, vid_out_name) From 0f980597a104cbd4907d8cda2ce5aa42c6af87ac Mon Sep 17 00:00:00 2001 From: ko1N Date: Sun, 17 Jan 2021 16:05:48 +0100 Subject: [PATCH 05/16] Added ability to run both inference scripts from docker --- README.md | 5 ++++- docker/Dockerfile | 8 +++++--- docker/inference_img | 2 ++ docker/{rife.sh => inference_video} | 0 4 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 docker/inference_img rename docker/{rife.sh => inference_video} (100%) diff --git a/README.md b/README.md index 26be3c6..6e377f1 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,10 @@ docker build -t rife -f docker/Dockerfile . Running the container: ``` -docker run --rm -it -v $PWD:/host rife:latest --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4 +docker run --rm -it -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4 +``` +``` +docker run --rm -it -v $PWD:/host rife:latest inference_img --img img0.png img1.png --exp=4 ``` ## Evaluation diff --git a/docker/Dockerfile b/docker/Dockerfile index 84ad60c..eca0589 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -9,13 +9,15 @@ WORKDIR /rife COPY . . RUN pip3 install -r requirements.txt -ADD docker/rife.sh /usr/local/bin/rife -RUN chmod +x /usr/local/bin/rife +ADD docker/inference_img /usr/local/bin/inference_img +RUN chmod +x /usr/local/bin/inference_img +ADD docker/inference_video /usr/local/bin/inference_video +RUN chmod +x /usr/local/bin/inference_video # add pre-trained models COPY docker/pretrained_models /rife/train_log WORKDIR /host -ENTRYPOINT ["rife"] +ENTRYPOINT ["/bin/bash"] ENV NVIDIA_DRIVER_CAPABILITIES all \ No newline at end of file diff --git a/docker/inference_img b/docker/inference_img new file mode 100644 index 0000000..5557be4 --- /dev/null +++ b/docker/inference_img @@ -0,0 +1,2 @@ +#!/bin/sh +python3 /rife/inference_img.py $@ diff --git a/docker/rife.sh b/docker/inference_video similarity index 100% rename from docker/rife.sh rename to docker/inference_video From f455a0573656fcf95da54a538341f1672c556126 Mon Sep 17 00:00:00 2001 From: ko1N Date: Sun, 17 Jan 2021 16:07:54 +0100 Subject: [PATCH 06/16] Dockerfile now uses the same train_log folder as the scripts --- README.md | 2 +- docker/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6e377f1..65e008e 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf "split[s0][s1];[s0]pal ``` ### Run in docker -Place the pre-trained models in the `./docker/pretrained_models directory` +Place the pre-trained models in `train_log/\*.pkl` (as above) Building the container: ``` diff --git a/docker/Dockerfile b/docker/Dockerfile index eca0589..801dbb7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -15,7 +15,7 @@ ADD docker/inference_video /usr/local/bin/inference_video RUN chmod +x /usr/local/bin/inference_video # add pre-trained models -COPY docker/pretrained_models /rife/train_log +COPY train_log /rife/train_log WORKDIR /host ENTRYPOINT ["/bin/bash"] From dd21a4bc9f1c14d1d5a676042fd7184dae2f0a0b Mon Sep 17 00:00:00 2001 From: ko1N Date: Sun, 17 Jan 2021 16:11:24 +0100 Subject: [PATCH 07/16] Updated docker gpu instructions --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 65e008e..aaf2847 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,11 @@ docker run --rm -it -v $PWD:/host rife:latest inference_video --exp=1 --video=un docker run --rm -it -v $PWD:/host rife:latest inference_img --img img0.png img1.png --exp=4 ``` +Using gpu acceleration (requires proper gpu drivers for docker): +``` +docker run --rm -it --gpus all -v /dev/dri:/dev/dri -v $PWD:/host rife:latest inference_video --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4 +``` + ## Evaluation Download [RIFE model](https://drive.google.com/file/d/1c1R7iF-ypN6USo-D2YH_ORtaH3tukSlo/view?usp=sharing) or [RIFE2F1.5C model](https://drive.google.com/file/d/1ve9w-cRWotdvvbU1KcgtsSm12l-JUkeT/view?usp=sharing) reported by our paper. From 166d10e9a723584a15f27b6b88a5f2f24f1a50d5 Mon Sep 17 00:00:00 2001 From: ko1N Date: Sun, 17 Jan 2021 17:34:48 +0100 Subject: [PATCH 08/16] Fixed audio merge when using abs paths --- inference_video.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/inference_video.py b/inference_video.py index 28fd756..7b5d56c 100644 --- a/inference_video.py +++ b/inference_video.py @@ -27,25 +27,26 @@ def transferAudio(sourceVideo, targetVideo): os.makedirs("temp") # extract audio from video os.system("ffmpeg -y -i " + sourceVideo + " -c:a copy -vn " + tempAudioFileName) - - os.rename(targetVideo, "noAudio_"+targetVideo) + + targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1] + os.rename(targetVideo, targetNoAudio) # combine audio file and new video file - os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo) + os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo) if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac tempAudioFileName = "./temp/audio.m4a" os.system("ffmpeg -y -i " + sourceVideo + " -c:a aac -b:a 160k -vn " + tempAudioFileName) - os.system("ffmpeg -y -i " + "noAudio_"+targetVideo + " -i " + tempAudioFileName + " -c copy " + targetVideo) + os.system("ffmpeg -y -i " + targetNoAudio + " -i " + tempAudioFileName + " -c copy " + targetVideo) if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format - os.rename("noAudio_"+targetVideo, targetVideo) + os.rename(targetNoAudio, targetVideo) print("Audio transfer failed. Interpolated video will have no audio") else: print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.") # remove audio-less video - os.remove("noAudio_"+targetVideo) + os.remove(targetNoAudio) else: - os.remove("noAudio_"+targetVideo) + os.remove(targetNoAudio) # remove temp directory shutil.rmtree("temp") @@ -221,4 +222,5 @@ if args.png == False and fpsNotAssigned == True and not args.skip and not args.v transferAudio(args.video, vid_out_name) except: print("Audio transfer failed. Interpolated video will have no audio") - os.rename("noAudio_"+vid_out_name, vid_out_name) + targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1] + os.rename(targetNoAudio, vid_out_name) From 1dc2dba7d1b564ac6ca5cc8f3a0c37b24b57441d Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Mon, 18 Jan 2021 17:40:44 +0800 Subject: [PATCH 09/16] Use ssim to skip --- inference_video.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/inference_video.py b/inference_video.py index 7b5d56c..5e221fd 100644 --- a/inference_video.py +++ b/inference_video.py @@ -9,9 +9,11 @@ import warnings import _thread import skvideo.io from queue import Queue, Empty +from benchmark.pytorch_msssim import ssim_matlab + warnings.filterwarnings("ignore") -def transferAudio(sourceVideo, targetVideo): +def transferAudio(sourceVideo, targetVideo): import shutil import moviepy.editor tempAudioFileName = "./temp/audio.mkv" @@ -107,7 +109,7 @@ else: tot_frame = len(videogen) videogen.sort(key= lambda x:int(x[:-4])) lastframe = cv2.imread(os.path.join(args.img, videogen[0]))[:, :, ::-1].copy() - videogen = videogen[1:] + videogen = videogen[1:] h, w, _ = lastframe.shape vid_out_name = None vid_out = None @@ -179,15 +181,16 @@ while True: I0 = I1 I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. I1 = F.pad(I1, padding) - diff = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs() - if diff.max() < 2e-3 and args.skip: + I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) + ssim = ssim_matlab(I0_small, I1_small) + if ssim > 0.995 and args.skip: if skip_frame % 100 == 0: print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame)) skip_frame += 1 pbar.update(1) continue - if diff.mean() > 0.15: + if ssim < 0.5: output = [] for i in range((2 ** args.exp) - 1): output.append(I0) From e515b1d3647b3839c45d8b2317596e0aa5eea29f Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Mon, 18 Jan 2021 17:41:27 +0800 Subject: [PATCH 10/16] Support cpu ssim --- benchmark/pytorch_msssim/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmark/pytorch_msssim/__init__.py b/benchmark/pytorch_msssim/__init__.py index 118d265..a4d3032 100644 --- a/benchmark/pytorch_msssim/__init__.py +++ b/benchmark/pytorch_msssim/__init__.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from math import exp import numpy as np +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) @@ -11,7 +12,7 @@ def gaussian(window_size, sigma): def create_window(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda() + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() return window @@ -19,7 +20,7 @@ def create_window_3d(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()) _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) - window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda() + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) return window From 1ad0890ebc1b07a3ff4219cb7c52c5f3df14d7cc Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Wed, 20 Jan 2021 16:54:47 +0800 Subject: [PATCH 11/16] Update README.md --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index aaf2847..32d9ef8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # RIFE Video Frame Interpolation v1.8 +**Our paper has not been officially published yet, and our method and experimental results are under improvement.** ## [arXiv](https://arxiv.org/abs/2011.06294) | [Project Page](https://rife-vfi.github.io) | [Reddit](https://www.reddit.com/r/linux/comments/jy4jjl/opensourced_realtime_video_frame_interpolation/) | [YouTube](https://www.youtube.com/watch?v=60DX2T3zyVo&feature=youtu.be) | [Bilibili](https://www.bilibili.com/video/BV1K541157te?from=search&seid=5131698847373645765) **1.4 News: We have updated the v1.8 model optimized for 2D animation.** @@ -33,6 +34,8 @@ We are optimizing the visual effects and will support animation in the future. ( * Unzip and move the pretrained parameters to train_log/\*.pkl +**This model is designed to provide better visual effects for users and should not be used for benchmarking.** + ### Run **Video Frame Interpolation** From ae8bdb3a8151fa01903f52f7fb3529b2d42ff753 Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Wed, 20 Jan 2021 17:27:06 +0800 Subject: [PATCH 12/16] Add time testing script --- benchmark/testtime.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 benchmark/testtime.py diff --git a/benchmark/testtime.py b/benchmark/testtime.py new file mode 100644 index 0000000..c695d5f --- /dev/null +++ b/benchmark/testtime.py @@ -0,0 +1,29 @@ +import cv2 +import sys +sys.path.append('.') +import time +import torch +import torch.nn as nn +from model.RIFE import Model + +model = Model() +model.eval() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) +if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + +I0 = torch.rand(1, 3, 480, 640).to(device) +I1 = torch.rand(1, 3, 480, 640).to(device) +with torch.no_grad(): + for i in range(100): + pred = model.inference(I0, I1) + if torch.cuda.is_available(): + torch.cuda.synchronize() + time_stamp = time.time() + for i in range(100): + pred = model.inference(I0, I1) + if torch.cuda.is_available(): + torch.cuda.synchronize() + print((time.time() - time_stamp) / 100) From 83d60cbb3feeaee056fe2e73c365174389b6aad3 Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Mon, 25 Jan 2021 16:53:21 +0800 Subject: [PATCH 13/16] Update README.md --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 32d9ef8..061bd77 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # RIFE Video Frame Interpolation v1.8 -**Our paper has not been officially published yet, and our method and experimental results are under improvement.** +**Our paper has not been officially published yet, and our method and experimental results are under improvement. Due to the incorrect data reference, the latency measurement of Sepconv and TOFlow in our arxiv paper needs to be modified.** ## [arXiv](https://arxiv.org/abs/2011.06294) | [Project Page](https://rife-vfi.github.io) | [Reddit](https://www.reddit.com/r/linux/comments/jy4jjl/opensourced_realtime_video_frame_interpolation/) | [YouTube](https://www.youtube.com/watch?v=60DX2T3zyVo&feature=youtu.be) | [Bilibili](https://www.bilibili.com/video/BV1K541157te?from=search&seid=5131698847373645765) **1.4 News: We have updated the v1.8 model optimized for 2D animation.** @@ -143,7 +143,6 @@ python3 -m torch.distributed.launch --nproc_per_node=4 train.py --world_size=4 ``` ## Reference -img Optical Flow: [ARFlow](https://github.com/lliuz/ARFlow) [pytorch-liteflownet](https://github.com/sniklaus/pytorch-liteflownet) [RAFT](https://github.com/princeton-vl/RAFT) [pytorch-PWCNet](https://github.com/sniklaus/pytorch-pwc) From 2a1eafe27d5ff12eb31df96e47352fe30c18ac46 Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Mon, 25 Jan 2021 16:56:43 +0800 Subject: [PATCH 14/16] Update README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 061bd77..9043dd0 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,15 @@ **Our paper has not been officially published yet, and our method and experimental results are under improvement. Due to the incorrect data reference, the latency measurement of Sepconv and TOFlow in our arxiv paper needs to be modified.** ## [arXiv](https://arxiv.org/abs/2011.06294) | [Project Page](https://rife-vfi.github.io) | [Reddit](https://www.reddit.com/r/linux/comments/jy4jjl/opensourced_realtime_video_frame_interpolation/) | [YouTube](https://www.youtube.com/watch?v=60DX2T3zyVo&feature=youtu.be) | [Bilibili](https://www.bilibili.com/video/BV1K541157te?from=search&seid=5131698847373645765) -**1.4 News: We have updated the v1.8 model optimized for 2D animation.** +1.4 News: We have updated the v1.8 model optimized for 2D animation. -**12.13 News: We have updated the v1.6 model and support UHD mode. Please check our [update log](https://github.com/hzwer/arXiv2020-RIFE/issues/41#issuecomment-737651979).** +12.13 News: We have updated the v1.6 model and support UHD mode. Please check our [update log](https://github.com/hzwer/arXiv2020-RIFE/issues/41#issuecomment-737651979). -**11.22 News: We notice a new windows app is trying to integrate RIFE, we hope everyone to try and help them improve. You can download [Flowframes](https://nmkd.itch.io/flowframes) for free.** +11.22 News: We notice a new windows app is trying to integrate RIFE, we hope everyone to try and help them improve. You can download [Flowframes](https://nmkd.itch.io/flowframes) for free. -**There is [a tutorial of RIFE](https://www.youtube.com/watch?v=gf_on-dbwyU&feature=emb_title) on Youtube.** +There is [a tutorial of RIFE](https://www.youtube.com/watch?v=gf_on-dbwyU&feature=emb_title) on Youtube. -**You can easily use [colaboratory](https://colab.research.google.com/github/hzwer/arXiv2020-RIFE/blob/main/Colab_demo.ipynb) to have a try and generate the [our youtube demo](https://www.youtube.com/watch?v=LE2Dzl0oMHI).** +You can easily use [colaboratory](https://colab.research.google.com/github/hzwer/arXiv2020-RIFE/blob/main/Colab_demo.ipynb) to have a try and generate the [our youtube demo](https://www.youtube.com/watch?v=LE2Dzl0oMHI). Our model can run 30+FPS for 2X 720p interpolation on a 2080Ti GPU. Currently, our method supports 2X,4X,8X... interpolation for 1080p video, and multi-frame interpolation between a pair of images. Everyone is welcome to use our alpha version and make suggestions! From bf790c2ffda92f598ee65ad5022129473a9f129f Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Mon, 1 Feb 2021 11:13:05 +0800 Subject: [PATCH 15/16] Add HDv2 --- RIFE_HDv2.py | 247 ++++++++++++++++++++++++++++++++++++++++++++ model/IFNet_HDv2.py | 92 +++++++++++++++++ 2 files changed, 339 insertions(+) create mode 100644 RIFE_HDv2.py create mode 100644 model/IFNet_HDv2.py diff --git a/RIFE_HDv2.py b/RIFE_HDv2.py new file mode 100644 index 0000000..9f19ae2 --- /dev/null +++ b/RIFE_HDv2.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +import numpy as np +from torch.optim import AdamW +import torch.optim as optim +import itertools +from model.warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +from model.IFNet_HDv2 import * +import torch.nn.functional as F +from model.loss import * + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=4, stride=2, padding=1, bias=True), + nn.PReLU(out_planes) + ) + +def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + +c = 32 + +class ContextNet(nn.Module): + def __init__(self): + super(ContextNet, self).__init__() + self.conv0 = Conv2(3, c) + self.conv1 = Conv2(c, c) + self.conv2 = Conv2(c, 2*c) + self.conv3 = Conv2(2*c, 4*c) + self.conv4 = Conv2(4*c, 8*c) + + def forward(self, x, flow): + x = self.conv0(x) + x = self.conv1(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5 + f2 = warp(x, flow) + x = self.conv3(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5 + f3 = warp(x, flow) + x = self.conv4(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5 + f4 = warp(x, flow) + return [f1, f2, f3, f4] + + +class FusionNet(nn.Module): + def __init__(self): + super(FusionNet, self).__init__() + self.conv0 = Conv2(10, c) + self.down0 = Conv2(c, 2*c) + self.down1 = Conv2(4*c, 4*c) + self.down2 = Conv2(8*c, 8*c) + self.down3 = Conv2(16*c, 16*c) + self.up0 = deconv(32*c, 8*c) + self.up1 = deconv(16*c, 4*c) + self.up2 = deconv(8*c, 2*c) + self.up3 = deconv(4*c, c) + self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1) + + def forward(self, img0, img1, flow, c0, c1, flow_gt): + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + if flow_gt == None: + warped_img0_gt, warped_img1_gt = None, None + else: + warped_img0_gt = warp(img0, flow_gt[:, :2]) + warped_img1_gt = warp(img1, flow_gt[:, 2:4]) + x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1)) + s0 = self.down0(x) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt + + +class Model: + def __init__(self, local_rank=-1): + self.flownet = IFNet() + self.contextnet = ContextNet() + self.fusionnet = FusionNet() + self.device() + self.optimG = AdamW(itertools.chain( + self.flownet.parameters(), + self.contextnet.parameters(), + self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) + self.schedulerG = optim.lr_scheduler.CyclicLR( + self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) + self.epe = EPE() + self.ter = Ternary() + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[ + local_rank], output_device=local_rank) + self.contextnet = DDP(self.contextnet, device_ids=[ + local_rank], output_device=local_rank) + self.fusionnet = DDP(self.fusionnet, device_ids=[ + local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + self.contextnet.train() + self.fusionnet.train() + + def eval(self): + self.flownet.eval() + self.contextnet.eval() + self.fusionnet.eval() + + def device(self): + self.flownet.to(device) + self.contextnet.to(device) + self.fusionnet.to(device) + + def load_model(self, path, rank): + def convert(param): + if rank == -1: + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k + } + else: + return param + if rank <= 0: + self.flownet.load_state_dict( + convert(torch.load('{}/flownet.pkl'.format(path), map_location=device))) + self.contextnet.load_state_dict( + convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device))) + self.fusionnet.load_state_dict( + convert(torch.load('{}/unet.pkl'.format(path), map_location=device))) + + def save_model(self, path, rank): + if rank == 0: + torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path)) + torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) + torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) + + def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if UHD: + flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0 + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", + align_corners=False) * 2.0 + refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet( + img0, img1, flow, c0, c1, flow_gt) + res = torch.sigmoid(refine_output[:, :3]) * 2 - 1 + mask = torch.sigmoid(refine_output[:, 3:4]) + merged_img = warped_img0 * mask + warped_img1 * (1 - mask) + pred = merged_img + res + pred = torch.clamp(pred, 0, 1) + if training: + return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt + else: + return pred + + def inference(self, img0, img1, UHD=False): + imgs = torch.cat((img0, img1), 1) + flow, _ = self.flownet(imgs, UHD) + return self.predict(imgs, flow, training=False, UHD=UHD) + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group['lr'] = learning_rate + if training: + self.train() + else: + self.eval() + flow, flow_list = self.flownet(imgs) + pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict( + imgs, flow, flow_gt=flow_gt) + loss_ter = self.ter(pred, gt).mean() + if training: + with torch.no_grad(): + loss_flow = torch.abs(warped_img0_gt - gt).mean() + loss_mask = torch.abs( + merged_img - gt).sum(1, True).float().detach() + loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear", + align_corners=False).detach() + flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear", + align_corners=False) * 0.5).detach() + loss_cons = 0 + for i in range(4): + loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1) + loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1) + loss_cons = loss_cons.mean() * 0.01 + else: + loss_cons = torch.tensor([0]) + loss_flow = torch.abs(warped_img0 - gt).mean() + loss_mask = 1 + loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean() + if training: + self.optimG.zero_grad() + loss_G = loss_l1 + loss_cons + loss_ter + loss_G.backward() + self.optimG.step() + return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask + + +if __name__ == '__main__': + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal( + 0, 1, (3, 3, 256, 256))).float().to(device) + imgs = torch.cat((img0, img1), 1) + model = Model() + model.eval() + print(model.inference(imgs).shape) diff --git a/model/IFNet_HDv2.py b/model/IFNet_HDv2.py new file mode 100644 index 0000000..6ed293a --- /dev/null +++ b/model/IFNet_HDv2.py @@ -0,0 +1,92 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from model.warplayer import warp + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + +class IFBlock(nn.Module): + def __init__(self, in_planes, scale=1, c=64): + super(IFBlock, self).__init__() + self.scale = scale + self.conv0 = conv(in_planes, c, 5, 2, 2) + self.convblock = nn.Sequential( + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + conv(c, c), + ) + self.conv1 = nn.Conv2d(c, 4, 3, 1, 1) + + def forward(self, x): + if self.scale != 1: + x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", + align_corners=False) + x = self.conv0(x) + x = self.convblock(x) + x = self.conv1(x) + flow = x + if self.scale != 1: + flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", + align_corners=False) * self.scale + return flow + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(6, scale=8, c=192) + self.block1 = IFBlock(10, scale=4, c=128) + self.block2 = IFBlock(10, scale=2, c=96) + self.block3 = IFBlock(10, scale=1, c=48) + + def forward(self, x, UHD=False): + if UHD: + x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False) + flow0 = self.block0(x) + F1 = flow0 + F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 + warped_img0 = warp(x[:, :3], F1_large[:, :2]) + warped_img1 = warp(x[:, 3:], F1_large[:, 2:4]) + flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1)) + F2 = (flow0 + flow1) + F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 + warped_img0 = warp(x[:, :3], F2_large[:, :2]) + warped_img1 = warp(x[:, 3:], F2_large[:, 2:4]) + flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1)) + F3 = (flow0 + flow1 + flow2) + F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 + warped_img0 = warp(x[:, :3], F3_large[:, :2]) + warped_img1 = warp(x[:, 3:], F3_large[:, 2:4]) + flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1)) + F4 = (flow0 + flow1 + flow2 + flow3) + return F4, [F1, F2, F3, F4] + +if __name__ == '__main__': + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal( + 0, 1, (3, 3, 256, 256))).float().to(device) + imgs = torch.cat((img0, img1), 1) + flownet = IFNet() + flow, _ = flownet(imgs) + print(flow.shape) From 0d33876da2906eea8ea57454a60a3c4ecdc079a6 Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Tue, 2 Feb 2021 18:09:52 +0800 Subject: [PATCH 16/16] Update IFNet_HDv2.py --- model/IFNet_HDv2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/IFNet_HDv2.py b/model/IFNet_HDv2.py index 6ed293a..6ee5be5 100644 --- a/model/IFNet_HDv2.py +++ b/model/IFNet_HDv2.py @@ -48,7 +48,7 @@ class IFBlock(nn.Module): flow = x if self.scale != 1: flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", - align_corners=False) * self.scale + align_corners=False) return flow