svjack commited on
Commit
39f07e6
·
1 Parent(s): 633f17a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +426 -0
app.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ !pip install "deepsparse-nightly==1.6.0.20231007"
3
+ !pip install "deepsparse[image_classification]"
4
+ !pip install opencv-python-headless
5
+ !pip uninstall numpy -y
6
+ !pip install numpy
7
+ !pip install gradio
8
+ !pip install pandas
9
+ '''
10
+
11
+ import os
12
+
13
# Re-install numpy and make sure pandas is present before anything imports them.
# NOTE(review): presumably works around a binary-incompatible numpy pulled in
# by the deepsparse install (the header docstring runs the same sequence) —
# confirm before removing.
for _pip_cmd in (
    "pip uninstall numpy -y",
    "pip install numpy",
    "pip install pandas",
):
    os.system(_pip_cmd)
16
+
17
+ import gradio as gr
18
+ import sys
19
+ from uuid import uuid1
20
+ from PIL import Image
21
+ from zipfile import ZipFile
22
+ import pathlib
23
+ import shutil
24
+ import pandas as pd
25
+ import deepsparse
26
+ import json
27
+ import numpy as np
28
+
29
def _build_rn50_embedding_pipeline(emb_extraction_layer=None):
    """Create a DeepSparse embedding-extraction pipeline for the pruned,
    quantized ResNet-50 ImageNet model.

    Parameters
    ----------
    emb_extraction_layer : int | None
        Which layer's activations to return. ``None`` keeps deepsparse's own
        default; ``-1`` extracts the last layer before the projection head
        and softmax, ``-2``/``-3`` earlier ones.
    """
    kwargs = dict(
        task="embedding-extraction",
        # tells the pipeline to expect images and normalize input with
        # ImageNet means/stds
        base_task="image-classification",
        model_path="zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/channel20_pruned75_quant-none-vnni",
    )
    # Only forward the kwarg when explicitly set so the "default" pipeline
    # keeps deepsparse's built-in default exactly as before (the original
    # code omitted the argument entirely in that case).
    if emb_extraction_layer is not None:
        kwargs["emb_extraction_layer"] = emb_extraction_layer
    return deepsparse.Pipeline.create(**kwargs)


# The four constructions below were previously four near-identical
# Pipeline.create(...) calls; they are kept as module-level names because the
# UI callbacks reference them through rn50_embedding_pipeline_dict.
rn50_embedding_pipeline_default = _build_rn50_embedding_pipeline()
rn50_embedding_pipeline_last_1 = _build_rn50_embedding_pipeline(-1)
rn50_embedding_pipeline_last_2 = _build_rn50_embedding_pipeline(-2)
rn50_embedding_pipeline_last_3 = _build_rn50_embedding_pipeline(-3)

# Radio-button value ("0".."3") -> embedding pipeline used by the callbacks.
rn50_embedding_pipeline_dict = {
    "0": rn50_embedding_pipeline_default,
    "1": rn50_embedding_pipeline_last_1,
    "2": rn50_embedding_pipeline_last_2,
    "3": rn50_embedding_pipeline_last_3,
}
63
+
64
def zip_ims(g):
    """Bundle the files of a gallery selection into ``tmp.zip``.

    ``g`` is a list of gallery items (dicts with a "name" file path) or
    ``None``. Returns the zip file name, or ``None`` when there is nothing
    to pack. Each file is stored under a fresh ``<uuid>.png`` entry name.
    """
    from uuid import uuid1
    if g is None:
        return None
    paths = [item["name"] for item in g]
    if not paths:
        return None
    zip_file_name = "tmp.zip"
    with ZipFile(zip_file_name, "w") as archive:
        for path in paths:
            # Random entry names avoid collisions between identically named
            # source files.
            archive.write(path, "{}.png".format(uuid1()))
    return zip_file_name
77
+
78
def unzip_ims_func(zip_file_name, choose_model,
                   rn50_embedding_pipeline_dict=rn50_embedding_pipeline_dict):
    """Extract images from a zip archive and embed them.

    Parameters
    ----------
    zip_file_name : str | None
        Path of the uploaded archive; png/jpg/jpeg members are used.
    choose_model : str
        Key ("0".."3") into ``rn50_embedding_pipeline_dict``.

    Returns
    -------
    (str, list | None)
        A JSON string ``{"names": [...], "embs": [...]}`` and the list of
        opened PIL images (``None``/empty JSON when there is no input).
    """
    print("call file")
    if zip_file_name is None:
        return json.dumps({}), None
    print("zip_file_name :")
    print(zip_file_name)
    unzip_path = "img_dir"
    # Start from a clean extraction directory every call.
    if os.path.exists(unzip_path):
        shutil.rmtree(unzip_path)
    with ZipFile(zip_file_name) as archive:
        archive.extractall(unzip_path)
    im_name_l = pd.Series(
        list(pathlib.Path(unzip_path).rglob("*.png")) +
        list(pathlib.Path(unzip_path).rglob("*.jpg")) +
        list(pathlib.Path(unzip_path).rglob("*.jpeg"))
    ).map(str).values.tolist()
    # Fix: an archive with no supported images previously fell through to the
    # embedding pipeline with an empty batch.
    if not im_name_l:
        shutil.rmtree(unzip_path, ignore_errors=True)
        return json.dumps({}), None
    rn50_embedding_pipeline = rn50_embedding_pipeline_dict[choose_model]
    embeddings = rn50_embedding_pipeline(images=im_name_l)
    # Open the images before the directory is removed (as the original did).
    im_l = pd.Series(im_name_l).map(Image.open).values.tolist()
    if os.path.exists(unzip_path):
        shutil.rmtree(unzip_path)
    # Fix: use os.path.basename instead of split("/") so this also works on
    # platforms whose path separator is not "/".
    im_name_l = [os.path.basename(p) for p in im_name_l]
    return json.dumps({
        "names": im_name_l,
        "embs": embeddings.embeddings[0]
    }), im_l
105
+
106
+
107
def emb_img_func(im, choose_model,
                 rn50_embedding_pipeline_dict=rn50_embedding_pipeline_dict):
    """Embed a single uploaded image (numpy array) with the chosen pipeline.

    Saves the array to a temporary png, runs the embedding pipeline on it,
    deletes the file, and returns a JSON string
    ``{"names": [tmp_name], "embs": [...]}`` (empty JSON when no image).
    """
    print("call im :")
    if im is None:
        return json.dumps({})
    tmp_name = "{}.png".format(uuid1())
    Image.fromarray(im).save(tmp_name)
    pipeline = rn50_embedding_pipeline_dict[choose_model]
    embeddings = pipeline(images=[tmp_name])
    os.remove(tmp_name)
    return json.dumps({
        "names": [tmp_name],
        "embs": embeddings.embeddings[0]
    })
122
+
123
def image_grid(imgs, rows, cols):
    """Paste *imgs* left-to-right, top-to-bottom onto a rows x cols canvas.

    All images are assumed to share the size of ``imgs[0]``; the canvas is
    sized for exactly rows*cols cells.
    """
    assert len(imgs) <= rows * cols
    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for idx, img in enumerate(imgs):
        row_idx, col_idx = divmod(idx, cols)
        canvas.paste(img, box=(col_idx * cell_w, row_idx * cell_h))
    return canvas
132
+
133
def expand2square(pil_img, background_color):
    """Pad *pil_img* to a centered square whose side is its longer dimension.

    Returns the original image unchanged when it is already square; otherwise
    a new image of the same mode filled with *background_color*.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Center the original along the shorter axis (offset is 0 on the longer).
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas
145
+
146
def image_click(images, evt: gr.SelectData,
    choose_model,
    rn50_embedding_pipeline_dict = rn50_embedding_pipeline_dict,
    top_k = 5
    ):
    """Gallery .select handler: rank gallery images by similarity to the clicked one.

    images: gallery component value (pydantic model in gradio 4 — serialized
    via model_dump_json below); evt.index is the clicked position;
    choose_model selects the embedding pipeline ("0".."3").
    Returns (deduplicated similar images, full top_k x top_k image list),
    both as lists of square-padded PIL images.
    """

    # Normalize the gradio-4 gallery payload to [{"name": <file path>}, ...].
    images = json.loads(images.model_dump_json())
    images = list(map(lambda x: {"name": x["image"]["path"]}, images))

    img_selected = images[evt.index]
    pivot_image_path = images[evt.index]['name']

    # Re-embed every image currently in the gallery with the chosen pipeline.
    im_name_l = list(map(lambda x: x["name"], images))
    rn50_embedding_pipeline = rn50_embedding_pipeline_dict[choose_model]
    embeddings = rn50_embedding_pipeline(images=im_name_l)
    json_text = json.dumps({
        "names": im_name_l,
        "embs": embeddings.embeddings[0]
    })

    assert type(json_text) == type("")
    assert type(pivot_image_path) in [type(""), type(0)]
    dd_obj = json.loads(json_text)
    names = dd_obj["names"]
    embs = dd_obj["embs"]

    assert pivot_image_path in names
    # Pairwise similarity: correlation between embedding vectors (embs is
    # transposed so each column is one image's embedding).
    corr_df = pd.DataFrame(np.asarray(embs).T).corr()
    corr_df.columns = names
    corr_df.index = names
    # For each image, sort all others by descending correlation.
    arr_l = []
    for i, r in corr_df.iterrows():
        arr_ll = sorted(r.to_dict().items(), key = lambda t2: t2[1], reverse = True)
        arr_l.append(arr_ll)
    top_k = min(len(corr_df), top_k)
    # Names of the top_k images most similar to the clicked (pivot) image.
    cols = pd.Series(arr_l[names.index(pivot_image_path)]).map(lambda x: x[0]).values.tolist()[:top_k]
    # NOTE(review): DataFrame.applymap is deprecated in pandas >= 2.1 in
    # favor of DataFrame.map — left as-is to preserve exact behavior.
    corr_array_df = pd.DataFrame(arr_l).applymap(lambda x: x[0])
    corr_array_df.index = names
    #### corr_array
    # top_k x top_k grid of file paths: each row is the ranking of one of the
    # pivot's nearest neighbours.
    corr_array = corr_array_df.loc[cols].iloc[:, :top_k].values
    l_list = pd.Series(corr_array.reshape([-1])).values.tolist()
    # Open and square-pad every path (duplicates are opened repeatedly).
    l_list = pd.Series(l_list).map(Image.open).map(lambda x: expand2square(x, (0, 0, 0))).values.tolist()
    # Deduplicate by image content (PIL Image equality compares pixel data).
    l_dist_list = []
    for ele in l_list:
        if ele not in l_dist_list:
            l_dist_list.append(ele)
    return l_dist_list, l_list
193
+
194
+ import gradio as gr
195
+ from Lex import *
196
+ '''
197
+ lex = Lexica(query="man woman fire snow").images()
198
+ '''
199
+ from PIL import Image
200
+ import imagehash
201
+ import requests
202
+
203
+ from zipfile import ZipFile
204
+
205
+ from time import sleep
206
sleep_time = 0.5  # seconds slept between image downloads in lexica()/enterpix()
207
+
208
# Hash functions from the `imagehash` package offered for similarity ordering.
# Fix: the previous dynamic scan of dir(imagehash) was dead code — its result
# was immediately overwritten by this explicit, ordered list, so only the
# literal list is kept.
hash_func_name = ['average_hash', 'colorhash', 'dhash', 'phash', 'whash', 'crop_resistant_hash',]
211
+
212
def min_dim_to_size(img, size = 512):
    """Proportionally resize *img* so its larger side equals *size*.

    Returns ``(scale_ratio, resized_image)``.
    NOTE(review): despite the name, this scales by the *max* dimension
    (``size / max(...)``) — confirm whether min-side scaling was intended.
    """
    first, second = img.size
    scale = size / max(first, second)
    new_dims = tuple(int(d * scale) for d in (first, second))
    return (scale, img.resize(new_dims))
217
+
218
+ #ratio_size = 512
219
+ #ratio, img_rs = min_dim_to_size(img, ratio_size)
220
+
221
+ '''
222
+ def image_click(images, evt: gr.SelectData):
223
+ img_selected = images[evt.index]
224
+ return images[evt.index]['name']
225
+
226
+ def swap_gallery(im, images, func_name):
227
+ #### name data is_file
228
+ #print(images[0].keys())
229
+ if im is None:
230
+ return list(map(lambda x: x["name"], images))
231
+ hash_func = getattr(imagehash, func_name)
232
+
233
+ im_hash = hash_func(Image.fromarray(im))
234
+ t2_list = sorted(images, key = lambda imm:
235
+ hash_func(Image.open(imm["name"])) - im_hash, reverse = False)
236
+ return list(map(lambda x: x["name"], t2_list))
237
+ '''
238
+
239
+
240
+
241
def lexica(prompt, limit_size = 128, ratio_size = 256 + 128):
    """Search lexica.art for *prompt*; download and resize up to *limit_size* images.

    Returns a non-empty list of PIL images resized so their larger side is
    *ratio_size* (asserts if every download failed).
    """
    lex = Lexica(query=prompt).images()
    lex = lex[:limit_size]
    # Swap full-resolution URLs for small thumbnails to cut download time.
    lex = list(map(lambda x: x.replace("full_jpg", "sm2"), lex))
    lex_ = []
    for ele in lex:
        try:
            im = Image.open(
                requests.get(ele, stream = True).raw
            )
            lex_.append(im)
        # Fix: the original bare `except:` also swallowed KeyboardInterrupt
        # and SystemExit; failed downloads are still skipped best-effort.
        except Exception:
            print("err")
        sleep(sleep_time)
    assert lex_
    lex = list(map(lambda x: min_dim_to_size(x, ratio_size)[1], lex_))
    return lex
258
+
259
def enterpix(prompt, limit_size = 100, ratio_size = 256 + 128, use_key = "bigThumbnailUrl"):
    """Search enterpix.app for *prompt*; download and resize up to *limit_size* images.

    use_key selects which URL field of each result to download. Returns a
    non-empty list of PIL images resized so their larger side is *ratio_size*
    (asserts if every download failed).
    """
    resp = requests.post(
        url = "https://www.enterpix.app/enterpix/v1/image/prompt-search",
        data= {
            "length": limit_size,
            "platform": "stable-diffusion,midjourney",
            "prompt": prompt,
            "start": 0
        }
    )
    resp = resp.json()
    resp = list(map(lambda x: x[use_key], resp["images"]))
    lex_ = []
    for ele in resp:
        try:
            im = Image.open(
                requests.get(ele, stream = True).raw
            )
            lex_.append(im)
        # Fix: the original bare `except:` also swallowed KeyboardInterrupt
        # and SystemExit; failed downloads are still skipped best-effort.
        except Exception:
            print("err")
        sleep(sleep_time)
    assert lex_
    resp = list(map(lambda x: min_dim_to_size(x, ratio_size)[1], lex_))
    return resp
284
+
285
+ #def search(prompt, search_name, im, func_name):
286
def search(prompt, search_name,):
    """Fetch gallery images for *prompt* from the selected backend.

    search_name is "lexica" for lexica.art; anything else falls back to
    enterpix. Returns the backend's list of PIL images.
    """
    fetcher = lexica if search_name == "lexica" else enterpix
    return fetcher(prompt)
292
+ '''
293
+ if im is None:
294
+ return im_l
295
+ hash_func = getattr(imagehash, func_name)
296
+
297
+ im_hash = hash_func(Image.fromarray(im))
298
+ t2_list = sorted(im_l, key = lambda imm:
299
+ hash_func(imm) - im_hash, reverse = False)
300
+ return t2_list
301
+ #return list(map(lambda x: x["name"], t2_list))
302
+ '''
303
+
304
def zip_ims(g):
    """Pack gallery files into ``tmp.zip`` under random ``<uuid>.png`` names.

    NOTE(review): this duplicates (and shadows) the earlier zip_ims
    definition in this file; behavior is identical. Returns the zip file
    name, or ``None`` for an empty/absent gallery.
    """
    from uuid import uuid1
    if g is None:
        return None
    file_paths = [entry["name"] for entry in g]
    if not file_paths:
        return None
    zip_file_name = "tmp.zip"
    with ZipFile(zip_file_name, "w") as archive:
        for file_path in file_paths:
            archive.write(file_path, "{}.png".format(uuid1()))
    return zip_file_name
317
+
318
# Gradio UI: prompt search on the left (results gallery + zip download),
# similarity-sorted galleries on the right, populated by clicking a result.
with gr.Blocks(css="custom.css") as demo:
    title = gr.HTML(
        """<h1><img src="https://i.imgur.com/52VJ8vS.png" alt="SD"> StableDiffusion Search by Prompt order by Image</h1>""",
        elem_id="title",
    )

    with gr.Row():
        with gr.Column():
            with gr.Row():
                search_func_name = gr.Radio(
                    choices=["lexica", "enterpix"],
                    value="lexica", label="Search by", elem_id="search_radio",
                )
            with gr.Row():
                inputs = gr.Textbox(
                    label="Prompt", show_label=False, lines=1, max_lines=20,
                    min_width=256,
                    placeholder="Enter prompt to search", elem_id="prompt",
                )
                text_button = gr.Button("Retrieve Images", elem_id="run_button")
            with gr.Row():
                title = gr.Markdown(
                    value="### Click on a Image in the gallery to select it",
                    visible=True,
                    elem_id="selected_model",
                )
                choose_model = gr.Radio(
                    choices=["0", "1", "2", "3"],
                    value="0", label="Choose embedding layer", elem_id="layer_radio",
                )
            # Fix: the original passed `lable=` (typo, silently ignored) and
            # used the deprecated .style(...) helper; constructor kwargs match
            # the other galleries below.
            # NOTE(review): three galleries share elem_id="gallery" — confirm
            # whether distinct ids were intended for the custom CSS.
            g_outputs = gr.Gallery(
                label="retrieve Images", elem_id="gallery",
                columns=5, height=768 + 64 + 32, allow_preview=False,
            )
            with gr.Row():
                with gr.Tab(label="Download"):
                    zip_button = gr.Button("Zip Images to Download", elem_id="zip_button")
                    downloads = gr.File(label="Image zipped", elem_id="zip_file")

        with gr.Column():
            sdg_outputs = gr.Gallery(label='Sort Distinct gallery', elem_id="gallery",
                                     columns=[5], object_fit="contain", height="auto")
            sg_outputs = gr.Gallery(label='Sort gallery', elem_id="gallery",
                                    columns=[5], object_fit="contain", height="auto")

    with gr.Row():
        gr.Examples(
            [
                ["chinese zodiac signs", "lexica", ],
                ["trending digital art", "lexica", ],
                ["masterpiece, best quality, 1girl, solo, crop top, denim shorts, choker, (graffiti:1.5), paint splatter, arms behind back, against wall, looking at viewer, armband, thigh strap, paint on body, head tilt, bored, multicolored hair, aqua eyes, headset,", "lexica", ],
                ["beautiful home", "enterpix", ],
                ["interior design of living room", "enterpix", ],
                ["1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt",
                 "enterpix", ],
            ],
            inputs=[inputs, search_func_name, ],
            label="Examples",
        )

    # Clicking a retrieved image re-embeds the gallery and fills the two
    # similarity galleries (deduplicated and full top-k ordering).
    g_outputs.select(
        image_click,
        inputs=[g_outputs, choose_model],
        outputs=[sdg_outputs, sg_outputs],
    )

    # Prompt search populates the main gallery.
    text_button.click(search, inputs=[inputs, search_func_name, ], outputs=g_outputs)

    # Zip the distinct-similarity gallery for download.
    zip_button.click(
        zip_ims, inputs=sdg_outputs, outputs=downloads
    )

# Fix: launch()'s first positional parameter is not the bind address, so
# "0.0.0.0" was never applied as one; bind explicitly on all interfaces.
demo.launch(server_name="0.0.0.0")