import os
import subprocess
import sys
import time
import urllib.request
import webbrowser
from datetime import datetime
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import requests
from gradio.themes import GoogleFont, Size
from marqo import Client
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# FashionCLIP model and its processor, loaded once at import time and reused
# by every label-probability request (see `get_labels_probs`).
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

# Directory where uploaded query images are written before a search.
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)

# --- Theme -------------------------------------------------------------------
# Colour hues for primary, secondary and neutral (text/background) elements.
primary_color = gr.themes.colors.slate
secondary_color = gr.themes.colors.rose
neutral_color = gr.themes.colors.stone

# Spacing, corner-radius and text sizes.
spacing_size = gr.themes.sizes.spacing_md
radius_size = gr.themes.sizes.radius_md
text_size = gr.themes.sizes.text_md

# Body and monospace fonts.
font = GoogleFont("Source Sans Pro")
font_mono = GoogleFont("IBM Plex Mono")

theme = gr.themes.Base(
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
    font=font,
    font_mono=font_mono,
)


def load_image(image_input):
    """Save *image_input* as ``img_path.jpg`` and copy it into the ``marqo`` container.

    The file is written to the current working directory and then copied to
    ``/local/`` inside the Docker container named ``marqo`` (presumably so the
    Marqo server can read it by path -- confirm against the index setup).

    Args:
        image_input: Image object with a ``save(path)`` method (a PIL image in
            this app).

    Raises:
        subprocess.CalledProcessError: If ``docker cp`` exits with a non-zero status.
        FileNotFoundError: If the ``docker`` executable cannot be found.
    """
    image_input.save("img_path.jpg")
    # Argument list (no shell) avoids quoting problems; check=True surfaces a
    # failed copy instead of silently ignoring os.system's return code.
    subprocess.run(["docker", "cp", "img_path.jpg", "marqo:/local/"], check=True)
# Marqo endpoint and index queried by `search_images`. MARQO_URL can be
# overridden through the environment without editing the code.
MARQO_URL = os.environ.get(
    "MARQO_URL", "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882"
)
MARQO_INDEX = "test"
# Number of result slots shown in the UI (also the number of hits requested).
NUM_RESULTS = 5


def search_images(query, best_seller_score_weight, url=None, index_name=None, limit=NUM_RESULTS):
    """Search the Marqo index over the ``primary_image`` field.

    Args:
        query: A text query, an image path/URL, or a ``{term: weight}`` dict
            for a weighted multi-term query.
        best_seller_score_weight: Slider value; divided by 1000 and used as the
            weight of the ``best_seller_score`` field added to each hit's score.
        url: Marqo endpoint; defaults to ``MARQO_URL``.
        index_name: Index to search; defaults to ``MARQO_INDEX``.
        limit: Maximum number of hits to return.

    Returns:
        The list of hit dicts from Marqo's response (``result["hits"]``).
    """
    client = Client(url or MARQO_URL)
    result = client.index(index_name or MARQO_INDEX).search(
        query,
        score_modifiers={
            "add_to_score": [
                {"field_name": "best_seller_score", "weight": best_seller_score_weight / 1000},
            ],
        },
        searchable_attributes=["primary_image"],
        device="cpu",
        limit=limit,
    )
    return list(result["hits"])


