Mirror of https://github.com/hzwer/ECCV2022-RIFE.git (synced 2026-05-18 05:04:43 +02:00)
Update UHD mode
This commit is contained in:
@@ -61,6 +61,7 @@ parser.add_argument('--montage', dest='montage', action='store_true', help='mont
|
|||||||
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
|
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
|
||||||
parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
|
parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
|
||||||
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
|
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
|
||||||
|
parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
|
||||||
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
|
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
|
||||||
parser.add_argument('--fps', dest='fps', type=int, default=None)
|
parser.add_argument('--fps', dest='fps', type=int, default=None)
|
||||||
parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
|
parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
|
||||||
@@ -68,6 +69,7 @@ parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out
|
|||||||
parser.add_argument('--exp', dest='exp', type=int, default=1)
|
parser.add_argument('--exp', dest='exp', type=int, default=1)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
assert (not args.video is None or not args.img is None)
|
assert (not args.video is None or not args.img is None)
|
||||||
|
assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
|
||||||
if not args.img is None:
|
if not args.img is None:
|
||||||
args.png = True
|
args.png = True
|
||||||
|
|
||||||
@@ -159,7 +161,7 @@ def build_read_buffer(user_args, read_buffer, videogen):
|
|||||||
|
|
||||||
def make_inference(I0, I1, exp):
|
def make_inference(I0, I1, exp):
|
||||||
global model
|
global model
|
||||||
middle = model.inference(I0, I1, args.UHD)
|
middle = model.inference(I0, I1, args.scale)
|
||||||
if exp == 1:
|
if exp == 1:
|
||||||
return [middle]
|
return [middle]
|
||||||
first_half = make_inference(I0, middle, exp=exp - 1)
|
first_half = make_inference(I0, middle, exp=exp - 1)
|
||||||
@@ -175,12 +177,9 @@ def pad_image(img):
|
|||||||
if args.montage:
|
if args.montage:
|
||||||
left = w // 4
|
left = w // 4
|
||||||
w = w // 2
|
w = w // 2
|
||||||
if args.UHD:
|
tmp = max(32, int(32 / args.scale))
|
||||||
ph = ((h - 1) // 64 + 1) * 64
|
ph = ((h - 1) // tmp + 1) * tmp
|
||||||
pw = ((w - 1) // 64 + 1) * 64
|
pw = ((w - 1) // tmp + 1) * tmp
|
||||||
else:
|
|
||||||
ph = ((h - 1) // 32 + 1) * 32
|
|
||||||
pw = ((w - 1) // 32 + 1) * 32
|
|
||||||
padding = (0, pw - w, 0, ph - h)
|
padding = (0, pw - w, 0, ph - h)
|
||||||
pbar = tqdm(total=tot_frame)
|
pbar = tqdm(total=tot_frame)
|
||||||
skip_frame = 1
|
skip_frame = 1
|
||||||
|
|||||||
@@ -91,12 +91,9 @@ class IFNet(nn.Module):
|
|||||||
self.block2 = IFBlock(8, scale=2, c=96)
|
self.block2 = IFBlock(8, scale=2, c=96)
|
||||||
self.block3 = IFBlock(8, scale=1, c=48)
|
self.block3 = IFBlock(8, scale=1, c=48)
|
||||||
|
|
||||||
def forward(self, x, UHD=False):
|
def forward(self, x, scale=1.0):
|
||||||
if UHD:
|
x = F.interpolate(x, scale_factor=0.5 * scale, mode="bilinear",
|
||||||
x = F.interpolate(x, scale_factor=0.25, mode="bilinear", align_corners=False)
|
align_corners=False)
|
||||||
else:
|
|
||||||
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
|
|
||||||
align_corners=False)
|
|
||||||
flow0 = self.block0(x)
|
flow0 = self.block0(x)
|
||||||
F1 = flow0
|
F1 = flow0
|
||||||
warped_img0 = warp(x[:, :3], F1)
|
warped_img0 = warp(x[:, :3], F1)
|
||||||
@@ -111,6 +108,8 @@ class IFNet(nn.Module):
|
|||||||
warped_img1 = warp(x[:, 3:], -F3)
|
warped_img1 = warp(x[:, 3:], -F3)
|
||||||
flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1))
|
flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1))
|
||||||
F4 = (flow0 + flow1 + flow2 + flow3)
|
F4 = (flow0 + flow1 + flow2 + flow3)
|
||||||
|
F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear",
|
||||||
|
align_corners=False) / scale
|
||||||
return F4, [F1, F2, F3, F4]
|
return F4, [F1, F2, F3, F4]
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -188,11 +188,9 @@ class Model:
|
|||||||
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
|
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
|
||||||
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
|
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
|
||||||
|
|
||||||
def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
|
def predict(self, imgs, flow, training=True, flow_gt=None):
|
||||||
img0 = imgs[:, :3]
|
img0 = imgs[:, :3]
|
||||||
img1 = imgs[:, 3:]
|
img1 = imgs[:, 3:]
|
||||||
if UHD:
|
|
||||||
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
|
|
||||||
c0 = self.contextnet(img0, flow)
|
c0 = self.contextnet(img0, flow)
|
||||||
c1 = self.contextnet(img1, -flow)
|
c1 = self.contextnet(img1, -flow)
|
||||||
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
|
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
|
||||||
@@ -209,10 +207,10 @@ class Model:
|
|||||||
else:
|
else:
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
def inference(self, img0, img1, UHD=False):
|
def inference(self, img0, img1, scale=1.0):
|
||||||
imgs = torch.cat((img0, img1), 1)
|
imgs = torch.cat((img0, img1), 1)
|
||||||
flow, _ = self.flownet(imgs, UHD)
|
flow, _ = self.flownet(imgs, scale)
|
||||||
return self.predict(imgs, flow, training=False, UHD=UHD)
|
return self.predict(imgs, flow, training=False)
|
||||||
|
|
||||||
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
|
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
|
||||||
for param_group in self.optimG.param_groups:
|
for param_group in self.optimG.param_groups:
|
||||||
|
|||||||
@@ -173,11 +173,9 @@ class Model:
|
|||||||
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
|
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
|
||||||
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
|
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
|
||||||
|
|
||||||
def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
|
def predict(self, imgs, flow, training=True, flow_gt=None):
|
||||||
img0 = imgs[:, :3]
|
img0 = imgs[:, :3]
|
||||||
img1 = imgs[:, 3:]
|
img1 = imgs[:, 3:]
|
||||||
if UHD:
|
|
||||||
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
|
|
||||||
c0 = self.contextnet(img0, flow[:, :2])
|
c0 = self.contextnet(img0, flow[:, :2])
|
||||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||||
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
|
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
|
||||||
@@ -194,10 +192,10 @@ class Model:
|
|||||||
else:
|
else:
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
def inference(self, img0, img1, UHD=False):
|
def inference(self, img0, img1, scale=1.0):
|
||||||
imgs = torch.cat((img0, img1), 1)
|
imgs = torch.cat((img0, img1), 1)
|
||||||
flow, _ = self.flownet(imgs, UHD)
|
flow, _ = self.flownet(imgs, scale)
|
||||||
return self.predict(imgs, flow, training=False, UHD=UHD)
|
return self.predict(imgs, flow, training=False)
|
||||||
|
|
||||||
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
|
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
|
||||||
for param_group in self.optimG.param_groups:
|
for param_group in self.optimG.param_groups:
|
||||||
|
|||||||
Reference in New Issue
Block a user