File size: 4,186 Bytes
5cc7af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import pandas as pd
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear

# Model architecture definitions (the trained weights are loaded further below)
class GNNEncoder(torch.nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Args:
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) leaves both input sizes unspecified at construction time
        # (PyG's lazy-initialisation convention for bipartite inputs).
        lazy_in_channels = (-1, -1)
        self.conv1 = SAGEConv(lazy_in_channels, hidden_channels)
        self.conv2 = SAGEConv(lazy_in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the result of conv2."""
        hidden = self.conv1(x, edge_index)
        activated = hidden.relu()
        return self.conv2(activated, edge_index)

class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP that scores (user, product) pairs from node embeddings.

    Args:
        hidden_channels: Width of each node embedding; the first linear layer
            consumes the concatenation of two such embeddings.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return a flat tensor holding one score per candidate edge.

        Args:
            z_dict: Mapping with ``'user'`` and ``'products'`` embedding
                tensors, indexable along their first dimension.
            edge_label_index: Two-row index; row 0 selects users, row 1
                selects products.
        """
        user_rows, product_rows = edge_label_index
        user_emb = z_dict['user'][user_rows]
        product_emb = z_dict['products'][product_rows]
        # Concatenate along the feature axis so each pair becomes one vector.
        pair_features = torch.cat([user_emb, product_emb], dim=-1)
        hidden = self.lin1(pair_features).relu()
        scores = self.lin2(hidden)
        return scores.view(-1)

class Model(torch.nn.Module):
    """Heterogeneous GNN encoder followed by an edge-scoring decoder.

    Args:
        hidden_channels: Width used for both encoder layers and the decoder.
        metadata: Graph metadata handed to ``to_hetero``. When omitted, the
            metadata of the module-level ``data`` graph is used, which keeps
            existing ``Model(hidden_channels=...)`` calls working.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback: requires the module-level `data`
            # graph to exist before the model is instantiated.
            metadata = data.metadata()
        homogeneous_encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(homogeneous_encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then score the edges in ``edge_label_index``.

        Args:
            x_dict: Per-node-type input features, passed to the encoder.
            edge_index_dict: Per-edge-type connectivity, passed to the encoder.
            edge_label_index: Candidate edges, passed to the decoder.

        Returns:
            Whatever the decoder returns for the given candidate edges.
        """
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

# Load data and model
# Artifact locations, resolved relative to the current working directory.
data_path = 'data/product_data_PYG.pt'
model_path = 'models/amazon_best_model.pt'
reviews_path = 'data/organized_reviews.csv'
user_mapping_path = 'data/user_mapping.json'
rev_user_mapping_path = 'data/rev_user_mapping.json'

# Defined before any load so the same value drives map_location and .to().
device = 'cpu'

print("Loading data...")
# NOTE(review): torch.load unpickles the file, so only load artifacts from a
# trusted source. Newer PyTorch releases default to weights_only=True, which
# may reject a pickled graph object - confirm against the installed version.
data = torch.load(data_path, map_location=torch.device(device))
data = data.to(device)

print("Loading model...")
model = Model(hidden_channels=32).to(device)
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
# The model is only used for inference in this file, so leave training mode.
model.eval()

print("Loading reviews dataframe...")
reviews_df = pd.read_csv(reviews_path)

print("Loading user mappings...")
# typ='series' reads each flat JSON object as key -> value pairs.
user_mapping = pd.read_json(user_mapping_path, typ='series').to_dict()
rev_user_mapping = pd.read_json(rev_user_mapping_path, typ='series').to_dict()

# Function to get the username from user_id
def get_username(user_id):
    """Return the username stored for ``user_id`` in ``reviews_df``.

    Args:
        user_id: Value compared for equality against the ``user_id`` column.

    Returns:
        The ``username`` value of the first row whose ``user_id`` matches.

    Raises:
        ValueError: If no row in ``reviews_df`` has a matching ``user_id``.
    """
    # One boolean mask serves both the existence check and the lookup, so the
    # column is scanned once instead of twice.
    matching_usernames = reviews_df.loc[reviews_df['user_id'] == user_id, 'username']
    if matching_usernames.empty:
        raise ValueError(f"User ID {user_id} not found in reviews_df")
    return matching_usernames.iloc[0]

# Function to get product recommendations
def get_product_recommendations(model, data, user_id, total_products, top_k=5):
    """Score every product for one user and describe the best-scoring ones.

    Args:
        model: Callable invoked as
            ``model(data.x_dict, data.edge_index_dict, edge_label_index)`` and
            returning one score per candidate edge.
        data: Object exposing ``x_dict`` and ``edge_index_dict``.
        user_id: Key into the module-level ``user_mapping``.
        total_products: Number of product indices (``0..total_products-1``)
            to score against the user.
        top_k: Maximum number of recommendations to return (default 5, the
            previously hard-coded value).

    Returns:
        A list of ``(product_id, category, subcategory)`` tuples ordered from
        highest to lowest score. It holds fewer than ``top_k`` entries when
        fewer products were scored.

    Raises:
        KeyError: If ``user_id`` is not present in ``user_mapping``.
    """
    user_idx = int(user_mapping[user_id])  # embedding index for this user

    # Pair the single user with every product index: shape (2, total_products).
    user_row = torch.full((total_products,), user_idx, dtype=torch.long, device=device)
    all_product_ids = torch.arange(total_products, device=device)
    edge_label_index = torch.stack([user_row, all_product_ids], dim=0)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).cpu()

    # Clamp k so topk does not fail when fewer than top_k products exist.
    k = max(0, min(top_k, pred.numel()))
    top_indices = pred.topk(k).indices.tolist()

    recommendations = []
    for idx in top_indices:
        # NOTE(review): the product index is used as a positional row of
        # reviews_df; this presumes reviews_df row order matches the product
        # node order - confirm against how the graph was built.
        row = reviews_df.iloc[idx]
        recommendations.append((row['product_id'], row['category'], row['subcategory']))
    return recommendations

# Function to get and print recommendations for a given user
def get_recommendations(user_id):
    """Build a heading and a recommendation list for ``user_id``.

    The id is converted to ``str`` before any lookup. Any exception raised
    along the way is turned into an error heading instead of propagating.

    Returns:
        ``(heading, recommendations)`` on success, or
        ``("Error: <message>", [])`` if anything failed.
    """
    try:
        uid = str(user_id)
        display_name = get_username(uid)
        # Score against every product node present in the loaded graph.
        product_count = data['products'].x.shape[0]
        picks = get_product_recommendations(model, data, uid, product_count)
        heading = f"Recommendations for {display_name} (User ID: {uid}):"
    except Exception as err:
        return f"Error: {err!s}", []
    return heading, picks

if __name__ == "__main__":
    # Manual smoke test: print the heading and the recommendations for one
    # example user ID.
    demo_user_id = 'A314APAWYQFKBJ'
    heading, picks = get_recommendations(demo_user_id)
    for output in (heading, picks):
        print(output)