Added Docker setup

This commit is contained in:
ko1N
2021-01-17 15:24:57 +01:00
parent 9239bf4df4
commit ea994bbe34
7 changed files with 49 additions and 8 deletions

View File

@@ -80,6 +80,19 @@ You can also use pngs to generate gif:
ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf "split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1" output/slomo.gif
```
### Run in docker
Place the pre-trained models in the `./docker/pretrained_models directory`
Building the container:
```
docker build -t rife -f docker/Dockerfile .
```
Running the container:
```
docker run --rm -it -v $PWD:/host rife:latest --exp=1 --video=untitled.mp4 --output=untitled_rife.mp4
```
## Evaluation
Download [RIFE model](https://drive.google.com/file/d/1c1R7iF-ypN6USo-D2YH_ORtaH3tukSlo/view?usp=sharing) or [RIFE2F1.5C model](https://drive.google.com/file/d/1ve9w-cRWotdvvbU1KcgtsSm12l-JUkeT/view?usp=sharing) reported by our paper.

View File

@@ -12,7 +12,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.load_model('./train_log')
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
model.eval()
model.device()

View File

@@ -13,7 +13,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.load_model('./train_log')
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
model.eval()
model.device()

21
docker/Dockerfile Normal file
View File

@@ -0,0 +1,21 @@
FROM python:3.8-slim
# install deps
RUN apt-get update && apt-get -y install \
bash ffmpeg
# setup RIFE
WORKDIR /rife
COPY . .
RUN pip3 install -r requirements.txt
ADD docker/rife.sh /usr/local/bin/rife
RUN chmod +x /usr/local/bin/rife
# add pre-trained models
COPY docker/pretrained_models /rife/train_log
WORKDIR /host
ENTRYPOINT ["rife"]
ENV NVIDIA_DRIVER_CAPABILITIES all

2
docker/rife.sh Normal file
View File

@@ -0,0 +1,2 @@
#!/bin/sh
python3 /rife/inference_video.py $@

View File

@@ -19,7 +19,7 @@ parser.add_argument('--exp', default=4, type=int)
args = parser.parse_args()
model = Model()
model.load_model('./train_log', -1)
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1)
model.eval()
model.device()

View File

@@ -59,6 +59,7 @@ if torch.cuda.is_available():
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--video', dest='video', type=str, default=None)
parser.add_argument('--output', dest='output', type=str, default=None)
parser.add_argument('--img', dest='img', type=str, default=None)
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
@@ -74,7 +75,7 @@ if not args.img is None:
from model.RIFE_HD import Model
model = Model()
model.load_model('./train_log', -1)
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'), -1)
model.eval()
model.device()
@@ -107,12 +108,17 @@ else:
lastframe = cv2.imread(os.path.join(args.img, videogen[0]))[:, :, ::-1].copy()
videogen = videogen[1:]
h, w, _ = lastframe.shape
vid_out_name = None
vid_out = None
if args.png:
if not os.path.exists('vid_out'):
os.mkdir('vid_out')
else:
vid_out = cv2.VideoWriter('{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext), fourcc, args.fps, (w, h))
if args.output is not None:
vid_out_name = args.output
else:
vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext)
vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
def clear_write_buffer(user_args, write_buffer):
cnt = 0
@@ -211,9 +217,8 @@ if not vid_out is None:
# move audio to new video file if appropriate
if args.png == False and fpsNotAssigned == True and not args.skip and not args.video is None:
outputVideoFileName = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, 2 ** args.exp, int(np.round(args.fps)), args.ext)
try:
transferAudio(args.video, outputVideoFileName)
transferAudio(args.video, vid_out_name)
except:
print("Audio transfer failed. Interpolated video will have no audio")
os.rename("noAudio_"+outputVideoFileName, outputVideoFileName)
os.rename("noAudio_"+vid_out_name, vid_out_name)