Spaces:
Runtime error
Runtime error
| import os | |
| import time | |
| import gradio as gr | |
| from gradio.themes import Size, GoogleFont | |
| import sys | |
| import pandas as pd | |
| import webbrowser | |
| from marqo import Client | |
| from PIL import Image | |
| import urllib.request | |
| from PIL import Image | |
| import requests | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| from datetime import datetime | |
| import time | |
| import webbrowser | |
| from transformers import CLIPProcessor, CLIPModel | |
| # model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip") | |
| # processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip") | |
# Local folder for files the app writes at runtime (e.g. the uploaded query
# image); created up front so later saves into it do not fail.
static_dir = Path('static')
static_dir.mkdir(exist_ok=True, parents=True)
| # client = Client("http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882") | |
| # client = Client() | |
| # index_name = "new_look_expanded_dresses" | |
| # device = "cpu" | |
class Client_Settings:
    """Mutable connection settings shared by every Marqo query in this app.

    Bundles the ``marqo.Client`` in use with the index name and the device
    string that the search helpers pass to ``Index.search``.
    """

    def __init__(self):
        # Start from a no-argument Client (the same call conn_to_local makes),
        # the dresses index, and "cpu" as the device.
        self.client, self.index_name, self.device = (
            Client(),
            "new_look_expanded_dresses",
            "cpu",
        )

    def conn_to_local(self):
        """Swap in a ``Client()`` built with no URL argument."""
        self.client = Client()

    def conn_to_server(self, url):
        """Swap in a ``Client`` pointed at ``url``."""
        self.client = Client(url)

    def set_index_name(self, new_index_name):
        """Choose the index queried by subsequent searches."""
        self.index_name = new_index_name

    def set_device(self, new_device):
        """Choose the device string passed along with subsequent searches."""
        self.device = new_device
# Shared Marqo connection used by every search helper in this module.
# The server URL, index name and device can be overridden with the
# MARQO_URL, MARQO_INDEX and MARQO_DEVICE environment variables; without
# them the hard-wired defaults below are used.
client_obj = Client_Settings()
# client_obj.conn_to_local()  # alternative: no-URL Client() instead of the remote server
client_obj.conn_to_server(
    os.environ.get(
        "MARQO_URL",
        "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882",
    )
)
client_obj.set_index_name(os.environ.get("MARQO_INDEX", "new_look_expanded_dresses"))
client_obj.set_device(os.environ.get("MARQO_DEVICE", "cuda"))
# Gradio's built-in colour palettes: slate as the primary hue, rose as the
# secondary hue and stone as the neutral hue.
primary_color, secondary_color, neutral_color = (
    gr.themes.colors.slate,
    gr.themes.colors.rose,
    gr.themes.colors.stone,
)
# Medium ("md") presets for spacing, corner radius and text size.
spacing_size, radius_size, text_size = (
    gr.themes.sizes.spacing_md,
    gr.themes.sizes.radius_md,
    gr.themes.sizes.text_md,
)
# Body font and monospace font, both given as Google Fonts families.
font, font_mono = GoogleFont("Source Sans Pro"), GoogleFont("IBM Plex Mono")
# Assemble the app-wide theme passed to gr.Blocks.
theme = gr.themes.Base(
    font=font,
    font_mono=font_mono,
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
)
def filter_by_column(dataset, search_term, column_name, case=True, regex=True) -> pd.DataFrame:
    """Return the rows of ``dataset`` whose ``column_name`` contains ``search_term``.

    Parameters
    ----------
    dataset : pd.DataFrame
        Frame to filter; ``column_name`` must hold text.
    search_term : str
        Pattern to look for (a regular expression when ``regex`` is true).
    column_name : str
        Column to search in.
    case, regex : bool
        Forwarded to ``Series.str.contains``; the defaults give case-sensitive
        regular-expression matching.

    Rows whose value is missing or not a string never match; without
    ``na=False`` they would put NaN into the boolean mask and the indexing
    would raise ``ValueError``.
    """
    mask = dataset[column_name].str.contains(search_term, case=case, regex=regex, na=False)
    return dataset[mask]
