Spaces:
Runtime error
Runtime error
import os | |
import time | |
import gradio as gr | |
from gradio.themes import Size, GoogleFont | |
import sys | |
import pandas as pd | |
import webbrowser | |
from marqo import Client | |
from PIL import Image | |
import urllib.request | |
from PIL import Image | |
import requests | |
import matplotlib.pyplot as plt | |
from pathlib import Path | |
from datetime import datetime | |
import time | |
import webbrowser | |
from transformers import CLIPProcessor, CLIPModel | |
# FashionCLIP model for label probabilities -- currently disabled:
# model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
# processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

# Folder where uploaded query images are written; `load` builds the public
# ".../file=static/..." URL that Marqo downloads them from.
static_dir = Path("static")
static_dir.mkdir(exist_ok=True, parents=True)
# Connection/index/device settings live in `Client_Settings` below.
class Client_Settings:
    """Mutable holder for the Marqo client, the index to query and the device.

    The search helpers in this module read ``client``, ``index_name`` and
    ``device`` from one shared instance, so changing them here affects every
    subsequent search.
    """

    def __init__(self, url=None, index_name="new_look_expanded_dresses", device="cpu"):
        """Create the settings.

        Args:
            url: Marqo server URL. When ``None`` (the default, as before) a
                ``Client()`` with the library's default URL is created. Passing
                a URL avoids building a throwaway default client that
                ``conn_to_server`` would immediately replace.
            index_name: Marqo index searched by default.
            device: Inference device forwarded to Marqo searches.
        """
        self.client = Client() if url is None else Client(url)
        self.index_name = index_name
        self.device = device

    def conn_to_local(self):
        """Replace the client with one using Marqo's default URL."""
        self.client = Client()

    def conn_to_server(self, url):
        """Replace the client with one pointing at *url*."""
        self.client = Client(url)

    def set_index_name(self, new_index_name):
        """Switch the index used by subsequent searches."""
        self.index_name = new_index_name

    def set_device(self, new_device):
        """Switch the inference device sent with subsequent searches."""
        self.device = new_device
# Shared connection settings read by every search helper in this module.
# Defaults are the values this app was deployed with; each can be overridden
# through an environment variable (e.g. point MARQO_URL at a local Marqo for
# development instead of editing the code).
client_obj = Client_Settings()
client_obj.conn_to_server(
    os.environ.get(
        "MARQO_URL",
        "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882",
    )
)
client_obj.set_index_name(os.environ.get("MARQO_INDEX", "new_look_expanded_dresses"))
client_obj.set_device(os.environ.get("MARQO_DEVICE", "cuda"))
# Palette for the app: slate primary, rose secondary and stone as the
# neutral palette.
primary_color, secondary_color, neutral_color = (
    gr.themes.colors.slate,
    gr.themes.colors.rose,
    gr.themes.colors.stone,
)

# Medium spacing, corner radius and text size throughout.
spacing_size, radius_size, text_size = (
    gr.themes.sizes.spacing_md,
    gr.themes.sizes.radius_md,
    gr.themes.sizes.text_md,
)

# Google-hosted fonts for body text and monospace text.
font, font_mono = GoogleFont("Source Sans Pro"), GoogleFont("IBM Plex Mono")

theme = gr.themes.Base(
    font=font,
    font_mono=font_mono,
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
)
def filter_by_column(dataset, search_term, column_name, *, regex=True) -> pd.DataFrame:
    """Return the rows of *dataset* whose *column_name* contains *search_term*.

    Args:
        dataset: Frame to filter; *column_name* must hold strings.
        search_term: Pattern (or literal text when ``regex=False``) to look for.
        column_name: Column to search in.
        regex: Interpret *search_term* as a regular expression (the historical
            behaviour, kept as default). Pass ``False`` for literal matching so
            characters such as ``.`` or ``(`` are not treated specially.

    Missing cells count as non-matching. Without ``na=False``, they produce
    NaN in the mask and the boolean indexing raises ``ValueError``.
    """
    mask = dataset[column_name].str.contains(search_term, regex=regex, na=False)
    return dataset[mask]
