diff --git a/tests/__init__.py b/tests/__init__.py index f1445c92..c7930ef9 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,16 @@ import os +from TTS.utils.generic_utils import get_cuda + + +def get_device_id(): + use_cuda, _ = get_cuda() + if use_cuda: + GPU_ID = "0" + else: + GPU_ID = "" + return GPU_ID + def get_tests_path(): """Returns the path to the test directory."""