mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 04:19:41 +01:00
Much better parallel!
This commit is contained in:
@@ -51,7 +51,7 @@ cnt = 0
|
||||
skip_frame = 1
|
||||
|
||||
|
||||
def write_frame(vid_out, i0, infs, i1, p, user_args):
|
||||
def write_frame(i0, infs, i1, p, user_args):
|
||||
global skip_frame, cnt
|
||||
|
||||
for i in range(i0.shape[0]):
|
||||
@@ -104,7 +104,7 @@ while success:
|
||||
success, frame = videoCapture.read()
|
||||
if success:
|
||||
img_list.append(frame)
|
||||
if len(img_list) == 5 or (not success and len(img_list) > 1):
|
||||
if len(img_list) == 3 or (not success and len(img_list) > 1):
|
||||
I0 = torch.from_numpy(np.transpose(img_list[:-1], (0, 3, 1, 2)).astype('float32') / 255.).to(device, non_blocking=True)
|
||||
I1 = torch.from_numpy(np.transpose(img_list[1:], (0, 3, 1, 2)).astype('float32') / 255.).to(device, non_blocking=True)
|
||||
p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False)
|
||||
@@ -117,7 +117,7 @@ while success:
|
||||
I1 = np.array(img_list[1:])
|
||||
inferences = list(map(lambda x: ((x[:, :, :h, :w] * 255.).byte().cpu().detach().numpy().transpose(0, 2, 3, 1)), inferences))
|
||||
|
||||
write_frame(vid_out, I0, inferences, I1, p.mean(3).mean(2).mean(1), args)
|
||||
write_frame(I0, inferences, I1, p.mean(3).mean(2).mean(1), args)
|
||||
pbar.update(4)
|
||||
img_list = img_list[-1:]
|
||||
pbar.close()
|
||||
|
||||
Reference in New Issue
Block a user