import os
import time
import gradio as gr
from gradio.themes import Size, GoogleFont
import sys
import pandas as pd
import webbrowser
from marqo import Client
from PIL import Image
import urllib.request
from PIL import Image
import requests
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import time
import webbrowser
from transformers import CLIPProcessor, CLIPModel
# model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
# processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
# Directory for files the app writes at runtime (e.g. the query image saved before a search).
# Created up-front so later saves into it never fail on a missing folder.
static_dir = Path("static")
static_dir.mkdir(exist_ok=True, parents=True)
# client = Client("http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882")
# client = Client()
# index_name = "new_look_expanded_dresses"
# device = "cpu"
class Client_Settings:
    """Mutable holder for the Marqo client, index name and device used by searches.

    Attributes:
        client: marqo ``Client`` that search requests are sent through.
        index_name: name of the Marqo index to query.
        device: device string forwarded as ``device=`` to ``search``.
    """

    def __init__(self):
        # Defaults: a client built with no arguments, the dresses index, device "cpu".
        self.conn_to_local()
        self.index_name = "new_look_expanded_dresses"
        self.device = "cpu"

    def conn_to_local(self):
        """Replace the client with one built from ``Client()``'s own defaults."""
        self.client = Client()

    def conn_to_server(self, url):
        """Replace the client with one that talks to the Marqo server at ``url``."""
        self.client = Client(url)

    def set_index_name(self, new_index_name):
        """Select the index that subsequent searches are run against."""
        self.index_name = new_index_name

    def set_device(self, new_device):
        """Select the device string passed to subsequent searches."""
        self.device = new_device
# Marqo connection shared by every search in the app. Each setting can be
# overridden through the environment; the defaults are the EC2 Marqo server,
# the "new_look_expanded_dresses" index and the "cuda" device.
# Setting MARQO_URL to an empty string uses conn_to_local() instead of a server URL.
marqo_url = os.environ.get(
    "MARQO_URL", "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882"
)
client_obj = Client_Settings()
if marqo_url:
    client_obj.conn_to_server(marqo_url)
else:
    client_obj.conn_to_local()
client_obj.set_index_name(os.environ.get("MARQO_INDEX", "new_look_expanded_dresses"))
client_obj.set_device(os.environ.get("MARQO_DEVICE", "cuda"))
# Palette: gradio's predefined slate (primary), rose (secondary) and stone (neutral) hues.
primary_color = gr.themes.colors.slate
secondary_color = gr.themes.colors.rose
neutral_color = gr.themes.colors.stone

# Medium ("_md") presets for spacing, corner radius and text size.
spacing_size, radius_size, text_size = (
    gr.themes.sizes.spacing_md,
    gr.themes.sizes.radius_md,
    gr.themes.sizes.text_md,
)

# Google Fonts for regular and monospace text.
font, font_mono = GoogleFont("Source Sans Pro"), GoogleFont("IBM Plex Mono")

# The app-wide theme passed to gr.Blocks.
theme = gr.themes.Base(
    font=font,
    font_mono=font_mono,
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
)
def filter_by_column(dataset, search_term, column_name, *, regex=True, case=True) -> pd.DataFrame:
    """Return the rows of ``dataset`` whose ``column_name`` text contains ``search_term``.

    Args:
        dataset: frame to filter.
        search_term: substring (or regular expression when ``regex`` is true) to look for.
        column_name: string column to search in.
        regex: treat ``search_term`` as a regular expression (default, as before);
            pass ``False`` to match it literally (e.g. terms containing ``+`` or ``(``).
        case: case-sensitive match when true (default).

    Returns:
        The matching rows; rows whose value is missing never match.
    """
    # na=False: missing values would otherwise yield NaN in the mask, and
    # boolean indexing with a NaN-containing mask raises ValueError.
    mask = dataset[column_name].str.contains(search_term, case=case, regex=regex, na=False)
    return dataset[mask]
def dedup_by(dataset, column_name) -> pd.DataFrame:
    """Return ``dataset`` with only the first row kept for each distinct ``column_name`` value."""
    key_columns = [column_name]
    return dataset.drop_duplicates(key_columns)
def drop_secondary_images(dataset) -> pd.DataFrame:
    """Keep one row per product image, with ``image`` set to that row's ``primary_image``.

    Args:
        dataset: frame with a ``primary_image`` column (an ``image`` column is optional).

    Returns:
        A new frame: the first row for each distinct ``primary_image``, whose
        ``image`` column equals ``primary_image``. ``dataset`` is not modified.
    """
    # Work on a copy so the caller's frame keeps its original image column, and use
    # item assignment: attribute assignment (df.image = ...) cannot create a missing column.
    primary_only = dataset.copy()
    primary_only["image"] = primary_only["primary_image"]
    return primary_only.drop_duplicates(subset=["primary_image"])
