---
license: mit
---
# A Text-Conditioned Diffusion Prior

## Training Details
Training details can be found [here](https://wandb.ai/nousr_laion/1B%20Prior/reports/Distributed-Training-of-the-Prior--VmlldzoyMDkxMDQ5?accessToken=md54qpjikfxhf366iv64rxv94d47z05iojh28335fz6qlov11vlq313z63z42h3m)
## Source Code
The models are `DiffusionPriorTrainer` checkpoints from https://github.com/lucidrains/DALLE2-pytorch

## Community: LAION
Join us: https://discord.gg/uPMftTmrvS

---

# Models
The repo currently has many models. I recommend using the latest EMA model checkpoints, as they are the best-performing models right now.

> **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of *DALLE2 PyTorch* and massively underperform compared to recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README's commit timestamp).

### Loading the models might look something like this:

> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and having issues loading checkpoints, please reach out on the LAION Discord.

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path, device):
    
    # If you get size-mismatch errors when loading, this configuration is the likely cause
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )
    
    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )
    
    # this will load the entire trainer
    # If you only want EMA weights for inference you will need to extract them yourself for now 
    # (if you beat me to writing a nice function for that please make a PR on Github!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )
    
    trainer.load(dprior_path)
    
    return trainer
```
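The comments in the loading snippet above mention two common pain points: size-mismatch errors when loading, and pulling out only the EMA weights for inference. Below is a minimal sketch of two small helpers for both. They are hypothetical (not part of DALLE2-pytorch) and work on plain dicts, so they apply to PyTorch state dicts without depending on torch. The EMA key prefix `"ema_model."` is an assumption about the checkpoint layout, not something this repo confirms.

```python
from typing import Dict, List, Mapping, Sequence, TypeVar

V = TypeVar("V")


def extract_prefixed_state_dict(state_dict: Mapping[str, V], prefix: str) -> Dict[str, V]:
    """Keep only entries whose key starts with `prefix`, and strip that prefix.

    Hypothetical helper: prefix="ema_model." is an ASSUMED layout for the
    EMA weights; inspect your checkpoint's keys before relying on it.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def diff_shapes(
    expected: Mapping[str, Sequence[int]],
    found: Mapping[str, Sequence[int]],
) -> Dict[str, List[str]]:
    """Compare parameter shapes of a freshly built model against a checkpoint.

    Returns sorted key lists for parameters that are missing from the
    checkpoint, unexpected in the checkpoint, or present in both with
    different shapes (the usual sign of a config mismatch).
    """
    missing = sorted(k for k in expected if k not in found)
    unexpected = sorted(k for k in found if k not in expected)
    mismatched = sorted(
        k for k in expected
        if k in found and tuple(expected[k]) != tuple(found[k])
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


if __name__ == "__main__":
    # Toy stand-ins: real values would be tensors, real keys much longer.
    ckpt = {"ema_model.net.weight": "w", "ema_model.net.bias": "b", "other.step": 0}
    print(extract_prefixed_state_dict(ckpt, "ema_model."))
    # {'net.weight': 'w', 'net.bias': 'b'}

    model_shapes = {"net.proj.weight": (768, 768), "net.norm.weight": (768,)}
    ckpt_shapes = {"net.proj.weight": (512, 512), "net.extra": (1,)}
    print(diff_shapes(model_shapes, ckpt_shapes))
    # {'missing': ['net.norm.weight'], 'unexpected': ['net.extra'], 'mismatched': ['net.proj.weight']}
```

With PyTorch, the model's shape map could be built as `{k: tuple(v.shape) for k, v in diffusion_prior.state_dict().items()}`, and the checkpoint's shape map could be built with the same comprehension over the tensors returned by `torch.load(dprior_path, map_location="cpu")`. Which top-level key holds the model or EMA weights is not documented here, so print the loaded dict's keys first. If the prefix guess is right, the filtered EMA dict could then be passed to `diffusion_prior.load_state_dict(...)`.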