mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 04:19:41 +01:00
Add inference script
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,3 +1,5 @@
|
||||
*.pyc
|
||||
*.py~
|
||||
*.py#
|
||||
|
||||
*.pkl
|
||||
|
||||
24
inference.py
Normal file
24
inference.py
Normal file
"""Minimal RIFE inference script.

Reads two neighbouring frames (0.png, 1.png), runs the RIFE model to
synthesize the in-between frame, and writes it to out.png.
"""
import cv2
import torch

from model.RIFE import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('./train_log')
model.eval()
model.device()

img0 = cv2.imread('0.png')
img1 = cv2.imread('1.png')
# cv2.imread returns None (no exception) when a file is missing/unreadable;
# fail loudly instead of crashing later with an opaque AttributeError.
if img0 is None or img1 is None:
    raise FileNotFoundError("could not read input frames '0.png' / '1.png'")

h, w, _ = img0.shape

# HWC uint8 [0, 255] -> CHW float [0, 1] on the target device.
img0 = torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.
img1 = torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.

# Stack along the channel axis: the network takes both frames as one 6-channel input.
imgs = torch.cat((img0, img1), 0).float()

with torch.no_grad():
    res = model.inference(imgs.unsqueeze(0)) * 255
# Bug fix: .numpy() raises on a CUDA tensor — move the result to CPU first.
cv2.imwrite('out.png', res[0].cpu().numpy().transpose(1, 2, 0))
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from warplayer import warp
|
||||
from model.warplayer import warp
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -4,11 +4,11 @@ import numpy as np
|
||||
from torch.optim import AdamW
|
||||
import torch.optim as optim
|
||||
import itertools
|
||||
from warplayer import warp
|
||||
from model.warplayer import warp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from IFNet import *
|
||||
from model.IFNet import *
|
||||
import torch.nn.functional as F
|
||||
from loss import *
|
||||
from model.loss import *
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -21,13 +21,6 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
)
|
||||
|
||||
|
||||
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a 2D convolution without an activation ("woact"), as a Sequential.

    Same defaults as conv(): 3x3 kernel, stride 1, padding 1, dilation 1,
    with a bias term; only the PReLU is omitted.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    return nn.Sequential(layer)
|
||||
|
||||
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
|
||||
@@ -35,6 +28,11 @@ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=True),
|
||||
)
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=2):
|
||||
@@ -61,10 +59,8 @@ class ResBlock(nn.Module):
|
||||
x = self.relu2(x * w + y)
|
||||
return x
|
||||
|
||||
|
||||
c = 16
|
||||
|
||||
|
||||
class ContextNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(ContextNet, self).__init__()
|
||||
@@ -163,13 +159,19 @@ class Model:
|
||||
self.fusionnet.to(device)
|
||||
|
||||
def load_model(self, path, rank=0):
    """Load flownet/contextnet/fusionnet weights from *path* (rank 0 only).

    Checkpoints saved under DistributedDataParallel prefix every key with
    "module."; convert() strips that prefix so the plain modules accept
    the state dicts. map_location=device keeps loading working on CPU-only
    machines.
    """
    def convert(param):
        stripped = {}
        for key, value in param.items():
            if "module." in key:
                stripped[key.replace("module.", "")] = value
        return stripped
    if rank == 0:
        self.flownet.load_state_dict(
            convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
        self.contextnet.load_state_dict(
            convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
        self.fusionnet.load_state_dict(
            convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
|
||||
|
||||
def save_model(self, path, rank=0):
|
||||
if rank == 0:
|
||||
@@ -208,8 +210,6 @@ class Model:
|
||||
param_group['lr'] = learning_rate
|
||||
if training:
|
||||
self.train()
|
||||
# with torch.no_grad():
|
||||
# flow_gt = estimate(gt, img0)
|
||||
else:
|
||||
self.eval()
|
||||
flow, flow_list = self.flownet(imgs)
|
||||
|
||||
Reference in New Issue
Block a user