mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 12:29:43 +01:00
Add 2F model
This commit is contained in:
@@ -5,7 +5,6 @@ import argparse
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from torch.nn import functional as F
|
||||
from model.RIFE import Model
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
@@ -15,9 +14,14 @@ if torch.cuda.is_available():
|
||||
|
||||
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
|
||||
parser.add_argument('--video', dest='video', required=True)
|
||||
parser.add_argument('--model', dest='model', type=str, default='RIFE')
|
||||
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model == '2F':
|
||||
from model.RIFE2F import Model
|
||||
else:
|
||||
from model.RIFE import Model
|
||||
model = Model()
|
||||
model.load_model('./train_log')
|
||||
model.eval()
|
||||
|
||||
@@ -5,7 +5,6 @@ import argparse
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from torch.nn import functional as F
|
||||
from model.RIFE import Model
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
@@ -16,8 +15,13 @@ if torch.cuda.is_available():
|
||||
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
|
||||
parser.add_argument('--video', dest='video', required=True)
|
||||
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
|
||||
parser.add_argument('--model', dest='model', type=str, default='RIFE')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model == '2F':
|
||||
from model.RIFE2F import Model
|
||||
else:
|
||||
from model.RIFE import Model
|
||||
model = Model()
|
||||
model.load_model('./train_log')
|
||||
model.eval()
|
||||
|
||||
@@ -86,9 +86,9 @@ class IFBlock(nn.Module):
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(IFNet, self).__init__()
|
||||
self.block0 = IFBlock(6, scale=4, c=288)
|
||||
self.block1 = IFBlock(8, scale=2, c=192)
|
||||
self.block2 = IFBlock(8, scale=1, c=96)
|
||||
self.block0 = IFBlock(6, scale=4, c=192)
|
||||
self.block1 = IFBlock(8, scale=2, c=128)
|
||||
self.block2 = IFBlock(8, scale=1, c=64)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
|
||||
@@ -6,7 +6,7 @@ import torch.optim as optim
|
||||
import itertools
|
||||
from model.warplayer import warp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from model.IFNet_Large import *
|
||||
from model.IFNet2F import *
|
||||
import torch.nn.functional as F
|
||||
from model.loss import *
|
||||
|
||||
@@ -59,7 +59,7 @@ class ResBlock(nn.Module):
|
||||
x = self.relu2(x * w + y)
|
||||
return x
|
||||
|
||||
c = 24
|
||||
c = 16
|
||||
|
||||
class ContextNet(nn.Module):
|
||||
def __init__(self):
|
||||
Reference in New Issue
Block a user