diff --git a/inpainter/base_inpainter.py b/inpainter/base_inpainter.py index 18fdfce..35d6d4e 100644 --- a/inpainter/base_inpainter.py +++ b/inpainter/base_inpainter.py @@ -69,10 +69,11 @@ class BaseInpainter: size = None else: size = [int(W*ratio), int(H*ratio)] - if size[0] % 2 > 0: - size[0] += 1 - if size[1] % 2 > 0: - size[1] += 1 + size = [si+1 if si%2>0 else si for si in size] # only consider even values + # shortest side should be larger than 50 + if min(size) < 50: + ratio = 50. / min(H, W) + size = [int(W*ratio), int(H*ratio)] masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1 binary_masks = resize_masks(masks, tuple(size)) @@ -156,7 +157,7 @@ if __name__ == '__main__': base_inpainter = BaseInpainter(checkpoint, device) # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W) # ratio: (0, 1], ratio for down sample, default value is 1 - inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=1) # numpy array, T, H, W, 3 + inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.01) # numpy array, T, H, W, 3 # ---------------------------------------------- # end # ----------------------------------------------