mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 16:37:51 +01:00
Add static frame interpolation
This commit is contained in:
@@ -161,14 +161,17 @@ def build_read_buffer(user_args, read_buffer, videogen):
|
||||
pass
|
||||
read_buffer.put(None)
|
||||
|
||||
def make_inference(I0, I1, exp):
|
||||
def make_inference(I0, I1, n):
|
||||
global model
|
||||
middle = model.inference(I0, I1, args.scale)
|
||||
if exp == 1:
|
||||
if n == 1:
|
||||
return [middle]
|
||||
first_half = make_inference(I0, middle, exp=exp - 1)
|
||||
second_half = make_inference(middle, I1, exp=exp - 1)
|
||||
return [*first_half, middle, *second_half]
|
||||
first_half = make_inference(I0, middle, n=n//2)
|
||||
second_half = make_inference(middle, I1, n=n//2)
|
||||
if n%2:
|
||||
return [*first_half, middle, *second_half]
|
||||
else:
|
||||
return [*first_half, *second_half]
|
||||
|
||||
def pad_image(img):
|
||||
if(args.fp16):
|
||||
@@ -194,6 +197,11 @@ _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
|
||||
|
||||
I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
||||
I1 = pad_image(I1)
|
||||
# number of frames to interpolate including duplicate frames to replace
|
||||
duplicate_count = 0
|
||||
# last valid frame (non-duplicate)
|
||||
duplicate_frame = None
|
||||
|
||||
while True:
|
||||
frame = read_buffer.get()
|
||||
if frame is None:
|
||||
@@ -204,11 +212,20 @@ while True:
|
||||
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
|
||||
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
|
||||
ssim = ssim_matlab(I0_small, I1_small)
|
||||
if ssim > 0.995 and args.skip:
|
||||
|
||||
if ssim > 0.995:
|
||||
if skip_frame % 100 == 0:
|
||||
print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame))
|
||||
print("\nWarning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame))
|
||||
skip_frame += 1
|
||||
pbar.update(1)
|
||||
if args.skip:
|
||||
continue
|
||||
|
||||
if duplicate_count:
|
||||
duplicate_count += 2**args.exp # 2^exp-1+1: number of frames to interpolate + duplicate frame
|
||||
else:
|
||||
duplicate_count = 2**args.exp # 2^exp-1+1: number of frames to interpolate + duplicate frame
|
||||
duplicate_frame = I0
|
||||
continue
|
||||
if ssim < 0.5:
|
||||
output = []
|
||||
@@ -219,7 +236,12 @@ while True:
|
||||
beta = 1-alpha
|
||||
output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
|
||||
else:
|
||||
output = make_inference(I0, I1, args.exp)
|
||||
if duplicate_count:
|
||||
duplicate_count += 2**args.exp - 1 # number of frames to interpolate
|
||||
output = make_inference(duplicate_frame, I1, duplicate_count)
|
||||
else:
|
||||
output = make_inference(I0, I1, 2**args.exp-1) if args.exp else []
|
||||
|
||||
if args.montage:
|
||||
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
|
||||
for mid in output:
|
||||
@@ -232,6 +254,9 @@ while True:
|
||||
write_buffer.put(mid[:h, :w])
|
||||
pbar.update(1)
|
||||
lastframe = frame
|
||||
# reset if and only if not duplicate
|
||||
duplicate_count=0
|
||||
duplicate_frame=None
|
||||
if args.montage:
|
||||
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user