Spaces:
Runtime error
Runtime error
''' | |
!pip install "deepsparse-nightly==1.6.0.20231007" | |
!pip install "deepsparse[image_classification]" | |
!pip install opencv-python-headless | |
!pip uninstall numpy -y | |
!pip install numpy | |
!pip install gradio | |
!pip install pandas | |
''' | |
import os | |
# Reinstall numpy and install pandas at start-up, before numpy/pandas are
# imported below. Exit statuses are deliberately ignored.
# NOTE(review): presumably works around a binary mismatch with the pinned
# deepsparse build — confirm before removing.
for _pip_cmd in (
    "pip uninstall numpy -y",
    "pip install numpy",
    "pip install pandas",
):
    os.system(_pip_cmd)
import gradio as gr | |
import sys | |
from uuid import uuid1 | |
from PIL import Image | |
from zipfile import ZipFile | |
import pathlib | |
import shutil | |
import pandas as pd | |
import deepsparse | |
import json | |
import numpy as np | |
# Sparse-quantized ResNet-50 from the SparseZoo shared by every pipeline below.
_RN50_MODEL_PATH = "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/channel20_pruned75_quant-none-vnni"


def _create_rn50_embedding_pipeline(emb_extraction_layer=None):
    """Create a DeepSparse embedding-extraction pipeline over the ResNet-50 model.

    Args:
        emb_extraction_layer: Passed through to ``Pipeline.create`` when not
            None. When None the argument is omitted entirely, so the pipeline
            uses whatever extraction layer it picks by default.

    Returns:
        The created pipeline.
    """
    extra = {}
    if emb_extraction_layer is not None:
        extra["emb_extraction_layer"] = emb_extraction_layer
    return deepsparse.Pipeline.create(
        task="embedding-extraction",
        # tells the pipeline to expect images and normalize input with ImageNet means/stds
        base_task="image-classification",
        model_path=_RN50_MODEL_PATH,
        **extra,
    )


# Negative layer indices count back from the end of the network; per the
# original note, -1 is the last layer before the projection head and softmax.
# NOTE(review): -2 / -3 are presumably progressively earlier layers — confirm
# against the DeepSparse embedding-extraction docs.
rn50_embedding_pipeline_default = _create_rn50_embedding_pipeline()
rn50_embedding_pipeline_last_1 = _create_rn50_embedding_pipeline(-1)
rn50_embedding_pipeline_last_2 = _create_rn50_embedding_pipeline(-2)
rn50_embedding_pipeline_last_3 = _create_rn50_embedding_pipeline(-3)

# Keys match the "Choose embedding layer" radio choices in the UI.
rn50_embedding_pipeline_dict = {
    "0": rn50_embedding_pipeline_default,
    "1": rn50_embedding_pipeline_last_1,
    "2": rn50_embedding_pipeline_last_2,
    "3": rn50_embedding_pipeline_last_3,
}
def zip_ims(g):
    """Zip every image shown in a Gradio gallery into ``tmp.zip``.

    Args:
        g: The gallery value Gradio passes to the handler, or None. It must
            expose ``model_dump_json()`` whose JSON is a list of items shaped
            like ``{"image": {"path": ...}, ...}``.

    Returns:
        The archive path ``"tmp.zip"``, or None when ``g`` is None or holds
        no images.
    """
    if g is None:
        return None
    # Round-trip through JSON to turn the gallery data model into plain dicts.
    items = json.loads(g.model_dump_json())
    paths = [item["image"]["path"] for item in items]
    if not paths:
        return None
    zip_file_name = "tmp.zip"
    with ZipFile(zip_file_name, "w") as zip_obj:
        for path in paths:
            # Random member names avoid collisions between equally named files;
            # keep the file's real extension (fallback .png) so a webp/jpg
            # cache file is not mislabelled as PNG inside the archive.
            suffix = pathlib.Path(path).suffix or ".png"
            zip_obj.write(path, "{}{}".format(uuid1(), suffix))
    return zip_file_name
