Add 2F model

This commit is contained in:
hzwer
2020-11-15 19:23:08 +08:00
parent 117f016192
commit 56d2bb3362
4 changed files with 15 additions and 7 deletions

View File

@@ -5,7 +5,6 @@ import argparse
import numpy as np
from tqdm import tqdm
from torch.nn import functional as F
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
@@ -15,9 +14,14 @@ if torch.cuda.is_available():
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--video', dest='video', required=True)
parser.add_argument('--model', dest='model', type=str, default='RIFE')
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
args = parser.parse_args()
if args.model == '2F':
from model.RIFE2F import Model
else:
from model.RIFE import Model
model = Model()
model.load_model('./train_log')
model.eval()

View File

@@ -5,7 +5,6 @@ import argparse
import numpy as np
from tqdm import tqdm
from torch.nn import functional as F
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
@@ -16,8 +15,13 @@ if torch.cuda.is_available():
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--video', dest='video', required=True)
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
parser.add_argument('--model', dest='model', type=str, default='RIFE')
args = parser.parse_args()
if args.model == '2F':
from model.RIFE2F import Model
else:
from model.RIFE import Model
model = Model()
model.load_model('./train_log')
model.eval()

View File

@@ -86,9 +86,9 @@ class IFBlock(nn.Module):
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=288)
self.block1 = IFBlock(8, scale=2, c=192)
self.block2 = IFBlock(8, scale=1, c=96)
self.block0 = IFBlock(6, scale=4, c=192)
self.block1 = IFBlock(8, scale=2, c=128)
self.block2 = IFBlock(8, scale=1, c=64)
def forward(self, x):
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",

View File

@@ -6,7 +6,7 @@ import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet_Large import *
from model.IFNet2F import *
import torch.nn.functional as F
from model.loss import *
@@ -59,7 +59,7 @@ class ResBlock(nn.Module):
x = self.relu2(x * w + y)
return x
c = 24
c = 16
class ContextNet(nn.Module):
def __init__(self):