recsys-hw / app.py
feeeper's picture
data, train/predict script
33e1108
raw
history blame contribute delete
No virus
1.84 kB
from surprise import SVDpp
from surprise import Dataset
from surprise import Reader
import streamlit as st
import pandas as pd
import pickle
def predict():
    """Render all known books ranked by the saved model's estimated rating.

    Unpickles the model stored at ``./model.pkl``, reads the title table from
    ``./amazon-books-titles.csv`` and, for every row, asks the model for an
    estimate for the user id held in the module-level ``x`` and that row's
    ``asin``. The resulting ``title``/``rating`` table is written to the page
    ordered from the highest estimate to the lowest.
    """
    # NOTE(review): pickle executes arbitrary code on load; this is only safe
    # as long as model.pkl is produced by this app itself.
    with open('./model.pkl', 'rb') as model_file:
        recommender = pickle.load(model_file)
    catalog = pd.read_csv('./amazon-books-titles.csv')

    scored = [
        {
            'title': book['title'],
            'rating': recommender.predict(x, book['asin']).est,
        }
        for _, book in catalog.iterrows()
    ]
    # Negated key gives a descending order while keeping ties in file order.
    scored.sort(key=lambda entry: -entry['rating'])
    st.write(pd.DataFrame(scored))
def train_model():
    """Retrain the SVD++ model with the current user's picks and save it.

    Reads the existing ratings from ``./amazon-books.zip`` and the title table
    from ``./amazon-books-titles.csv``. Every title selected in the
    module-level ``y`` is turned into a rating of 5 given by the user id held
    in the module-level ``x``; these rows are echoed to the page, added to the
    existing ratings, and an ``SVDpp`` model is fitted on the combined data.
    The fitted model is pickled to ``./model.pkl``.
    """
    books = pd.read_csv('./amazon-books.zip')
    titles = pd.read_csv('./amazon-books-titles.csv')

    # Strip surrounding whitespace from the selected entries before matching
    # them against the 'title' column.
    selected_titles = [t.strip() for t in y]
    current_user_book_ids = titles[titles['title'].isin(selected_titles)]['asin'].values

    # One synthetic top-score rating per selected book for the current user.
    current_user_ratings = pd.DataFrame({
        'reviewerID': [x] * len(current_user_book_ids),
        'asin': current_user_book_ids,
        'overall': [5] * len(current_user_book_ids)
    })
    st.write(current_user_ratings)

    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat with default arguments produces the same combined frame.
    books = pd.concat([books, current_user_ratings])

    data = Dataset.load_from_df(
        books[['reviewerID', 'asin', 'overall']],
        Reader(line_format='user item rating', rating_scale=(1, 5)),
    )
    trainset = data.build_full_trainset()

    # Fixed hyperparameters; random_state seeds the fit so runs are repeatable.
    best_params = {
        'n_epochs': 15,
        'lr_all': 0.004760245463611792,
        'reg_all': 0.40040712444861504,
        'random_state': 42
    }
    algo = SVDpp(**best_params)
    algo.fit(trainset)

    with open('./model.pkl', 'wb') as f:
        pickle.dump(algo, f)
# Load the selectable book titles, one per line (line endings are kept).
with open('./titles.txt', 'r', encoding='utf8') as f:
    options = list(f)

# Inputs read by train_model() and predict() through these module-level names.
x = st.text_input('uid')
y = st.multiselect('Select book', options)

# 'Submit' retrains the model with the selection; 'Predict' shows the ranking.
btn = st.button(
    'Submit',
    on_click=train_model,
)
btn2 = st.button(
    'Predict',
    on_click=predict,
)