burakcanbiner commited on
Commit
fafcf34
1 Parent(s): 6ce1706

Update pnp.py

Browse files
Files changed (1) hide show
  1. pnp.py +11 -5
pnp.py CHANGED
@@ -67,11 +67,17 @@ class PNP(nn.Module):
67
  ).to("cuda")
68
 
69
 
70
- # gate_dict = torch.load(adapter_ckpt_path)
 
 
 
 
 
71
 
72
- # for name, param in self.unet.named_parameters():
73
- # if "adapter" in name:
74
- # param.data = gate_dict[name]
 
75
  #unet.to(self.device);
76
 
77
  #pipe.unet = unet.to(self.device);
@@ -102,7 +108,7 @@ class PNP(nn.Module):
102
  # self.set_audio_projector(adapter_ckpt_path, audio_projector_ckpt_path)
103
  self.text_encoder = self.text_encoder.cuda()
104
 
105
- # self.audio_projector.load_state_dict(torch.load(audio_projector_ckpt_path))
106
 
107
  self.audio_projector_ckpt_path = audio_projector_ckpt_path
108
  self.adapter_ckpt_path = adapter_ckpt_path
 
67
  ).to("cuda")
68
 
69
 
70
+
71
+ audio_projector_path = "ckpts/audio_projector_landscape.pth"
72
+ adapter_ckpt_path = "ckpts/landscape.pt"
73
+ #self.pnp.set_audio_projector(gate_dict_path, audio_projector_path)
74
+
75
+ gate_dict = torch.load(adapter_ckpt_path)
76
 
77
+ for name, param in self.unet.named_parameters():
78
+ if "adapter" in name:
79
+ param.data = gate_dict[name]
80
+
81
  #unet.to(self.device);
82
 
83
  #pipe.unet = unet.to(self.device);
 
108
  # self.set_audio_projector(adapter_ckpt_path, audio_projector_ckpt_path)
109
  self.text_encoder = self.text_encoder.cuda()
110
 
111
+ self.audio_projector.load_state_dict(torch.load(audio_projector_path))
112
 
113
  self.audio_projector_ckpt_path = audio_projector_ckpt_path
114
  self.adapter_ckpt_path = adapter_ckpt_path