Fix device mismatch for text to video (#207)

This commit is contained in:
hysts
2023-03-19 21:11:04 +09:00
committed by GitHub
parent b9bfbb70bc
commit fe673953b1

View File

@@ -10,6 +10,7 @@ __all__ = ['GaussianDiffusion', 'beta_schedule']
def _i(tensor, t, x):
r"""Index tensor using t and format the output according to x.
"""
tensor = tensor.to(x.device)
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t].view(shape).to(x)