mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 04:19:41 +01:00
Add multi thread!
This commit is contained in:
@@ -6,6 +6,8 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
from torch.nn import functional as F
|
||||
import warnings
|
||||
import _thread
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -48,11 +50,10 @@ else:
|
||||
|
||||
cnt = 0
|
||||
skip_frame = 1
|
||||
|
||||
buffer = []
|
||||
|
||||
def write_frame(i0, infs, i1, p, user_args):
|
||||
global skip_frame, cnt
|
||||
|
||||
global skip_frame, cnt
|
||||
for i in range(i0.shape[0]):
|
||||
# A video transition occurs.
|
||||
if p[i] > 0.2:
|
||||
@@ -70,18 +71,19 @@ def write_frame(i0, infs, i1, p, user_args):
|
||||
skip_frame += 1
|
||||
continue
|
||||
|
||||
# Write results.
|
||||
if user_args.png:
|
||||
cv2.imwrite('output/{:0>7d}.png'.format(cnt), i0[i])
|
||||
cnt += 1
|
||||
for inf in infs:
|
||||
cv2.imwrite('output/{:0>7d}.png'.format(cnt), inf[i])
|
||||
cnt += 1
|
||||
else:
|
||||
vid_out.write(i0[i])
|
||||
for inf in infs:
|
||||
vid_out.write(inf[i])
|
||||
# Write results.
|
||||
buffer.append(i0[i])
|
||||
for inf in infs:
|
||||
buffer.append(inf[i])
|
||||
|
||||
def clear_buffer(user_args, buffer):
|
||||
global cnt
|
||||
for i in buffer:
|
||||
if user_args.png:
|
||||
cv2.imwrite('output/{:0>7d}.png'.format(cnt), i)
|
||||
cnt += i
|
||||
else:
|
||||
vid_out.write(i)
|
||||
|
||||
def make_inference(model, I0, I1, exp):
|
||||
middle = model.inference(I0, I1)
|
||||
@@ -120,5 +122,9 @@ while success:
|
||||
write_frame(I0, inferences, I1, p.mean(3).mean(2).mean(1), args)
|
||||
pbar.update(4)
|
||||
img_list = img_list[-1:]
|
||||
if len(buffer) > 100:
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
buffer = []
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
pbar.close()
|
||||
vid_out.release()
|
||||
|
||||
Reference in New Issue
Block a user