Esmail-AGumaan commited on
Commit
dd7cfee
1 Parent(s): 01342fe

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +27 -27
model_loader.py CHANGED
@@ -1,28 +1,28 @@
1
- from nanograd.models.stable_diffusion.clip import CLIP
2
- from nanograd.models.stable_diffusion.encoder import VAE_Encoder
3
- from nanograd.models.stable_diffusion.decoder import VAE_Decoder
4
- from nanograd.models.stable_diffusion.diffusion import Diffusion
5
-
6
- from nanograd.models.stable_diffusion import model_converter
7
-
8
- def preload_models_from_standard_weights(ckpt_path, device):
9
- state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
10
-
11
- encoder = VAE_Encoder().to(device)
12
- encoder.load_state_dict(state_dict['encoder'], strict=True)
13
-
14
- decoder = VAE_Decoder().to(device)
15
- decoder.load_state_dict(state_dict['decoder'], strict=True)
16
-
17
- diffusion = Diffusion().to(device)
18
- diffusion.load_state_dict(state_dict['diffusion'], strict=True)
19
-
20
- clip = CLIP().to(device)
21
- clip.load_state_dict(state_dict['clip'], strict=True)
22
-
23
- return {
24
- 'clip': clip,
25
- 'encoder': encoder,
26
- 'decoder': decoder,
27
- 'diffusion': diffusion,
28
  }
 
1
+ from clip import CLIP
2
+ from encoder import VAE_Encoder
3
+ from decoder import VAE_Decoder
4
+ from diffusion import Diffusion
5
+
6
+ import model_converter
7
+
8
+ def preload_models_from_standard_weights(ckpt_path, device):
9
+ state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
10
+
11
+ encoder = VAE_Encoder().to(device)
12
+ encoder.load_state_dict(state_dict['encoder'], strict=True)
13
+
14
+ decoder = VAE_Decoder().to(device)
15
+ decoder.load_state_dict(state_dict['decoder'], strict=True)
16
+
17
+ diffusion = Diffusion().to(device)
18
+ diffusion.load_state_dict(state_dict['diffusion'], strict=True)
19
+
20
+ clip = CLIP().to(device)
21
+ clip.load_state_dict(state_dict['clip'], strict=True)
22
+
23
+ return {
24
+ 'clip': clip,
25
+ 'encoder': encoder,
26
+ 'decoder': decoder,
27
+ 'diffusion': diffusion,
28
  }