def unzip_ims_func(zip_file_name, choose_model,
                   rn50_embedding_pipeline_dict=rn50_embedding_pipeline_dict):
    """Extract a zip of images into ``img_dir`` and embed every image in it.

    Args:
        zip_file_name: Path to the zip archive, or None.
        choose_model: Key into ``rn50_embedding_pipeline_dict`` ("0".."3").
        rn50_embedding_pipeline_dict: Mapping of key -> embedding pipeline.

    Returns:
        ``(json_text, images)``. ``json_text`` is a JSON object with
        ``"names"`` (base file names) and ``"embs"`` (the pipeline output's
        ``embeddings[0]``). ``images`` holds the decoded PIL images in the same
        order. Returns ``(json.dumps({}), None)`` when no archive is given or
        it contains no .png/.jpg/.jpeg files.
    """
    print("call file")
    if zip_file_name is None:
        return json.dumps({}), None
    print("zip_file_name :")
    print(zip_file_name)
    unzip_path = "img_dir"
    if os.path.exists(unzip_path):
        shutil.rmtree(unzip_path)
    try:
        with ZipFile(zip_file_name) as archive:
            archive.extractall(unzip_path)
        files = [p for p in pathlib.Path(unzip_path).rglob("*") if p.is_file()]
        # Keep the png -> jpg -> jpeg grouping, but match extensions
        # case-insensitively so e.g. "IMG_01.JPG" is not silently skipped.
        im_paths = [
            p
            for ext in (".png", ".jpg", ".jpeg")
            for p in files
            if p.suffix.lower() == ext
        ]
        if not im_paths:
            return json.dumps({}), None
        im_name_l = [str(p) for p in im_paths]
        rn50_embedding_pipeline = rn50_embedding_pipeline_dict[choose_model]
        embeddings = rn50_embedding_pipeline(images=im_name_l)
        im_l = []
        for path in im_name_l:
            im = Image.open(path)
            # Image.open is lazy: decode now, because the extracted files
            # are deleted in the finally block below.
            im.load()
            im_l.append(im)
    finally:
        # Always remove the extraction directory, even when something failed.
        if os.path.exists(unzip_path):
            shutil.rmtree(unzip_path)
    return json.dumps({
        "names": [p.name for p in im_paths],
        "embs": embeddings.embeddings[0],
    }), im_l
def emb_img_func(im, choose_model,
                 rn50_embedding_pipeline_dict=rn50_embedding_pipeline_dict):
    """Embed a single image array with the selected pipeline.

    The array is written to a temporary ``<uuid>.png`` file (the pipeline is
    called with file names), which is always removed afterwards.

    Args:
        im: Image array accepted by ``PIL.Image.fromarray``, or None.
        choose_model: Key into ``rn50_embedding_pipeline_dict`` ("0".."3").
        rn50_embedding_pipeline_dict: Mapping of key -> embedding pipeline.

    Returns:
        JSON text with ``"names"`` (the temporary file name) and ``"embs"``
        (the pipeline output's ``embeddings[0]``); ``"{}"`` when ``im`` is None.
    """
    print("call im :")
    if im is None:
        return json.dumps({})
    im_name = "{}.png".format(uuid1())
    Image.fromarray(im).save(im_name)
    try:
        rn50_embedding_pipeline = rn50_embedding_pipeline_dict[choose_model]
        embeddings = rn50_embedding_pipeline(images=[im_name])
    finally:
        # Don't leave the temporary file behind if lookup or inference raises.
        os.remove(im_name)
    return json.dumps({
        "names": [im_name],
        "embs": embeddings.embeddings[0],
    })
