Merge pull request #65 from Mangio621/fix-cuda-req

Undo torch requirement change for compatibility
This commit is contained in:
kalomaze
2023-07-28 18:47:13 -05:00
committed by GitHub
5 changed files with 596 additions and 353 deletions

View File

@@ -22,6 +22,7 @@ DoFormant = False
Quefrency = 0.0 Quefrency = 0.0
Timbre = 0.0 Timbre = 0.0
def printt(strr): def printt(strr):
print(strr) print(strr)
f.write("%s\n" % strr) f.write("%s\n" % strr)

File diff suppressed because it is too large Load Diff

View File

@@ -11,22 +11,25 @@ import random
import csv import csv
platform_stft_mapping = { platform_stft_mapping = {
'linux': 'stftpitchshift', "linux": "stftpitchshift",
'darwin': 'stftpitchshift', "darwin": "stftpitchshift",
'win32': 'stftpitchshift.exe', "win32": "stftpitchshift.exe",
} }
stft = platform_stft_mapping.get(sys.platform) stft = platform_stft_mapping.get(sys.platform)
# praatEXE = join('.',os.path.abspath(os.getcwd()) + r"\Praat.exe") # praatEXE = join('.',os.path.abspath(os.getcwd()) + r"\Praat.exe")
def CSVutil(file, rw, type, *args): def CSVutil(file, rw, type, *args):
if type == 'formanting': if type == "formanting":
if rw == 'r': if rw == "r":
with open(file) as fileCSVread: with open(file) as fileCSVread:
csv_reader = list(csv.reader(fileCSVread)) csv_reader = list(csv.reader(fileCSVread))
return ( return (
csv_reader[0][0], csv_reader[0][1], csv_reader[0][2] (csv_reader[0][0], csv_reader[0][1], csv_reader[0][2])
) if csv_reader is not None else (lambda: exec('raise ValueError("No data")'))() if csv_reader is not None
else (lambda: exec('raise ValueError("No data")'))()
)
else: else:
if args: if args:
doformnt = args[0] doformnt = args[0]
@@ -34,18 +37,19 @@ def CSVutil(file, rw, type, *args):
doformnt = False doformnt = False
qfr = args[1] if len(args) > 1 else 1.0 qfr = args[1] if len(args) > 1 else 1.0
tmb = args[2] if len(args) > 2 else 1.0 tmb = args[2] if len(args) > 2 else 1.0
with open(file, rw, newline='') as fileCSVwrite: with open(file, rw, newline="") as fileCSVwrite:
csv_writer = csv.writer(fileCSVwrite, delimiter=',') csv_writer = csv.writer(fileCSVwrite, delimiter=",")
csv_writer.writerow([doformnt, qfr, tmb]) csv_writer.writerow([doformnt, qfr, tmb])
elif type == 'stop': elif type == "stop":
stop = args[0] if args else False stop = args[0] if args else False
with open(file, rw, newline='') as fileCSVwrite: with open(file, rw, newline="") as fileCSVwrite:
csv_writer = csv.writer(fileCSVwrite, delimiter=',') csv_writer = csv.writer(fileCSVwrite, delimiter=",")
csv_writer.writerow([stop]) csv_writer.writerow([stop])
def load_audio(file, sr, DoFormant, Quefrency, Timbre): def load_audio(file, sr, DoFormant, Quefrency, Timbre):
converted = False converted = False
DoFormant, Quefrency, Timbre = CSVutil('csvdb/formanting.csv', 'r', 'formanting') DoFormant, Quefrency, Timbre = CSVutil("csvdb/formanting.csv", "r", "formanting")
try: try:
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26 # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
# This launches a subprocess to decode audio while down-mixing and resampling as necessary. # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
@@ -54,66 +58,77 @@ def load_audio(file, sr, DoFormant, Quefrency, Timbre):
file.strip(" ").strip('"').strip("\n").strip('"').strip(" ") file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
) # 防止小白拷路径头尾带了空格和"和回车 ) # 防止小白拷路径头尾带了空格和"和回车
file_formanted = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ") file_formanted = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
#print(f"dofor={bool(DoFormant)} timbr={Timbre} quef={Quefrency}\n") # print(f"dofor={bool(DoFormant)} timbr={Timbre} quef={Quefrency}\n")
if (lambda DoFormant: True if DoFormant.lower() == 'true' else (False if DoFormant.lower() == 'false' else DoFormant))(DoFormant): if (
numerator = round(random.uniform(1,4), 4) lambda DoFormant: True
if DoFormant.lower() == "true"
else (False if DoFormant.lower() == "false" else DoFormant)
)(DoFormant):
numerator = round(random.uniform(1, 4), 4)
# os.system(f"stftpitchshift -i {file} -q {Quefrency} -t {Timbre} -o {file_formanted}") # os.system(f"stftpitchshift -i {file} -q {Quefrency} -t {Timbre} -o {file_formanted}")
# print('stftpitchshift -i "%s" -p 1.0 --rms -w 128 -v 8 -q %s -t %s -o "%s"' % (file, Quefrency, Timbre, file_formanted)) # print('stftpitchshift -i "%s" -p 1.0 --rms -w 128 -v 8 -q %s -t %s -o "%s"' % (file, Quefrency, Timbre, file_formanted))
if not file.endswith(".wav"): if not file.endswith(".wav"):
if not os.path.isfile(f"{file_formanted}.wav"): if not os.path.isfile(f"{file_formanted}.wav"):
converted = True converted = True
#print(f"\nfile = {file}\n") # print(f"\nfile = {file}\n")
#print(f"\nfile_formanted = {file_formanted}\n") # print(f"\nfile_formanted = {file_formanted}\n")
converting = ( converting = (
ffmpeg.input(file_formanted, threads = 0) ffmpeg.input(file_formanted, threads=0)
.output(f"{file_formanted}.wav") .output(f"{file_formanted}.wav")
.run( .run(
cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True cmd=["ffmpeg", "-nostdin"],
capture_stdout=True,
capture_stderr=True,
) )
) )
else: else:
pass pass
file_formanted = (
f"{file_formanted}.wav"
file_formanted = f"{file_formanted}.wav" if not file_formanted.endswith(".wav") else file_formanted if not file_formanted.endswith(".wav")
else file_formanted
)
print(f" · Formanting {file_formanted}...\n") print(f" · Formanting {file_formanted}...\n")
os.system( os.system(
'%s -i "%s" -q "%s" -t "%s" -o "%sFORMANTED_%s.wav"' '%s -i "%s" -q "%s" -t "%s" -o "%sFORMANTED_%s.wav"'
% (stft, file_formanted, Quefrency, Timbre, file_formanted, str(numerator)) % (
stft,
file_formanted,
Quefrency,
Timbre,
file_formanted,
str(numerator),
)
) )
print(f" · Formanted {file_formanted}!\n") print(f" · Formanted {file_formanted}!\n")
# filepraat = (os.path.abspath(os.getcwd()) + '\\' + file).replace('/','\\') # filepraat = (os.path.abspath(os.getcwd()) + '\\' + file).replace('/','\\')
# file_formantedpraat = ('"' + os.path.abspath(os.getcwd()) + '/' + 'formanted'.join(file_formanted) + '"').replace('/','\\') # file_formantedpraat = ('"' + os.path.abspath(os.getcwd()) + '/' + 'formanted'.join(file_formanted) + '"').replace('/','\\')
#print("%sFORMANTED_%s.wav" % (file_formanted, str(numerator))) # print("%sFORMANTED_%s.wav" % (file_formanted, str(numerator)))
out, _ = ( out, _ = (
ffmpeg.input("%sFORMANTED_%s.wav" % (file_formanted, str(numerator)), threads=0) ffmpeg.input(
"%sFORMANTED_%s.wav" % (file_formanted, str(numerator)), threads=0
)
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
.run( .run(
cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True
) )
) )
try: os.remove("%sFORMANTED_%s.wav" % (file_formanted, str(numerator))) try:
except Exception: pass; print("couldn't remove formanted type of file") os.remove("%sFORMANTED_%s.wav" % (file_formanted, str(numerator)))
except Exception:
pass
print("couldn't remove formanted type of file")
else: else:
out, _ = ( out, _ = (
ffmpeg.input(file, threads=0) ffmpeg.input(file, threads=0)
@@ -124,10 +139,13 @@ def load_audio(file, sr, DoFormant, Quefrency, Timbre):
) )
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load audio: {e}") raise RuntimeError(f"Failed to load audio: {e}")
if converted: if converted:
try: os.remove(file_formanted) try:
except Exception: pass; print("couldn't remove converted type of file") os.remove(file_formanted)
except Exception:
pass
print("couldn't remove converted type of file")
converted = False converted = False
return np.frombuffer(out, np.float32).flatten() return np.frombuffer(out, np.float32).flatten()

