jacob-c committed on
Commit
aeea655
1 Parent(s): 2f7e110
audioldm/hifigan/utilities.py CHANGED
@@ -64,13 +64,13 @@ def get_param_num(model):
64
  return num_param
65
 
66
 
67
- def get_vocoder(config, device):
68
- config = hifigan.AttrDict(HIFIGAN_16K_64)
69
- vocoder = hifigan.Generator(config)
70
- vocoder.eval()
71
- vocoder.remove_weight_norm()
72
- vocoder.to(device)
73
- return vocoder
74
 
75
 
76
  def vocoder_infer(mels, vocoder, lengths=None):
 
64
  return num_param
65
 
66
 
67
+ # def get_vocoder(config, device):
68
+ # config = hifigan.AttrDict(HIFIGAN_16K_64)
69
+ # vocoder = hifigan.Generator(config)
70
+ # vocoder.eval()
71
+ # vocoder.remove_weight_norm()
72
+ # vocoder.to(device)
73
+ # return vocoder
74
 
75
 
76
  def vocoder_infer(mels, vocoder, lengths=None):
audioldm/pipeline.py CHANGED
@@ -87,7 +87,11 @@ def build_model(
87
  resume_from_checkpoint = ckpt_path
88
 
89
  checkpoint = torch.load(resume_from_checkpoint, map_location=device)
90
- latent_diffusion.load_state_dict(checkpoint["state_dict"])
 
 
 
 
91
 
92
  latent_diffusion.eval()
93
  latent_diffusion = latent_diffusion.to(device)
 
87
  resume_from_checkpoint = ckpt_path
88
 
89
  checkpoint = torch.load(resume_from_checkpoint, map_location=device)
90
+ state_dict = checkpoint["state_dict"]
91
+ # Filter out vocoder keys
92
+ filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("vocoder.")}
93
+ latent_diffusion.load_state_dict(filtered_state_dict, strict=False)
94
+
95
 
96
  latent_diffusion.eval()
97
  latent_diffusion = latent_diffusion.to(device)