def dataset_to_gallery(dataset: pd.DataFrame, _score=None) -> list:
    """Convert search hits into ``(image, caption)`` tuples for a ``gr.Gallery``.

    The caption packs the fields the click handlers need, joined with ``"@@"``:
    ``name@@colour_code@@image@@_id``. When ``_score`` is a ``pd.Series`` (aligned
    with ``dataset`` by index), the score formatted with ``"{:,.4f}"`` is prepended,
    giving ``score@@name@@colour_code@@image@@_id``.

    Args:
        dataset: frame with at least ``_id``, ``image``, ``name`` and ``colour_code`` columns.
        _score: optional per-row scores; any value that is not a ``pd.Series`` is ignored.

    Returns:
        List of ``(image, caption)`` tuples in row order.
    """
    new_df = dataset[["_id", "image", "name", "colour_code"]].copy()
    caption = (
        new_df["name"]
        + "@@" + new_df["colour_code"].astype(str)
        + "@@" + new_df["image"].astype(str)
        + "@@" + new_df["_id"].astype(str)
    )
    # isinstance avoids constructing a Series just to compare types.
    if isinstance(_score, pd.Series):
        caption = _score.map("{:,.4f}".format).astype(str) + "@@" + caption
    # Assigning into new_df keeps new_df's index (scores are aligned to it, as before).
    new_df["name_code_combined"] = caption
    return new_df[["image", "name_code_combined"]].to_records(index=False).tolist()
def get_items_from_dataset(start_index=0, end_index=50, dataset=None) -> pd.DataFrame:
    """Return rows ``start_index:end_index`` of ``dataset`` ranked by ``best_seller_score``.

    Args:
        start_index: first position (inclusive) in the ranked order.
        end_index: last position (exclusive) in the ranked order.
        dataset: frame with a ``best_seller_score`` column. ``None`` (the default)
            yields an empty frame; previously the column-less default frame raised KeyError.

    Returns:
        The requested slice, highest ``best_seller_score`` first.
    """
    if dataset is None:
        return pd.DataFrame()
    ranked = dataset.sort_values(by=["best_seller_score"], ascending=False)
    # Positional slice of the ranked rows (explicit .iloc rather than df[a:b]).
    return ranked.iloc[start_index:end_index]
# def return_page(page, dataset: pd.DataFrame):
# start_index = page * result_per_page
# end_index = (page + 1) * result_per_page
# df = get_items_from_dataset(start_index, end_index, dataset)
# return dataset_to_gallery(dedup_by(df, 'colour_code'))
def start_page(num_results=50):
    """Build the landing-page gallery from a "Dress" search over the image field.

    ``best_seller_score`` is added to each hit's score with weight 5. At most
    ``num_results`` hits are requested before duplicates are collapsed by
    ``return_results_page``.
    """
    index = client_obj.client.index(client_obj.index_name)
    boost = {"add_to_score": [{"field_name": "best_seller_score", "weight": 5}]}
    response = index.search(
        "Dress",
        score_modifiers=boost,
        searchable_attributes=["image"],
        device=client_obj.device,
        limit=num_results,
    )
    return return_results_page(list(response["hits"]))
def return_results_page(results_list: list) -> list:
    """Turn raw search hits into scored gallery items, one per primary image.

    Args:
        results_list: hit dicts from a Marqo search; each needs the ``primary_image``,
            ``_score``, ``_id``, ``name`` and ``colour_code`` fields used downstream.

    Returns:
        ``(image, caption)`` tuples from ``dataset_to_gallery``; an empty list when
        there are no hits (an empty frame has no ``primary_image`` column to dedupe on).
    """
    if not results_list:
        return []
    df = pd.DataFrame(results_list)
    df_unique = drop_secondary_images(df)
    return dataset_to_gallery(df_unique, df_unique["_score"])
def return_item(combined) -> tuple:
    """Fetch every indexed image sharing the colour code of a clicked gallery item.

    Args:
        combined: gallery caption built by ``dataset_to_gallery``:
            ``[score@@]name@@colour_code@@image@@_id``.

    Returns:
        ``(gallery_items, description, url)``: the matching images as gallery tuples,
        plus ``description_total`` and ``url`` of the first hit. Returns ``([], "", "")``
        when the search finds nothing.
    """
    # Count from the end so the optional leading score field cannot shift the position.
    colour_code = combined.split("@@")[-3]
    result = client_obj.client.index(client_obj.index_name).search(
        "",
        filter_string="colour_code:" + str(colour_code),
        searchable_attributes=["image"],
        device=client_obj.device,
    )
    hits = list(result["hits"])
    if not hits:
        return [], "", ""
    first = hits[0]
    return dataset_to_gallery(pd.DataFrame(hits)), first["description_total"], first["url"]
