Mirror of https://github.com/hzwer/ECCV2022-RIFE.git (synced 2025-12-16 16:37:51 +01:00)
Merge pull request #113 from n00mkrad/main
Add optional fp16 mode for faster/lightweight inference on RTX cards
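The diff below edits the interpolation entry script (the hunk contexts transferAudio, make_inference, and build_read_buffer match inference_video.py in this repo). Assuming that file name, a post-merge run that enables the new flag might look like this, with input.mp4 as a stand-in for a real video:

    # Hypothetical invocation; --fp16 is the flag added by this PR.
    python3 inference_video.py --video=input.mp4 --fp16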
@@ -53,18 +53,12 @@ def transferAudio(sourceVideo, targetVideo):
     # remove temp directory
     shutil.rmtree("temp")
 
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-torch.set_grad_enabled(False)
-if torch.cuda.is_available():
-    torch.backends.cudnn.enabled = True
-    torch.backends.cudnn.benchmark = True
-
 parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
 parser.add_argument('--video', dest='video', type=str, default=None)
 parser.add_argument('--output', dest='output', type=str, default=None)
 parser.add_argument('--img', dest='img', type=str, default=None)
 parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
+parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
 parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
 parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
 parser.add_argument('--fps', dest='fps', type=int, default=None)
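This hunk deletes the device/cuDNN setup from its old position (it is re-added after argument parsing in the next hunk, where args.fp16 exists) and registers the opt-in --fp16 flag. The "more lightweight" part of the flag's help text comes from float16 halving tensor storage; a minimal sketch of that saving, using a hypothetical padded 1088x1920 frame:

    import torch

    # Illustrative only: per-frame memory in fp32 vs fp16.
    frame32 = torch.zeros(1, 3, 1088, 1920)             # default float32
    frame16 = frame32.half()                            # cast to float16
    print(frame32.nelement() * frame32.element_size())  # 25067520 bytes (~24 MiB)
    print(frame16.nelement() * frame16.element_size())  # 12533760 bytes (~12 MiB)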
@@ -75,6 +69,14 @@ args = parser.parse_args()
 assert (not args.video is None or not args.img is None)
 if not args.img is None:
     args.png = True
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch.set_grad_enabled(False)
+if torch.cuda.is_available():
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+    if(args.fp16):
+        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+
 from model.RIFE_HDv2 import Model
 model = Model()
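torch.set_default_tensor_type(torch.cuda.HalfTensor) changes what plain tensor constructors produce, which is why the diff places it before the model import that follows. A minimal sketch of the effect (requires a CUDA device; the guard mirrors the one above):

    import torch

    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        # Tensors built without an explicit dtype/device now default
        # to float16 on the GPU.
        x = torch.zeros(2, 2)
        print(x.dtype, x.device)  # torch.float16 cuda:0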
@@ -155,7 +157,13 @@ def make_inference(I0, I1, exp):
     first_half = make_inference(I0, middle, exp=exp - 1)
     second_half = make_inference(middle, I1, exp=exp - 1)
     return [*first_half, middle, *second_half]
 
+def pad_image(img):
+    if(args.fp16):
+        return F.pad(img, padding).half()
+    else:
+        return F.pad(img, padding)
+
 if args.montage:
     left = w // 4
     w = w // 2
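pad_image centralizes padding plus the fp16 cast: input frames are created as float32 (see the .float() / 255. conversions in the next hunk), so in fp16 mode they must be cast to half to match the model. A standalone sketch, with an assumed padding tuple that rounds 1080 rows up to 1088:

    import torch
    import torch.nn.functional as F

    padding = (0, 0, 0, 8)              # (left, right, top, bottom), hypothetical
    img = torch.rand(1, 3, 1080, 1920)  # float32 frame in [0, 1]

    # What pad_image does when --fp16 is set:
    padded = F.pad(img, padding).half()
    print(padded.shape, padded.dtype)   # torch.Size([1, 3, 1088, 1920]) torch.float16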
@@ -176,14 +184,14 @@ _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
 _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
 
 I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-I1 = F.pad(I1, padding)
+I1 = pad_image(I1)
 while True:
     frame = read_buffer.get()
     if frame is None:
         break
     I0 = I1
     I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-    I1 = F.pad(I1, padding)
+    I1 = pad_image(I1)
     I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
     I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
     ssim = ssim_matlab(I0_small, I1_small)
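Both frame-ingest call sites now route through pad_image, so the dtype decision lives in one place. For context, the surrounding lines convert a decoded HWC uint8 frame into a normalized NCHW float batch; a sketch with a synthetic frame in place of real video data:

    import numpy as np
    import torch

    frame = np.zeros((1080, 1920, 3), dtype=np.uint8)     # HWC, as decoded
    t = torch.from_numpy(np.transpose(frame, (2, 0, 1)))  # CHW
    t = t.unsqueeze(0).float() / 255.                     # NCHW, values in [0, 1]
    print(t.shape)                                        # torch.Size([1, 3, 1080, 1920])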