From 50fe036ed554c86ff7f570fc3b842f683920f482 Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Thu, 12 Nov 2020 21:32:21 +0800 Subject: [PATCH] Add inference script --- .gitignore | 2 ++ inference.py | 24 ++++++++++++++++++++++++ model/IFNet.py | 2 +- model/RIFE.py | 34 +++++++++++++++++----------------- 4 files changed, 44 insertions(+), 18 deletions(-) create mode 100644 inference.py diff --git a/.gitignore b/.gitignore index 0f7b2ca..39362d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *.pyc *.py~ *.py# + +*.pkl diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..0dc7d5d --- /dev/null +++ b/inference.py @@ -0,0 +1,24 @@ +import cv2 +import torch +from model.RIFE import Model + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +model = Model() +model.load_model('./train_log') +model.eval() +model.device() + +img0 = cv2.imread('0.png') +img1 = cv2.imread('1.png') + +h, w, _ = img0.shape + +img0 = torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255. +img1 = torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255. + +imgs = torch.cat((img0, img1), 0).float() + +with torch.no_grad(): + res = model.inference(imgs.unsqueeze(0)) * 255 +cv2.imwrite('out.png', res[0].numpy().transpose(1, 2, 0)) diff --git a/model/IFNet.py b/model/IFNet.py index 852b69c..ec74aeb 100644 --- a/model/IFNet.py +++ b/model/IFNet.py @@ -2,7 +2,7 @@ import torch import numpy as np import torch.nn as nn import torch.nn.functional as F -from warplayer import warp +from model.warplayer import warp device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/model/RIFE.py b/model/RIFE.py index 3234130..45021ed 100644 --- a/model/RIFE.py +++ b/model/RIFE.py @@ -4,11 +4,11 @@ import numpy as np from torch.optim import AdamW import torch.optim as optim import itertools -from warplayer import warp +from model.warplayer import warp from torch.nn.parallel import DistributedDataParallel as DDP -from IFNet import * +from model.IFNet import * import torch.nn.functional as F -from loss import * +from model.loss import * device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -21,13 +21,6 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): ) -def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): - return nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, bias=True), - ) - - def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): return nn.Sequential( torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, @@ -35,6 +28,11 @@ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): nn.PReLU(out_planes) ) +def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + ) class ResBlock(nn.Module): def __init__(self, in_planes, out_planes, stride=2): @@ -61,10 +59,8 @@ class ResBlock(nn.Module): x = self.relu2(x * w + y) return x - c = 16 - class ContextNet(nn.Module): def __init__(self): super(ContextNet, self).__init__() @@ -163,13 +159,19 @@ class Model: self.fusionnet.to(device) def load_model(self, path, rank=0): + def convert(param): + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k + } if rank == 0: self.flownet.load_state_dict( - torch.load('{}/flownet.pkl'.format(path))) + convert(torch.load('{}/flownet.pkl'.format(path), map_location=device))) self.contextnet.load_state_dict( - torch.load('{}/contextnet.pkl'.format(path))) + convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device))) self.fusionnet.load_state_dict( - torch.load('{}/unet.pkl'.format(path))) + convert(torch.load('{}/unet.pkl'.format(path), map_location=device))) def save_model(self, path, rank=0): if rank == 0: @@ -208,8 +210,6 @@ class Model: param_group['lr'] = learning_rate if training: self.train() -# with torch.no_grad(): -# flow_gt = estimate(gt, img0) else: self.eval() flow, flow_list = self.flownet(imgs)