import argparse
import json
import sys
from multiprocessing import cpu_count

import torch

# Module-level flag mirroring the most recent fp16 probe (set by use_fp32_config).
usefp16 = False


def _set_fp16_run(enabled):
    """Write ``train.fp16_run = enabled`` into each sample-rate config JSON.

    Only updates the key when it already exists in the file; the JSON schema
    is never extended.
    """
    for config_file in ["32k.json", "40k.json", "48k.json"]:
        path = f"configs/{config_file}"
        with open(path, "r") as f:
            data = json.load(f)

        if "train" in data and "fp16_run" in data["train"]:
            data["train"]["fp16_run"] = enabled

        with open(path, "w") as f:
            json.dump(data, f, indent=4)

        # str(True).lower() -> "true" keeps the original log wording intact.
        print(f"Set fp16_run to {str(enabled).lower()} in {config_file}")


def _swap_preprocess_constant(old, new):
    """Replace every occurrence of *old* with *new* in the preprocessing script.

    Read-then-rewrite of trainset_preprocess_pipeline_print.py; presumably the
    constant being toggled is a per-file processing parameter there — the exact
    meaning of 3.0/3.7 is not visible from this file (TODO confirm).
    """
    path = "trainset_preprocess_pipeline_print.py"
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    with open(path, "w", encoding="utf-8") as f:
        f.write(text.replace(old, new))


def use_fp32_config():
    """Probe CUDA fp16 support and sync training configs on disk accordingly.

    Side effects: rewrites ``configs/{32k,40k,48k}.json`` (``train.fp16_run``)
    and patches ``trainset_preprocess_pipeline_print.py``.

    Returns:
        tuple: ``(usefp16, device_capability)`` — whether fp16 is usable, and
        the CUDA compute-capability major version (0 when CUDA is unavailable).
    """
    # Fix: the previous version assigned a local `usefp16`, silently shadowing
    # the module-level flag; declare it global so the flag is actually updated.
    global usefp16
    usefp16 = False
    device_capability = 0

    if torch.cuda.is_available():
        device = torch.device("cuda:0")  # assumes a single GPU (index 0)
        device_capability = torch.cuda.get_device_capability(device)[0]
        if device_capability >= 7:  # compute capability >= 7 (Volta+) has fast fp16
            usefp16 = True
            _set_fp16_run(True)
            _swap_preprocess_constant("3.0", "3.7")
        else:
            # Mojibake character in the original message repaired here.
            print("fp16 unavailable! Using fp32.....")
            _set_fp16_run(False)
            _swap_preprocess_constant("3.7", "3.0")
    else:
        print("CUDA is not available. Make sure you have an NVIDIA GPU and CUDA installed.")

    return (usefp16, device_capability)