mirror of
https://github.com/Mangio621/Mangio-RVC-Fork.git
synced 2026-02-24 03:49:51 +01:00
Redo the PR on a properly updated branch
This commit is contained in:
73
config.py
73
config.py
@@ -1,20 +1,62 @@
|
||||
import argparse
|
||||
import sys
|
||||
import torch
|
||||
import json
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
global usefp16
|
||||
usefp16 = False
|
||||
|
||||
def use_fp32_config():
|
||||
for config_file in ["32k.json", "40k.json", "48k.json"]:
|
||||
with open(f"configs/{config_file}", "r") as f:
|
||||
strr = f.read().replace("true", "false")
|
||||
with open(f"configs/{config_file}", "w") as f:
|
||||
f.write(strr)
|
||||
with open("trainset_preprocess_pipeline_print.py", "r") as f:
|
||||
strr = f.read().replace("3.7", "3.0")
|
||||
with open("trainset_preprocess_pipeline_print.py", "w") as f:
|
||||
f.write(strr)
|
||||
usefp16 = False
|
||||
device_capability = 0
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0") # Assuming you have only one GPU (index 0).
|
||||
device_capability = torch.cuda.get_device_capability(device)[0]
|
||||
if device_capability >= 7:
|
||||
usefp16 = True
|
||||
for config_file in ["32k.json", "40k.json", "48k.json"]:
|
||||
with open(f"configs/{config_file}", "r") as d:
|
||||
data = json.load(d)
|
||||
|
||||
if "train" in data and "fp16_run" in data["train"]:
|
||||
data["train"]["fp16_run"] = True
|
||||
|
||||
with open(f"configs/{config_file}", "w") as d:
|
||||
json.dump(data, d, indent=4)
|
||||
|
||||
print(f"Set fp16_run to true in {config_file}")
|
||||
|
||||
with open("trainset_preprocess_pipeline_print.py", "r", encoding="utf-8") as f:
|
||||
strr = f.read()
|
||||
|
||||
strr = strr.replace("3.0", "3.7")
|
||||
|
||||
with open("trainset_preprocess_pipeline_print.py", "w", encoding="utf-8") as f:
|
||||
f.write(strr)
|
||||
else:
|
||||
for config_file in ["32k.json", "40k.json", "48k.json"]:
|
||||
with open(f"configs/{config_file}", "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if "train" in data and "fp16_run" in data["train"]:
|
||||
data["train"]["fp16_run"] = False
|
||||
|
||||
with open(f"configs/{config_file}", "w") as d:
|
||||
json.dump(data, d, indent=4)
|
||||
|
||||
print(f"Set fp16_run to false in {config_file}")
|
||||
|
||||
with open("trainset_preprocess_pipeline_print.py", "r", encoding="utf-8") as f:
|
||||
strr = f.read()
|
||||
|
||||
strr = strr.replace("3.7", "3.0")
|
||||
|
||||
with open("trainset_preprocess_pipeline_print.py", "w", encoding="utf-8") as f:
|
||||
f.write(strr)
|
||||
else:
|
||||
print("CUDA is not available. Make sure you have an NVIDIA GPU and CUDA installed.")
|
||||
return (usefp16, device_capability)
|
||||
|
||||
class Config:
|
||||
def __init__(self):
|
||||
@@ -29,7 +71,10 @@ class Config:
|
||||
self.iscolab,
|
||||
self.noparallel,
|
||||
self.noautoopen,
|
||||
self.paperspace,
|
||||
self.is_cli,
|
||||
) = self.arg_parse()
|
||||
|
||||
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
|
||||
|
||||
@staticmethod
|
||||
@@ -47,6 +92,12 @@ class Config:
|
||||
action="store_true",
|
||||
help="Do not open in browser automatically",
|
||||
)
|
||||
parser.add_argument( # Fork Feature. Paperspace integration for web UI
|
||||
"--paperspace", action="store_true", help="Note that this argument just shares a gradio link for the web UI. Thus can be used on other non-local CLI systems."
|
||||
)
|
||||
parser.add_argument( # Fork Feature. Embed a CLI into the infer-web.py
|
||||
"--is_cli", action="store_true", help="Use the CLI instead of setting up a gradio UI. This flag will launch an RVC text interface where you can execute functions from infer-web.py!"
|
||||
)
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
|
||||
@@ -57,6 +108,8 @@ class Config:
|
||||
cmd_opts.colab,
|
||||
cmd_opts.noparallel,
|
||||
cmd_opts.noautoopen,
|
||||
cmd_opts.paperspace,
|
||||
cmd_opts.is_cli,
|
||||
)
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
|
||||
@@ -84,9 +137,9 @@ class Config:
|
||||
):
|
||||
print("Found GPU", self.gpu_name, ", force to fp32")
|
||||
self.is_half = False
|
||||
use_fp32_config()
|
||||
else:
|
||||
print("Found GPU", self.gpu_name)
|
||||
use_fp32_config()
|
||||
self.gpu_mem = int(
|
||||
torch.cuda.get_device_properties(i_device).total_memory
|
||||
/ 1024
|
||||
|
||||
Reference in New Issue
Block a user