def dedup_by(dataset, column_name) -> pd.DataFrame:
    """Keep only the first row for each distinct value of ``column_name``.

    The original row order and index labels of the surviving rows are kept.
    """
    key_columns = [column_name]
    deduplicated = dataset.drop_duplicates(subset=key_columns, keep="first")
    return deduplicated
def drop_secondary_images(dataset) -> pd.DataFrame:
    """Collapse results to one row per distinct ``primary_image``.

    Each row's ``image`` is replaced by its ``primary_image`` and only the
    first row for every ``primary_image`` value is kept.

    Works on a copy via ``DataFrame.assign``, so the caller's frame is not
    modified. The ``image`` column is created if it does not exist.
    Attribute assignment (``dataset.image = ...``) would not create it; pandas
    only sets a Python attribute in that case.
    """
    with_primary = dataset.assign(image=dataset['primary_image'])
    return with_primary.drop_duplicates(subset=['primary_image'])
def dataset_to_gallery(dataset: pd.DataFrame, _score=None) -> list:
    """Convert result rows into ``(image, caption)`` tuples for ``gr.Gallery``.

    The caption packs the row's fields into one ``"@@"``-separated string
    that the UI's select handlers later split apart:

    * without ``_score``: ``name@@colour_code@@image@@_id``
    * with ``_score`` (a ``pd.Series`` aligned on ``dataset``'s index):
      ``score@@name@@colour_code@@image@@_id``, score formatted as ``{:,.4f}``.

    ``dataset`` must have ``_id``, ``image``, ``name`` and ``colour_code``
    columns; ``name`` is concatenated as-is, so it is expected to be text.
    Any ``_score`` that is not a Series is ignored.
    """
    new_df = dataset[['_id', 'image', 'name', 'colour_code']].copy()
    caption = (
        new_df['name']
        + '@@' + new_df['colour_code'].astype(str)
        + '@@' + new_df['image'].astype(str)
        + '@@' + new_df['_id'].astype(str)
    )
    # isinstance avoids building a throwaway empty Series on every call just to compare types.
    if isinstance(_score, pd.Series):
        caption = _score.map('{:,.4f}'.format).astype(str) + '@@' + caption
    new_df['name_code_combined'] = caption
    return new_df[['image', 'name_code_combined']].to_records(index=False).tolist()
def get_items_from_dataset(start_index=0, end_index=50, dataset=None) -> pd.DataFrame:
    """Return rows ``start_index:end_index`` of ``dataset`` ranked by best-seller score.

    Parameters
    ----------
    start_index, end_index : int
        Positional slice bounds applied after sorting (end exclusive).
    dataset : pd.DataFrame or None
        Frame with a ``best_seller_score`` column. When omitted, an empty
        DataFrame is returned.

    Rows are sorted by ``best_seller_score`` in descending order before slicing.
    """
    # None sentinel: avoids building a DataFrame at import time, and avoids
    # sorting an empty frame that has no best_seller_score column (KeyError).
    if dataset is None:
        return pd.DataFrame()
    ranked = dataset.sort_values(by=['best_seller_score'], ascending=False)
    return ranked.iloc[start_index:end_index]
| # def return_page(page, dataset: pd.DataFrame): | |
| # start_index = page * result_per_page | |
| # end_index = (page + 1) * result_per_page | |
| # df = get_items_from_dataset(start_index, end_index, dataset) | |
| # return dataset_to_gallery(dedup_by(df, 'colour_code')) | |
def start_page(num_results=50):
    """Build the landing-page gallery.

    Searches the current index for "Dress" over the ``image`` attribute,
    with Marqo's ``add_to_score`` modifier adding ``best_seller_score`` at
    weight 5. Returns at most ``num_results`` hits, converted to gallery
    items by ``return_results_page``.
    """
    index = client_obj.client.index(client_obj.index_name)
    boost = {"add_to_score": [{"field_name": "best_seller_score", "weight": 5}]}
    response = index.search(
        "Dress",
        score_modifiers=boost,
        searchable_attributes=['image'],
        device=client_obj.device,
        limit=num_results,
    )
    return return_results_page(list(response["hits"]))
