mirror of
https://github.com/n00mkrad/flowframes.git
synced 2025-12-16 16:37:48 +01:00
flavr.py: removed unnecessary functions, added jpeg input support
This commit is contained in:
@@ -33,7 +33,6 @@ parser.add_argument('--fp16', dest='fp16', action='store_true', help='half-preci
|
|||||||
parser.add_argument('--imgformat', default="png")
|
parser.add_argument('--imgformat', default="png")
|
||||||
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
|
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
|
||||||
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
|
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
|
||||||
parser.add_argument("--downscale", type=float, help="Downscale input res. for memory", default=1)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
input_ext = args.input_ext
|
input_ext = args.input_ext
|
||||||
@@ -82,7 +81,7 @@ def make_image(img):
|
|||||||
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
||||||
return im
|
return im
|
||||||
|
|
||||||
def files_to_videoTensor(path, downscale=1.):
|
def files_to_videoTensor(path):
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
global in_files
|
global in_files
|
||||||
in_files_fixed = in_files
|
in_files_fixed = in_files
|
||||||
@@ -93,22 +92,18 @@ def files_to_videoTensor(path, downscale=1.):
|
|||||||
videoTensor = torch.stack(images)
|
videoTensor = torch.stack(images)
|
||||||
return videoTensor
|
return videoTensor
|
||||||
|
|
||||||
def video_transform(videoTensor, downscale=1):
|
def video_transform(videoTensor):
|
||||||
T, H, W = videoTensor.size(0), videoTensor.size(1), videoTensor.size(2)
|
T, H, W = videoTensor.size(0), videoTensor.size(1), videoTensor.size(2)
|
||||||
downscale = int(downscale * 8)
|
transforms = torchvision.transforms.Compose([ToTensorVideo()])
|
||||||
resizes = 8*(H//downscale), 8*(W//downscale)
|
|
||||||
transforms = torchvision.transforms.Compose([ToTensorVideo(), Resize(resizes)])
|
|
||||||
videoTensor = transforms(videoTensor)
|
videoTensor = transforms(videoTensor)
|
||||||
|
return videoTensor
|
||||||
print("Resizing to %dx%d"%(resizes[0], resizes[1]) )
|
|
||||||
return videoTensor, resizes
|
|
||||||
|
|
||||||
videoTensor = files_to_videoTensor(interp_input_path, args.downscale)
|
videoTensor = files_to_videoTensor(interp_input_path)
|
||||||
|
|
||||||
print(f"Video Tensor len: {len(videoTensor)}")
|
print(f"Video Tensor len: {len(videoTensor)}")
|
||||||
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1,size=nbr_frame,step=1).squeeze(0)
|
idxs = torch.Tensor(range(len(videoTensor))).type(torch.long).view(1, -1).unfold(1,size=nbr_frame,step=1).squeeze(0)
|
||||||
print(f"len(idxs): {len(idxs)}")
|
print(f"len(idxs): {len(idxs)}")
|
||||||
videoTensor, resizes = video_transform(videoTensor, args.downscale)
|
videoTensor = video_transform(videoTensor)
|
||||||
print("Video tensor shape is ", videoTensor.shape)
|
print("Video tensor shape is ", videoTensor.shape)
|
||||||
|
|
||||||
frames = torch.unbind(videoTensor, 1)
|
frames = torch.unbind(videoTensor, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user