Update it properly

This commit is contained in:
kalomaze
2023-07-22 22:26:40 -05:00
parent 6d12ba6015
commit 45293e07a6

View File

@@ -1,37 +1,62 @@
import argparse
import sys
import torch
import json
from multiprocessing import cpu_count
global IsFp16Support
IsFp16Support = False
global usefp16
usefp16 = False
def use_fp32_config():
usefp16 = False
device_capability = 0
if torch.cuda.is_available():
device = torch.device("cuda:0") # Assuming you have only one GPU (index 0).
if torch.cuda.get_device_capability(device)[0] >= 7:
print("Fp16 supported!")
IsFp16Support = True
device_capability = torch.cuda.get_device_capability(device)[0]
if device_capability >= 7:
usefp16 = True
for config_file in ["32k.json", "40k.json", "48k.json"]:
with open(f"configs/{config_file}", "r+") as f:
strr = f.read().replace("false", "true")
f.write(strr)
with open("trainset_preprocess_pipeline_print.py", "r+") as f:
strr = f.read().replace("3.0", "3.7")
with open(f"configs/{config_file}", "r") as d:
data = json.load(d)
if "train" in data and "fp16_run" in data["train"]:
data["train"]["fp16_run"] = True
with open(f"configs/{config_file}", "w") as d:
json.dump(data, d, indent=4)
print(f"Set fp16_run to true in {config_file}")
with open("trainset_preprocess_pipeline_print.py", "r", encoding="utf-8") as f:
strr = f.read()
strr = strr.replace("3.0", "3.7")
with open("trainset_preprocess_pipeline_print.py", "w", encoding="utf-8") as f:
f.write(strr)
else:
print("fp16 unavailable! Using fp32.....")
for config_file in ["32k.json", "40k.json", "48k.json"]:
with open(f"configs/{config_file}", "r+") as f:
strr = f.read().replace("true", "false")
f.write(strr)
with open("trainset_preprocess_pipeline_print.py", "r+") as f:
strr = f.read().replace("3.7", "3.0")
with open(f"configs/{config_file}", "r") as f:
data = json.load(f)
if "train" in data and "fp16_run" in data["train"]:
data["train"]["fp16_run"] = False
with open(f"configs/{config_file}", "w") as d:
json.dump(data, d, indent=4)
print(f"Set fp16_run to false in {config_file}")
with open("trainset_preprocess_pipeline_print.py", "r", encoding="utf-8") as f:
strr = f.read()
strr = strr.replace("3.7", "3.0")
with open("trainset_preprocess_pipeline_print.py", "w", encoding="utf-8") as f:
f.write(strr)
else:
print("CUDA is not available. Make sure you have an NVIDIA GPU and CUDA installed.")
return (IsFp16Support, torch.cuda.get_device_capability(device)[0])
return (usefp16, device_capability)
class Config:
def __init__(self):
@@ -114,7 +139,7 @@ class Config:
self.is_half = False
else:
print("Found GPU", self.gpu_name)
print(use_fp32_config())
use_fp32_config()
self.gpu_mem = int(
torch.cuda.get_device_properties(i_device).total_memory
/ 1024