Better fp16 detection

This commit is contained in:
kalomaze
2023-07-22 22:19:22 -05:00
parent add253b476
commit 6d12ba6015

View File

@@ -3,17 +3,34 @@ import sys
import torch
from multiprocessing import cpu_count
global IsFp16Support
IsFp16Support = False
def use_fp32_config():
for config_file in ["32k.json", "40k.json", "48k.json"]:
with open(f"configs/{config_file}", "r") as f:
strr = f.read().replace("true", "false")
with open(f"configs/{config_file}", "w") as f:
f.write(strr)
with open("trainset_preprocess_pipeline_print.py", "r") as f:
strr = f.read().replace("3.7", "3.0")
with open("trainset_preprocess_pipeline_print.py", "w") as f:
f.write(strr)
if torch.cuda.is_available():
device = torch.device("cuda:0") # Assuming you have only one GPU (index 0).
if torch.cuda.get_device_capability(device)[0] >= 7:
print("Fp16 supported!")
IsFp16Support = True
for config_file in ["32k.json", "40k.json", "48k.json"]:
with open(f"configs/{config_file}", "r+") as f:
strr = f.read().replace("false", "true")
f.write(strr)
with open("trainset_preprocess_pipeline_print.py", "r+") as f:
strr = f.read().replace("3.0", "3.7")
f.write(strr)
else:
print("fp16 unavailable! Using fp32.....")
for config_file in ["32k.json", "40k.json", "48k.json"]:
with open(f"configs/{config_file}", "r+") as f:
strr = f.read().replace("true", "false")
f.write(strr)
with open("trainset_preprocess_pipeline_print.py", "r+") as f:
strr = f.read().replace("3.7", "3.0")
f.write(strr)
else:
print("CUDA is not available. Make sure you have an NVIDIA GPU and CUDA installed.")
return (IsFp16Support, torch.cuda.get_device_capability(device)[0])
class Config:
@@ -29,7 +46,10 @@ class Config:
self.iscolab,
self.noparallel,
self.noautoopen,
self.paperspace,
self.is_cli,
) = self.arg_parse()
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
@staticmethod
@@ -47,6 +67,12 @@ class Config:
action="store_true",
help="Do not open in browser automatically",
)
parser.add_argument( # Fork Feature. Paperspace integration for web UI
"--paperspace", action="store_true", help="Note that this argument just shares a gradio link for the web UI. Thus can be used on other non-local CLI systems."
)
parser.add_argument( # Fork Feature. Embed a CLI into the infer-web.py
"--is_cli", action="store_true", help="Use the CLI instead of setting up a gradio UI. This flag will launch an RVC text interface where you can execute functions from infer-web.py!"
)
cmd_opts = parser.parse_args()
cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
@@ -57,6 +83,8 @@ class Config:
cmd_opts.colab,
cmd_opts.noparallel,
cmd_opts.noautoopen,
cmd_opts.paperspace,
cmd_opts.is_cli,
)
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
@@ -84,9 +112,9 @@ class Config:
):
print("Found GPU", self.gpu_name, ", force to fp32")
self.is_half = False
use_fp32_config()
else:
print("Found GPU", self.gpu_name)
print(use_fp32_config())
self.gpu_mem = int(
torch.cuda.get_device_properties(i_device).total_memory
/ 1024