def dedup_by(dataset, column_name) -> pd.DataFrame:
    """Keep only the first row for each distinct value in *column_name*.

    The original index labels of the surviving rows are preserved.
    """
    is_first_occurrence = ~dataset.duplicated(subset=[column_name])
    return dataset[is_first_occurrence]
def drop_secondary_images(dataset) -> pd.DataFrame:
    """Show each product once, using its primary image.

    The ``image`` column is replaced by ``primary_image``. Then only the first
    row per ``primary_image`` is kept.

    The caller's frame is left untouched. The previous version overwrote
    ``dataset.image`` in place. ``assign`` also works when *dataset* has no
    ``image`` column yet; attribute assignment only sets existing columns.
    """
    with_primary = dataset.assign(image=dataset["primary_image"])
    return with_primary.drop_duplicates(subset=["primary_image"])
def dataset_to_gallery(dataset: pd.DataFrame, _score=None) -> list:
    """Convert search hits into ``(image, caption)`` tuples for ``gr.Gallery``.

    Each caption packs the fields that the selection callbacks later recover
    with ``split("@@")``:

    * without a score: ``name@@colour_code@@image@@_id``
    * with *_score* given as a Series:
      ``score@@name@@colour_code@@image@@_id``, where the score is formatted
      as ``{:,.4f}``.

    Args:
        dataset: Frame with at least ``_id``, ``image``, ``name`` and
            ``colour_code`` columns.
        _score: Optional Series of scores, aligned to *dataset* by index. Any
            non-Series value means "no score prefix".

    Returns:
        A list of ``(image, caption)`` tuples, one per row.
    """
    new_df = dataset[["_id", "image", "name", "colour_code"]].copy()
    parts = [
        new_df["name"].astype(str),
        new_df["colour_code"].astype(str),
        new_df["image"].astype(str),
        new_df["_id"].astype(str),
    ]
    # isinstance instead of comparing against type(pd.Series()): the old check
    # built a throwaway empty Series on every call.
    if isinstance(_score, pd.Series):
        parts.insert(0, _score.map("{:,.4f}".format).astype(str))
    combined = parts[0]
    for part in parts[1:]:
        combined = combined + "@@" + part  # Series addition aligns on the index
    new_df["name_code_combined"] = combined
    return new_df[["image", "name_code_combined"]].to_records(index=False).tolist()
def get_items_from_dataset(start_index=0, end_index=50, dataset=None) -> pd.DataFrame:
    """Return rows ``[start_index, end_index)`` ranked by ``best_seller_score``.

    Args:
        start_index: First position (0-based) of the page, after ranking.
        end_index: Position one past the last row of the page.
        dataset: Frame with a ``best_seller_score`` column. Defaults to an
            empty frame. The default was previously a DataFrame built once at
            definition time via a deprecated literal ``read_json`` call, and
            using it always raised ``KeyError``.

    Returns:
        The requested slice, highest ``best_seller_score`` first.
    """
    if dataset is None:
        return pd.DataFrame()
    ranked = dataset.sort_values(by=["best_seller_score"], ascending=False)
    return ranked.iloc[start_index:end_index]  # positional, whatever the index labels
# def return_page(page, dataset: pd.DataFrame): | |
# start_index = page * result_per_page | |
# end_index = (page + 1) * result_per_page | |
# df = get_items_from_dataset(start_index, end_index, dataset) | |
# return dataset_to_gallery(dedup_by(df, 'colour_code')) | |
def start_page(num_results=50):
    """Fetch the landing-page gallery.

    Runs the query "Dress" against the ``image`` attribute of the current
    index. ``best_seller_score`` is added to each hit's score with weight 5.

    Args:
        num_results: Maximum number of hits requested from Marqo.

    Returns:
        Gallery items as produced by ``return_results_page``.
    """
    index = client_obj.client.index(client_obj.index_name)
    best_seller_boost = {
        "add_to_score": [{"field_name": "best_seller_score", "weight": 5}],
    }
    response = index.search(
        "Dress",
        score_modifiers=best_seller_boost,
        searchable_attributes=["image"],
        device=client_obj.device,
        limit=num_results,
    )
    return return_results_page(list(response["hits"]))
