diff --git a/inference_img.py b/inference_img.py index 3d44e2d..1e6ea25 100644 --- a/inference_img.py +++ b/inference_img.py @@ -8,8 +8,8 @@ import warnings warnings.filterwarnings("ignore") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_grad_enabled(False) if torch.cuda.is_available(): - torch.set_grad_enabled(False) torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True