Add static frame interpolation

This commit is contained in:
Mskycoder
2021-04-18 21:39:32 -04:00
parent 02e8d34a96
commit ee20c4edaf

View File

@@ -161,14 +161,17 @@ def build_read_buffer(user_args, read_buffer, videogen):
pass
read_buffer.put(None)
def make_inference(I0, I1, exp):
def make_inference(I0, I1, n):
global model
middle = model.inference(I0, I1, args.scale)
if exp == 1:
if n == 1:
return [middle]
first_half = make_inference(I0, middle, exp=exp - 1)
second_half = make_inference(middle, I1, exp=exp - 1)
return [*first_half, middle, *second_half]
first_half = make_inference(I0, middle, n=n//2)
second_half = make_inference(middle, I1, n=n//2)
if n%2:
return [*first_half, middle, *second_half]
else:
return [*first_half, *second_half]
def pad_image(img):
if(args.fp16):
@@ -194,6 +197,11 @@ _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
I1 = pad_image(I1)
# number of frames to interpolate including duplicate frames to replace
duplicate_count = 0
# last valid frame (non-duplicate)
duplicate_frame = None
while True:
frame = read_buffer.get()
if frame is None:
@@ -204,11 +212,20 @@ while True:
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small, I1_small)
if ssim > 0.995 and args.skip:
if ssim > 0.995:
if skip_frame % 100 == 0:
print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame))
print("\nWarning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame))
skip_frame += 1
pbar.update(1)
if args.skip:
continue
if duplicate_count:
duplicate_count += 2**args.exp # 2^exp-1+1: number of frames to interpolate + duplicate frame
else:
duplicate_count = 2**args.exp # 2^exp-1+1: number of frames to interpolate + duplicate frame
duplicate_frame = I0
continue
if ssim < 0.5:
output = []
@@ -219,7 +236,12 @@ while True:
beta = 1-alpha
output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
else:
output = make_inference(I0, I1, args.exp)
if duplicate_count:
duplicate_count += 2**args.exp - 1 # number of frames to interpolate
output = make_inference(duplicate_frame, I1, duplicate_count)
else:
output = make_inference(I0, I1, 2**args.exp-1) if args.exp else []
if args.montage:
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
for mid in output:
@@ -232,6 +254,9 @@ while True:
write_buffer.put(mid[:h, :w])
pbar.update(1)
lastframe = frame
# reset if and only if not duplicate
duplicate_count=0
duplicate_frame=None
if args.montage:
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
else: