mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-12-21 14:09:41 +01:00
@@ -35,12 +35,12 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
||||
if saved_state_dict[k].shape != state_dict[k].shape:
|
||||
logger.warn(
|
||||
"shape-%s-mismatch. need: %s, get: %s"
|
||||
% (k, state_dict[k].shape, saved_state_dict[k].shape)
|
||||
, k, state_dict[k].shape, saved_state_dict[k].shape
|
||||
) #
|
||||
raise KeyError
|
||||
except:
|
||||
# logger.info(traceback.format_exc())
|
||||
logger.info("%s is not in the checkpoint" % k) # pretrain缺失的
|
||||
logger.info("%s is not in the checkpoint", k) # pretrain缺失的
|
||||
new_state_dict[k] = v # 模型自带的随机值
|
||||
if hasattr(model, "module"):
|
||||
model.module.load_state_dict(new_state_dict, strict=False)
|
||||
@@ -111,12 +111,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||
if saved_state_dict[k].shape != state_dict[k].shape:
|
||||
logger.warn(
|
||||
"shape-%s-mismatch|need-%s|get-%s"
|
||||
% (k, state_dict[k].shape, saved_state_dict[k].shape)
|
||||
, k, state_dict[k].shape, saved_state_dict[k].shape
|
||||
) #
|
||||
raise KeyError
|
||||
except:
|
||||
# logger.info(traceback.format_exc())
|
||||
logger.info("%s is not in the checkpoint" % k) # pretrain缺失的
|
||||
logger.info("%s is not in the checkpoint", k) # pretrain缺失的
|
||||
new_state_dict[k] = v # 模型自带的随机值
|
||||
if hasattr(model, "module"):
|
||||
model.module.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
Reference in New Issue
Block a user