add main files
Browse files
api.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from sentence_transformers import SentenceTransformer
|
4 |
+
from utils.similarity import get_similar_items
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# Preprocessed clothing catalogue (descriptions + product URLs) used for retrieval.
clothing_data = pd.read_csv('data/clothing_data_preprocessed.csv')
# Sentence-transformer loaded from the local 'model' directory; presumably the
# same model that produced data/embeddings.npy — TODO confirm.
model = SentenceTransformer('model')
# Precomputed sentence embeddings; assumed row-aligned with clothing_data —
# NOTE(review): verify alignment against the embedding-generation script.
embeddings = np.load('data/embeddings.npy')

# FastAPI application exposing the similarity-search endpoint below.
app = FastAPI()
|
13 |
+
|
14 |
+
class Query(BaseModel):
    """Request body for the /predict endpoint."""
    # Free-text clothing description to search for.
    query: str
|
16 |
+
|
17 |
+
@app.post("/predict")
def getURL(query: Query):
    """Return URLs of the 5 catalogue items most similar to the query text.

    The request body is a ``Query`` model; the response is a JSON object
    with a single ``similar_urls`` list.
    """
    # Delegate retrieval to the shared similarity helper, using the
    # module-level embeddings and catalogue loaded at import time.
    return {"similar_urls": get_similar_items(query.query, embeddings, clothing_data, 5)}
|
24 |
+
|
25 |
+
if __name__ == '__main__':
    # Imported lazily so the module can be served by an external ASGI runner
    # without needing uvicorn at import time.
    import uvicorn
    # Binds to all interfaces on port 8080 — presumably matches the
    # container/deployment configuration; confirm before changing.
    uvicorn.run(app, host='0.0.0.0', port=8080)
|
main.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
import random

import gradio as gr
import markdown
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

from utils.similarity import get_similar_items
|
8 |
+
|
9 |
+
# Create title, description, and article strings
# (markdown copy rendered at the top of the Gradio interface).
title = "Clothing Similarity Search 👕"
description = "**Transformer-based search engine** to fetch Amazon URLs for similar clothing items given a text description.\n\n**Data Collection**:\nTo scrape quality clothing data containing proper description and URL for the product, Apify's Amazon Product Scraper was used. The scraped data for various clothing categories was downloaded into a CSV file.\n\n**Data Cleaning**:\nPandas was used to clean and preprocess the text data by removing special characters, lowercasing, and applying text normalization techniques.\n\n**Making Embeddings**:\nSentence-transformers library was used to generate embeddings for the cleaned data using the [all-MiniLM-L6-v2](https://example.com/model-card) model. The embeddings were saved into a .npy file for faster similarity search retrieval.\n\n**Cosine Similarity**:\nCosine similarity was used to find the similarity between the query and the product embeddings.\n"

# Sentence-transformer loaded from the local 'model' directory.
model = SentenceTransformer('model')
# Precomputed catalogue embeddings; assumed row-aligned with clothing_data —
# TODO confirm against the embedding-generation step.
embeddings = np.load('data/embeddings.npy')
# Preprocessed clothing catalogue (descriptions + product URLs).
clothing_data = pd.read_csv('data/clothing_data_preprocessed.csv')
|
16 |
+
|
17 |
+
def getURL(text, top_k):
    """Fetch URLs of the ``top_k`` catalogue items most similar to ``text``."""
    # Thin wrapper over the shared similarity helper; reads the module-level
    # embeddings and catalogue loaded at import time.
    return get_similar_items(text, embeddings, clothing_data, top_k)
|
21 |
+
|
22 |
+
# Gradio widgets: a one-line text box for the query, a slider for how many
# recommendations to return, and an HTML panel for the resulting links.
input_text = gr.components.Textbox(lines=1, label="Input Description")  # fixed label typo ("Descriptiont")
input_top_k = gr.components.Slider(label="Number of Recommendations", minimum=1, maximum=10, step=1, default=5)
# NOTE(review): gr.outputs is the legacy namespace (removed in Gradio 4.x);
# gr.components.HTML is the modern equivalent — confirm the pinned Gradio
# version before migrating.
output_html = gr.outputs.HTML(label="Similar Items")
|
25 |
+
|
26 |
+
def render_links(urls):
    """Render a list of URLs as an HTML block of clickable links.

    Module-internal helper for ``process_text``; each URL is HTML-escaped so
    characters such as ``&`` or ``"`` in scraped data cannot break out of the
    href attribute or the link text.
    """
    links = "<br>".join(
        f'<a href="{html.escape(url)}" target="_blank">{html.escape(url)}</a>'
        for url in urls
    )
    return f'<div style="padding: 20px">{links}</div>'


def process_text(text, top_k):
    """Gradio callback: look up similar items and return them as HTML links."""
    urls = getURL(text, top_k)
    random.shuffle(urls)  # Shuffle the URLs for variety
    return render_links(urls)
|
31 |
+
|
32 |
+
# Assemble the Gradio interface, wiring the callback to the widgets above.
iface = gr.Interface(
    fn=process_text,
    inputs=[input_text, input_top_k],
    outputs=output_html,
    title=title,
    description=description,
    # Pre-filled example queries shown below the form: [query text, top_k].
    examples=[
        ["casual men's t-shirt", 3],
        ["stylish summer dress", 5],
        ["elegant evening gown", 7],
    ],
    theme="default",
    # NOTE(review): layout= and interpretation= are legacy Interface kwargs
    # (dropped in newer Gradio releases) — confirm the pinned Gradio version.
    layout="vertical",
    interpretation="default",
    # Hide the flagging button entirely.
    allow_flagging="never",
)

# Start the Gradio server (blocking call).
iface.launch()
|