mridulk commited on
Commit
e01ed7b
1 Parent(s): 6e3f743

added butterflies app

Browse files
Files changed (1) hide show
  1. butterflies_app.py +330 -0
butterflies_app.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+
4
+
5
+ import argparse, os, sys, glob
6
+ import torch
7
+ import pickle
8
+ import numpy as np
9
+ from omegaconf import OmegaConf
10
+ from PIL import Image
11
+ from tqdm import tqdm, trange
12
+ from einops import rearrange
13
+ from torchvision.utils import make_grid
14
+
15
+ from ldm.util import instantiate_from_config
16
+ from ldm.models.diffusion.ddim import DDIMSampler
17
+ from ldm.models.diffusion.plms import PLMSSampler
18
+
19
+
20
+ def load_model_from_config(config, ckpt, verbose=False):
21
+ print(f"Loading model from {ckpt}")
22
+ # pl_sd = torch.load(ckpt, map_location="cpu")
23
+ pl_sd = torch.load(ckpt)#, map_location="cpu")
24
+ sd = pl_sd["state_dict"]
25
+ model = instantiate_from_config(config.model)
26
+ m, u = model.load_state_dict(sd, strict=False)
27
+ if len(m) > 0 and verbose:
28
+ print("missing keys:")
29
+ print(m)
30
+ if len(u) > 0 and verbose:
31
+ print("unexpected keys:")
32
+ print(u)
33
+
34
+ model.cuda()
35
+ model.eval()
36
+ return model
37
+
38
+
39
+ def masking_embed(embedding, levels=1):
40
+ """
41
+ size of embedding - nx1xd, n: number of samples, d - 512
42
+ replacing the last 128*levels from the embedding
43
+ """
44
+ replace_size = 128*levels
45
+ random_noise = torch.randn(embedding.shape[0], embedding.shape[1], replace_size)
46
+ embedding[:, :, -replace_size:] = random_noise
47
+ return embedding
48
+
49
+
50
+ # LOAD MODEL GLOBALLY
51
+ ckpt_path = '/globalscratch/mridul/ldm/butterflies/model_runs/2024-06-18T21-37-12_HLE_lr1e-6_custom_NEW/checkpoints/epoch=000233.ckpt'
52
+ config_path = '/globalscratch/mridul/ldm/butterflies/model_runs/2024-06-18T21-37-12_HLE_lr1e-6_custom_NEW/configs/2024-06-18T21-37-12-project.yaml'
53
+ config = OmegaConf.load(config_path) # TODO: Optionally download from same location as ckpt and chnage this logic
54
+ model = load_model_from_config(config, ckpt_path) # TODO: check path
55
+
56
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
57
+ model = model.to(device)
58
+
59
+ class_to_node = '/projects/ml4science/mridul/data/cambridge_butterfly/level_encodings/butterflies_hle_4levels_custom_NEW.pkl'
60
+ with open(class_to_node, 'rb') as pickle_file:
61
+ class_to_node_dict = pickle.load(pickle_file)
62
+
63
+ class_to_node_dict = {key.lower(): value for key, value in class_to_node_dict.items()}
64
+ species_name_to_class = {'_'.join(x.split('_')[2:]):x for x in class_to_node_dict.keys()}
65
+
66
+ species_names = list(species_name_to_class.keys())
67
+
68
+ def generate_image(fish_name, masking_level_input,
69
+ swap_fish_name, swap_level_input):
70
+
71
+ # fish_name = fish_name.lower()
72
+
73
+
74
+ # label_to_class_mapping = {0: 'Alosa-chrysochloris', 1: 'Carassius-auratus', 2: 'Cyprinus-carpio', 3: 'Esox-americanus',
75
+ # 4: 'Gambusia-affinis', 5: 'Lepisosteus-osseus', 6: 'Lepisosteus-platostomus', 7: 'Lepomis-auritus', 8: 'Lepomis-cyanellus',
76
+ # 9: 'Lepomis-gibbosus', 10: 'Lepomis-gulosus', 11: 'Lepomis-humilis', 12: 'Lepomis-macrochirus', 13: 'Lepomis-megalotis',
77
+ # 14: 'Lepomis-microlophus', 15: 'Morone-chrysops', 16: 'Morone-mississippiensis', 17: 'Notropis-atherinoides',
78
+ # 18: 'Notropis-blennius', 19: 'Notropis-boops', 20: 'Notropis-buccatus', 21: 'Notropis-buchanani', 22: 'Notropis-dorsalis',
79
+ # 23: 'Notropis-hudsonius', 24: 'Notropis-leuciodus', 25: 'Notropis-nubilus', 26: 'Notropis-percobromus',
80
+ # 27: 'Notropis-stramineus', 28: 'Notropis-telescopus', 29: 'Notropis-texanus', 30: 'Notropis-volucellus',
81
+ # 31: 'Notropis-wickliffi', 32: 'Noturus-exilis', 33: 'Noturus-flavus', 34: 'Noturus-gyrinus', 35: 'Noturus-miurus',
82
+ # 36: 'Noturus-nocturnus', 37: 'Phenacobius-mirabilis'}
83
+
84
+ # def get_label_from_class(class_name):
85
+ # for key, value in label_to_class_mapping.items():
86
+ # if value == class_name:
87
+ # return key
88
+
89
+
90
+ if opt.plms:
91
+ sampler = PLMSSampler(model)
92
+ else:
93
+ sampler = DDIMSampler(model)
94
+
95
+
96
+ prompt = class_to_node_dict[species_name_to_class[fish_name]]
97
+
98
+ ### Trait Swapping
99
+ if swap_fish_name!='None':
100
+ # swap_fish_name = swap_fish_name.lower()
101
+ swap_level = int(swap_level_input.split(" ")[-1]) - 1
102
+ swap_fish = class_to_node_dict[species_name_to_class[swap_fish_name]]
103
+
104
+ swap_fish_split = swap_fish[0].split(',')
105
+ fish_name_split = prompt[0].split(',')
106
+ fish_name_split[swap_level] = swap_fish_split[swap_level]
107
+
108
+ prompt = [','.join(fish_name_split)]
109
+
110
+ all_samples=list()
111
+ with torch.no_grad():
112
+ with model.ema_scope():
113
+ uc = None
114
+ for n in trange(opt.n_iter, desc="Sampling"):
115
+
116
+ all_prompts = opt.n_samples * (prompt)
117
+ all_prompts = [tuple(all_prompts)]
118
+ c = model.get_learned_conditioning({'class_to_node': all_prompts})
119
+ if masking_level_input != "None":
120
+ masked_level = int(masking_level_input.split(" ")[-1])
121
+ masked_level = 4-masked_level
122
+ c = masking_embed(c, levels=masked_level)
123
+ shape = [3, 64, 64]
124
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
125
+ conditioning=c,
126
+ batch_size=opt.n_samples,
127
+ shape=shape,
128
+ verbose=False,
129
+ unconditional_guidance_scale=opt.scale,
130
+ unconditional_conditioning=uc,
131
+ eta=opt.ddim_eta)
132
+
133
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
134
+ x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
135
+
136
+ all_samples.append(x_samples_ddim)
137
+
138
+ ###### to make grid
139
+ # additionally, save as grid
140
+ grid = torch.stack(all_samples, 0)
141
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
142
+ grid = make_grid(grid, nrow=opt.n_samples)
143
+
144
+ # to image
145
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
146
+ final_image = Image.fromarray(grid.astype(np.uint8))
147
+ # final_image.save(os.path.join(sample_path, f'{class_name.replace(" ", "-")}.png'))
148
+
149
+ return final_image
150
+
151
+
152
+ if __name__ == "__main__":
153
+ parser = argparse.ArgumentParser()
154
+
155
+ # parser.add_argument(
156
+ # "--prompt",
157
+ # type=str,
158
+ # nargs="?",
159
+ # default="a painting of a virus monster playing guitar",
160
+ # help="the prompt to render"
161
+ # )
162
+
163
+ # parser.add_argument(
164
+ # "--outdir",
165
+ # type=str,
166
+ # nargs="?",
167
+ # help="dir to write results to",
168
+ # default="outputs/txt2img-samples"
169
+ # )
170
+ parser.add_argument(
171
+ "--ddim_steps",
172
+ type=int,
173
+ default=200,
174
+ help="number of ddim sampling steps",
175
+ )
176
+
177
+ parser.add_argument(
178
+ "--plms",
179
+ action='store_true',
180
+ help="use plms sampling",
181
+ )
182
+
183
+ parser.add_argument(
184
+ "--ddim_eta",
185
+ type=float,
186
+ default=1.0,
187
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
188
+ )
189
+ parser.add_argument(
190
+ "--n_iter",
191
+ type=int,
192
+ default=1,
193
+ help="sample this often",
194
+ )
195
+
196
+ # parser.add_argument(
197
+ # "--H",
198
+ # type=int,
199
+ # default=256,
200
+ # help="image height, in pixel space",
201
+ # )
202
+
203
+ # parser.add_argument(
204
+ # "--W",
205
+ # type=int,
206
+ # default=256,
207
+ # help="image width, in pixel space",
208
+ # )
209
+
210
+ parser.add_argument(
211
+ "--n_samples",
212
+ type=int,
213
+ default=3,
214
+ help="how many samples to produce for the given prompt",
215
+ )
216
+
217
+ # parser.add_argument(
218
+ # "--output_dir_name",
219
+ # type=str,
220
+ # default='default_file',
221
+ # help="name of folder",
222
+ # )
223
+
224
+ # parser.add_argument(
225
+ # "--postfix",
226
+ # type=str,
227
+ # default='',
228
+ # help="name of folder",
229
+ # )
230
+
231
+ parser.add_argument(
232
+ "--scale",
233
+ type=float,
234
+ # default=5.0,
235
+ default=1.0,
236
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
237
+ )
238
+ opt = parser.parse_args()
239
+
240
+ title = "🎞️ Phylo Diffusion - Generating Butterfly Images Tool"
241
+ description = "Write the Species name to generate an image for.\n For Trait Masking: Specify the Level information as well"
242
+
243
+
244
+ def load_example(prompt, level, option, components):
245
+ components['prompt_input'].value = prompt
246
+ components['masking_level_input'].value = level
247
+ # components['option'].value = option
248
+
249
+ def setup_interface():
250
+ with gr.Blocks() as demo:
251
+
252
+ gr.Markdown("# Phylo Diffusion - Generating Butterfly Images Tool")
253
+ gr.Markdown("### Write the Species name to generate a butterfly image")
254
+ gr.Markdown("### 1. Trait Masking: Specify the Level information as well")
255
+ gr.Markdown("### 2. Trait Swapping: Specify the species name to swap trait with at also at what level")
256
+
257
+ with gr.Row():
258
+ with gr.Column():
259
+ gr.Markdown("## Generate Images Based on Prompts")
260
+ gr.Markdown("Select a species to generate an image:")
261
+ # prompt_input = gr.Textbox(label="Species Name")
262
+ prompt_input = gr.Dropdown(label="Select Butterfly", choices=species_names, value="None")
263
+ gr.Markdown("Trait Masking")
264
+ with gr.Row():
265
+ masking_level_input = gr.Dropdown(label="Select Ancestral Level", choices=["None", "Level 3", "Level 2"], value="None")
266
+ # masking_node_input = gr.Dropdown(label="Select Internal", choices=["0", "1", "2", "3", "4", "5", "6", "7", "8"], value="0")
267
+
268
+ gr.Markdown("Trait Swapping")
269
+ with gr.Row():
270
+ swap_fish_name = gr.Dropdown(label="Select species Name to swap trait with:", choices=species_names, value="None")
271
+ swap_level_input = gr.Dropdown(label="Level of swapping", choices=["Level 3", "Level 2"], value="Level 3")
272
+ submit_button = gr.Button("Generate")
273
+ gr.Markdown("## Phylogeny Tree")
274
+ architecture_image = "phylogeny_tree.jpg" # Update this with the actual path
275
+ gr.Image(value=architecture_image, label="Phylogeny Tree")
276
+
277
+ with gr.Column():
278
+
279
+ gr.Markdown("## Generated Image")
280
+ output_image = gr.Image(label="Generated Image", width=768, height=256)
281
+
282
+
283
+ # # Place to put example buttons
284
+ # gr.Markdown("## Select an example:")
285
+ # examples = [
286
+ # ("Gambusia Affinis", "None", "", "Level 3"),
287
+ # ("Lepomis Auritus", "None", "", "Level 3"),
288
+ # ("Lepomis Auritus", "Level 3", "", "Level 3"),
289
+ # ("Noturus nocturnus", "None", "Notropis dorsalis", "Level 2")]
290
+
291
+ # for text, level, swap_text, swap_level in examples:
292
+ # if level == "None" and swap_text == "":
293
+ # button = gr.Button(f"Species: {text}")
294
+ # elif level != "None":
295
+ # button = gr.Button(f"Species: {text} | Masking: {level}")
296
+ # elif swap_text != "":
297
+ # button = gr.Button(f"Species: {text} | Swapping with {swap_text} at {swap_level} ")
298
+ # button.click(
299
+ # fn=lambda text=text, level=level, swap_text=swap_text, swap_level=swap_level: (text, level, swap_text, swap_level),
300
+ # inputs=[],
301
+ # outputs=[prompt_input, masking_level_input, swap_fish_name, swap_level_input]
302
+ # )
303
+
304
+
305
+ # Display an image of the architecture
306
+
307
+
308
+ submit_button.click(
309
+ fn=generate_image,
310
+ inputs=[prompt_input, masking_level_input,
311
+ swap_fish_name, swap_level_input],
312
+ outputs=output_image
313
+ )
314
+
315
+ return demo
316
+
317
+ # # Launch the interface
318
+ # iface = setup_interface()
319
+
320
+ # iface = gr.Interface(
321
+ # fn=generate_image,
322
+ # inputs=gr.Textbox(label="Prompt"),
323
+ # outputs=[
324
+ # gr.Image(label="Generated Image"),
325
+ # ]
326
+ # )
327
+
328
+ iface = setup_interface()
329
+
330
+ iface.launch(share=True)