import os
import time
import gradio as gr
from gradio.themes import Size, GoogleFont
import sys
import pandas as pd
import webbrowser
from marqo import Client
from PIL import Image
import urllib.request
from PIL import Image
import requests
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
from transformers import CLIPProcessor, CLIPModel
# Hugging Face checkpoint shared by the CLIP model and its processor.
_CLIP_CHECKPOINT = "patrickjohncyh/fashion-clip"
# Model/processor pair used by get_labels_probs to score text labels against an image.
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Uploaded query images are written here by `load`; make sure it exists at start-up.
static_dir = Path('./static')
static_dir.mkdir(exist_ok=True, parents=True)
# sys.path.insert(1, 'C:/Users/Alexandre/Documents/University/5_Ano/Estagio/repos_1')
# Palette: slate primary, rose secondary, stone neutral hue.
primary_color, secondary_color, neutral_color = (
    gr.themes.colors.slate,
    gr.themes.colors.rose,
    gr.themes.colors.stone,
)
# Medium ("md") presets for spacing, corner radius and text size.
spacing_size, radius_size, text_size = (
    gr.themes.sizes.spacing_md,
    gr.themes.sizes.radius_md,
    gr.themes.sizes.text_md,
)
# Body and monospace fonts.
font, font_mono = GoogleFont("Source Sans Pro"), GoogleFont("IBM Plex Mono")
# Base theme customised with the settings above; passed to gr.Blocks(theme=...).
theme_options = dict(
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
    font=font,
    font_mono=font_mono,
)
theme = gr.themes.Base(**theme_options)
### Load local
# def load_image(image_input):
# image_input.save("../../../Documents/images/img_path.jpg")
# os.system('docker cp "../../../Documents/images/img_path.jpg" marqo:"/images/images/"')
### Load AWS
def load_image(image_input, local_path="img_path.jpg", container_dest="marqo:/local/"):
    """Save a query image locally, then copy it into a Docker container.

    Args:
        image_input: object with a ``save(path)`` method (a PIL image in this app).
        local_path: where the image is written before being copied.
        container_dest: ``docker cp`` destination in ``<container>:<path>`` form.

    Returns:
        The ``subprocess.CompletedProcess`` of the ``docker cp`` call.

    Raises:
        subprocess.CalledProcessError: if ``docker cp`` exits with a non-zero status.
        FileNotFoundError: if the ``docker`` executable cannot be found.
    """
    import subprocess  # local import: the file does not import subprocess at top level

    image_input.save(local_path)
    # Argument list instead of an os.system shell string: no shell quoting/injection
    # issues, and a failed copy is reported instead of silently ignored.
    return subprocess.run(["docker", "cp", local_path, container_dest], check=True)
### Search local
# def search_images(query, best_seller_score_weight):
# client = Client()
# result = client.index("multimodal").search(query, score_modifiers = {
# "add_to_score": [{"field_name": "best_seller_score","weight": best_seller_score_weight/1000}],
# }, searchable_attributes=['primary_image'], device="cpu", limit=5)
# imgs = [r for r in result["hits"]]
# return imgs
### Search AWS
# Default Marqo endpoint and index used by the app.
_MARQO_URL = "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882"
_MARQO_INDEX = "test"


def search_images(query, best_seller_score_weight, url=_MARQO_URL, index_name=_MARQO_INDEX, limit=5):
    """Query a Marqo index over its ``primary_image`` field.

    Args:
        query: the Marqo query -- a text string, an image path/URL, or a dict
            mapping such strings to relevance weights (as built by ``search``).
        best_seller_score_weight: UI slider value; divided by 1000 and used as
            the ``add_to_score`` weight of the ``best_seller_score`` field.
        url: Marqo server URL.
        index_name: name of the index to search.
        limit: maximum number of hits to return.

    Returns:
        A new list containing the ``hits`` entries of the Marqo response.
    """
    client = Client(url)
    result = client.index(index_name).search(
        query,
        score_modifiers={
            "add_to_score": [
                {"field_name": "best_seller_score", "weight": best_seller_score_weight / 1000},
            ],
        },
        searchable_attributes=["primary_image"],
        device="cpu",
        limit=limit,
    )
    return list(result["hits"])
def get_labels_probs(labels, image):
    """Score each text label against an image with the module-level CLIP model.

    Args:
        labels: list of label strings.
        image: the image to score (a PIL image from the Gradio input in this app).

    Returns:
        List of floats, one per label in the same order; they sum to 1 because
        they are a softmax over the image-text similarity logits.
    """
    import torch  # local import: only this function needs torch directly

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # Image-text similarity logits; softmax over the label axis (dim=1).
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs[0].tolist()
def get_bar_plot(labels, probs):
    """Build a bar chart with one bar per label showing its probability.

    Args:
        labels: label names for the x axis.
        probs: probability of each label, in the same order as ``labels``.

    Returns:
        A ``matplotlib.figure.Figure``; each bar is annotated with its value
        to four decimal places and the y axis spans 0..1.
    """
    # Create the Figure directly instead of via pyplot: pyplot keeps every
    # figure it creates alive until explicitly closed, so one figure per
    # request would pile up in a long-running server.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    bars = ax.bar(labels, probs)
    ax.set(ylabel='probability', title='Labels probabilities\n', ylim=(0, 1))
    # Explicit labels (rather than a "{}"-style fmt, which needs matplotlib >= 3.7).
    ax.bar_label(bars, labels=[f"{p:,.4f}" for p in probs])
    return fig