def image_grid(imgs, rows, cols):
    """Paste images row-major into a ``rows x cols`` RGB grid.

    Every cell takes the size of the first image; later images are pasted
    at their cell's top-left corner without resizing.

    Args:
        imgs: Non-empty sequence of PIL images, at most ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        The new grid image.

    Raises:
        ValueError: if ``imgs`` is empty or has more than ``rows * cols`` items.
    """
    if not imgs:
        raise ValueError("image_grid needs at least one image")
    if len(imgs) > rows * cols:
        raise ValueError(
            "{} images do not fit in a {}x{} grid".format(len(imgs), rows, cols)
        )
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centring it along the shorter axis.

    Args:
        pil_img: Source PIL image.
        background_color: Fill colour for the padding.

    Returns:
        ``pil_img`` itself when it is already square, otherwise a new image
        of side ``max(width, height)`` with ``pil_img`` pasted in the middle.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def image_click(images, evt: gr.SelectData,
                choose_model,
                rn50_embedding_pipeline_dict=rn50_embedding_pipeline_dict,
                top_k=5
                ):
    """Rank gallery images by embedding correlation with the clicked one.

    Gradio calls this on a gallery ``select`` event; ``evt.index`` is the
    position of the clicked image.

    Steps:
    - Every gallery image is embedded with the chosen pipeline.
    - The Pearson correlation between the embedding vectors of each pair of
      images is computed.
    - The ``top_k`` images most correlated with the clicked one are taken
      (the clicked image normally comes first, since it has self-correlation 1).
    - For each of those, its own ``top_k`` most-correlated images are listed.

    Args:
        images: Gallery value exposing ``model_dump_json()`` (a list of
            ``{"image": {"path": ...}}`` items).
        evt: Gradio selection event for the click.
        choose_model: Key into ``rn50_embedding_pipeline_dict``.
        rn50_embedding_pipeline_dict: Mapping of key -> embedding pipeline.
        top_k: Rows / neighbours per row; capped at the gallery size.

    Returns:
        ``(distinct_images, all_images)``:
        - ``all_images`` is the flattened ``top_k x top_k`` neighbour grid,
          each image padded to a square with black.
        - ``distinct_images`` is the same list with pixel-identical duplicates
          removed, in first-appearance order.
    """
    gallery = json.loads(images.model_dump_json())
    names = [item["image"]["path"] for item in gallery]
    pivot_image_path = names[evt.index]

    rn50_embedding_pipeline = rn50_embedding_pipeline_dict[choose_model]
    embeddings = rn50_embedding_pipeline(images=names)
    # One embedding per image; float64 matches the values the former JSON
    # round-trip produced. Transposed so each column is one image.
    embs = np.asarray(embeddings.embeddings[0], dtype=float)
    corr_df = pd.DataFrame(embs.T).corr()
    corr_df.columns = names
    corr_df.index = names

    # For every image: all (name, corr) pairs sorted by correlation, descending.
    ranked = [
        sorted(row.to_dict().items(), key=lambda t2: t2[1], reverse=True)
        for _, row in corr_df.iterrows()
    ]
    top_k = min(len(corr_df), top_k)
    cols = [name for name, _ in ranked[names.index(pivot_image_path)]][:top_k]
    # Neighbour names per image, built directly instead of via the
    # deprecated DataFrame.applymap.
    corr_array_df = pd.DataFrame(
        [[name for name, _ in row] for row in ranked], index=names
    )
    corr_array = corr_array_df.loc[cols].iloc[:, :top_k].values
    flat_paths = corr_array.reshape([-1]).tolist()

    # The grid repeats paths up to top_k times: open each file once, and
    # decode eagerly so no lazy file handles stay open.
    squared = {}
    for path in flat_paths:
        if path not in squared:
            img = Image.open(path)
            img.load()
            squared[path] = expand2square(img, (0, 0, 0))
    l_list = [squared[path] for path in flat_paths]

    # Pixel-equality de-duplication as before, but only across distinct files.
    l_dist_list = []
    for img in squared.values():
        if img not in l_dist_list:
            l_dist_list.append(img)
    return l_dist_list, l_list
import gradio as gr | |
from Lex import * | |
''' | |
lex = Lexica(query="man woman fire snow").images() | |
''' | |
from PIL import Image | |
import imagehash | |
import requests | |
from zipfile import ZipFile | |
from time import sleep | |
# Seconds to sleep after each image download request in the search helpers
# (presumably to avoid hammering the image hosts).
sleep_time = 0.5
# imagehash function names offered for ordering; currently only referenced by
# the commented-out "Order by" radio in the UI.
hash_func_name = ['average_hash', 'colorhash', 'dhash', 'phash', 'whash', 'crop_resistant_hash',]
def min_dim_to_size(img, size = 512):
    """Resize ``img`` so its longer side is exactly ``size`` pixels.

    Despite the name, it is the *longer* dimension that is mapped to
    ``size``; the aspect ratio is kept, and smaller images are scaled up.

    Args:
        img: PIL image (anything with ``.size`` and ``.resize``).
        size: Target length of the longer side.

    Returns:
        ``(ratio, resized)`` where ``ratio = size / longer_side``.
    """
    # PIL's Image.size is (width, height).
    width, height = img.size
    longest = max(width, height)
    ratio = size / longest

    def _scale(dim):
        # Pin the longer side to exactly `size`: int(dim * ratio) can land one
        # pixel short through float rounding (e.g. 49 * (1 / 49) < 1).
        if dim == longest:
            return int(size)
        # Never let the shorter side collapse to a zero-size dimension.
        return max(1, int(dim * ratio))

    return (ratio, img.resize((_scale(width), _scale(height))))
#ratio_size = 512 | |
#ratio, img_rs = min_dim_to_size(img, ratio_size) | |
''' | |
def image_click(images, evt: gr.SelectData): | |
img_selected = images[evt.index] | |
return images[evt.index]['name'] | |
def swap_gallery(im, images, func_name): | |
#### name data is_file | |
#print(images[0].keys()) | |
if im is None: | |
return list(map(lambda x: x["name"], images)) | |
hash_func = getattr(imagehash, func_name) | |
im_hash = hash_func(Image.fromarray(im)) | |
t2_list = sorted(images, key = lambda imm: | |
hash_func(Image.open(imm["name"])) - im_hash, reverse = False) | |
return list(map(lambda x: x["name"], t2_list)) | |
''' | |
def _download_images(urls, timeout=10):
    """Download each URL and decode it as a PIL image, skipping failures.

    Sleeps ``sleep_time`` seconds after every request. Failures are printed
    and skipped.

    Args:
        urls: Iterable of image URLs.
        timeout: Per-request timeout in seconds.

    Returns:
        The list of successfully decoded images (never empty).

    Raises:
        gr.Error: if no image at all could be downloaded (shown in the UI).
    """
    images = []
    for url in urls:
        try:
            with requests.get(url, stream=True, timeout=timeout) as resp:
                resp.raise_for_status()
                # Undo any gzip/deflate content encoding on the raw stream.
                resp.raw.decode_content = True
                im = Image.open(resp.raw)
                # Decode while the stream is open so truncated or corrupt
                # downloads are skipped here instead of crashing the resize.
                im.load()
            images.append(im)
        except Exception as err:  # best effort: skip any unusable result
            print("err", err)
        sleep(sleep_time)
    if not images:
        raise gr.Error("No images could be downloaded for this prompt.")
    return images


def lexica(prompt, limit_size = 128, ratio_size = 256 + 128, timeout = 10):
    """Search Lexica for ``prompt`` and return up to ``limit_size`` images.

    Result URLs have "full_jpg" replaced by "sm2" (presumably a smaller
    rendition). Each image is then resized so its longer side is
    ``ratio_size`` pixels.

    Args:
        prompt: Search text.
        limit_size: Maximum number of results to download.
        ratio_size: Target length of each image's longer side.
        timeout: Per-download timeout in seconds.

    Returns:
        List of resized PIL images.
    """
    urls = Lexica(query=prompt).images()[:limit_size]
    urls = [url.replace("full_jpg", "sm2") for url in urls]
    images = _download_images(urls, timeout=timeout)
    return [min_dim_to_size(im, ratio_size)[1] for im in images]


