Update README.md
Browse files
README.md
CHANGED
@@ -1,13 +1,65 @@
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
|
|
4 |
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
|
|
|
9 |
```
|
10 |
depth=12
|
11 |
d_model=768
|
12 |
-
clip = OpenAIClipAdapter(clip_choice=["ViT-L/14" | "ViT-B/32"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
```
|
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
4 |
+
# A Text-Conditioned Diffusion-Prior
|
5 |
|
6 |
+
## Training Details
|
7 |
+
Training details can be found here: https://wandb.ai/nousr_laion/conditioned-prior/reports/Updated-Text-Conditioned-Prior--VmlldzoyMDI2OTIx
|
8 |
+
## Source Code
|
9 |
+
Models are diffusion trainers from https://github.com/lucidrains/DALLE2-pytorch with defaults specified in the train_diffusion_prior.py script
|
10 |
+
## Community: LAION
|
11 |
+
Join Us!: https://discord.gg/uPMftTmrvS
|
12 |
|
13 |
+
---
|
14 |
|
15 |
+
# Models
|
16 |
```
|
17 |
depth=12
|
18 |
d_model=768
|
19 |
+
clip = OpenAIClipAdapter(clip_choice=["ViT-L/14" | "ViT-B/32"])
|
20 |
+
```
|
21 |
+
|
22 |
+
### Loading the models might look something like this:
|
23 |
+
```python
|
24 |
+
def load_diffusion_model(dprior_path, device, clip_choice):
|
25 |
+
|
26 |
+
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
27 |
+
|
28 |
+
if clip_choice == "ViT-B/32":
|
29 |
+
dim = 512
|
30 |
+
else:
|
31 |
+
dim = 768
|
32 |
+
|
33 |
+
prior_network = DiffusionPriorNetwork(
|
34 |
+
dim=dim,
|
35 |
+
depth=12,
|
36 |
+
dim_head=64,
|
37 |
+
heads=12,
|
38 |
+
normformer=True
|
39 |
+
).to(device)
|
40 |
+
|
41 |
+
diffusion_prior = DiffusionPrior(
|
42 |
+
net=prior_network,
|
43 |
+
clip=OpenAIClipAdapter(clip_choice),
|
44 |
+
image_embed_dim=dim,
|
45 |
+
timesteps=1000,
|
46 |
+
cond_drop_prob=0.1,
|
47 |
+
loss_type="l2",
|
48 |
+
).to(device)
|
49 |
+
|
50 |
+
|
51 |
+
diffusion_prior.load_state_dict(loaded_obj["model"], strict=True)
|
52 |
+
|
53 |
+
diffusion_prior = DiffusionPriorTrainer(
|
54 |
+
diffusion_prior = diffusion_prior,
|
55 |
+
lr = 1.1e-4,
|
56 |
+
wd = 6.02e-2,
|
57 |
+
max_grad_norm = 0.5,
|
58 |
+
amp = False,
|
59 |
+
).to(device)
|
60 |
+
|
61 |
+
diffusion_prior.optimizer.load_state_dict(loaded_obj['optimizer'])
|
62 |
+
diffusion_prior.scaler.load_state_dict(loaded_obj['scaler'])
|
63 |
+
|
64 |
+
return diffusion_prior
|
65 |
```
|