Execute linspace in warplayer.py on GPU when available

This commit is contained in:
Heylon
2021-02-20 17:53:22 +10:00
parent 556c6ca990
commit 34b91a13f2

View File

@@ -8,9 +8,9 @@ backwarp_tenGrid = {}
def warp(tenInput, tenFlow):
k = (str(tenFlow.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3],device=device).view(
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2],device=device).view(
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat(
[tenHorizontal, tenVertical], 1).to(device)