def get_image_url_in_state(url):
    """Normalise an image URL and return it twice (for two Gradio outputs).

    A scheme-less URL such as ``"media2.newlookassets.com/..."`` gets
    ``https://`` prepended. A URL that already starts with ``http://`` or
    ``https://`` is returned unchanged instead of becoming ``https://https://...``.

    Returns:
        ``(full_url, full_url)``, or ``(None, None)`` when ``url`` is None.
    """
    if url is None:
        return None, None
    full_url = url if url.startswith(("http://", "https://")) else "https://" + url
    return full_url, full_url
# Extra CSS passed to gr.Blocks: beige page background and grey backgrounds
# for gallery items, labels and the h1 heading.
css = "\n".join([
    "",
    ".gradio-container {background-color: beige}",
    "button.gallery-item {background-color: grey}",
    ".label {background-color: grey; width: 80px}",
    "h1 {background-color: grey; width: 180px}",
    "",
])
# css = """
# .gradio-container {background-color: beige}
# .gallery-item {
# """
# Product photos offered as one-click image examples in both tabs.
example_image_urls = [
    "https://media2.newlookassets.com/i/newlook/869030934/womens/clothing/dresses/khaki-utility-mini-shirt-dress.jpg?strip=true&qlt=50&w=1400",
    "https://media3.newlookassets.com/i/newlook/872692409/womens/clothing/dresses/black-floral-lace-trim-mini-dress.jpg?strip=true&qlt=50&w=1400",
]

with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
    # Empty placeholder above the tabs.
    gr.Markdown("\n")

    # Tab 1: query the Marqo index by text, by image, or by both (weighted).
    with gr.Tab(label="Search for images"):
        with gr.Row().style(equal_height=False):
            text_input = gr.Text(label="Search with text:")
            text_relevance = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1)
            image_input = gr.Image(type="pil", label="Search with an image")
            # Path/URL of the query image: set by the image examples below and
            # overwritten by `load` when Search is clicked.
            image_path = gr.State(visible=False)
            image_relevance = gr.Slider(label="Image search relevance", minimum=-5, maximum=5, value=1, step=1)
        with gr.Row():
            gr.Examples(["Green", "Red", "Blue", "Sleeveless", "V-Neck", "Long dress, sleeveless, red"], text_input)
            gr.Markdown()
            # Each example row sets both the preview image and image_path to the same URL.
            gr.Examples([[url, url] for url in example_image_urls], [image_input, image_path])
            gr.Markdown()
        # Empty Markdown cells are spacers around the middle component of a row.
        with gr.Row():
            gr.Markdown()
            best_seller_score_weight = gr.Slider(label="Best seller relevance", minimum=-1, maximum=1, value=0, step=0.01)
            gr.Markdown()
        with gr.Row():
            gr.Markdown()
            search_button = gr.Button(value="Search")
            gr.Markdown()
        # Five result slots, then the raw Marqo response as text.
        with gr.Row():
            image_res_1 = gr.Image(type="pil")
            image_res_2 = gr.Image(type="pil")
            image_res_3 = gr.Image(type="pil")
            image_res_4 = gr.Image(type="pil")
            image_res_5 = gr.Image(type="pil")
        response = gr.Text()

    # Tab 2: score comma-separated labels against an image with CLIP.
    # Its label previously duplicated tab 1's "Search for images".
    with gr.Tab(label="Classify images"):
        labels_input = gr.Text(label="List of labels")
        gr.Examples(
            ["shirt, dress, shoe",
             "short_sleeve, long_sleeve, three_quarter_sleeve, sleeveless, bell_sleeve"],
            labels_input,
        )
        with gr.Row():
            image_labels_input = gr.Image(type="pil", label="Image to compute")
            bar_plot = gr.Plot()
        with gr.Row():
            gr.Examples(list(example_image_urls), image_labels_input)
            gr.Markdown()
        compute_button = gr.Button(value="Compute")
        response_labels = gr.Text()

    # Tab 3: dataset picker. NOTE(review): neither the dropdown nor the Select
    # button is bound to a name, so no handler is attached to them here.
    with gr.Tab(label="Choose dataset"):
        gr.Markdown("# Choose Dataset")
        with gr.Row():
            gr.Dropdown(["New Look Dresses", "New Look All"], label="Available datasets")
            gr.Markdown()
            gr.Markdown()
        with gr.Row():
            gr.Button("Select")
            gr.Markdown()
            gr.Markdown()
