From 183f5014e84ff5fb3f36e42ca546dfc6943a3962 Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Tue, 17 Nov 2020 19:00:39 +0800 Subject: [PATCH] Fixing parallel --- inference_mp4_4x_parallel.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/inference_mp4_4x_parallel.py b/inference_mp4_4x_parallel.py index 76df6cf..e30a3a0 100644 --- a/inference_mp4_4x_parallel.py +++ b/inference_mp4_4x_parallel.py @@ -42,9 +42,9 @@ def writeframe(I0, mid0, mid1, mid2, I1, p): global cnt, skip_frame for i in range(I0.shape[0]): if p[i] > 0.2: - mid0[i] = I0 - mid1[i] = I0 - mid2[i] = I1 + mid0[i] = I0[i] + mid1[i] = I0[i] + mid2[i] = I1[i] if p[i] < 1e-3 and args.skip: if skip_frame % 100 == 0: print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame)) @@ -70,26 +70,28 @@ tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) img_list = [] +img_list.append(frame) while success: - img_list.append(frame) success, frame = videoCapture.read() if success: img_list.append(frame) - if img_list == 5 or not success: + if len(img_list) == 5 or not success: I0 = torch.from_numpy(np.transpose(img_list[:-1], (0, 3, 1, 2)).astype("float32") / 255.).to(device) I1 = torch.from_numpy(np.transpose(img_list[1:], (0, 3, 1, 2)).astype("float32") / 255.).to(device) p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs().mean(3).mean(2).mean(1) + - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs() I0 = F.pad(I0, padding) I1 = F.pad(I1, padding) mid1 = model.inference(I0, I1) mid0 = model.inference(I0, mid1) mid2 = model.inference(mid1, I1) + I0 = (((I0 * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') + I1 = (((I1 * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') mid0 = (((mid0 * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') mid1 = (((mid1 * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') mid2 = (((mid2 * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') - writeframe(p, mid0, mid1, mid2) + writeframe(I0, mid0, mid1, mid2, I1, p.mean(3).mean(2).mean(1)) pbar.update(4) - img_list = img_list[-1] + img_list = img_list[-1:] pbar.close() output.release()