diff --git a/model/RIFE.py b/model/RIFE.py index 6abf090..9af628a 100644 --- a/model/RIFE.py +++ b/model/RIFE.py @@ -13,7 +13,7 @@ from model.loss import * from model.laplacian import * from model.refine import * -device = torch.device("cuda") +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class Model: def __init__(self, local_rank=-1, arbitrary=False):