from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
# Account credentials
storage_account_name = "datawarehousegoogle"
storage_account_access_key = "yZ35ST91n+BnuyU3Bgrt9jnmELS6yl3h04UameA5SJiPoo51htikj2Wl2NdwfaiKyMRJb59vT1ov+AStu1yRsw=="
file_location = "wasbs://datasets@datawarehousegoogle.blob.core.windows.net/"
file_type = "parquet"
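# The wasbs:// URI points at the "datasets" container in the storage account above;
# the account key is handed to the hadoop-azure connector below so Spark can read from it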
spark.conf.set(
    "fs.azure.account.key." + storage_account_name + ".blob.core.windows.net",
    storage_account_access_key)
# Read the data
df = spark.read.format(file_type).option("inferSchema", "true").load(file_location)
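# The parquet dataset is expected to contain (among others) the columns referenced below:
# user_id, business_id, main_category, local_name, latitude, longitude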
# Keep only the rows where main_category is "food services" (restaurants)
restaurants = df.filter(df.main_category == "food services")
# Import the functions module as F to hash user_id and business_id
from pyspark.sql import functions as F
# Hash these columns so they have integer type
restaurants = restaurants.withColumn("user_id_hash", F.hash(restaurants.user_id))
restaurants = restaurants.withColumn("business_id_hash", F.hash(restaurants.business_id))
#Drop the columns that we don't need for the recommendation system
df_ml = restaurants.drop("latitude", "longitude", "main_category", "date", "resp", "opinion", "platform")
# Loading the trained model
from pyspark.ml.recommendation import ALS, ALSModel
path = "/model/modelALS"
model = ALSModel.load(path)
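# ALSModel.load restores an ALS model that was fitted and saved earlier;
# the path is assumed to exist in this environment's filesystem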
# RECOMMENDATION SYSTEM
# Function to retrieve a restaurant's name and coordinates from its hashed id
def name_retriever(business_id_hash, restaurants):
    row = restaurants.where(restaurants.business_id_hash == business_id_hash).take(1)[0]
    return (row['local_name'], row['latitude'], row['longitude'])
from pyspark.sql.functions import rand
# Select a random user; for now we identify users by user_id_hash
#usr_id = df_ml.select('user_id_hash').orderBy(rand()).limit(1).collect()
#my_user = [val.user_id_hash for val in usr_id][0]
#my_user = -709842381
def user_recommendation(my_user):
    # Open the dataframe previously saved as a table
    try:
        recommendations = spark.read.table("recommendations")
    except Exception:
        # Make recommendations for all users using the recommendForAllUsers method;
        # num_recs sets how many recommendations to generate per user
        num_recs = 5
        recommendations = model.recommendForAllUsers(num_recs)
        # Save the dataframe so later calls can reuse it instead of recomputing
        recommendations.write.format("parquet").saveAsTable("recommendations")
    # Get the recommendations specifically for this user
    recs_for_user = recommendations.where(recommendations.user_id_hash == my_user).take(1)
    string = ""
    for ranking, (business_id_hash, rating) in enumerate(recs_for_user[0]['recommendations']):
        local_name, latitude, longitude = name_retriever(business_id_hash, restaurants)
        string += f"Recommendation {ranking + 1}: {local_name}. Coordinates: {latitude}, {longitude}\n"
    return string
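# Example call (hash value taken from the commented-out sample above):
# print(user_recommendation(-709842381))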
import gradio as gr

title = "Recommendation System"
with gr.Blocks(title=title) as demo:
    text = gr.HTML("""
    <h1>Welcome to the Recommendation System!</h1>
    """)
    userId = gr.Number(label="Enter your ID")
    get_recommendation_btn = gr.Button("Recommend me!")
    #title = gr.Textbox(label="Local Name:")
    output = gr.Textbox(label="You can go to the following restaurants:")
    get_recommendation_btn.click(fn=user_recommendation, inputs=[userId], outputs=[output])  # outputs=[title, output])

demo.launch()