mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 08:27:45 +01:00
Support cpu ssim
This commit is contained in:
@@ -3,6 +3,7 @@ import torch.nn.functional as F
|
||||
from math import exp
|
||||
import numpy as np
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
||||
@@ -11,7 +12,7 @@ def gaussian(window_size, sigma):
|
||||
|
||||
def create_window(window_size, channel=1):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).cuda()
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
|
||||
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
||||
return window
|
||||
|
||||
@@ -19,7 +20,7 @@ def create_window_3d(window_size, channel=1):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t())
|
||||
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
|
||||
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda()
|
||||
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
|
||||
return window
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user