def get_labels_probs(labels, image):
    """Return FashionCLIP's probability for each label in *labels* describing *image*.

    Args:
        labels: List of label strings.
        image: A single PIL image.

    Returns:
        A list of floats, one per label, that sums to 1.
    """
    import torch  # only needed for the no-grad context below

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image holds image-text similarity scores (one row per image,
    # one column per label); softmax over labels turns them into probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs[0].tolist()


def get_bar_plot(labels, probs):
    """Build a bar chart of *probs* against *labels* and return the Figure."""
    fig, ax = plt.subplots()
    bars = ax.bar(labels, probs)
    ax.set(ylabel="frequency", title="Labels probabilities\n", ylim=(0, 1))
    ax.bar_label(bars, fmt="{:,.4f}")
    # Drop the figure from pyplot's global registry so repeated clicks don't
    # accumulate open figures; the Figure object itself stays renderable.
    plt.close(fig)
    return fig


def get_image_url_in_state(url):
    """Prefix *url* with ``https://`` and return it twice (one value per output)."""
    full_url = "https://" + url
    return full_url, full_url


def _fetch_image(url):
    """Download *url* and return it as a fully-loaded, in-memory PIL image."""
    import io

    with urllib.request.urlopen(url, timeout=30) as resp:
        data = resp.read()
    image = Image.open(io.BytesIO(data))
    # Force decoding now so the image no longer depends on the byte buffer.
    image.load()
    return image


def load(image_input):
    """Save the uploaded query image under ``static_dir`` and return its path as a str.

    Returns None when no image was uploaded (text-only search), so the chained
    `search` still runs.
    """
    if image_input is None:
        return None
    # Nanosecond timestamp: unique per call and portable. The previous
    # strftime('%s') is a glibc-only extension and treats a naive UTC datetime
    # as local time.
    file_path = static_dir / f"{time.time_ns()}.jpg"
    print(file_path)
    # JPEG cannot store alpha/palette modes, so normalise to RGB first.
    image_input.convert("RGB").save(file_path)
    # Marqo serialises the query to JSON, which cannot encode a Path object.
    return str(file_path)


def search(text_input, image_input, image_path, text_relevance, image_relevance, best_seller_score_weight):
    """Search by text, image, or both, and return images for the result slots.

    Args:
        text_input: Text query ("" or None when unused).
        image_input: Uploaded PIL image, or None when unused.
        image_path: Path/URL of the query image, as stored by `load`.
        text_relevance: Weight of the text term (used only when both inputs are given).
        image_relevance: Weight of the image term (used only when both inputs are given).
        best_seller_score_weight: Passed through to `search_images`.

    Returns:
        ``NUM_RESULTS`` PIL images (None for unused slots) followed by the raw
        Marqo hits, or all-None slots plus "" when there is nothing to search for.
    """
    if not text_input and image_input is None:
        return [None] * NUM_RESULTS + [""]

    # NOTE(review): `image_path` is a path on this machine (written by `load`);
    # a remote Marqo instance presumably cannot fetch it -- confirm how the
    # index resolves image queries.
    if not text_input:
        query = image_path
    elif image_input is None:
        query = text_input
    else:
        # Weighted multi-term query: each term contributes by its weight.
        query = {image_path: image_relevance, text_input: text_relevance}

    hits = search_images(query, best_seller_score_weight)
    images = [_fetch_image(hit["primary_image"]) for hit in hits[:NUM_RESULTS]]
    # Pad so every slot gets a value even when Marqo returns fewer hits.
    images += [None] * (NUM_RESULTS - len(images))
    return (*images, hits)


def get_labels(labels_input, image_labels_input):
    """Score a comma-separated label list against an image; return (plot, probabilities).

    Raises:
        gr.Error: If no non-empty label or no image was provided.
    """
    labels = [label.strip() for label in (labels_input or "").split(",") if label.strip()]
    if not labels or image_labels_input is None:
        raise gr.Error("Enter at least one label and provide an image.")
    labels_probs = get_labels_probs(labels, image_labels_input)
    return get_bar_plot(labels, labels_probs), labels_probs


css = """
.gradio-container {background-color: beige}
button.gallery-item {background-color: grey}
.label {background-color: grey; width: 80px}
h1 {background-color: grey; width: 180px}
"""

_KHAKI_DRESS_URL = "https://media2.newlookassets.com/i/newlook/869030934/womens/clothing/dresses/khaki-utility-mini-shirt-dress.jpg?strip=true&qlt=50&w=1400"
_BLACK_DRESS_URL = "https://media3.newlookassets.com/i/newlook/872692409/womens/clothing/dresses/black-floral-lace-trim-mini-dress.jpg?strip=true&qlt=50&w=1400"

with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
    gr.Markdown("\n")

    with gr.Tab(label="Search for images"):
        with gr.Row().style(equal_height=False):
            text_input = gr.Text(label="Search with text:")
            text_relevance = gr.Slider(
                label="Text search relevance", minimum=-5, maximum=5, value=1, step=1
            )
            image_input = gr.Image(type="pil", label="Search with an image")
            # Carries the query image's path/URL from `load` (or an example) to `search`.
            image_path = gr.State(visible=False)
            image_relevance = gr.Slider(
                label="Image search relevance", minimum=-5, maximum=5, value=1, step=1
            )
        with gr.Row():
            gr.Examples(
                ["Green", "Red", "Blue", "Sleeveless", "V-Neck", "Long dress, sleeveless, red"],
                text_input,
            )
            gr.Markdown()
            # Each example fills both the image component and the path state.
            gr.Examples(
                [[_KHAKI_DRESS_URL, _KHAKI_DRESS_URL], [_BLACK_DRESS_URL, _BLACK_DRESS_URL]],
                [image_input, image_path],
            )
            gr.Markdown()
        with gr.Row():
            gr.Markdown()
            best_seller_score_weight = gr.Slider(
                label="Best seller relevance", minimum=-1, maximum=1, value=0, step=0.01
            )
            gr.Markdown()
        with gr.Row():
            gr.Markdown()
            search_button = gr.Button(value="Search")
            gr.Markdown()
        with gr.Row():
            image_results = [gr.Image(type="pil") for _ in range(NUM_RESULTS)]
        response = gr.Text()

    # Previously also labelled "Search for images", which duplicated the first tab.
    with gr.Tab(label="Compute labels"):
        labels_input = gr.Text(label="List of labels")
        gr.Examples(
            [
                "shirt, dress, shoe",
                "short_sleeve, long_sleeve, three_quarter_sleeve, sleeveless, bell_sleeve",
            ],
            labels_input,
        )
        with gr.Row():
            image_labels_input = gr.Image(type="pil", label="Image to compute")
            bar_plot = gr.Plot()
        with gr.Row():
            gr.Examples([_KHAKI_DRESS_URL, _BLACK_DRESS_URL], image_labels_input)
            gr.Markdown()
        compute_button = gr.Button(value="Compute")
        response_labels = gr.Text()

    # Dataset selection UI; no event handlers are attached to these widgets.
    with gr.Tab(label="Choose dataset"):
        gr.Markdown("# Choose Dataset")
        with gr.Row():
            gr.Dropdown(["New Look Dresses", "New Look All"], label="Available datasets")
            gr.Markdown()
            gr.Markdown()
        with gr.Row():
            gr.Button("Select")
            gr.Markdown()
            gr.Markdown()

    # Save the uploaded image first so `search` receives its path in `image_path`.
    search_button.click(load, image_input, image_path).then(
        search,
        [text_input, image_input, image_path, text_relevance, image_relevance, best_seller_score_weight],
        [*image_results, response],
    )
    compute_button.click(
        get_labels, [labels_input, image_labels_input], [bar_plot, response_labels]
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()