View File

@@ -146,7 +146,7 @@ tensorboard-plugin-wit==1.8.1
tensorboardX==2.6.1 tensorboardX==2.6.1
threadpoolctl==3.1.0 threadpoolctl==3.1.0
toolz==0.12.0 toolz==0.12.0
torch==2.0.1 torch @ https://download.pytorch.org/whl/cu118/torch-2.0.0%2Bcu118-cp39-cp39-win_amd64.whl
torchaudio==2.0.1 torchaudio==2.0.1
torchcrepe==0.0.19 torchcrepe==0.0.19
torchgen==0.0.1 torchgen==0.0.1

View File

@@ -256,7 +256,6 @@ def run(rank, n_gpus, hps):
def train_and_evaluate( def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
): ):
net_g, net_d = nets net_g, net_d = nets
optim_g, optim_d = optims optim_g, optim_d = optims
train_loader, eval_loader = loaders train_loader, eval_loader = loaders
@@ -353,7 +352,7 @@ def train_and_evaluate(
# Run steps # Run steps
epoch_recorder = EpochRecorder() epoch_recorder = EpochRecorder()
for batch_idx, info in data_iterator: for batch_idx, info in data_iterator:
# Data # Data
## Unpack ## Unpack
@@ -572,15 +571,23 @@ def train_and_evaluate(
), ),
) )
) )
try: try:
with open('csvdb/stop.csv') as CSVStop: with open("csvdb/stop.csv") as CSVStop:
csv_reader = list(csv.reader(CSVStop)) csv_reader = list(csv.reader(CSVStop))
stopbtn = csv_reader[0][0] if csv_reader is not None else (lambda: exec('raise ValueError("No data")'))() stopbtn = (
stopbtn = (lambda stopbtn: True if stopbtn.lower() == 'true' else (False if stopbtn.lower() == 'false' else stopbtn))(stopbtn) csv_reader[0][0]
if csv_reader is not None
else (lambda: exec('raise ValueError("No data")'))()
)
stopbtn = (
lambda stopbtn: True
if stopbtn.lower() == "true"
else (False if stopbtn.lower() == "false" else stopbtn)
)(stopbtn)
except (ValueError, TypeError, IndexError): except (ValueError, TypeError, IndexError):
stopbtn = False stopbtn = False
if stopbtn: if stopbtn:
logger.info("Stop Button was pressed. The program is closed.") logger.info("Stop Button was pressed. The program is closed.")
if hasattr(net_g, "module"): if hasattr(net_g, "module"):
@@ -602,9 +609,9 @@ def train_and_evaluate(
) )
) )
sleep(1) sleep(1)
with open('csvdb/stop.csv', 'w+', newline='') as STOPCSVwrite: with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
csv_writer = csv.writer(STOPCSVwrite, delimiter=',') csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
csv_writer.writerow(['False']) csv_writer.writerow(["False"])
os._exit(2333333) os._exit(2333333)
if rank == 0: if rank == 0:
@@ -625,11 +632,10 @@ def train_and_evaluate(
) )
) )
sleep(1) sleep(1)
with open('csvdb/stop.csv', 'w+', newline='') as STOPCSVwrite: with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
csv_writer = csv.writer(STOPCSVwrite, delimiter=',') csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
csv_writer.writerow(['False']) csv_writer.writerow(["False"])
os._exit(2333333) os._exit(2333333)
if __name__ == "__main__": if __name__ == "__main__":