def return_results_page(results_list: list):
    """Turn raw Marqo hits into score-captioned gallery items.

    Each hit is a dict with at least ``primary_image``, ``_score`` and the
    fields ``dataset_to_gallery`` reads. Rows sharing a ``primary_image`` are
    collapsed first. An empty hit list yields an empty gallery. Building a
    DataFrame from ``[]`` would produce a frame with no ``primary_image``
    column, and the dedup step would raise ``KeyError``.
    """
    if not results_list:
        return []
    df_unique = drop_secondary_images(pd.DataFrame(results_list))
    return dataset_to_gallery(df_unique, df_unique['_score'])
def return_item(combined) -> tuple:
    """Fetch every indexed image that shares the clicked result's colour code.

    ``combined`` is a caption built by ``dataset_to_gallery`` *with* a score
    prefix (``score@@name@@colour_code@@image@@_id``), so the colour code is
    field 2.

    Returns ``(gallery_items, description, url)``; the description and url
    come from the first hit's ``description_total`` and ``url`` fields.
    When nothing matches, returns ``([], "", "")`` rather than raising
    ``IndexError``.
    """
    colour_code = combined.split("@@")[2]
    result = client_obj.client.index(client_obj.index_name).search(
        "",
        filter_string="colour_code:" + str(colour_code),
        searchable_attributes=['image'],
        device=client_obj.device,
    )
    hits = list(result["hits"])
    if not hits:
        return [], "", ""
    first = hits[0]
    return dataset_to_gallery(pd.DataFrame(hits)), first["description_total"], first["url"]
def return_specific_item(combined):
    """Return the ``image`` field of the document whose ``_id`` is in ``combined``.

    ``combined`` is a caption built by ``dataset_to_gallery`` *without* a
    score prefix (``name@@colour_code@@image@@_id``), so the id is field 3.
    Returns ``None`` when no document matches, instead of raising
    ``IndexError``.
    """
    _id = combined.split("@@")[3]
    result = client_obj.client.index(client_obj.index_name).search(
        "",
        filter_string="_id:" + str(_id),
        searchable_attributes=['image'],
        device=client_obj.device,
    )
    hits = result["hits"]
    if not hits:
        return None
    # The first hit's image; no DataFrame round-trip is needed to read one field.
    return hits[0]["image"]
### Load local
def load_image(image_input):
    """Save ``image_input`` locally, then copy it into the ``marqo`` container.

    Local-run helper: the image is written to a fixed relative path and copied
    with ``docker cp`` to ``/images/images/`` inside the container named
    ``marqo``. The command's exit status is ignored.
    """
    local_copy = "../../../Documents/images/img_path.jpg"
    image_input.save(local_copy)
    os.system(f'docker cp "{local_copy}" marqo:"/images/images/"')
def search_images(query, best_seller_score_weight):
    """Run ``query`` against the current index and return the raw hit list.

    ``query`` is passed straight to Marqo (a single string or a dict of
    weighted queries). ``best_seller_score_weight`` is divided by 1000 and
    used as the ``add_to_score`` weight for ``best_seller_score``. At most 40
    hits are returned, searching only the ``image`` attribute.
    """
    index = client_obj.client.index(client_obj.index_name)
    modifiers = {
        "add_to_score": [
            {"field_name": "best_seller_score", "weight": best_seller_score_weight / 1000},
        ],
    }
    response = index.search(
        query,
        score_modifiers=modifiers,
        searchable_attributes=['image'],
        device=client_obj.device,
        limit=40,
    )
    return list(response["hits"])
