Update README.md
you have 7 days from this commit to backup old models before they get culled

README.md (CHANGED)
@@ -6,60 +6,73 @@ license: mit
**Removed:** the old "Models" section's loading snippet, which picked the CLIP model via `clip = OpenAIClipAdapter(clip_choice=["ViT-L/14" | "ViT-B/32"])`, set `dim = 512` or `dim = 768` to match, and built `DiffusionPriorNetwork` / `DiffusionPrior` with an older configuration. It is replaced by the example in the updated README below.

**Updated README:**
## Training Details
Training details can be found here: https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx
## Source Code
Models are diffusion trainers from https://github.com/lucidrains/DALLE2-pytorch

## Community: LAION
Join Us!: https://discord.gg/uPMftTmrvS

---

# Models
The repo currently has many models (most of which are actually pretty bad). I recommend using the latest EMA checkpoints for now.

> **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of the repo and massively underperform recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README commit timestamp).

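If you want to fetch (or back up) a checkpoint programmatically rather than through the web UI, something along these lines should work. This is only a sketch: the `repo_id` and `filename` below are placeholders, not real checkpoint names (check this repo's *Files and versions* tab), and it assumes `dalle2-pytorch` and `huggingface_hub` are installed (`pip install dalle2-pytorch huggingface_hub`).

```python
# Sketch only -- repo_id and filename are placeholders, not the real checkpoint names.
from huggingface_hub import hf_hub_download

# downloads the file into the local Hugging Face cache and returns its path,
# which can then be passed to load_diffusion_model() defined below
prior_checkpoint_path = hf_hub_download(
    repo_id="nousr/conditioned-prior",    # assumption: this repo's id
    filename="ema_prior_checkpoint.pth",  # placeholder: check the Files tab for real names
)
```
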
### Loading the models might look something like this:

> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and having issues loading checkpoints, please reach out on LAION.

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path, device):

    # If you are getting issues with size mismatches, it's likely this configuration
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # this will load the entire trainer
    # If you only want EMA weights for inference you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that please make a PR on Github!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
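As the comments above say, checkpoints store the whole trainer. If you only need the EMA weights for inference, a rough sketch of pulling them out is below (continuing from the snippet above, so it reuses `torch` and `load_diffusion_model`). The `ema_diffusion_prior.ema_model` attribute path is an assumption based on the `EMA` wrapper from ema-pytorch that DALLE2-pytorch uses, so verify it against your installed version.

```python
# Rough sketch only -- the attribute names below are assumptions based on the
# ema-pytorch EMA wrapper used by DALLE2-pytorch, not a documented API.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
trainer = load_diffusion_model("prior_checkpoint.pth", device)  # placeholder path

# the EMA wrapper keeps a full copy of the prior with the averaged weights
ema_prior = trainer.ema_diffusion_prior.ema_model.to(device)

# save just those weights so inference doesn't need the full trainer state
torch.save(ema_prior.state_dict(), "prior_ema_only.pth")
```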
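Once the prior is loaded, turning a caption into an image embedding might look like this. It continues from the sketch above; `clip.tokenize` is OpenAI's CLIP tokenizer (pulled in as a DALLE2-pytorch dependency), and the `sample()` / `cond_scale` call follows the DALLE2-pytorch source rather than anything documented in this repo, so double-check it against your version if it errors.

```python
# Inference sketch, continuing from above -- sample() / cond_scale follow the
# DALLE2-pytorch source and should be checked against your installed version.
import clip  # OpenAI CLIP tokenizer, installed as a dalle2-pytorch dependency

tokens = clip.tokenize(["a portrait of a corgi wearing a top hat"]).to(device)

with torch.no_grad():
    # returns one image embedding per caption (768-d for ViT-L/14)
    image_embed = ema_prior.sample(tokens, cond_scale=1.0)
```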