# Author: samarthagarwal23
# Commit fc825b8: "Update app.py" (file size 3.07 kB)
import gradio as gr
import pandas as pd
import pickle
from sentence_transformers import SentenceTransformer, util
import re
# Sentence-embedding model used to encode user queries at request time.
mdl_name = 'sentence-transformers/all-distilroberta-v1'
model = SentenceTransformer(mdl_name)

# Pre-computed embeddings plus the review table they were computed from.
# Presumably row i of `embedding_table` belongs to row i of `reviews`; the
# recommender relies on that alignment -- confirm against the script that
# produced the cache.
# NOTE(review): pickle.load can execute arbitrary code stored in the file.
# That is acceptable only because this cache is a local artifact shipped with
# the app; never point this path at untrusted data.
embedding_cache_path = "scotch_embd_distilroberta.pkl"
with open(embedding_cache_path, "rb") as fIn:
    cache_data = pickle.load(fIn)
embedding_table = cache_data["embeddings"]
reviews = cache_data["data"]


def _parse_price(raw):
    """Return the first run of digits in a price string as an int.

    Every ",", ".00" and "$" is removed before searching, so "$1,500.00"
    becomes 1500.

    Args:
        raw: Price text as stored in the review table.

    Returns:
        The first group of consecutive digits left after stripping, as an int.

    Raises:
        ValueError: If no digit is left after stripping.
    """
    cleaned = raw.replace(",", "").replace(".00", "").replace("$", "")
    # Raw string on purpose: "\d" in a plain literal is an invalid escape
    # sequence and triggers a warning on current Python versions.
    match = re.search(r"\d+", cleaned)
    if match is None:
        raise ValueError(f"no numeric price found in {raw!r}")
    return int(match.group())


# Normalise the textual price column to integers once, at start-up.
reviews['price'] = reviews.price.apply(_parse_price).astype('int')
# Inclusive (min, max) price bounds for every option of the price-range radio.
# "$150+" is open-ended, so its upper bound is infinity instead of a fixed cap.
_PRICE_BOUNDS = {
    "$0-$70": (0, 70),
    "$70-$150": (70, 150),
    "$150+": (150, float("inf")),
}


def user_query_recommend(query, price_rng):
    """Recommend up to six whiskies whose reviews best match a flavour query.

    The query is embedded with the module-level ``model`` and compared, by
    cosine similarity, with every row of ``embedding_table``. For each whisky
    name the two highest-scoring rows are kept, filtered to the requested
    price band, and collapsed into one row per whisky whose
    ``description_sent`` is the space-joined text of those rows (best match
    first).

    Args:
        query: Free-text flavour description.
        price_rng: One of "$0-$70", "$70-$150" or "$150+". Both bounds are
            inclusive, so a price of exactly 70 or 150 matches two ranges.

    Returns:
        A DataFrame with columns ``name``, ``price`` and ``description_sent``,
        at most six rows, ordered from most to least similar.

    Raises:
        ValueError: If ``price_rng`` is not one of the known labels.
    """
    try:
        min_p, max_p = _PRICE_BOUNDS[price_rng]
    except KeyError:
        raise ValueError(
            f"unknown price range {price_rng!r}; "
            f"expected one of {list(_PRICE_BOUNDS)}"
        ) from None

    # Embed user query
    embedding = model.encode(query)
    # Calculate similarity with all reviews
    sim_scores = util.cos_sim(embedding, embedding_table)

    # Work on a fresh frame with a unique 0..n-1 index so that the row labels
    # returned by nlargest() below identify exactly one row each.
    # Assumes `reviews` rows are aligned with `embedding_table` rows.
    recommendations = reviews.reset_index(drop=True)
    recommendations['sim'] = sim_scores.T

    # Two best-scoring rows per whisky. The last index level of the result
    # holds the original row labels; selecting by label avoids joining back on
    # the float `sim` column, which duplicated rows whenever two rows of the
    # same whisky had equal scores.
    best_per_name = recommendations.groupby("name").sim.nlargest(2)
    best_rows = best_per_name.index.get_level_values(-1)
    op = recommendations.loc[
        best_rows,
        ['name', 'category', 'price', 'description', 'description_sent', 'sim'],
    ]

    op = op.loc[(op.price >= min_p) & (op.price <= max_p)]
    if op.empty:
        # Nothing in this price band: return an empty table of the right shape.
        return op[['name', 'price', 'description_sent']].reset_index(drop=True)

    op = (
        op
        # Sort first so each whisky's sentences are joined best match first.
        .sort_values('sim', ascending=False)
        .groupby(['name', 'category', 'price', 'description'])
        .agg(
            description_sent=("description_sent", " ".join),
            sim=("sim", "max"),
        )
        .reset_index()
        # groupby() returns groups in key order, so the ranking has to be
        # re-established here; otherwise head(6) picks the first six names
        # alphabetically instead of the six best matches.
        .sort_values('sim', ascending=False)
    )
    return op[['name', 'price', 'description_sent']].head(6).reset_index(drop=True)
# Input widgets: a free-text flavour description and a price-range selector.
flavour_input = gr.inputs.Textbox(lines=5, label = "enter flavour profile")
price_input = gr.inputs.Radio(
    choices = ["$0-$70", "$70-$150", "$150+"],
    default="$0-$70",
    type="value",
    label='Price range',
)

# Output widget: the recommendation table returned by user_query_recommend.
results_output = gr.outputs.Dataframe(
    max_rows=3,
    overflow_row_behaviour="paginate",
    type="pandas",
    label="Scotch recommendations",
)

# Example (query, price range) pairs offered in the UI.
example_queries = [
    ["very sweet with lemons and oranges and marmalades", "$0-$70"],
    ["smoky peaty and wood fire","$70-$150"],
    ["salty and spicy with exotic fruits", "$150+"],
    ["fragrant nose with chocolate, custard, toffee, pudding and caramel", "$70-$150"],
]

interface = gr.Interface(
    user_query_recommend,
    inputs=[flavour_input, price_input],
    outputs=results_output,
    title = "Scotch Recommendation",
    description = "Looking for scotch recommendations and have some flavours in mind? \nGet recommendations at a preferred price range using semantic search :) ",
    examples=example_queries,
    theme="grass",
)

# Start the app with request queuing enabled. Example caching is left off
# (cache_examples is intentionally not passed).
interface.launch(enable_queue=True)