Add inference script

This commit is contained in:
hzwer
2020-11-12 21:32:21 +08:00
parent 6d88c4e154
commit 50fe036ed5
4 changed files with 44 additions and 18 deletions

2
.gitignore vendored
View File

@@ -1,3 +1,5 @@
*.pyc
*.py~
*.py#
*.pkl

24
inference.py Normal file
View File

@@ -0,0 +1,24 @@
import cv2
import torch
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.load_model('./train_log')
model.eval()
model.device()
img0 = cv2.imread('0.png')
img1 = cv2.imread('1.png')
h, w, _ = img0.shape
img0 = torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.
img1 = torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.
imgs = torch.cat((img0, img1), 0).float()
with torch.no_grad():
res = model.inference(imgs.unsqueeze(0)) * 255
cv2.imwrite('out.png', res[0].numpy().transpose(1, 2, 0))

View File

@@ -2,7 +2,7 @@ import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from warplayer import warp
from model.warplayer import warp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

View File

@@ -4,11 +4,11 @@ import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from warplayer import warp
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from IFNet import *
from model.IFNet import *
import torch.nn.functional as F
from loss import *
from model.loss import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -21,13 +21,6 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
@@ -35,6 +28,11 @@ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
nn.PReLU(out_planes)
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
class ResBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
@@ -61,10 +59,8 @@ class ResBlock(nn.Module):
x = self.relu2(x * w + y)
return x
c = 16
class ContextNet(nn.Module):
def __init__(self):
super(ContextNet, self).__init__()
@@ -163,13 +159,19 @@ class Model:
self.fusionnet.to(device)
def load_model(self, path, rank=0):
def convert(param):
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
if rank == 0:
self.flownet.load_state_dict(
torch.load('{}/flownet.pkl'.format(path)))
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict(
torch.load('{}/contextnet.pkl'.format(path)))
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
self.fusionnet.load_state_dict(
torch.load('{}/unet.pkl'.format(path)))
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank=0):
if rank == 0:
@@ -208,8 +210,6 @@ class Model:
param_group['lr'] = learning_rate
if training:
self.train()
# with torch.no_grad():
# flow_gt = estimate(gt, img0)
else:
self.eval()
flow, flow_list = self.flownet(imgs)