mirror of
https://github.com/coqui-ai/TTS.git
synced 2025-12-25 20:59:48 +01:00
load_checkpoint for hifigan and no_grad for inference
This commit is contained in:
@@ -159,6 +159,7 @@ class HifiganGenerator(torch.nn.Module):
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(
|
||||
@@ -173,3 +174,11 @@ class HifiganGenerator(torch.nn.Module):
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
|
||||
Reference in New Issue
Block a user