Files
ECCV2022-RIFE/model/RIFE.py
hzwer 1c28fbde2d Merge remote-tracking branch 'origin/main' into main
# Conflicts:
#	model/RIFE.py
2021-08-13 16:34:55 +08:00

100 lines
3.1 KiB
Python

import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torchstat import stat
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet import *
import torch.nn.functional as F
from model.loss import *
from model.laplacian import *
from model.refine import *
# Device that Model.device() moves the network onto. Fall back to the CPU when
# CUDA is unavailable instead of failing at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Model:
    """Wrap an ``IFNet`` with its optimizer, losses and checkpoint I/O.

    ``update`` runs one training (or evaluation) step against a ground-truth
    image; ``inference`` produces a prediction from two inputs alone.

    Args:
        local_rank: device index for DistributedDataParallel. The default
            ``-1`` leaves the network unwrapped (single-process use).
    """

    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.device()
        # The lr given here is only a placeholder: update() overwrites it on
        # every call. The optimizer is built before the DDP wrap below, which
        # wraps the same module, so it still steps the wrapped parameters.
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        self.lap = LapLoss()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        """Put the network into training mode."""
        self.flownet.train()

    def eval(self):
        """Put the network into evaluation mode."""
        self.flownet.eval()

    def device(self):
        """Move the network onto the module-level ``device``."""
        # Inside this method the bare name `device` resolves to the module
        # global, not to this method.
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        """Load ``{path}/flownet.pkl`` into the network; only rank 0 loads.

        Checkpoints written from a DDP/DataParallel-wrapped network carry a
        ``module.`` prefix on every key. The prefix is stripped and the
        weights are loaded into the underlying module. That way a checkpoint
        saved wrapped or bare loads into a network that is wrapped or bare.
        """
        if rank != 0:
            return
        # load_state_dict copies into the existing parameters on their own
        # device, so staging the tensors on CPU avoids needing the device the
        # checkpoint was saved from.
        state = torch.load('{}/flownet.pkl'.format(path), map_location="cpu")
        prefix = "module."
        state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
        }
        net = self.flownet
        if isinstance(net, (DDP, nn.DataParallel)):
            net = net.module
        net.load_state_dict(state)

    def save_model(self, path, rank=0):
        """Save the network's state_dict to ``{path}/flownet.pkl`` (rank 0 only)."""
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, scale=None):
        """Return the network's final-level prediction for an input pair.

        Args:
            img0, img1: input tensors, concatenated along dim 1 before being
                fed to the network (presumably (N, 3, H, W) each - confirm).
            scale: optional resolution factor. The ``[4, 2, 1]`` scale list
                used by ``update`` is divided by it; ``None`` means 1.0,
                i.e. the same scales as training. NOTE(review): this
                interpretation of ``scale`` is presumed - confirm against
                model/IFNet.py.

        Returns:
            ``merged[2]``, the last entry of the network's ``merged`` output.
        """
        if scale is None:
            scale = 1.0
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        imgs = torch.cat((img0, img1), 1)
        # NOTE(review): unlike update(), no ground-truth channels are
        # appended; this assumes IFNet tolerates input without gt (skipping
        # its teacher branch) - confirm against model/IFNet.py.
        _, _, merged, _, _, _ = self.flownet(imgs, scale=scale_list)
        return merged[2]

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when ``training``, one optimizer step.

        Args:
            imgs: the input pair stacked along dim 1 (presumably
                (N, 6, H, W) - confirm); ``gt`` is appended along dim 1.
            gt: ground-truth target compared against the predictions.
            learning_rate: written into every optimizer param group.
            mul, flow_gt: accepted for interface compatibility; unused.
            training: if True, switch to train mode and take an optimizer
                step; otherwise switch to eval mode and only compute losses.

        Returns:
            ``(merged[2], info)`` where ``info`` holds the teacher output,
            mask, flows and the three loss terms.
        """
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
            torch.cat((imgs, gt), 1), scale=[4, 2, 1])
        # Despite the name, both terms use self.lap (LapLoss), not a plain L1.
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_tea + loss_distill * 0.01
            loss_G.backward()
            self.optimG.step()
        else:
            # In eval the reported "teacher" entries are replaced by the
            # student's final outputs; loss_tea above still uses the teacher.
            flow_teacher = flow[2]
            merged_teacher = merged[2]
        return merged[2], {
            'merged_tea': merged_teacher,
            'mask': mask,
            'mask_tea': mask,
            'flow': flow[2][:, :2],
            'flow_tea': flow_teacher,
            'loss_l1': loss_l1,
            'loss_tea': loss_tea,
            'loss_distill': loss_distill,
        }