def return_specific_item(combined) -> list:
_id = combined.split("@@")[3]
result = client_obj.client.index(client_obj.index_name).search("", filter_string = "_id:" + str(_id), searchable_attributes=['image'], device=client_obj.device)
imgs = [r for r in result["hits"]]
print(imgs)
df = pd.DataFrame(imgs)
return dataset_to_gallery(df)[0][0]
### Load local
def load_image(image_input, local_path="../../../Documents/images/img_path.jpg",
               container_dest="marqo:/images/images/"):
    """Save ``image_input`` locally and copy it into the ``marqo`` Docker container.

    Used when running against a local Marqo container instead of fetching a URL.
    Best effort, as before: a failed ``docker cp`` is reported, not raised.

    Args:
        image_input: PIL image to save.
        local_path: file the image is written to before copying.
        container_dest: ``docker cp`` destination (``container:path``).
    """
    import subprocess

    image_input.save(local_path)
    # Argument list instead of an os.system shell string: no shell, no quoting pitfalls.
    try:
        completed = subprocess.run(["docker", "cp", local_path, container_dest], check=False)
    except FileNotFoundError:
        print("docker executable not found; image was not copied into the container")
        return
    if completed.returncode != 0:
        print("docker cp failed with exit code " + str(completed.returncode))
def search_images(query, best_seller_score_weight):
    """Search the image field for ``query`` and return the raw hits (at most 40).

    ``query`` is forwarded unchanged (a string or a weighted dict). The
    ``add_to_score`` weight for ``best_seller_score`` is
    ``best_seller_score_weight / 1000``.
    """
    target = client_obj.client.index(client_obj.index_name)
    boost = {
        "add_to_score": [
            {"field_name": "best_seller_score", "weight": best_seller_score_weight / 1000}
        ],
    }
    response = target.search(
        query,
        score_modifiers=boost,
        searchable_attributes=["image"],
        device=client_obj.device,
        limit=40,
    )
    return list(response["hits"])
# def get_labels_probs(labels, image):
# inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
# outputs = model(**inputs)
# logits_per_image = outputs.logits_per_image # this is the image-text similarity score
# probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
# return probs.tolist()[0]
def get_bar_plot(labels, probs):
    """Return a matplotlib Figure with one bar per label, values annotated to 4 decimals.

    The y axis is fixed to the range 0-1.
    """
    figure, axes = plt.subplots()
    bars = axes.bar(labels, probs)
    axes.set(ylabel='frequency', title='Labels probabilities\n', ylim=(0, 1))
    axes.bar_label(bars, fmt='{:,.4f}')
    return figure
