mirror of https://github.com/hzwer/ECCV2022-RIFE.git (synced 2025-12-16 08:27:45 +01:00)
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torchstat import stat
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet import *
import torch.nn.functional as F
from model.loss import *
from model.laplacian import *
from model.refine import *

device = torch.device("cuda")

class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        # losses: endpoint error, Laplacian pyramid L1, and Sobel gradient loss
        self.epe = EPE()
        self.lap = LapLoss()
        self.sobel = SOBEL()
        # local_rank != -1 means the caller launched distributed training
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def device(self):
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        if rank == 0:
            self.flownet.load_state_dict(torch.load('{}/flownet.pkl'.format(path)))

    def save_model(self, path, rank=0):
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    # Legacy refine-based prediction path, kept disabled: it relies on
    # self.contextnet / self.unet, which this model no longer constructs,
    # and on an older flownet call signature.
    '''
    def predict(self, imgs, flow, merged, training=True, flow_gt=None):
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        refine_output = self.unet(img0, img1, flow, merged, c0, c1, flow_gt)
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        pred = merged + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, merged
        else:
            return pred

    def inference(self, img0, img1, scale=None):
        imgs = torch.cat((img0, img1), 1)
        flow, _ = self.flownet(torch.cat((img0, img1), 1))
        return self.predict(imgs, flow, training=False)
    '''

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        # the IFNet consumes the two input frames plus the ground-truth middle
        # frame and returns coarse-to-fine flows, fusion masks, merged frames,
        # the teacher branch outputs, and the distillation loss
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(torch.cat((imgs, gt), 1), scale=[4, 2, 1])
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_tea + loss_distill * 0.01
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
            merged_teacher = merged[2]
        return merged[2], {
            'merged_tea': merged_teacher,
            'mask': mask,
            'mask_tea': mask,
            'flow': flow[2][:, :2],
            'flow_tea': flow_teacher,
            'loss_l1': loss_l1,
            'loss_tea': loss_tea,
            'loss_distill': loss_distill,
            }
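
# Minimal usage sketch for one training step of Model.update(), assuming a
# CUDA device is available. The tensor shapes, random inputs, and learning
# rate below are illustrative assumptions, not values from the repository.
if __name__ == '__main__':
    model = Model()
    model.train()
    # two stacked RGB frames (channels 0-2 and 3-5) and the ground-truth middle frame
    imgs = torch.rand(1, 6, 256, 256, device=device)
    gt = torch.rand(1, 3, 256, 256, device=device)
    pred, info = model.update(imgs, gt, learning_rate=1e-6, training=True)
    print(pred.shape, info['loss_l1'].item())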