2023-05-02 12:07:03 +00:00
import argparse
import torch
from multiprocessing import cpu_count
class Config:
    """Runtime configuration: launcher options plus device/precision settings.

    On construction the compute device is probed (CUDA, then MPS, then CPU)
    and ``x_pad``/``x_query``/``x_center``/``x_max`` are derived from the
    selected precision and, when known, the GPU memory size.
    """

    def __init__(self, is_gui=True):
        """Build the configuration.

        Args:
            is_gui: When true, parse the command line (see ``arg_parse``) and
                store the options as ``python_cmd``, ``listen_port``,
                ``iscolab``, ``noparallel``, ``noautoopen`` and ``paperspace``.
                When false those attributes are not set at all.
        """
        self.device = "cuda:0"
        self.is_half = True
        self.n_cpu = 0
        self.gpu_name = None
        self.gpu_mem = None

        if is_gui:
            (
                self.python_cmd,
                self.listen_port,
                self.iscolab,
                self.noparallel,
                self.noautoopen,
                self.paperspace,
            ) = self.arg_parse()

        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

    @staticmethod
    def arg_parse() -> tuple:
        """Parse ``sys.argv`` and return the launcher options.

        Returns:
            ``(pycmd, port, colab, noparallel, noautoopen, paperspace)``.
            A port outside ``0..65535`` falls back to the default 7865.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--port", type=int, default=7865, help="Listen port")
        parser.add_argument(
            "--pycmd", type=str, default="python", help="Python command"
        )
        parser.add_argument("--colab", action="store_true", help="Launch in colab")
        parser.add_argument(
            "--noparallel", action="store_true", help="Disable parallel processing"
        )
        parser.add_argument(
            "--noautoopen",
            action="store_true",
            help="Do not open in browser automatically",
        )
        parser.add_argument(  # Fork Feature. Paperspace integration for web UI
            "--paperspace",
            action="store_true",
            help=(
                "Note that this argument just shares a gradio link for the web UI. "
                "Thus can be used on other non-local CLI systems."
            ),
        )
        cmd_opts = parser.parse_args()
        # Out-of-range ports silently fall back to the default instead of erroring.
        cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
        return (
            cmd_opts.pycmd,
            cmd_opts.port,
            cmd_opts.colab,
            cmd_opts.noparallel,
            cmd_opts.noautoopen,
            cmd_opts.paperspace,
        )

    @staticmethod
    def _needs_fp32(gpu_name: str) -> bool:
        """Return True for GPUs that this project forces to single precision.

        The check is a substring match on the CUDA device name. The print
        message in ``device_config`` names 16-series and 10-series cards and
        the P40. "V100" is explicitly excluded from the "16" match.
        """
        upper = gpu_name.upper()
        return (
            ("16" in gpu_name and "V100" not in upper)
            or "P40" in upper
            or "1060" in gpu_name
            or "1070" in gpu_name
            or "1080" in gpu_name
        )

    @staticmethod
    def _replace_in_file(path: str, old: str, new: str) -> None:
        """Rewrite *path* in place, replacing every occurrence of *old* with *new*.

        UTF-8 is used explicitly for both reading and writing. This makes the
        round-trip independent of the platform's locale encoding.
        """
        with open(path, "r", encoding="utf-8") as f:
            text = f.read().replace(old, new)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    def device_config(self) -> tuple:
        """Select the compute device and derive the ``x_*`` settings.

        Side effects:
            * Without CUDA: sets ``self.device`` to ``"mps"`` or ``"cpu"`` and
              ``self.is_half`` to False.
            * With CUDA:
                - sets ``self.gpu_name`` (kept only when the GPU is fp32-only,
                  otherwise reset to None);
                - sets ``self.gpu_mem`` to the total device memory in GiB;
                - for fp32-only GPUs sets ``self.is_half`` to False and rewrites
                  ``configs/{32k,40k,48k}.json`` (every ``true`` -> ``false``);
                - for fp32-only GPUs or GPUs with <= 4 GiB, rewrites
                  ``trainset_preprocess_pipeline_print.py`` (``3.7`` -> ``3.0``).
              All paths are relative to the current working directory.
            * Sets ``self.n_cpu`` to ``cpu_count()`` if it is still 0.

        Returns:
            ``(x_pad, x_query, x_center, x_max)``.
        """
        if torch.cuda.is_available():
            i_device = int(self.device.split(":")[-1])
            self.gpu_name = torch.cuda.get_device_name(i_device)
            if self._needs_fp32(self.gpu_name):
                print("16系/10系显卡和P40强制单精度")
                self.is_half = False
                # NOTE(review): blanket string replace, presumably meant to switch
                # off an fp16 flag in these JSON configs; confirm no other "true".
                for config_file in ["32k.json", "40k.json", "48k.json"]:
                    self._replace_in_file(f"configs/{config_file}", "true", "false")
                # NOTE(review): presumably lowers a tuning constant in the
                # preprocessing script; meaning of 3.7/3.0 not visible here.
                self._replace_in_file(
                    "trainset_preprocess_pipeline_print.py", "3.7", "3.0"
                )
            else:
                self.gpu_name = None
            # Total memory in GiB; adding 0.4 before truncation rounds a
            # fractional part of 0.6 or more up to the next whole GiB.
            self.gpu_mem = int(
                torch.cuda.get_device_properties(i_device).total_memory
                / 1024
                / 1024
                / 1024
                + 0.4
            )
            if self.gpu_mem <= 4:
                self._replace_in_file(
                    "trainset_preprocess_pipeline_print.py", "3.7", "3.0"
                )
        elif torch.backends.mps.is_available():
            print("没有发现支持的N卡, 使用MPS进行推理")
            self.device = "mps"
            self.is_half = False
        else:
            print("没有发现支持的N卡, 使用CPU进行推理")
            self.device = "cpu"
            self.is_half = False

        if self.n_cpu == 0:
            self.n_cpu = cpu_count()

        if self.is_half:
            # Configuration for 6 GB of VRAM.
            x_pad, x_query, x_center, x_max = 3, 10, 60, 65
        else:
            # Configuration for 5 GB of VRAM.
            x_pad, x_query, x_center, x_max = 1, 6, 38, 41
        if self.gpu_mem is not None and self.gpu_mem <= 4:
            # Smaller settings for GPUs with 4 GiB or less.
            x_pad, x_query, x_center, x_max = 1, 5, 30, 32
        return x_pad, x_query, x_center, x_max