justinpinkney committed
Commit
3b80337
1 Parent(s): 8881aeb

init commit

Files changed (2)
  1. app.py +210 -0
  2. requirements.txt +26 -0
app.py ADDED
@@ -0,0 +1,210 @@
from io import BytesIO
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from torch import autocast
from contextlib import nullcontext
import requests
import functools

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.extras import load_model_from_config, load_training_dir
import clip

from huggingface_hub import hf_hub_download

# Fetch the fine-tuned checkpoint and its config from the Hugging Face Hub.
ckpt = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-pruned.ckpt")
config = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-config.yaml")

device = "cuda:0"
model = load_model_from_config(config, ckpt, device=device, verbose=False)
model = model.to(device).half()

# CLIP ViT-L/14 provides the image and text embeddings used as conditioning.
clip_model, preprocess = clip.load("ViT-L/14", device=device)

n_inputs = 5

@functools.lru_cache()
def get_url_im(t):
    # Download an image from a URL (cached across calls).
    user_agent = {'User-agent': 'gradio-app'}
    response = requests.get(t, headers=user_agent)
    return Image.open(BytesIO(response.content))


@torch.no_grad()
def get_im_c(im_path, clip_model):
    # im = Image.open(im_path).convert("RGB")
    prompts = preprocess(im_path).to(device).unsqueeze(0)
    return clip_model.encode_image(prompts).float()


@torch.no_grad()
def get_txt_c(txt, clip_model):
    text = clip.tokenize([txt,]).to(device)
    return clip_model.encode_text(text)


def get_txt_diff(txt1, txt2, clip_model):
    return get_txt_c(txt1, clip_model) - get_txt_c(txt2, clip_model)


def to_im_list(x_samples_ddim):
    # Map decoded samples from [-1, 1] to 8-bit PIL images.
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    ims = []
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        ims.append(Image.fromarray(x_sample.astype(np.uint8)))
    return ims

@torch.no_grad()
def sample(sampler, model, c, uc, scale, start_code, h=512, w=512, precision="autocast", ddim_steps=50):
    ddim_eta = 0.0
    precision_scope = autocast if precision == "autocast" else nullcontext
    with precision_scope("cuda"):
        shape = [4, h // 8, w // 8]
        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                         conditioning=c,
                                         batch_size=c.shape[0],
                                         shape=shape,
                                         verbose=False,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta,
                                         x_T=start_code)

        x_samples_ddim = model.decode_first_stage(samples_ddim)
    return to_im_list(x_samples_ddim)


def run(*args):
    # Gradio passes all inputs flattened: n_inputs radio buttons, then n_inputs
    # textboxes, images and strength sliders, followed by the four global settings.
    inps = []
    for i in range(0, len(args)-4, n_inputs):
        inps.append(args[i:i+n_inputs])

    scale, n_samples, seed, steps = args[-4:]
    h = w = 640

    sampler = DDIMSampler(model)
    # sampler = PLMSSampler(model)

    torch.manual_seed(seed)
    start_code = torch.randn(n_samples, 4, h//8, w//8, device=device)
    conds = []

    # Build one CLIP embedding per input, scaled by its strength slider.
    for b, t, im, s in zip(*inps):
        if b == "Image":
            this_cond = s*get_im_c(im, clip_model)
        elif b == "Text/URL":
            if t.startswith("http"):
                im = get_url_im(t)
                this_cond = s*get_im_c(im, clip_model)
            else:
                this_cond = s*get_txt_c(t, clip_model)
        else:
            this_cond = torch.zeros((1, 768), device=device)
        conds.append(this_cond)
    conds = torch.cat(conds, dim=0).unsqueeze(0)
    conds = conds.tile(n_samples, 1, 1)

    ims = sample(sampler, model, conds, 0*conds, scale, start_code, ddim_steps=steps)
    # return make_row(ims)
    return ims

import gradio as gr
from functools import partial
from itertools import chain


def change_visible(txt1, im1, val):
    outputs = {}
    if val == "Image":
        outputs[im1] = gr.update(visible=True)
        outputs[txt1] = gr.update(visible=False)
    elif val == "Text/URL":
        outputs[im1] = gr.update(visible=False)
        outputs[txt1] = gr.update(visible=True)
    elif val == "Nothing":
        outputs[im1] = gr.update(visible=False)
        outputs[txt1] = gr.update(visible=False)
    return outputs

with gr.Blocks(title="Image Mixer") as demo:

    gr.Markdown("")
    gr.Markdown(
        """
# Image Mixer

_Created by [Justin Pinkney](https://www.justinpinkney.com) at [Lambda Labs](https://lambdalabs.com/)_

### __Provide one or more images to be mixed together by a fine-tuned Stable Diffusion model.__

![banner-large.jpeg](https://s3.amazonaws.com/moonup/production/uploads/1673968679262-62bd5f951e22ec84279820e8.jpeg)

""")

    btns = []
    txts = []
    ims = []
    strengths = []

    with gr.Row():
        for i in range(n_inputs):
            with gr.Column():
                btn1 = gr.Radio(
                    choices=["Image", "Text/URL", "Nothing"],
                    label=f"Input {i} type",
                    interactive=True,
                    value="Nothing",
                )
                txt1 = gr.Textbox(label="Text or Image URL", visible=False, interactive=True)
                im1 = gr.Image(label="Image", interactive=True, visible=False, type="pil")
                strength = gr.Slider(label="Strength", minimum=0, maximum=5, step=0.05, value=1, interactive=True)

                fn = partial(change_visible, txt1, im1)
                btn1.change(fn=fn, inputs=[btn1], outputs=[txt1, im1])

                btns.append(btn1)
                txts.append(txt1)
                ims.append(im1)
                strengths.append(strength)

    with gr.Row():
        cfg_scale = gr.Slider(label="CFG scale", value=3, minimum=1, maximum=10, step=0.5)
        n_samples = gr.Slider(label="Num samples", value=2, minimum=1, maximum=4, step=1)
        seed = gr.Slider(label="Seed", value=0, minimum=0, maximum=10000, step=1)
        steps = gr.Slider(label="Steps", value=30, minimum=10, maximum=100, step=5)

    with gr.Row():
        submit = gr.Button("Generate")
        output = gr.Gallery().style(grid=[1,2,2,2,4,4], height="640px")

    inps = list(chain(btns, txts, ims, strengths))
    inps.extend([cfg_scale, n_samples, seed, steps])
    submit.click(fn=run, inputs=inps, outputs=[output])

    gr.Markdown(
        """

## Tips

- You can provide between 1 and 5 inputs; each can be an uploaded image, a text prompt, or a URL to an image file.
- The order of the inputs shouldn't matter; any images will be centre cropped before use.
- Each input has an individual strength parameter which controls how much influence it has on the output.
- Using only text prompts doesn't work well; make sure there is at least one image or URL to an image.
- The parameters on the bottom row, such as CFG scale, work the same as for a normal Stable Diffusion model.
- Balancing the different inputs requires tweaking the strengths. I suggest finding the right balance with a small number of samples and few steps, then increasing the steps for better quality once you're happy with the result.
- Outputs are 640x640 by default.

## How does this work?

This model is based on the [Stable Diffusion Image Variations model](https://huggingface.co/lambdalabs/sd-image-variations-diffusers)
but it has been fine-tuned to take multiple CLIP image embeddings. During training, up to 5 random crops were taken from the training images and
their CLIP image embeddings were computed; these were then concatenated and used as the conditioning for the model. At inference time we can combine the image
embeddings from multiple images to mix their concepts (and we can also use the text encoder to add text concepts).

The model was trained on a subset of LAION Improved Aesthetics at a resolution of 640x640 and was trained using 8xA100 GPUs on [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud).

""")

demo.launch()
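
The "How does this work?" section above describes the mixing step: each input is encoded into a CLIP embedding (image or text), scaled by its strength, and the embeddings are concatenated into the conditioning the model samples from. Below is a minimal sketch of driving the helpers defined in app.py directly, without the Gradio UI. It assumes the module-level objects above (model, clip_model, device, DDIMSampler, get_im_c, get_txt_c, sample) are in scope; the file name "dog.jpg" and the strength values are illustrative, not part of the original app.

import torch
from PIL import Image

# One local image plus one text prompt, each scaled by a hand-picked strength.
im_cond = 1.0 * get_im_c(Image.open("dog.jpg").convert("RGB"), clip_model)  # (1, 768)
txt_cond = 1.5 * get_txt_c("an oil painting", clip_model)                   # (1, 768)

# Unused input slots are zero embeddings, matching what run() does for "Nothing".
empty = torch.zeros((1, 768), device=device)

# Stack the five per-input embeddings -> (1, 5, 768), then repeat per sample.
n_samples = 2
conds = torch.cat([im_cond, txt_cond, empty, empty, empty], dim=0).unsqueeze(0)
conds = conds.tile(n_samples, 1, 1)

# Classifier-free guidance with an all-zero unconditional branch, as in run().
start_code = torch.randn(n_samples, 4, 640 // 8, 640 // 8, device=device)
images = sample(DDIMSampler(model), model, conds, 0 * conds, scale=3.0,
                start_code=start_code, h=640, w=640, ddim_steps=30)
images[0].save("mixed.png")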
requirements.txt ADDED
@@ -0,0 +1,26 @@
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.12.1
torchvision==0.13.1
albumentations==0.4.3
opencv-python==4.5.5.64
pudb==2019.2
imageio==2.9.0
imageio-ffmpeg==0.4.2
pytorch-lightning==1.4.2
omegaconf==2.1.1
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0
torch-fidelity==0.3.0
transformers==4.22.2
kornia==0.6
webdataset==0.2.5
torchmetrics==0.6.0
fire==0.4.0
gradio==3.1.4
diffusers==0.3.0
datasets[vision]==2.4.0
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/justinpinkney/nomi.git@e9ded23b7e2269cc64d39683e1bf3c0319f552ab#egg=nomi
-e .