Merge pull request #113 from n00mkrad/main

Add optional fp16 mode for faster/lightweight inference on RTX cards
This commit is contained in:
hzwer
2021-02-25 11:10:17 +08:00
committed by GitHub

View File

@@ -53,18 +53,12 @@ def transferAudio(sourceVideo, targetVideo):
# remove temp directory
shutil.rmtree("temp")
# Select the CUDA device when one is available, otherwise use the CPU.
# NOTE(review): the same five statements appear again after argument parsing
# further down; this view looks like a diff rendering without +/- markers, so
# this may be the old (removed) location of the setup -- confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Disable gradient tracking globally.
torch.set_grad_enabled(False)
# Turn on cuDNN and its benchmark mode when CUDA is available.
# NOTE(review): indentation is lost in this view; the two assignments below
# are presumably the body of this `if` -- confirm against the real file.
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Command-line interface. Only part of the option list is visible in this hunk.
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
# String options that default to None (presumably file paths -- confirm).
parser.add_argument('--video', dest='video', type=str, default=None)
parser.add_argument('--output', dest='output', type=str, default=None)
parser.add_argument('--img', dest='img', type=str, default=None)
# Boolean switches: absent -> False, present -> True (store_true).
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
# Integer option; None when the flag is not given.
parser.add_argument('--fps', dest='fps', type=int, default=None)
@@ -75,6 +69,14 @@ args = parser.parse_args()
# Require at least one of --video / --img to be supplied.
# NOTE(review): `assert` is stripped under `python -O`, so this check would
# silently disappear in optimized mode.
assert (not args.video is None or not args.img is None)
# Force args.png to True whenever --img was given.
# NOTE(review): indentation is lost in this view; the assignment below is
# presumably the body of this `if` -- confirm.
if not args.img is None:
args.png = True
# Select the CUDA device when one is available, otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Disable gradient tracking globally.
torch.set_grad_enabled(False)
# Turn on cuDNN and its benchmark mode when CUDA is available.
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# With --fp16, make CUDA half precision the default tensor type.
# NOTE(review): torch.cuda.HalfTensor requires CUDA. Because indentation is
# lost here it is unclear whether this check is nested under the CUDA check
# above; if it is not, --fp16 on a CPU-only machine would likely fail -- confirm.
if(args.fp16):
torch.set_default_tensor_type(torch.cuda.HalfTensor)
# The model class is imported and instantiated only after the settings above
# have been applied.
from model.RIFE_HDv2 import Model
model = Model()
@@ -155,7 +157,13 @@ def make_inference(I0, I1, exp):
first_half = make_inference(I0, middle, exp=exp - 1)
second_half = make_inference(middle, I1, exp=exp - 1)
return [*first_half, middle, *second_half]
def pad_image(img):
    """Pad ``img`` with the module-level ``padding``; cast to half in fp16 mode.

    The result is ``F.pad(img, padding)``. When ``args.fp16`` is truthy the
    padded tensor is additionally converted with ``.half()``; otherwise it is
    returned as produced by ``F.pad``.
    """
    padded = F.pad(img, padding)
    return padded.half() if args.fp16 else padded
# Montage mode: `left` becomes a quarter of the current width `w`, then `w`
# itself is halved (integer division in both cases).
# NOTE(review): indentation is lost in this view and the hunk ends here; both
# assignments are presumably inside this `if`, and the branch may continue
# beyond what is visible -- confirm.
if args.montage:
left = w // 4
w = w // 2
@@ -176,14 +184,14 @@ _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
# Run clear_write_buffer(args, write_buffer) on a new thread.
_thread.start_new_thread(clear_write_buffer, (args, write_buffer))
# Turn `lastframe` into a tensor on `device`: move its last axis to the front,
# add a leading axis of size 1, cast to float and divide by 255.
I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
# NOTE(review): the next two lines both pad I1. This looks like a diff
# rendering that shows the old line (F.pad) followed by its replacement
# (pad_image); if both actually ran, I1 would be padded twice -- confirm.
I1 = F.pad(I1, padding)
I1 = pad_image(I1)
# Main loop (body indentation is lost in this view, and the loop continues
# past the last visible line).
while True:
# Take the next frame from read_buffer; None ends the loop.
frame = read_buffer.get()
if frame is None:
break
# The previously processed frame becomes I0; the new frame becomes I1.
I0 = I1
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
# NOTE(review): same old/new pair as above -- only one of these two padding
# lines is expected to exist in the real file.
I1 = F.pad(I1, padding)
I1 = pad_image(I1)
# Resize both frames to 32x32 with bilinear interpolation and pass the small
# versions to ssim_matlab.
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small, I1_small)