diff --git a/inference_video.py b/inference_video.py index c3f2a71..609bb04 100644 --- a/inference_video.py +++ b/inference_video.py @@ -13,6 +13,7 @@ if torch.cuda.is_available(): torch.set_grad_enabled(False) torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True + torch.set_num_threads(4) parser = argparse.ArgumentParser(description='Interpolation for a pair of images') parser.add_argument('--video', dest='video', required=True) @@ -72,8 +73,8 @@ while success: if success: if args.montage: frame = frame[:, left: left + w] - I0 = torch.from_numpy(np.transpose(lastframe, (2,0,1)).astype("float32") / 255.).to(device).unsqueeze(0) - I1 = torch.from_numpy(np.transpose(frame, (2,0,1)).astype("float32") / 255.).to(device).unsqueeze(0) + I0 = torch.from_numpy(np.transpose(lastframe, (2,0,1)).astype('float32') / 255.).to(device, non_blocking=True).unsqueeze(0) + I1 = torch.from_numpy(np.transpose(frame, (2,0,1)).astype('float32') / 255.).to(device, non_blocking=True).unsqueeze(0) I0 = F.pad(I0, padding) I1 = F.pad(I1, padding) p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) diff --git a/inference_video_parallel.py b/inference_video_parallel.py index f06bcea..f22684c 100644 --- a/inference_video_parallel.py +++ b/inference_video_parallel.py @@ -13,6 +13,7 @@ if torch.cuda.is_available(): torch.set_grad_enabled(False) torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True + torch.set_num_threads(4) parser = argparse.ArgumentParser(description='Interpolation for a pair of images') parser.add_argument('--video', dest='video', required=True) @@ -104,16 +105,16 @@ while success: if success: img_list.append(frame) if len(img_list) == 5 or (not success and len(img_list) > 1): - I0 = torch.from_numpy(np.transpose(img_list[:-1], (0, 3, 1, 2)).astype("float32") / 255.).to(device) - I1 = torch.from_numpy(np.transpose(img_list[1:], (0, 3, 1, 2)).astype("float32") / 255.).to(device) + I0 = torch.from_numpy(np.transpose(img_list[:-1], (0, 3, 1, 2)).astype('float32') / 255.).to(device, non_blocking=True) + I1 = torch.from_numpy(np.transpose(img_list[1:], (0, 3, 1, 2)).astype('float32') / 255.).to(device, non_blocking=True) p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs() I0 = F.pad(I0, padding) I1 = F.pad(I1, padding) inferences = make_inference(model, I0, I1, exp=args.exp) - - I0 = ((I0[:, :, :h, :w] * 255.).byte().cpu().detach().numpy().transpose(0, 2, 3, 1)) - I1 = ((I1[:, :, :h, :w] * 255.).byte().cpu().detach().numpy().transpose(0, 2, 3, 1)) + + I0 = np.array(img_list[:-1]) + I1 = np.array(img_list[1:]) inferences = list(map(lambda x: ((x[:, :, :h, :w] * 255.).byte().cpu().detach().numpy().transpose(0, 2, 3, 1)), inferences)) write_frame(vid_out, I0, inferences, I1, p.mean(3).mean(2).mean(1), args)