diff --git a/inference_mp4_2x.py b/inference_mp4_2x.py index 15442ec..e47d6ea 100644 --- a/inference_mp4_2x.py +++ b/inference_mp4_2x.py @@ -5,7 +5,6 @@ import argparse import numpy as np from tqdm import tqdm from torch.nn import functional as F -from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): @@ -15,9 +14,14 @@ if torch.cuda.is_available(): parser = argparse.ArgumentParser(description='Interpolation for a pair of images') parser.add_argument('--video', dest='video', required=True) +parser.add_argument('--model', dest='model', type=str, default='RIFE') parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video') args = parser.parse_args() +if args.model == '2F': + from model.RIFE2F import Model +else: + from model.RIFE import Model model = Model() model.load_model('./train_log') model.eval() diff --git a/inference_mp4_4x.py b/inference_mp4_4x.py index ea52f75..80d3b3a 100644 --- a/inference_mp4_4x.py +++ b/inference_mp4_4x.py @@ -5,7 +5,6 @@ import argparse import numpy as np from tqdm import tqdm from torch.nn import functional as F -from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): @@ -16,8 +15,13 @@ if torch.cuda.is_available(): parser = argparse.ArgumentParser(description='Interpolation for a pair of images') parser.add_argument('--video', dest='video', required=True) parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video') +parser.add_argument('--model', dest='model', type=str, default='RIFE') args = parser.parse_args() +if args.model == '2F': + from model.RIFE2F import Model +else: + from model.RIFE import Model model = Model() model.load_model('./train_log') model.eval() diff --git a/model/IFNet_Large.py b/model/IFNet2F.py similarity index 96% rename from model/IFNet_Large.py rename to model/IFNet2F.py index b9079e8..802999b 100644 --- a/model/IFNet_Large.py +++ b/model/IFNet2F.py @@ -86,9 +86,9 @@ class IFBlock(nn.Module): class IFNet(nn.Module): def __init__(self): super(IFNet, self).__init__() - self.block0 = IFBlock(6, scale=4, c=288) - self.block1 = IFBlock(8, scale=2, c=192) - self.block2 = IFBlock(8, scale=1, c=96) + self.block0 = IFBlock(6, scale=4, c=192) + self.block1 = IFBlock(8, scale=2, c=128) + self.block2 = IFBlock(8, scale=1, c=64) def forward(self, x): x = F.interpolate(x, scale_factor=0.5, mode="bilinear", diff --git a/model/RIFE_Large.py b/model/RIFE2F.py similarity index 99% rename from model/RIFE_Large.py rename to model/RIFE2F.py index 4e10a18..fca5661 100644 --- a/model/RIFE_Large.py +++ b/model/RIFE2F.py @@ -6,7 +6,7 @@ import torch.optim as optim import itertools from model.warplayer import warp from torch.nn.parallel import DistributedDataParallel as DDP -from model.IFNet_Large import * +from model.IFNet2F import * import torch.nn.functional as F from model.loss import * @@ -59,7 +59,7 @@ class ResBlock(nn.Module): x = self.relu2(x * w + y) return x -c = 24 +c = 16 class ContextNet(nn.Module): def __init__(self):