def return_results_page(results_list: list):
    """Turn raw Marqo hits into scored gallery items, one per primary image.

    Args:
        results_list: Hit dicts from a Marqo search. Each is expected to
            carry ``_id``, ``image``, ``primary_image``, ``name``,
            ``colour_code`` and ``_score``.

    Returns:
        ``(image, caption)`` tuples with the score-prefixed caption format of
        ``dataset_to_gallery``. Returns an empty list when there are no hits.
    """
    if not results_list:
        # pd.DataFrame([]) has no columns, so the column lookups below would
        # raise KeyError; an empty gallery is the meaningful answer.
        return []
    df_unique = drop_secondary_images(pd.DataFrame(results_list))
    return dataset_to_gallery(df_unique, df_unique["_score"])
def return_item(combined) -> tuple:
    """Look up every image sharing the selected product's colour code.

    Args:
        combined: Gallery caption in the score-prefixed format produced by
            ``return_results_page`` (``score@@name@@colour_code@@image@@_id``),
            so field 2 is the colour code.

    Returns:
        ``(gallery_items, description, url)``. The gallery items use the
        unscored caption format. The description and url come from the first
        hit. When the filter matches nothing, ``([], "", "")`` is returned
        instead of raising ``IndexError``.
    """
    colour_code = combined.split("@@")[2]
    index = client_obj.client.index(client_obj.index_name)
    result = index.search(
        "",
        filter_string="colour_code:" + str(colour_code),
        searchable_attributes=["image"],
        device=client_obj.device,
    )
    hits = list(result["hits"])
    if not hits:
        return [], "", ""
    first_hit = hits[0]
    return dataset_to_gallery(pd.DataFrame(hits)), first_hit["description_total"], first_hit["url"]
def return_specific_item(combined):
    """Return the ``image`` field of the document whose ``_id`` is in *combined*.

    Args:
        combined: Gallery caption in the unscored format produced by
            ``dataset_to_gallery`` (``name@@colour_code@@image@@_id``), so
            field 3 is the document id.

    Returns:
        The first hit's ``image`` value. Returns ``None`` (an empty image
        input) when nothing matches, instead of raising ``IndexError``.
    """
    _id = combined.split("@@")[3]
    index = client_obj.client.index(client_obj.index_name)
    result = index.search(
        "",
        filter_string="_id:" + str(_id),
        searchable_attributes=["image"],
        device=client_obj.device,
    )
    hits = list(result["hits"])
    if not hits:
        return None
    # Same value dataset_to_gallery(...)[0][0] produced: the first hit's image.
    return hits[0]["image"]
### Load local
def load_image(image_input):
    """Save *image_input* and copy it into the ``marqo`` Docker container.

    Only meant for running against a local Marqo container. The file lands at
    ``/images/images/img_path.jpg`` inside the container.

    Args:
        image_input: Image object with a ``save(path)`` method (a PIL image
            from ``gr.Image(type="pil")``).
    """
    import subprocess

    local_path = "../../../Documents/images/img_path.jpg"
    image_input.save(local_path)
    # Argument list instead of a shell string. The old os.system call also
    # discarded the exit status, so a failed copy went unnoticed.
    completed = subprocess.run(
        ["docker", "cp", local_path, "marqo:/images/images/"],
        check=False,
    )
    if completed.returncode != 0:
        print(f"docker cp failed with exit code {completed.returncode}")
def search_images(query, best_seller_score_weight):
    """Search the current index's ``image`` attribute and return up to 40 hits.

    Args:
        query: Text/URL query string, or a ``{query: weight}`` dict for a
            weighted multi-part query.
        best_seller_score_weight: Slider value. It is divided by 1000 and used
            as the weight with which ``best_seller_score`` is added to each
            hit's score.

    Returns:
        The list of hit dicts from Marqo.
    """
    best_seller_boost = {
        "add_to_score": [
            {"field_name": "best_seller_score", "weight": best_seller_score_weight / 1000},
        ],
    }
    index = client_obj.client.index(client_obj.index_name)
    response = index.search(
        query,
        score_modifiers=best_seller_boost,
        searchable_attributes=["image"],
        device=client_obj.device,
        limit=40,
    )
    return list(response["hits"])
