mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 04:19:41 +01:00
Use thread safe queue
This commit is contained in:
@@ -7,6 +7,7 @@ from tqdm import tqdm
|
||||
from torch.nn import functional as F
|
||||
import warnings
|
||||
import _thread
|
||||
from queue import Queue
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -40,7 +41,6 @@ if args.fps is None:
|
||||
success, frame = videoCapture.read()
|
||||
h, w, _ = frame.shape
|
||||
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
||||
buffer = []
|
||||
if args.png:
|
||||
if not os.path.exists('vid_out'):
|
||||
os.mkdir('vid_out')
|
||||
@@ -51,12 +51,15 @@ else:
|
||||
cnt = 0
|
||||
def clear_buffer(user_args, buffer):
|
||||
global cnt
|
||||
for i in buffer:
|
||||
while True:
|
||||
item = buffer.get()
|
||||
if item is None:
|
||||
break
|
||||
if user_args.png:
|
||||
cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), i)
|
||||
cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item)
|
||||
cnt += 1
|
||||
else:
|
||||
vid_out.write(i)
|
||||
vid_out.write(item)
|
||||
|
||||
if args.montage:
|
||||
left = w // 4
|
||||
@@ -70,6 +73,7 @@ pbar = tqdm(total=tot_frame)
|
||||
skip_frame = 1
|
||||
if args.montage:
|
||||
frame = frame[:, left: left + w]
|
||||
buffer = Queue()
|
||||
while success:
|
||||
lastframe = frame
|
||||
success, frame = videoCapture.read()
|
||||
@@ -101,27 +105,27 @@ while success:
|
||||
mid0 = (((mid[0] * 255.).byte().cpu().detach().numpy().transpose(1, 2, 0)))
|
||||
mid2 = (((mid[1] * 255.).byte().cpu().detach().numpy().transpose(1, 2, 0)))
|
||||
if args.montage:
|
||||
buffer.append(np.concatenate((lastframe, lastframe), 1))
|
||||
buffer.put(np.concatenate((lastframe, lastframe), 1))
|
||||
if args.exp == 4:
|
||||
buffer.append(np.concatenate((lastframe, mid0[:h, :w]), 1))
|
||||
buffer.append(np.concatenate((lastframe, mid1[:h, :w]), 1))
|
||||
buffer.put(np.concatenate((lastframe, mid0[:h, :w]), 1))
|
||||
buffer.put(np.concatenate((lastframe, mid1[:h, :w]), 1))
|
||||
if args.exp == 4:
|
||||
buffer.append(np.concatenate((lastframe, mid2[:h, :w]), 1))
|
||||
buffer.put(np.concatenate((lastframe, mid2[:h, :w]), 1))
|
||||
else:
|
||||
buffer.append(lastframe)
|
||||
buffer.put(lastframe)
|
||||
if args.exp == 4:
|
||||
buffer.append(mid0[:h, :w])
|
||||
buffer.append(mid1[:h, :w])
|
||||
buffer.put(mid0[:h, :w])
|
||||
buffer.put(mid1[:h, :w])
|
||||
if args.exp == 4:
|
||||
buffer.append(mid2[:h, :w])
|
||||
buffer.put(mid2[:h, :w])
|
||||
pbar.update(1)
|
||||
if len(buffer) > 100:
|
||||
if buffer.qsize() > 100:
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
buffer = []
|
||||
buffer.clear()
|
||||
if args.montage:
|
||||
buffer.append(np.concatenate((lastframe, lastframe), 1))
|
||||
buffer.put(np.concatenate((lastframe, lastframe), 1))
|
||||
else:
|
||||
buffer.append(lastframe)
|
||||
buffer.put(lastframe)
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
pbar.close()
|
||||
if not vid_out is None:
|
||||
|
||||
@@ -7,6 +7,7 @@ from tqdm import tqdm
|
||||
from torch.nn import functional as F
|
||||
import warnings
|
||||
import _thread
|
||||
from queue import Queue
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
@@ -50,7 +51,7 @@ else:
|
||||
|
||||
cnt = 0
|
||||
skip_frame = 1
|
||||
buffer = []
|
||||
buffer = Queue()
|
||||
|
||||
def write_frame(i0, infs, i1, p, user_args):
|
||||
global skip_frame, cnt
|
||||
@@ -72,18 +73,21 @@ def write_frame(i0, infs, i1, p, user_args):
|
||||
continue
|
||||
|
||||
# Write results.
|
||||
buffer.append(i0[i])
|
||||
buffer.put(i0[i])
|
||||
for inf in infs:
|
||||
buffer.append(inf[i])
|
||||
buffer.put(inf[i])
|
||||
|
||||
def clear_buffer(user_args, buffer):
|
||||
def clear_buffer(user_args, buffer):
|
||||
global cnt
|
||||
for i in buffer:
|
||||
while True:
|
||||
item = buffer.get()
|
||||
if item is None:
|
||||
break
|
||||
if user_args.png:
|
||||
cv2.imwrite('output/{:0>7d}.png'.format(cnt), i)
|
||||
cv2.imwrite('output/{:0>7d}.png'.format(cnt), item)
|
||||
cnt += 1
|
||||
else:
|
||||
vid_out.write(i)
|
||||
vid_out.write(item)
|
||||
|
||||
def make_inference(model, I0, I1, exp):
|
||||
middle = model.inference(I0, I1)
|
||||
@@ -101,6 +105,7 @@ tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
|
||||
pbar = tqdm(total=tot_frame)
|
||||
img_list = [frame]
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
while success:
|
||||
success, frame = videoCapture.read()
|
||||
if success:
|
||||
@@ -122,10 +127,6 @@ while success:
|
||||
write_frame(I0, inferences, I1, p.mean(3).mean(2).mean(1), args)
|
||||
pbar.update(4)
|
||||
img_list = img_list[-1:]
|
||||
if len(buffer) > 100:
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
buffer = []
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
pbar.close()
|
||||
if not vid_out is None:
|
||||
vid_out.release()
|
||||
|
||||
Reference in New Issue
Block a user