def load(image_input):
    """Persist the uploaded query image under ``static_dir`` and return its path.

    Args:
        image_input: image from the ``gr.Image(type="pil")`` input, or None when
            no image was supplied (text-only search).

    Returns:
        The saved file's path as a ``str`` (stored in the ``image_path`` state and
        later used in the Marqo query), or None when there is no image.
    """
    import uuid  # local import: only needed for unique file names

    if image_input is None:
        # Text-only search: nothing to save. `search` does not use image_path
        # when no image is given.
        return None
    # A random name avoids two uploads in the same second overwriting each
    # other, and avoids strftime('%s'), which is not portable (e.g. Windows).
    file_path = static_dir / f"{uuid.uuid4().hex}.jpg"
    print(file_path)
    image_input.save(file_path)
    # A plain str, not a Path: the value becomes part of the Marqo query, which is
    # presumably JSON-serialised by the client (a Path is not JSON-serialisable).
    # NOTE(review): this is a local path, while search_images targets a remote
    # Marqo server -- confirm the server can actually read it.
    return str(file_path)
def search(text_input, image_input, image_path, text_relevance, image_relevance, best_seller_score_weight):
    """Run a text and/or image query against Marqo and fetch the result images.

    Args:
        text_input: query text; empty (or None) means "no text query".
        image_input: uploaded image, or None when no image was supplied. Only its
            presence is checked; the query itself uses ``image_path``.
        image_path: path/URL of the query image, from the ``image_path`` state.
        text_relevance: weight of the text term when both text and image are given.
        image_relevance: weight of the image term when both are given.
        best_seller_score_weight: forwarded to ``search_images``.

    Returns:
        Six values for the UI: five images (``None`` for unused slots when fewer
        than five hits come back) followed by the raw list of hits. When neither
        text nor image is given: ``[None] * 5 + [""]``.
    """
    slots = 5
    if not text_input and image_input is None:
        return [None] * slots + [""]
    if not text_input:
        query = image_path
    elif image_input is None:
        query = text_input
    else:
        # Weighted multi-term query: {term: relevance}.
        query = {image_path: image_relevance, text_input: text_relevance}
    hits = search_images(query, best_seller_score_weight)
    images = [_fetch_image(hit["primary_image"]) for hit in hits[:slots]]
    # Pad so every result slot gets a value (the old code raised IndexError on < 5 hits).
    images += [None] * (slots - len(images))
    return (*images, hits)


def _fetch_image(url, timeout=30):
    """Download ``url`` and open it as an image held in memory."""
    import io  # local import: only needed for the in-memory buffer

    # Reading into memory avoids shared img_res_path_<i>.jpg files in the working
    # directory, which concurrent searches could overwrite.
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return Image.open(io.BytesIO(resp.read()))
def get_labels(labels_input, image_labels_input):
    """Score comma-separated labels against an image and plot the result.

    Args:
        labels_input: labels separated by commas; surrounding whitespace is
            stripped and empty entries (e.g. from a trailing comma) are dropped.
        image_labels_input: image to score, or None when none was supplied.

    Returns:
        ``(figure, probabilities)``; ``(None, "")`` when there is no image or no
        non-empty label.
    """
    labels = [label.strip() for label in (labels_input or "").split(",") if label.strip()]
    if not labels or image_labels_input is None:
        return None, ""
    labels_probs = get_labels_probs(labels, image_labels_input)
    return get_bar_plot(labels, labels_probs), labels_probs
# search_button.click(
# search, [text_input, image_input, image_path, text_relevance, image_relevance, best_seller_score_weight], [image_res_1, image_res_2, image_res_3, image_res_4, image_res_5, response]
# )
# Search: first save the uploaded image (load -> image_path state), then run the
# query and fill the five result slots plus the raw response text.
search_inputs = [text_input, image_input, image_path, text_relevance, image_relevance, best_seller_score_weight]
search_outputs = [image_res_1, image_res_2, image_res_3, image_res_4, image_res_5, response]
search_event = search_button.click(fn=load, inputs=image_input, outputs=image_path)
search_event.then(fn=search, inputs=search_inputs, outputs=search_outputs)

# Label scoring: bar chart plus raw probabilities for the entered labels.
compute_button.click(fn=get_labels, inputs=[labels_input, image_labels_input], outputs=[bar_plot, response_labels])
# image_input.upload(
# user, image_input, image_input
# ).then(
# respond, response, [image_res_1, image_res_2, image_res_3, image_res_4, image_res_5, response]
# )
# response = isComplete_state.change(
# lambda: gr.update(interactive=False), None, [user_input], queue=False
# ).then(
# respond_itinerary, [chatbot, isComplete_state, dataCollected_state], [chatbot, map, result_df]
# ).then(
# lambda: gr.update(visible=True), None, [map], queue=False
# ).then(
# lambda: gr.update(visible=True), None, [result_df], queue=False
# ).then(
# lambda: gr.update(visible=False), None, [text_map_before_itinerary], queue=False
# )
# response.then(
# lambda: gr.update(interactive=True), None, [user_input], queue=False
# )
# if map != None:
# map.update(visible=True)
# result_df.update(visible=True)
# Enable the request queue, then start the app; queue() returns the same Blocks
# object, so the two calls chain.
demo.queue().launch()