# Page-level CSS handed to gr.Blocks: beige page background, grey gallery tiles.
css = (
    "\n"
    ".gradio-container {background-color: beige}\n"
    "button.gallery-item {background-color: grey}\n"
)
# .label {background-color: grey; width: 80px}
# h1 {background-color: grey; width: 180px}
with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
    gr.Markdown("\n")  # empty spacer above the tabs

    with gr.Tab(label="Search for images"):
        # Query inputs: up to two text queries, a best-seller boost and an image query.
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Text(label="Search with text:")
                text_relevance = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1)
                # Second text query, hidden until "More text fields" is clicked.
                text_input_1 = gr.Text(label="Search with text:", visible=False)
                text_relevance_1 = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1,
                                             step=1, visible=False)
                more_text_search = gr.Button(value="More text fields")
            with gr.Column(scale=3):
                best_seller_score_weight = gr.Slider(label="Best seller relevance", minimum=-1, maximum=1,
                                                     value=0, step=0.01)
                search_button = gr.Button(value="Search")
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Search with image")
                # Public URL of the saved query image; filled in by `load` before `search` runs.
                image_path = gr.State(visible=False)
                image_relevance = gr.Slider(label="Image search relevance", minimum=-5, maximum=5, value=1, step=1)
        # Results: main gallery on the left, details of the clicked product on the right.
        with gr.Row():
            with gr.Column(scale=3):
                images_gallery = gr.Gallery(value=start_page(), columns=4,
                                            allow_preview=False, show_label=False, object_fit="contain")
            with gr.Column():
                detail_gallery = gr.Gallery(value=[], columns=2, allow_preview=False, show_label=False, rows=1,
                                            height="400", object_fit="contain")
                image_description = gr.Text(label="Description")
                product_link = gr.State()
                page = gr.HTML()

        def on_new_text_box(more_text_search):
            """Show or hide the second text query; the button label doubles as the toggle state."""
            if more_text_search == "More text fields":
                return (gr.update(visible=True, interactive=True),
                        gr.update(visible=True, interactive=True),
                        gr.update(value="Hide extra text box"))
            # Hiding also clears the text, and `search` skips empty queries.
            return (gr.update(value="", visible=False, interactive=False),
                    gr.update(visible=False, interactive=False),
                    gr.update(value="More text fields"))

        def on_focus(evt: gr.SelectData):
            """Show the clicked product's colour variants, its description and a product-page link."""
            import html

            gallery_items, description, url = return_item(evt.value)
            link = ""
            if url:
                # The URL comes from the index, so escape it before embedding it in HTML.
                link = '<a href="{}" target="_blank" rel="noopener">Go to product page</a>'.format(
                    html.escape(str(url), quote=True))
            return gallery_items, description, url, gr.update(value=link)

        def on_new_image_to_search(images, evt: gr.SelectData):
            """Load the clicked detail image into the image-query input."""
            return return_specific_item(evt.value)

        more_text_search.click(on_new_text_box, more_text_search, [text_input_1, text_relevance_1, more_text_search])
        images_gallery.select(on_focus, None, [detail_gallery, image_description, product_link, page])
        detail_gallery.select(on_new_image_to_search, detail_gallery, image_input)

    # NOTE: the label-probabilities tab (zero-shot labels plotted with get_bar_plot) is currently disabled.

    # Dropdown label -> Marqo index name.
    dataset_indexes = {
        "New Look Dresses": "new_look_expanded_dresses",
        "New Look All": "new_look_expanded_all",
    }

    with gr.Tab(label="Choose dataset"):
        gr.Markdown("# Choose Dataset")
        with gr.Row():
            list_datasets = gr.Dropdown(list(dataset_indexes), label="Available datasets", value="New Look Dresses")
            gr.Markdown()
            gr.Markdown()
        with gr.Row():
            select_dataset_button = gr.Button("Select")
            gr.Markdown()
            gr.Markdown()

        def on_change_dataset(choice):
            """Point subsequent searches at the index behind the chosen dataset label."""
            index_name = dataset_indexes.get(choice)
            if index_name is None:
                # Unrecognised or cleared selection: keep the current index rather than switching to "".
                print("Unknown dataset selected: " + str(choice))
                return choice
            print("Dataset selected: " + index_name)
            client_obj.set_index_name(index_name)
            return choice

        select_dataset_button.click(on_change_dataset, list_datasets, list_datasets)

    def load(image_input):
        """Save the query image under static/ and return a public URL Marqo can fetch it from.

        Returns "" when no image was given, which `search` treats as an empty image query.
        The base URL defaults to the hosted Space and can be overridden with PUBLIC_BASE_URL.
        """
        import hashlib

        if image_input is None:
            return ""
        # JPEG cannot store alpha/palette modes (e.g. RGBA PNG uploads), so convert to RGB first.
        rgb = image_input.convert("RGB")
        # Content-addressed name: concurrent users no longer overwrite each other's query image
        # between the `load` and `search` steps. One file is kept per distinct image.
        digest = hashlib.sha256(repr(rgb.size).encode() + rgb.tobytes()).hexdigest()[:16]
        file_path = "static/image_to_search_" + digest + ".jpg"
        rgb.save(file_path)
        base_url = os.environ.get("PUBLIC_BASE_URL", "https://minderalabs-newlook.hf.space")
        return base_url + "/file=" + file_path

    def search(text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1,
               image_relevance, best_seller_score_weight):
        """Combine the non-empty queries into one search and return gallery items.

        The image itself (`image_input`) is not sent; the image query is the URL in
        `image_path` produced by `load`. A single active query is sent as a plain string;
        several are sent as a {query: relevance} dict. With no active query the gallery is cleared.
        """
        # To query a local Marqo instead, copy the image into the container with
        # load_image(image_input) and use "/images/images/img_path.jpg" as the image query.
        queries = [text_input, text_input_1, image_path]
        relevances = [text_relevance, text_relevance_1, image_relevance]
        active = [(q, w) for q, w in zip(queries, relevances) if q is not None and q != ""]
        if not active:
            return []
        query = active[0][0] if len(active) == 1 else dict(active)
        return return_results_page(search_images(query, best_seller_score_weight))

    # Save the image first (its URL lands in image_path), then run the search with that URL.
    search_button.click(
        load, image_input, image_path
    ).then(
        search,
        [text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1,
         image_relevance, best_seller_score_weight],
        [images_gallery],
    )
# compute_button.click(
# get_labels, [labels_input, image_labels_input], [bar_plot, response_labels]
# )
demo.queue()
if __name__ == "__main__":
    # Start the server only when run as a script, not when the module is imported.
    demo.launch()