mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 08:27:45 +01:00
Execute linspace in warplayer.py on GPU when available
This commit is contained in:
@@ -8,9 +8,9 @@ backwarp_tenGrid = {}
|
||||
def warp(tenInput, tenFlow):
|
||||
k = (str(tenFlow.device), str(tenFlow.size()))
|
||||
if k not in backwarp_tenGrid:
|
||||
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(
|
||||
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3],device=device).view(
|
||||
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
||||
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(
|
||||
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2],device=device).view(
|
||||
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
||||
backwarp_tenGrid[k] = torch.cat(
|
||||
[tenHorizontal, tenVertical], 1).to(device)
|
||||
|
||||
Reference in New Issue
Block a user