justinpinkney committed
Commit 70e803f
1 Parent(s): 4184646
Files changed (2)
  1. app.py +234 -0
  2. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,234 @@
+ import gradio as gr
+ import torch
+ from clip2latent import models
+ from PIL import Image
+
+ device = "cuda"
+ model_choices = {
+     "faces": {
+         "checkpoint": "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.ckpt",
+         "config": "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.yaml",
+     },
+     "landscape": {
+         "checkpoint": "https://huggingface.co/lambdalabs/clip2latent/resolve/main/lhq-sg3-410.ckpt",
+         "config": "https://huggingface.co/lambdalabs/clip2latent/resolve/main/lhq-sg3-410.yaml",
+     }
+ }
+
+ model_cache = {}
+ for k, v in model_choices.items():
+     checkpoint = v["checkpoint"]
+     cfg_file = v["config"]
+     # Moving to the cpu seems to break the model, so just put all on the gpu
+     model_cache[k] = models.Clip2StyleGAN(cfg_file, device, checkpoint)
+
+ @torch.no_grad()
+ def infer(prompt, model_select, n_samples, scale):
+     model = model_cache[model_select]
+     images, _ = model(prompt, n_samples_per_txt=n_samples, cond_scale=scale, skips=250, clip_sort=True)
+     images = images.cpu()
+     make_im = lambda x: (255*x.clamp(-1, 1)/2 + 127.5).to(torch.uint8).permute(1,2,0).numpy()
+     images = [Image.fromarray(make_im(x)) for x in images]
+     return images
+
+
+ css = """
+ a {
+     color: inherit;
+     text-decoration: underline;
+ }
+ .gradio-container {
+     font-family: 'IBM Plex Sans', sans-serif;
+ }
+ .gr-button {
+     color: white;
+     border-color: #9d66e5;
+     background: #9d66e5;
+ }
+ input[type='range'] {
+     accent-color: #9d66e5;
+ }
+ .dark input[type='range'] {
+     accent-color: #dfdfdf;
+ }
+ .container {
+     max-width: 730px;
+     margin: auto;
+     padding-top: 1.5rem;
+ }
+ #gallery {
+     min-height: 22rem;
+     margin-bottom: 15px;
+     margin-left: auto;
+     margin-right: auto;
+     border-bottom-right-radius: .5rem !important;
+     border-bottom-left-radius: .5rem !important;
+ }
+ #gallery>div>.h-full {
+     min-height: 20rem;
+ }
+ .details:hover {
+     text-decoration: underline;
+ }
+ .gr-button {
+     white-space: nowrap;
+ }
+ .gr-button:focus {
+     border-color: rgb(147 197 253 / var(--tw-border-opacity));
+     outline: none;
+     box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
+     --tw-border-opacity: 1;
+     --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
+     --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
+     --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
+     --tw-ring-opacity: .5;
+ }
+ #advanced-options {
+     margin-bottom: 20px;
+ }
+ .footer {
+     margin-bottom: 45px;
+     margin-top: 35px;
+     text-align: center;
+     border-bottom: 1px solid #e5e5e5;
+ }
+ .footer>p {
+     font-size: .8rem;
+     display: inline-block;
+     padding: 0 10px;
+     transform: translateY(10px);
+     background: white;
+ }
+ .dark .logo{ filter: invert(1); }
+ .dark .footer {
+     border-color: #303030;
+ }
+ .dark .footer>p {
+     background: #0b0f19;
+ }
+ .acknowledgments h4{
+     margin: 1.25em 0 .25em 0;
+     font-weight: bold;
+     font-size: 115%;
+ }
+ """
+
+ examples = [
+     [
+         'a photograph of a happy person wearing sunglasses by the sea',
+         'faces',
+         2,
+         2,
+     ],
+     [
+         'a photograph of Captain Jean Luc Picard',
+         'faces',
+         2,
+         2,
+     ],
+     [
+         'a mountain in the middle of the sea',
+         'landscape',
+         2,
+         2,
+     ],
+     [
+         'The sun setting over the sea',
+         'landscape',
+         2,
+         2,
+     ],
+ ]
+
+ def main():
+     block = gr.Blocks(css=css)
+
+     with block:
+         gr.HTML(
+             """
+             <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+             <div>
+             <img class="logo" src="https://lambdalabs.com/static/images/lambda-logo.svg" alt="Lambda Logo"
+             style="margin: auto; max-width: 7rem;">
+             <h1 style="font-weight: 900; font-size: 3rem;">
+             clip2latent
+             </h1>
+             </div>
+             <p style="font-size: 94%">
+             Official demo for <em>clip2latent: Text driven sampling of a pre-trained StyleGAN using denoising diffusion and CLIP</em>, accepted to BMVC 2022
+             </p>
+             <p style="margin-bottom: 10px; font-size: 94%">
+             Get the <a href="https://github.com/justinpinkney/clip2latent">code on GitHub</a>, see the <a href="#">paper on arXiv</a>.
+             </p>
+             </div>
+             """
+         )
+         with gr.Group():
+             with gr.Box():
+                 with gr.Row().style(mobile_collapse=False, equal_height=True):
+                     text = gr.Textbox(
+                         label="Enter your prompt",
+                         show_label=False,
+                         max_lines=1,
+                         placeholder="Enter your prompt",
+                     ).style(
+                         border=(True, False, True, True),
+                         rounded=(True, False, False, True),
+                         container=False,
+                     )
+                     btn = gr.Button("Generate image").style(
+                         margin=False,
+                         rounded=(False, True, True, False),
+                     )
+
+             gallery = gr.Gallery(
+                 label="Generated images", show_label=False, elem_id="gallery"
+             ).style(grid=[2], height="auto")
+
+
+             with gr.Row(elem_id="advanced-options"):
+                 model_select = gr.Dropdown(label="Model", choices=["faces", "landscape"], value="faces",)
+                 samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
+                 scale = gr.Slider(
+                     label="Guidance Scale", minimum=0, maximum=10, value=2, step=0.5
+                 )
+
+
+             ex = gr.Examples(examples=examples, fn=infer, inputs=[text, model_select, samples, scale], outputs=gallery, cache_examples=False)
+             ex.dataset.headers = [""]
+
+             text.submit(infer, inputs=[text, model_select, samples, scale], outputs=gallery)
+             btn.click(infer, inputs=[text, model_select, samples, scale], outputs=gallery)
+         gr.HTML(
+             """
+             <div class="footer">
+             <p> Gradio Demo by Lambda Labs
+             </p>
+             </div>
+             <div class="acknowledgments">
+             <img src="https://raw.githubusercontent.com/justinpinkney/clip2latent/main/images/headline-large.jpeg"></img>
+             <br>
+             <h2 style="font-size:1.5em">clip2latent: Text driven sampling of a pre-trained StyleGAN using denoising diffusion and CLIP</h2>
+             <p>Justin N. M. Pinkney and Chuan Li @ <a href="https://lambdalabs.com/">Lambda Inc.</a>
+             <br>
+             <br>
+             <em>Abstract:</em>
+             We introduce a new method to efficiently create text-to-image models from a pre-trained CLIP and StyleGAN.
+             It enables text driven sampling with an existing generative model without any external data or fine-tuning.
+             This is achieved by training a diffusion model conditioned on CLIP embeddings to sample latent vectors of a pre-trained StyleGAN, which we call <em>clip2latent</em>.
+             We leverage the alignment between CLIP’s image and text embeddings to avoid the need for any text labelled data for training the conditional diffusion model.
+             We demonstrate that clip2latent allows us to generate high-resolution (1024x1024 pixels) images based on text prompts with fast sampling, high image quality, and low training compute and data requirements.
+             We also show that the use of the well studied StyleGAN architecture, without further fine-tuning, allows us to directly apply existing methods to control and modify the generated images adding a further layer of control to our text-to-image pipeline.
+             </p>
+             <br>
+             <p>Trained using <a href="https://lambdalabs.com/service/gpu-cloud">Lambda GPU Cloud</a></p>
+             </div>
+             """
+         )
+
+     block.queue()
+     block.launch()
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ --extra-index-url https://download.pytorch.org/whl/cu113
+ torch
+ torchvision
+ wandb==0.12.16
+ ninja==1.10.2.3
+ dalle2-pytorch==0.2.38
+ hydra-core==1.1.2
+ typer==0.4.1
+ joblib==1.1.0
+ webdataset==0.2.5
+ gradio==3.4
+ protobuf==3.20.1
+ scipy==1.9.1
+ git+https://github.com/justinpinkney/clip2latent.git
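
For reference, the inference path added in app.py can also be exercised without the Gradio UI. The following is a minimal sketch, not part of the commit, that reuses only calls appearing in the diff above (models.Clip2StyleGAN and the model call with n_samples_per_txt, cond_scale, skips and clip_sort); it assumes clip2latent is installed per requirements.txt, an available CUDA GPU, and the "faces" checkpoint/config URLs from model_choices.

import torch
from PIL import Image
from clip2latent import models

# Assumptions: CUDA device available; checkpoint and config URLs are the
# "faces" entry of model_choices in app.py above.
device = "cuda"
checkpoint = "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.ckpt"
cfg_file = "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.yaml"
model = models.Clip2StyleGAN(cfg_file, device, checkpoint)

with torch.no_grad():
    # Same model call as in app.py's infer(); the second return value is unused there too
    images, _ = model("a photograph of a happy person", n_samples_per_txt=2, cond_scale=2, skips=250, clip_sort=True)

# Map the [-1, 1] float tensors to uint8 HWC arrays exactly as app.py's make_im does
for i, x in enumerate(images.cpu()):
    arr = (255 * x.clamp(-1, 1) / 2 + 127.5).to(torch.uint8).permute(1, 2, 0).numpy()
    Image.fromarray(arr).save(f"sample_{i}.png")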