Recommender / utils.py
Vermeer's picture
Upload 7 files
89e6926 verified
raw
history blame
2.53 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from tqdm import tqdm
import pandas as pd
from torch_geometric.datasets import AmazonBook, MovieLens100K, MovieLens1M
from torch_geometric.nn import GCNConv, LGConv
from torch_geometric.utils import degree
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import HeteroData, Data
import torch_geometric.transforms as T
def predict(model, device, data, num_users, num_items, user_id, train_edge_label_index, k=5):
    """Recommend up to ``k`` not-yet-seen items for ``user_id``.

    Interaction pairs are read from ``./u1.base`` (first two
    whitespace-separated columns) and item titles from ``./u.item``
    (first two ``|``-separated columns, latin-1 encoded). Scores are the
    sigmoid of the dot product between user and item embeddings returned by
    ``model.get_embedding(data.edge_index)``; every (user, item) pair listed
    in ``train_edge_label_index`` is masked out before ranking.

    Args:
        model: Object exposing ``get_embedding(edge_index)``. The first
            ``num_users`` rows of the result are used as user embeddings and
            the remaining rows as item embeddings.
        device: Unused here; kept for interface compatibility.
        data: Object exposing ``edge_index``.
        num_users: Number of user rows at the start of the embedding table.
        num_items: Unused here; kept for interface compatibility.
        user_id: Raw user id as it appears in the first column of
            ``./u1.base``.
        train_edge_label_index: Indexable pair ``(user_rows, item_nodes)``.
            ``num_users`` is subtracted from ``item_nodes`` to obtain item
            column indices.
        k: Maximum number of recommendations to return.

    Returns:
        A tuple ``(ground_truth_titles, recommended_items)``.
        ``ground_truth_titles`` is a list with the titles of up to five
        items the user interacted with in ``./u1.base``.
        ``recommended_items`` is a list, best score first, whose elements
        are the arrays of titles matching each recommended item id.

    Raises:
        ValueError: If ``user_id`` does not occur in ``./u1.base``.
    """
    # ML100k
    # `sep=r'\s+'` is the supported spelling of the deprecated
    # `delim_whitespace=True`.
    interaction_dataframe = pd.read_csv('./u1.base', sep=r'\s+', header=None)
    meta_dataframe = pd.read_csv('./u.item', sep='|', encoding='latin-1', header=None)
    interaction_dataframe = interaction_dataframe[[0, 1]]
    interaction_dataframe.columns = ['reviewerID', 'asin']
    meta_dataframe = meta_dataframe[[0, 1]]
    meta_dataframe.columns = ['asin', 'title']

    # Raw ids in order of first appearance; the position in these lists is
    # used as the row/column index into the score matrix.
    # NOTE(review): this assumes the embedding table was built with the same
    # first-appearance ordering -- confirm against the training code.
    unique_users = interaction_dataframe['reviewerID'].unique().tolist()
    unique_items = interaction_dataframe['asin'].unique().tolist()

    try:
        user_row = unique_users.index(user_id)
    except ValueError:
        raise ValueError(f"user_id {user_id!r} not found in ./u1.base") from None

    with torch.no_grad():
        out = model.get_embedding(data.edge_index)
        user_emb, item_emb = out[:num_users], out[num_users:]
        logits = torch.sigmoid(user_emb @ item_emb.t())
        # Mask known interactions so they can never be ranked.
        logits[train_edge_label_index[0], train_edge_label_index[1] - num_users] = float('-inf')

        user_rates = logits[user_row]
        # topk raises when k exceeds the number of candidates, so clamp it.
        top_scores, top_ratings = torch.topk(user_rates, min(k, user_rates.numel()))
        # Drop masked entries that only surfaced because too few unseen
        # items remain to fill k slots.
        top_indices = top_ratings[top_scores != float('-inf')].tolist()

    seen_mask = interaction_dataframe['reviewerID'] == user_id
    ground_truth_asins = interaction_dataframe.loc[seen_mask, 'asin'].to_list()
    ground_truth_items = meta_dataframe[meta_dataframe['asin'].isin(ground_truth_asins)].head(5)

    recommended_items = [
        meta_dataframe[meta_dataframe['asin'] == unique_items[index]]['title'].values
        for index in top_indices
    ]

    # A DataFrame has no `to_list`; return the titles column as a list.
    return ground_truth_items['title'].to_list(), recommended_items