nousr committed
Commit ee6e0e8
1 Parent(s): 10edf38

Update README.md


you have 7 days from this commit to back up old models before they get culled
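If you want to keep any of the current checkpoints, something like the following should snapshot the repo before the cull. The repo id here is an assumption inferred from this page, so substitute the real one:

```python
from huggingface_hub import snapshot_download

# assumed repo id for this model card; revision pins the parent (pre-cull)
# commit shown on this page -- use the full commit sha if the short one fails
local_dir = snapshot_download(repo_id="nousr/conditioned-prior", revision="10edf38")
print(local_dir)  # cached snapshot path; copy it somewhere safe
```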

Files changed (1)
  1. README.md +50 -37
README.md CHANGED
@@ -6,60 +6,73 @@ license: mit
 ## Training Details
 Training details can be found here: https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx
 ## Source Code
-Models are diffusion trainers from https://github.com/lucidrains/DALLE2-pytorch with defaults specified in the train_diffusion_prior.py script
+Models are diffusion trainers from https://github.com/lucidrains/DALLE2-pytorch
+
 ## Community: LAION
 Join Us!: https://discord.gg/uPMftTmrvS
 
 ---
 
 # Models
-```
-depth=12
-d_model=768
-clip = OpenAIClipAdapter(clip_choice=["ViT-L/14" | "ViT-B/32"])
-```
+The repo currently has many models (most of which are actually pretty bad). I recommend using the latest EMA checkpoints for now.
+
+> **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of the repo and massively underperform recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README commit's timestamp).
 
 ### Loading the models might look something like this:
+
+> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and having issues loading checkpoints, please reach out on LAION.
+
 ```python
-def load_diffusion_model(dprior_path, device, clip_choice):
-
-    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
-
-    if clip_choice == "ViT-B/32":
-        dim = 512
-    else:
-        dim = 768
-
+import torch
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
+from dalle2_pytorch.trainer import DiffusionPriorTrainer
+
+def load_diffusion_model(dprior_path, device):
+
+    # If you are getting issues with size mismatches, it's likely this configuration
     prior_network = DiffusionPriorNetwork(
-        dim=dim,
-        depth=12,
+        dim=768,
+        depth=24,
         dim_head=64,
-        heads=12,
-        normformer=True
-    ).to(device)
-
+        heads=32,
+        normformer=True,
+        attn_dropout=5e-2,
+        ff_dropout=5e-2,
+        num_time_embeds=1,
+        num_image_embeds=1,
+        num_text_embeds=1,
+        num_timesteps=1000,
+        ff_mult=4,
+    )
+
+    # currently, only ViT-L/14 models are being trained
     diffusion_prior = DiffusionPrior(
         net=prior_network,
-        clip=OpenAIClipAdapter(clip_choice),
-        image_embed_dim=dim,
+        clip=OpenAIClipAdapter("ViT-L/14"),
+        image_embed_dim=768,
         timesteps=1000,
         cond_drop_prob=0.1,
         loss_type="l2",
-    ).to(device)
-
-    diffusion_prior.load_state_dict(loaded_obj["model"], strict=True)
-
-    diffusion_prior = DiffusionPriorTrainer(
-        diffusion_prior = diffusion_prior,
-        lr = 1.1e-4,
-        wd = 6.02e-2,
-        max_grad_norm = 0.5,
-        amp = False,
-    ).to(device)
-
-    diffusion_prior.optimizer.load_state_dict(loaded_obj['optimizer'])
-    diffusion_prior.scaler.load_state_dict(loaded_obj['scaler'])
-
-    return diffusion_prior
+        condition_on_text_encodings=True,
+    )
+
+    # this will load the entire trainer
+    # If you only want EMA weights for inference, you will need to extract them yourself for now
+    # (if you beat me to writing a nice function for that, please make a PR on GitHub!)
+    trainer = DiffusionPriorTrainer(
+        diffusion_prior=diffusion_prior,
+        lr=1.1e-4,
+        wd=6.02e-2,
+        max_grad_norm=0.5,
+        amp=False,
+        group_wd_params=True,
+        use_ema=True,
+        device=device,
+        accelerator=None,
+    )
+
+    trainer.load(dprior_path)
+
+    return trainer
 ```
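For completeness, a minimal usage sketch for the loader above. The checkpoint filename is hypothetical, and it assumes `clip.tokenize`, `DiffusionPrior.sample`, and the trainer's `diffusion_prior` attribute behave as in OpenAI CLIP and DALLE2-pytorch at the time of this commit:

```python
import clip
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# hypothetical filename; point this at a checkpoint downloaded from this repo
trainer = load_diffusion_model("./prior_checkpoint.pth", device)

# tokenize a caption with OpenAI CLIP, then ask the prior for an image embedding
tokens = clip.tokenize(["a photograph of a corgi"]).to(device)
image_embed = trainer.diffusion_prior.sample(tokens)  # (1, 768) for ViT-L/14
```

Raising `cond_scale` above 1.0 in `sample` should strengthen the text conditioning (classifier-free guidance).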
 
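The loader comment above leaves EMA extraction to the reader. A rough sketch of one way it could look: `load_ema_weights` is a hypothetical helper and `"ema_model"` is an assumed checkpoint key, so inspect the file's keys before relying on it.

```python
import torch
from dalle2_pytorch import DiffusionPrior

def load_ema_weights(prior: DiffusionPrior, ckpt_path: str) -> DiffusionPrior:
    # hypothetical helper: pull only the EMA weights out of a trainer checkpoint
    ckpt = torch.load(ckpt_path, map_location="cpu")
    print(ckpt.keys())  # check what your DiffusionPriorTrainer version saved

    # "ema_model" is an assumed key name; adjust to whatever your file uses
    ema_state = ckpt["ema_model"]

    # drop a possible "ema_model." prefix so names match a bare DiffusionPrior
    ema_state = {k.replace("ema_model.", "", 1): v for k, v in ema_state.items()}
    prior.load_state_dict(ema_state, strict=False)
    return prior
```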