diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0fe34b2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ + +*.pyc +*.py~ diff --git a/Flownet.py b/model/IFNet.py similarity index 92% rename from Flownet.py rename to model/IFNet.py index c2b2ee7..6172870 100644 --- a/Flownet.py +++ b/model/IFNet.py @@ -49,9 +49,9 @@ class ResBlock(nn.Module): x = self.relu2(x * w + y) return x -class Flownet(nn.Module): +class IFBlock(nn.Module): def __init__(self, in_planes, scale=1, c=64): - super(Flownet, self).__init__() + super(IFBlock, self).__init__() self.scale = scale self.conv0 = conv(in_planes, c, 3, 2, 1) self.res0 = ResBlock(c, c) @@ -79,12 +79,12 @@ class Flownet(nn.Module): flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) return flow -class FlownetCas(nn.Module): +class IFNet(nn.Module): def __init__(self): - super(FlownetCas, self).__init__() - self.block0 = Flownet(6, scale=4, c=192) - self.block1 = Flownet(8, scale=2, c=128) - self.block2 = Flownet(8, scale=1, c=64) + super(IFNet, self).__init__() + self.block0 = IFBlock(6, scale=4, c=192) + self.block1 = IFBlock(8, scale=2, c=128) + self.block2 = IFBlock(8, scale=1, c=64) def forward(self, x): x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) diff --git a/model.py b/model/RIFE.py similarity index 91% rename from model.py rename to model/RIFE.py index 8c622d7..9bf1416 100644 --- a/model.py +++ b/model/RIFE.py @@ -6,7 +6,7 @@ import torch.optim as optim import itertools from warplayer import warp from torch.nn.parallel import DistributedDataParallel as DDP -from Flownet import * +from IFNet import * import torch.nn.functional as F from loss import * @@ -55,9 +55,9 @@ class ResBlock(nn.Module): x = self.relu2(x * w + y) return x c = 16 -class Contextnet(nn.Module): +class ContextNet(nn.Module): def __init__(self): - super(Contextnet, self).__init__() + super(ContextNet, self).__init__() self.conv1 = ResBlock(3, c) self.conv2 = ResBlock(c, 2*c) self.conv3 = ResBlock(2*c, 4*c) @@ -77,9 +77,9 @@ class Contextnet(nn.Module): f4 = warp(x, flow) return [f1, f2, f3, f4] -class Unet(nn.Module): +class FusionNet(nn.Module): def __init__(self): - super(Unet, self).__init__() + super(FusionNet, self).__init__() self.down0 = ResBlock(8, 2*c) self.down1 = ResBlock(4*c, 4*c) self.down2 = ResBlock(8*c, 8*c) @@ -111,14 +111,14 @@ class Unet(nn.Module): class Model: def __init__(self, local_rank=-1): - self.flownet = FlownetCas() - self.contextnet = Contextnet() - self.unet = Unet() + self.flownet = IFNet() + self.contextnet = ContextNet() + self.fusionnet = FusionNet() self.device() self.optimG = AdamW(itertools.chain( self.flownet.parameters(), self.contextnet.parameters(), - self.unet.parameters()), lr=1e-6, weight_decay=1e-5) + self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.schedulerG = optim.lr_scheduler.CyclicLR(self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.epe = EPE() self.ter = Ternary() @@ -126,34 +126,34 @@ class Model: if local_rank != -1: self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) self.contextnet = DDP(self.contextnet, device_ids=[local_rank], output_device=local_rank) - self.unet = DDP(self.unet, device_ids=[local_rank], output_device=local_rank) + self.fusionnet = DDP(self.fusionnet, device_ids=[local_rank], output_device=local_rank) def train(self): self.flownet.train() self.contextnet.train() - self.unet.train() + self.fusionnet.train() def eval(self): self.flownet.eval() self.contextnet.eval() - self.unet.eval() + self.fusionnet.eval() def device(self): self.flownet.to(device) self.contextnet.to(device) - self.unet.to(device) + self.fusionnet.to(device) def load_model(self, path, rank=0): if rank == 0: self.flownet.load_state_dict(torch.load('{}/flownet.pkl'.format(path))) self.contextnet.load_state_dict(torch.load('{}/contextnet.pkl'.format(path))) - self.unet.load_state_dict(torch.load('{}/unet.pkl'.format(path))) + self.fusionnet.load_state_dict(torch.load('{}/unet.pkl'.format(path))) def save_model(self, path, rank=0): if rank == 0: torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path)) torch.save(self.contextnet.state_dict(),'{}/contextnet.pkl'.format(path)) - torch.save(self.unet.state_dict(),'{}/unet.pkl'.format(path)) + torch.save(self.fusionnet.state_dict(),'{}/unet.pkl'.format(path)) def predict(self, imgs, flow, training=True, flow_gt=None): img0 = imgs[:, :3] @@ -161,7 +161,7 @@ class Model: c0 = self.contextnet(img0, flow) c1 = self.contextnet(img1, -flow) flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 - refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.unet(img0, img1, flow, c0, c1, flow_gt) + refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(img0, img1, flow, c0, c1, flow_gt) res = torch.sigmoid(refine_output[:, :3]) * 2 - 1 mask = torch.sigmoid(refine_output[:, 3:4]) merged_img = warped_img0 * mask + warped_img1 * (1 - mask) diff --git a/loss.py b/model/loss.py similarity index 100% rename from loss.py rename to model/loss.py diff --git a/warplayer.py b/model/warplayer.py similarity index 100% rename from warplayer.py rename to model/warplayer.py