# ECCV2022-RIFE/model/RIFE.py
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
from model.warplayer import warp
from model.IFNet import *
from model.IFNet_m import *
from model.loss import *
from model.laplacian import *
from model.refine import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Model:
    def __init__(self, local_rank=-1, arbitrary=False):
        # IFNet_m supports an arbitrary interpolation timestep; the default
        # IFNet interpolates only at t = 0.5.
        if arbitrary:
            self.flownet = IFNet_m()
        else:
            self.flownet = IFNet()
        self.device()
        # A large weight decay may help avoid NaN loss.
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-3)
        self.epe = EPE()
        self.lap = LapLoss()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def device(self):
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        def convert(param):
            # Strip the "module." prefix that DistributedDataParallel adds,
            # so a DDP checkpoint loads into a plain (unwrapped) module.
            return {
                k.replace("module.", ""): v
                for k, v in param.items()
                if "module." in k
            }
        if rank <= 0:
            self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))))

    def save_model(self, path, rank=0):
        # Only rank 0 writes the checkpoint, so DDP workers do not race on the file.
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, scale=1, scale_list=None, TTA=False, timestep=0.5):
        if scale_list is None:
            scale_list = [4, 2, 1]
        # Rescale the coarse-to-fine pyramid by the global scale factor
        # (a smaller scale processes high-resolution input more coarsely).
        for i in range(3):
            scale_list[i] = scale_list[i] * 1.0 / scale
        imgs = torch.cat((img0, img1), 1)
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep)
        if not TTA:
            return merged[2]
        else:
            # Test-time augmentation: run the input flipped along H and W,
            # flip the prediction back, and average the two passes.
            flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(imgs.flip(2).flip(3), scale_list, timestep=timestep)
            return (merged[2] + merged2[2].flip(2).flip(3)) / 2

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        # The ground-truth frame is concatenated onto the input so the privileged
        # teacher branch can see it during training.
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(torch.cat((imgs, gt), 1), scale=[4, 2, 1])
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()
        if training:
            self.optimG.zero_grad()
            # When training RIFEm, the weight of loss_distill should be 0.005 or 0.002.
            loss_G = loss_l1 + loss_tea + loss_distill * 0.01
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        return merged[2], {
            'merged_tea': merged_teacher,
            'mask': mask,
            'mask_tea': mask,
            'flow': flow[2][:, :2],
            'flow_tea': flow_teacher,
            'loss_l1': loss_l1,
            'loss_tea': loss_tea,
            'loss_distill': loss_distill,
        }
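

# A minimal usage sketch, added for illustration (not part of the original
# scripts). It assumes img0/img1 are RGB tensors in [0, 1] with shape
# (N, 3, H, W) and H, W divisible by 32 (pad beforehand if they are not);
# 'train_log' is a hypothetical directory holding a trained flownet.pkl.
if __name__ == '__main__':
    model = Model()
    # model.load_model('train_log')  # uncomment once a trained checkpoint exists
    model.eval()
    img0 = torch.rand(1, 3, 256, 448, device=device)
    img1 = torch.rand(1, 3, 256, 448, device=device)
    with torch.no_grad():
        mid = model.inference(img0, img1, TTA=False)  # midpoint frame, t = 0.5
    print(mid.shape)  # torch.Size([1, 3, 256, 448])

    # One supervised training step: update() expects the two frames stacked on
    # the channel axis plus a ground-truth middle frame.
    gt = torch.rand(1, 3, 256, 448, device=device)
    pred, info = model.update(torch.cat((img0, img1), 1), gt, learning_rate=1e-6)
    print(info['loss_l1'].item(), info['loss_distill'].item())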