# def get_labels_probs(labels, image): | |
# inputs = processor(text=labels, images=image, return_tensors="pt", padding=True) | |
# outputs = model(**inputs) | |
# logits_per_image = outputs.logits_per_image # this is the image-text similarity score | |
# probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities | |
# return probs.tolist()[0] | |
def get_bar_plot(labels, probs):
    """Draw a bar chart of per-label probabilities and return the figure.

    Each bar is annotated with its value to four decimals. The y-axis is
    fixed to [0, 1].
    """
    figure, axes = plt.subplots()
    bars = axes.bar(labels, probs)
    axes.set_ylabel("frequency")
    axes.set_title("Labels probabilities\n")
    axes.set_ylim((0, 1))
    axes.bar_label(bars, fmt="{:,.4f}")
    return figure
# Page-level CSS overrides: beige app background, grey gallery tiles.
css = "\n".join([
    "",
    ".gradio-container {background-color: beige}",
    "button.gallery-item {background-color: grey}",
    "",
])
# Other overrides that were tried and left out:
# .label {background-color: grey; width: 80px}
# h1 {background-color: grey; width: 180px}
with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
    gr.Markdown(
        """
<div style="vertical-align: middle">
<div style="float: left">
<img src="https://1000logos.net/wp-content/uploads/2021/05/New-Look-logo.png" alt=""
width="250" height="250">
</div>
</div>
""")

    with gr.Tab(label="Search for images"):
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Text(label="Search with text:")
                text_relevance = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1)
                text_input_1 = gr.Text(label="Search with text:", visible=False)
                text_relevance_1 = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1, visible=False)
                more_text_search = gr.Button(value="More text fields")
                text_expanded = gr.State(value=False)
            with gr.Column(scale=3):
                best_seller_score_weight = gr.Slider(label="Best seller relevance", minimum=-1, maximum=1, value=0, step=0.01)
                search_button = gr.Button(value="Search")
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Search with image")
                # State is never rendered; `visible` is not accepted by every
                # Gradio version, so it is omitted.
                image_path = gr.State()
                image_relevance = gr.Slider(label="Image search relevance", minimum=-5, maximum=5, value=1, step=1)
        with gr.Row():
            with gr.Column(scale=3):
                try:
                    initial_results = start_page()
                except Exception as err:  # top-level boundary: keep the UI up if Marqo is unreachable
                    print(f"Could not load the initial gallery: {err}")
                    initial_results = []
                images_gallery = gr.Gallery(value=initial_results, columns=4,
                                            allow_preview=False, show_label=False, object_fit="contain")
            with gr.Column():
                detail_gallery = gr.Gallery(value=[], columns=2, allow_preview=False, show_label=False, rows=1,
                                            height="400", object_fit="contain")
                image_description = gr.Text(label="Description")
                product_link = gr.State()
                page = gr.HTML()

        def on_new_text_box(more_text_search):
            """Show or hide the second text query, keyed on the button's label.

            Returns updates for (text_input_1, text_relevance_1, the button).
            Hiding also clears the second text so searches ignore it.
            """
            if more_text_search == "More text fields":
                return (
                    gr.update(visible=True, interactive=True),
                    gr.update(visible=True, interactive=True),
                    gr.update(value="Hide extra text box"),
                )
            return (
                gr.update(value="", visible=False, interactive=False),
                gr.update(visible=False, interactive=False),
                gr.update(value="More text fields"),
            )

        def on_focus(evt: gr.SelectData):
            """Fill the detail panel for the gallery item the user clicked.

            Returns (detail gallery items, description, product url, link HTML).
            """
            import html

            gallery_items, description, url = return_item(evt.value)
            # Quote and escape the URL; it was previously spliced unquoted into
            # the attribute, so spaces or quotes in it broke or injected HTML.
            safe_url = html.escape(str(url or ""), quote=True)
            link = ('<a href="' + safe_url + '" target="_blank" rel="noopener noreferrer">'
                    ' Go to product page </a>')
            return gallery_items, description, url, gr.update(value=link)

        def on_new_image_to_search(images, evt: gr.SelectData):
            """Load the clicked detail image into the image-search input.

            *images* (the detail gallery's value) is passed by the event
            wiring but not needed.
            """
            return return_specific_item(evt.value)

        more_text_search.click(on_new_text_box, more_text_search, [text_input_1, text_relevance_1, more_text_search])
        images_gallery.select(on_focus, None, [detail_gallery, image_description, product_link, page])
        detail_gallery.select(on_new_image_to_search, detail_gallery, image_input)
        # A "label probabilities" tab (FashionCLIP + get_bar_plot) was
        # prototyped here and is currently disabled.

    with gr.Tab(label="Choose dataset"):
        gr.Markdown("# Choose Dataset")
        with gr.Row():
            list_datasets = gr.Dropdown(["New Look Dresses", "New Look All"], label="Available datasets", value="New Look Dresses")
            gr.Markdown()
            gr.Markdown()
        with gr.Row():
            select_dataset_button = gr.Button("Select")
            gr.Markdown()
            gr.Markdown()

        def on_change_dataset(choice):
            """Point all searches at the Marqo index behind the chosen dataset label.

            An unrecognised choice (e.g. a cleared dropdown) keeps the current
            index. Previously it was replaced by an empty index name.
            """
            index_names = {
                "New Look Dresses": "new_look_expanded_dresses",
                "New Look All": "new_look_expanded_all",
            }
            index_name = index_names.get(choice)
            if index_name is None:
                print("Unknown dataset " + repr(choice) + "; keeping index " + client_obj.index_name)
                return choice
            print("Dataset selected: " + index_name)
            client_obj.set_index_name(index_name)
            time.sleep(0.5)  # NOTE(review): presumably a deliberate UI pause; confirm it is still wanted
            return choice

        # Reload the main gallery afterwards. Otherwise it keeps showing items
        # from the previous index, and selecting one queries the new index.
        select_dataset_button.click(on_change_dataset, list_datasets, list_datasets).then(
            start_page, None, images_gallery
        )

    def load(image_input):
        """Save the uploaded query image under static/ and return its public URL.

        Returns "" when no image was uploaded. Each upload gets a unique file
        name. A single shared name let concurrent users overwrite each other's
        image between this step and the Marqo search that downloads it.
        NOTE(review): saved files are never cleaned up.
        """
        if image_input is None:
            return ""
        import uuid

        # JPEG cannot store alpha or palette modes (e.g. RGBA uploads from PNGs),
        # so normalise to RGB before saving.
        if image_input.mode != "RGB":
            image_input = image_input.convert("RGB")
        file_path = (static_dir / f"image_to_search_{uuid.uuid4().hex}.jpg").as_posix()
        print(file_path)
        image_input.save(file_path)
        return "https://minderalabs-newlook.hf.space/file=" + file_path

    def search(text_input, text_input_1, image_input, image_path, text_relevance,
               text_relevance_1, image_relevance, best_seller_score_weight):
        """Run a combined text/image search and return gallery items.

        Each non-blank query field (the two texts and the uploaded-image URL
        in *image_path*) is paired with its relevance slider.

        * No field filled: return an empty gallery.
        * One field filled: send it as a plain query string (its slider is
          not applied).
        * Several fields filled: send a ``{query: weight}`` dict. Identical
          queries have their weights summed, instead of the later one
          silently overwriting the earlier.

        *image_input* is not sent to Marqo. On the hosted app the image is
        referenced through the URL that ``load`` stored in *image_path*. For a
        local Marqo container, ``load_image(image_input)`` with the query path
        "/images/images/img_path.jpg" was the previous approach.
        """
        fields = [
            (text_input, text_relevance),
            (text_input_1, text_relevance_1),
            (image_path, image_relevance),
        ]
        active = [(str(term).strip(), weight) for term, weight in fields
                  if term is not None and str(term).strip()]
        if not active:
            return []
        if len(active) == 1:
            query = active[0][0]
        else:
            query = {}
            for term, weight in active:
                query[term] = query.get(term, 0) + weight
        print("Search query:", query)
        hits = search_images(query, best_seller_score_weight)
        if not hits:
            return []
        return return_results_page(hits)

    search_button.click(
        load, image_input, image_path
    ).then(
        search,
        [text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight],
        [images_gallery],
    )
# Enable Gradio's request queue, then start serving the app.
# Blocks.queue() returns the same Blocks object, so the calls chain.
demo.queue().launch()