Support cpu ssim

This commit is contained in:
hzwer
2021-01-18 17:41:27 +08:00
committed by GitHub
parent 1dc2dba7d1
commit e515b1d364

View File

@@ -3,6 +3,7 @@ import torch.nn.functional as F
from math import exp from math import exp
import numpy as np import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gaussian(window_size, sigma): def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
@@ -11,7 +12,7 @@ def gaussian(window_size, sigma):
def create_window(window_size, channel=1): def create_window(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1) _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda() _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window return window
@@ -19,7 +20,7 @@ def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1) _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()) _2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda() window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
return window return window