| # def get_labels_probs(labels, image): | |
| # inputs = processor(text=labels, images=image, return_tensors="pt", padding=True) | |
| # outputs = model(**inputs) | |
| # logits_per_image = outputs.logits_per_image # this is the image-text similarity score | |
| # probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities | |
| # return probs.tolist()[0] | |
def get_bar_plot(labels, probs):
    """Draw one bar per label with its probability and return the Figure.

    The y axis is fixed to [0, 1] and each bar is annotated with its value to
    four decimal places. A new pyplot figure is created on every call.
    """
    figure, axes = plt.subplots()
    bars = axes.bar(labels, probs)
    axes.set_ylabel('frequency')
    axes.set_title('Labels probabilities\n')
    axes.set_ylim(0, 1)
    axes.bar_label(bars, fmt='{:,.4f}')
    return figure
# Global CSS passed to gr.Blocks: beige background for the app container and
# grey background for `button.gallery-item` elements.
css = """
.gradio-container {background-color: beige}
button.gallery-item {background-color: grey}
"""
# Disabled rules kept for reference (not part of the css string):
# .label {background-color: grey; width: 80px}
# h1 {background-color: grey; width: 180px}
# Gradio UI: a search tab (text/image query inputs, a results gallery and a
# colour-variant detail pane) and a tab for switching the Marqo index.
with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
    # Brand logo header.
    gr.Markdown(
        '<div style="vertical-align: middle">\n'
        '<div style="float: left">\n'
        '<img src="https://1000logos.net/wp-content/uploads/2021/05/New-Look-logo.png" alt=""\n'
        'width="250" height="250">\n'
        '</div>\n'
        '</div>\n'
    )

    with gr.Tab(label="Search for images"):
        # --- Query inputs ---------------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Text(label="Search with text:")
                text_relevance = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1)
                # Second text query; hidden until "More text fields" is clicked.
                text_input_1 = gr.Text(label="Search with text:", visible=False)
                text_relevance_1 = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1, visible=False)
                more_text_search = gr.Button(value="More text fields")
                text_expanded = gr.State(value=False)
            with gr.Column(scale=3):
                best_seller_score_weight = gr.Slider(label="Best seller relevance", minimum=-1, maximum=1, value=0, step=0.01)
                search_button = gr.Button(value="Search")
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Search with image")
                # Public URL of the saved query image, filled in by `load`.
                # State is never rendered, so no `visible` argument (newer
                # Gradio's State signature does not accept one).
                image_path = gr.State()
                image_relevance = gr.Slider(label="Image search relevance", minimum=-5, maximum=5, value=1, step=1)

        # --- Results ----------------------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                # start_page() queries Marqo while the UI is being built; any
                # error it raises propagates out of this `with` block.
                images_gallery = gr.Gallery(value=start_page(), columns=4,
                                            allow_preview=False, show_label=False, object_fit="contain")
            with gr.Column():
                detail_gallery = gr.Gallery(value=[], columns=2, allow_preview=False, show_label=False, rows=1,
                                            height="400", object_fit="contain")
                image_description = gr.Text(label="Description")
                product_link = gr.State()
                page = gr.HTML()

        def on_new_text_box(more_text_search):
            """Toggle the second text query; the button's label carries the toggle state."""
            if more_text_search == "More text fields":
                return (
                    gr.update(visible=True, interactive=True),
                    gr.update(visible=True, interactive=True),
                    gr.update(value="Hide extra text box"),
                )
            # Hiding also clears the text, so `search` stops treating it as active.
            return (
                gr.update(value="", visible=False, interactive=False),
                gr.update(visible=False, interactive=False),
                gr.update(value="More text fields"),
            )

        def on_focus(evt: gr.SelectData):
            """Show the clicked result's colour variants, description and product link.

            Assumes ``evt.value`` is the gallery caption built by
            ``dataset_to_gallery`` (Gradio 3.x Gallery select behaviour — confirm
            when upgrading Gradio).
            """
            from html import escape

            variants, description, url = return_item(evt.value)
            if url:
                # Quote and escape the URL taken from the index so it cannot break
                # out of the attribute; noopener keeps the new tab from reaching
                # window.opener.
                link = ("<a href=\"" + escape(str(url), quote=True)
                        + "\" target='_blank' rel='noopener noreferrer'> Go to product page </a>")
            else:
                link = ""
            return variants, description, url, gr.update(value=link)

        def on_new_image_to_search(images, evt: gr.SelectData):
            """Put the clicked colour variant's image into the image query box."""
            return return_specific_item(evt.value)

        more_text_search.click(on_new_text_box, more_text_search, [text_input_1, text_relevance_1, more_text_search])
        images_gallery.select(on_focus, None, [detail_gallery, image_description, product_link, page])
        detail_gallery.select(on_new_image_to_search, detail_gallery, image_input)

    with gr.Tab(label="Choose dataset"):
        gr.Markdown("# Choose Dataset")
        with gr.Row():
            list_datasets = gr.Dropdown(["New Look Dresses", "New Look All"], label="Available datasets", value="New Look Dresses")
            gr.Markdown()
            gr.Markdown()
        with gr.Row():
            select_dataset_button = gr.Button("Select")
            gr.Markdown()
            gr.Markdown()

        def on_change_dataset(choice):
            """Point the shared client at the index behind ``choice`` and echo ``choice``.

            An unrecognised choice leaves the current index unchanged rather
            than selecting an empty index name.
            """
            index_names = {
                "New Look Dresses": "new_look_expanded_dresses",
                "New Look All": "new_look_expanded_all",
            }
            index_name = index_names.get(choice, "")
            print("Dataset selected: " + index_name)
            if index_name:
                client_obj.set_index_name(index_name)
            time.sleep(0.5)
            return choice

        # After switching index, reload the landing gallery so its results (and
        # the captions later used to look up variants) come from the new index.
        select_dataset_button.click(on_change_dataset, list_datasets, list_datasets).then(
            start_page, None, images_gallery
        )

    def load(image_input):
        """Save the query image under static/ and return its public URL ("" when there is none).

        For a local Marqo, ``load_image`` copies the image into the container instead.
        """
        if image_input is None:
            return ""
        # NOTE(review): every request writes the same file name, so concurrent
        # image searches can overwrite each other's query image — confirm
        # whether per-request file names are needed.
        file_path = "static/image_to_search.jpg"
        print(file_path)
        image_input.save(file_path)
        return "https://minderalabs-newlook.hf.space/file=" + file_path

    def search(text_input, text_input_1, image_path_unused_guard=None, *args):
        raise NotImplementedError  # placeholder replaced below

    def search(text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight):
        """Search Marqo with every non-empty query field and return gallery items.

        Active queries are the two text boxes and ``image_path`` (the URL set by
        ``load``); ``image_input`` itself is not used. With no active query the
        gallery is cleared. A single active query is sent as a plain string, so
        its relevance slider has no effect. Several active queries are sent as a
        ``{query: relevance}`` dict; an identical later query overrides an
        earlier one.
        """
        candidates = zip(
            [text_input, text_input_1, image_path],
            [text_relevance, text_relevance_1, image_relevance],
        )
        active = [(q, weight) for q, weight in candidates if q is not None and q != ""]
        if not active:
            return []
        if len(active) == 1:
            query = active[0][0]
        else:
            query = {q: weight for q, weight in active}
        print("Search query:", query)
        return return_results_page(search_images(query, best_seller_score_weight))

    search_button.click(
        load, image_input, image_path
    ).then(
        search, [text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight], [images_gallery]
    )
# Start the server only when this file is executed as a script; importing
# the module just builds `demo` without launching it.
if __name__ == "__main__":
    demo.queue()
    demo.launch()