mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-20 21:59:35 +01:00
129 lines
5.1 KiB
Python
129 lines
5.1 KiB
Python
import argparse
|
|
import os
|
|
import yaml
|
|
|
|
global_print_hparams = True
|
|
hparams = {}
|
|
|
|
|
|
class Args:
|
|
def __init__(self, **kwargs):
|
|
for k, v in kwargs.items():
|
|
self.__setattr__(k, v)
|
|
|
|
|
|
def override_config(old_config: dict, new_config: dict):
|
|
for k, v in new_config.items():
|
|
if isinstance(v, dict) and k in old_config:
|
|
override_config(old_config[k], new_config[k])
|
|
else:
|
|
old_config[k] = v
|
|
|
|
|
|
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
|
|
if config == '' and exp_name == '':
|
|
parser = argparse.ArgumentParser(description='')
|
|
parser.add_argument('--config', type=str, default='',
|
|
help='location of the data corpus')
|
|
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
|
|
parser.add_argument('-hp', '--hparams', type=str, default='',
|
|
help='location of the data corpus')
|
|
parser.add_argument('--infer', action='store_true', help='infer')
|
|
parser.add_argument('--validate', action='store_true', help='validate')
|
|
parser.add_argument('--reset', action='store_true', help='reset hparams')
|
|
parser.add_argument('--remove', action='store_true', help='remove old ckpt')
|
|
parser.add_argument('--debug', action='store_true', help='debug')
|
|
args, unknown = parser.parse_known_args()
|
|
print("| Unknow hparams: ", unknown)
|
|
else:
|
|
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
|
|
infer=False, validate=False, reset=False, debug=False, remove=False)
|
|
global hparams
|
|
assert args.config != '' or args.exp_name != ''
|
|
if args.config != '':
|
|
assert os.path.exists(args.config)
|
|
|
|
config_chains = []
|
|
loaded_config = set()
|
|
|
|
def load_config(config_fn):
|
|
# deep first inheritance and avoid the second visit of one node
|
|
if not os.path.exists(config_fn):
|
|
return {}
|
|
with open(config_fn) as f:
|
|
hparams_ = yaml.safe_load(f)
|
|
loaded_config.add(config_fn)
|
|
if 'base_config' in hparams_:
|
|
ret_hparams = {}
|
|
if not isinstance(hparams_['base_config'], list):
|
|
hparams_['base_config'] = [hparams_['base_config']]
|
|
for c in hparams_['base_config']:
|
|
if c.startswith('.'):
|
|
c = f'{os.path.dirname(config_fn)}/{c}'
|
|
c = os.path.normpath(c)
|
|
if c not in loaded_config:
|
|
override_config(ret_hparams, load_config(c))
|
|
override_config(ret_hparams, hparams_)
|
|
else:
|
|
ret_hparams = hparams_
|
|
config_chains.append(config_fn)
|
|
return ret_hparams
|
|
|
|
saved_hparams = {}
|
|
args_work_dir = ''
|
|
if args.exp_name != '':
|
|
args_work_dir = f'{args.exp_name}' # modified
|
|
ckpt_config_path = f'{args_work_dir}/config.yaml'
|
|
if os.path.exists(ckpt_config_path):
|
|
with open(ckpt_config_path) as f:
|
|
saved_hparams_ = yaml.safe_load(f)
|
|
if saved_hparams_ is not None:
|
|
saved_hparams.update(saved_hparams_)
|
|
hparams_ = {}
|
|
if args.config != '':
|
|
hparams_.update(load_config(args.config))
|
|
if not args.reset:
|
|
hparams_.update(saved_hparams)
|
|
hparams_['work_dir'] = args_work_dir
|
|
|
|
# Support config overriding in command line. Support list type config overriding.
|
|
# Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
|
|
if args.hparams != "":
|
|
for new_hparam in args.hparams.split(","):
|
|
k, v = new_hparam.split("=")
|
|
v = v.strip("\'\" ")
|
|
config_node = hparams_
|
|
for k_ in k.split(".")[:-1]:
|
|
config_node = config_node[k_]
|
|
k = k.split(".")[-1]
|
|
if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
|
|
if type(config_node[k]) == list:
|
|
v = v.replace(" ", ",")
|
|
config_node[k] = eval(v)
|
|
else:
|
|
config_node[k] = type(config_node[k])(v)
|
|
if args_work_dir != '' and args.remove:
|
|
answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
|
|
if answer.lower() == "y":
|
|
remove_file(args_work_dir)
|
|
if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
|
|
os.makedirs(hparams_['work_dir'], exist_ok=True)
|
|
with open(ckpt_config_path, 'w') as f:
|
|
yaml.safe_dump(hparams_, f)
|
|
|
|
hparams_['infer'] = args.infer
|
|
hparams_['debug'] = args.debug
|
|
hparams_['validate'] = args.validate
|
|
hparams_['exp_name'] = args.exp_name
|
|
global global_print_hparams
|
|
if global_hparams:
|
|
hparams.clear()
|
|
hparams.update(hparams_)
|
|
if print_hparams and global_print_hparams and global_hparams:
|
|
print('| Hparams chains: ', config_chains)
|
|
print('| Hparams: ')
|
|
for i, (k, v) in enumerate(sorted(hparams_.items())):
|
|
print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
|
|
print("")
|
|
global_print_hparams = False
|
|
return hparams_ |