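# Spaces: Sleeping
# NOTE(review): the line above appears to be a hosting-page status banner
# captured together with the source; it is not part of the program.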
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import random | |
from tqdm import tqdm | |
import pandas as pd | |
from torch_geometric.datasets import AmazonBook, MovieLens100K, MovieLens1M | |
from torch_geometric.nn import GCNConv, LGConv | |
from torch_geometric.utils import degree | |
from torch_geometric.nn.conv import MessagePassing | |
from torch_geometric.data import HeteroData, Data | |
import torch_geometric.transforms as T | |
def predict(model, device, data, num_users, num_items, user_id, train_edge_label_index, k=5):
    """Recommend up to ``k`` unseen items for one user of the ML-100K split.

    The function reads ``./u1.base`` (interactions) and ``./u.item`` (item
    metadata) from the current working directory, scores every user/item
    pair with the model's embeddings, masks the pairs listed in
    ``train_edge_label_index`` and returns the best remaining items for
    ``user_id``.

    Args:
        model: Object exposing ``get_embedding(edge_index)``; the result is
            sliced as ``[:num_users]`` (users) and ``[num_users:]`` (items).
        device: Unused here; kept so existing callers keep working.
        data: Object whose ``edge_index`` attribute is passed to the model.
        num_users: Number of leading rows of the embedding that are users.
        num_items: Unused here; kept so existing callers keep working.
        user_id: Raw user id as it appears in the first column of
            ``./u1.base``.
        train_edge_label_index: Index pair ``(user_rows, item_rows)`` of
            already-seen interactions; item rows are offset by
            ``num_users``.
        k: Maximum number of recommendations. Values larger than the number
            of scored items are clamped.

    Returns:
        A tuple ``(ground_truth_titles, recommended_items)``.
        ``ground_truth_titles`` is a list with the titles of at most five
        items the user interacted with in ``./u1.base``.
        ``recommended_items`` is a list with one array of matching titles
        per recommended item, best first. Items masked as already seen are
        never recommended, so the list can be shorter than ``k``.

    Raises:
        ValueError: If ``user_id`` does not occur in ``./u1.base``.
    """
    with torch.no_grad():
        # ML100k: only the first two columns (user id, item id) are used.
        # ``sep=r'\s+'`` replaces the deprecated ``delim_whitespace=True``.
        interaction_dataframe = pd.read_csv('./u1.base', sep=r'\s+', header=None)
        meta_dataframe = pd.read_csv('./u.item', sep='|', encoding='latin-1', header=None)
        interaction_dataframe = interaction_dataframe[[0, 1]]
        interaction_dataframe.columns = ['reviewerID', 'asin']
        meta_dataframe = meta_dataframe[[0, 1]]
        meta_dataframe.columns = ['asin', 'title']

        out = model.get_embedding(data.edge_index)
        user_emb, item_emb = out[:num_users], out[num_users:]
        logits = torch.sigmoid(user_emb @ item_emb.t())
        # Already-seen pairs must never win the top-k selection.
        logits[train_edge_label_index[0], train_edge_label_index[1] - num_users] = float('-inf')

        # Raw ids -> embedding rows, in order of first appearance in the file.
        # NOTE(review): assumes the model's rows were built with the same
        # ordering - confirm against the training pipeline.
        unique_users = interaction_dataframe['reviewerID'].unique().tolist()
        unique_items = interaction_dataframe['asin'].unique().tolist()

        try:
            user_row = unique_users.index(user_id)
        except ValueError:
            raise ValueError(f"user_id {user_id!r} not found in ./u1.base") from None
        user_rates = logits[user_row]

        seen_mask = interaction_dataframe['reviewerID'] == user_id
        ground_truth_asins = interaction_dataframe[seen_mask]['asin'].to_list()
        ground_truth_items = meta_dataframe[meta_dataframe['asin'].isin(ground_truth_asins)].head(5)

        # torch.topk raises when k exceeds the number of candidates.
        k = min(k, user_rates.numel())
        top_scores, top_ratings = torch.topk(user_rates, k)
        # Drop masked entries that only made it in because too few unseen
        # items were left.
        top_ratings = top_ratings[top_scores != float('-inf')]

        recommended_items = [
            meta_dataframe[meta_dataframe['asin'] == unique_items[index]]['title'].values
            for index in top_ratings.tolist()
        ]

        # A DataFrame has no ``to_list``; return the titles of the sample.
        return ground_truth_items['title'].to_list(), recommended_items