"""Gradio app that recommends scotches matching a free-text flavour profile.

The user's description is embedded with a SentenceTransformer model and
compared by cosine similarity against pre-computed embeddings of review
sentences loaded from a local pickle cache. The best-matching scotches in the
chosen price band are shown in a table.
"""
import pickle
import re

import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer, util

mdl_name = 'sentence-transformers/all-distilroberta-v1'
model = SentenceTransformer(mdl_name)

embedding_cache_path = "scotch_embd_distilroberta.pkl"
# NOTE(security): pickle.load can execute arbitrary code. This is acceptable
# only because the cache is a file shipped with the app; never point this path
# at user-supplied data.
with open(embedding_cache_path, "rb") as cache_file:
    cache_data = pickle.load(cache_file)
embedding_table = cache_data["embeddings"]
reviews = cache_data["data"]

# Inclusive (min, max) price bounds for each Radio choice. Boundary prices
# ($70, $150) satisfy both adjacent bands, exactly as the filter always did.
PRICE_RANGES = {
    "$0-$70": (0, 70),
    "$70-$150": (70, 150),
    "$150+": (150, float("inf")),
}
# Best-matching review sentences kept per scotch, and scotches returned per query.
SENTENCES_PER_SCOTCH = 2
MAX_RESULTS = 6

_PRICE_DIGITS = re.compile(r"\d+")


def _parse_price(raw):
    """Return the first run of digits in a price string, e.g. '$1,200.00' -> '1200'.

    Thousands separators, a literal '.00' and '$' are stripped first so the
    whole dollar amount is captured as a single number.

    Raises:
        ValueError: if the value contains no digits.
    """
    cleaned = str(raw).replace(",", "").replace(".00", "").replace("$", "")
    match = _PRICE_DIGITS.search(cleaned)
    if match is None:
        raise ValueError(f"No numeric price found in {raw!r}")
    return match.group()


reviews['price'] = reviews['price'].apply(_parse_price).astype('int')


def user_query_recommend(query, price_rng):
    """Recommend scotches whose review sentences best match a flavour description.

    Args:
        query: Free-text flavour profile entered by the user.
        price_rng: One of the ``PRICE_RANGES`` keys ("$0-$70", "$70-$150",
            "$150+").

    Returns:
        A DataFrame with columns ``name``, ``price`` and ``description_sent``
        holding at most ``MAX_RESULTS`` scotches, best match first.
        ``description_sent`` joins that scotch's top-scoring review sentences,
        highest-scoring first.

    Raises:
        ValueError: if ``price_rng`` is not a known price band.
    """
    try:
        min_p, max_p = PRICE_RANGES[price_rng]
    except KeyError:
        raise ValueError(f"Unknown price range: {price_rng!r}") from None

    # Cosine similarity of the query against every review sentence; for a
    # single query embedding cos_sim returns a (1, n_sentences) tensor.
    embedding = model.encode(query)
    sim_scores = util.cos_sim(embedding, embedding_table)

    recommendations = reviews.copy()
    # Flatten the single similarity row into a 1-D column aligned with reviews.
    recommendations['sim'] = sim_scores[0].cpu().numpy()

    # Keep each scotch's best-matching sentences. Sorting first makes the
    # per-name head() pick the highest scores and keeps them in descending
    # order for the join below; selecting rows directly (rather than merging
    # back on float similarity values) cannot duplicate rows on tied scores.
    top = (
        recommendations
        .sort_values('sim', ascending=False)
        .groupby('name')
        .head(SENTENCES_PER_SCOTCH)
    )
    top = top[top['price'].between(min_p, max_p)]
    if top.empty:
        return pd.DataFrame(columns=['name', 'price', 'description_sent'])

    # One row per scotch. The explicit sort on 'sim' is what ranks results;
    # groupby's default sort=True would order them alphabetically by name and
    # head() would then return the wrong scotches.
    op = (
        top.groupby(['name', 'category', 'price', 'description'], sort=False)
        .agg(description_sent=('description_sent', " ".join),
             sim=('sim', 'max'))
        .reset_index()
        .sort_values('sim', ascending=False)
    )
    return (
        op[['name', 'price', 'description_sent']]
        .reset_index(drop=True)
        .head(MAX_RESULTS)
    )


# NOTE(review): gr.inputs / gr.outputs, `default=`, `enable_queue` and string
# themes belong to the legacy Gradio API; confirm the pinned Gradio version
# before upgrading.
interface = gr.Interface(
    user_query_recommend,
    inputs=[
        gr.inputs.Textbox(lines=5, label="enter flavour profile"),
        gr.inputs.Radio(
            choices=list(PRICE_RANGES),
            default="$0-$70",
            type="value",
            label='Price range',
        ),
    ],
    outputs=gr.outputs.Dataframe(
        max_rows=3,
        overflow_row_behaviour="paginate",
        type="pandas",
        label="Scotch recommendations",
    ),
    title="Scotch Recommendation",
    description="Looking for scotch recommendations and have some flavours in mind? \nGet recommendations at a preferred price range using semantic search :) ",
    examples=[
        ["very sweet with lemons and oranges and marmalades", "$0-$70"],
        ["smoky peaty and wood fire", "$70-$150"],
        ["salty and spicy with exotic fruits", "$150+"],
        ["fragrant nose with chocolate, custard, toffee, pudding and caramel", "$70-$150"],
    ],
    theme="grass",
)

interface.launch(
    enable_queue=True,
)