def enterpix(prompt, limit_size = 100, ratio_size = 256 + 128, use_key = "bigThumbnailUrl", timeout = 10):
    """Search Enterpix for ``prompt`` and return the resized result images.

    Args:
        prompt: Search text.
        limit_size: Number of results requested from the API ("length").
        ratio_size: Target length of each image's longer side.
        use_key: Field of each API result holding the image URL.
        timeout: Timeout in seconds for the search call and each download.

    Returns:
        List of resized PIL images.
    """
    resp = requests.post(
        url="https://www.enterpix.app/enterpix/v1/image/prompt-search",
        data={
            "length": limit_size,
            "platform": "stable-diffusion,midjourney",
            "prompt": prompt,
            "start": 0,
        },
        timeout=timeout,
    )
    # Report the HTTP status instead of an opaque JSON decoding error.
    resp.raise_for_status()
    urls = [item[use_key] for item in resp.json()["images"]]
    images = _download_images(urls, timeout=timeout)
    return [min_dim_to_size(im, ratio_size)[1] for im in images]
#def search(prompt, search_name, im, func_name): | |
def search(prompt, search_name,):
    """Fetch images for ``prompt`` from the chosen backend.

    ``search_name == "lexica"`` selects Lexica; any other value falls back
    to Enterpix.
    """
    fetch = lexica if search_name == "lexica" else enterpix
    return fetch(prompt)
