diff --git a/benchmark/pytorch_msssim/__init__.py b/benchmark/pytorch_msssim/__init__.py index 118d265..a4d3032 100644 --- a/benchmark/pytorch_msssim/__init__.py +++ b/benchmark/pytorch_msssim/__init__.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from math import exp import numpy as np +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) @@ -11,7 +12,7 @@ def gaussian(window_size, sigma): def create_window(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda() + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() return window @@ -19,7 +20,7 @@ def create_window_3d(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()) _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) - window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda() + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) return window