''' | |
if im is None: | |
return im_l | |
hash_func = getattr(imagehash, func_name) | |
im_hash = hash_func(Image.fromarray(im)) | |
t2_list = sorted(im_l, key = lambda imm: | |
hash_func(imm) - im_hash, reverse = False) | |
return t2_list | |
#return list(map(lambda x: x["name"], t2_list)) | |
''' | |
''' | |
def zip_ims(g): | |
from uuid import uuid1 | |
if g is None: | |
return None | |
l = list(map(lambda x: x["name"], g)) | |
if not l: | |
return None | |
zip_file_name ="tmp.zip" | |
with ZipFile(zip_file_name ,"w") as zipObj: | |
for ele in l: | |
zipObj.write(ele, "{}.png".format(uuid1())) | |
#zipObj.write(file2.name, "file2") | |
return zip_file_name | |
''' | |
# Gradio UI: prompt search and results on the left, embedding-sorted
# galleries on the right, plus a download tab and examples.
# NOTE(review): the original indentation of this block was lost; the nesting
# of the Row/Column/Tab contexts below is a reconstruction — confirm layout.
with gr.Blocks(css="custom.css") as demo:
    title = gr.HTML(
        """<h1><img src="https://i.imgur.com/dBs990M.png" alt="SD"> StableDiffusion Search by Prompt order by Image Embedding</h1>""",
        elem_id="title",
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Backend used by `search`: "lexica" or "enterpix".
                search_func_name = gr.Radio(choices=["lexica", "enterpix"],
                    value="lexica", label="Search by", elem_id="search_radio")
            with gr.Row():
                #inputs = gr.Textbox(label = 'Enter prompt to search Lexica.art')
                inputs = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=20, min_width = 256,
                    placeholder="Enter prompt to search", elem_id="prompt")
                #gr.Slider(label='Number of images ', minimum = 4, maximum = 20, step = 1, value = 4)]
                text_button = gr.Button("Retrieve Images", elem_id="run_button")
            #i = gr.Image(elem_id="result-image", label = "Image upload or selected", height = 768 - 256 - 32)
            with gr.Row():
                with gr.Column():
                    title = gr.Markdown(
                        value="### Click on a Image in the gallery to select it",
                        visible=True,
                        elem_id="selected_model",
                    )
                    # Key into rn50_embedding_pipeline_dict: "0" is the pipeline's
                    # default layer, "1".."3" use emb_extraction_layer -1..-3.
                    choose_model = gr.Radio(choices=["0", "1", "2", "3"],
                        value="0", label="Choose embedding layer", elem_id="layer_radio")
            with gr.Row():
                # Raw search results; selecting an image triggers image_click.
                g_outputs = gr.Gallery(label='Output gallery', elem_id="gallery",
                    columns=[5],object_fit="contain", height="auto")
        with gr.Column():
            # NOTE(review): all three galleries share elem_id="gallery"; HTML ids
            # should be unique — check custom.css before renaming them.
            sdg_outputs = gr.Gallery(label='Sort Distinct gallery', elem_id="gallery",
                columns=[5],object_fit="contain", height="auto")
            sg_outputs = gr.Gallery(label='Sort gallery', elem_id="gallery",
                columns=[5],object_fit="contain", height="auto")
            #order_func_name = gr.Radio(choices=hash_func_name,
            #value=hash_func_name[0], label="Order by", elem_id="order_radio")
    #gr.Dataframe(label='prompts for corresponding images')]
    with gr.Row():
        with gr.Tab(label = "Download"):
            zip_button = gr.Button("Zip Images to Download", elem_id="zip_button")
            downloads = gr.File(label = "Image zipped", elem_id = "zip_file")
    with gr.Row():
        # Disabled older Examples (expects the removed `i` / `order_func_name`
        # components); kept as an inert string literal.
        '''
gr.Examples(
[
["chinese zodiac signs", "lexica", "images/chinese_zodiac_signs.png", "average_hash"],
["trending digital art", "lexica", "images/trending_digital_art.png", "colorhash"],
["masterpiece, best quality, 1girl, solo, crop top, denim shorts, choker, (graffiti:1.5), paint splatter, arms behind back, against wall, looking at viewer, armband, thigh strap, paint on body, head tilt, bored, multicolored hair, aqua eyes, headset,", "lexica", "images/yuzu_girl0.png", "average_hash"],
["beautiful home", "enterpix", "images/beautiful_home.png", "whash"],
["interior design of living room", "enterpix", "images/interior_design_of_living_room.png", "whash"],
["1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt",
"enterpix", "images/waifu_girl0.png", "phash"],
],
inputs = [inputs, search_func_name, i, order_func_name],
label = "Examples"
)
'''
        # Clickable (prompt, backend) examples that fill the two inputs.
        gr.Examples(
            [
                ["Chinese ink painting", "lexica", ],
                ["silk road", "lexica", ],
                ["masterpiece, best quality, 1girl, solo, crop top, denim shorts, choker, (graffiti:1.5), paint splatter, arms behind back, against wall, looking at viewer, armband, thigh strap, paint on body, head tilt, bored, multicolored hair, aqua eyes, headset,", "lexica",],
                ["beautiful home", "enterpix", ],
                ["interior design of living room", "enterpix", ],
                ["1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt",
                "enterpix", ],
            ],
            inputs = [inputs, search_func_name,],
            label = "Examples"
        )
    #outputs.select(image_click, outputs, i, _js="(x) => x.splice(0,x.length)")
    #outputs.select(image_click, outputs, i,)
    # Disabled hash-ordering handlers (reference removed components); inert string.
    '''
i.change(
fn=swap_gallery,
inputs=[i, outputs, order_func_name],
outputs=outputs,
queue=False
)
order_func_name.change(
fn=swap_gallery,
inputs=[i, outputs, order_func_name],
outputs=outputs,
queue=False
)
'''
    # Selecting a result image fills the distinct and full sorted galleries.
    g_outputs.select(image_click,
        inputs = [g_outputs, choose_model],
        outputs = [sdg_outputs, sg_outputs],)
    #### gr.Textbox().submit().success()
    ### lexica
    #text_button.click(lexica, inputs=inputs, outputs=outputs)
    ### enterpix
    #text_button.click(enterpix, inputs=inputs, outputs=outputs)
    # "Retrieve Images" runs the chosen backend search into the output gallery.
    text_button.click(search, inputs=[inputs, search_func_name,], outputs=g_outputs)
    # Zips the images of the "Sort Distinct gallery" into the download slot.
    zip_button.click(
        zip_ims, inputs = sdg_outputs, outputs=downloads
    )
# Bind to all interfaces. Blocks.launch's first positional parameter is
# `inline`, not the host, so the address must be passed as server_name.
demo.launch(server_name="0.0.0.0")