|
import pandas as pd |
|
from tqdm import tqdm |
|
import numpy as np |
|
import itertools |
|
import requests |
|
import sys |
|
|
|
from pyvis.network import Network |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn import Linear |
|
from arango import ArangoClient |
|
import torch_geometric.transforms as T |
|
from torch_geometric.nn import SAGEConv, to_hetero |
|
from torch_geometric.transforms import RandomLinkSplit, ToUndirected |
|
from sentence_transformers import SentenceTransformer |
|
from torch_geometric.data import HeteroData |
|
import yaml |
|
|
|
|
|
import pickle |
|
|
|
def net_repr_html(self):
    """Render this network to an HTML string.

    Reads the node/edge/size/option tuple from ``self.get_network_data()``
    and feeds it to ``self.template.render``. Intended to be attached to
    ``pyvis`` ``Network`` as its ``_repr_html_`` hook.

    Returns:
        Whatever ``self.template.render`` returns for the network data.
    """
    node_list, edge_list, graph_height, graph_width, graph_options = (
        self.get_network_data()
    )
    return self.template.render(
        height=graph_height,
        width=graph_width,
        nodes=node_list,
        edges=edge_list,
        options=graph_options,
    )
|
|
|
# Monkeypatch pyvis' Network so it exposes a `_repr_html_` hook backed by
# net_repr_html (the hook rich front-ends look up to display an object).
setattr(Network, "_repr_html_", net_repr_html)
|
|
|
|
|
class GNNEncoder(torch.nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Both layers are declared with ``(-1, -1)`` input sizes, i.e. their input
    dimensions are left for the layer to determine later.

    Args:
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()

        # NOTE: attribute names are part of the saved state-dict keys.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the result."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)
|
|
|
class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP that scores (user, movie) pairs from node embeddings.

    Args:
        hidden_channels: Size of each node embedding; the first linear layer
            consumes two concatenated embeddings of this size.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # NOTE: attribute names are part of the saved state-dict keys.
        self.lin1 = Linear(hidden_channels * 2, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return one score per column of ``edge_label_index``.

        Args:
            z_dict: Mapping with ``'user'`` and ``'movie'`` embedding tensors.
            edge_label_index: Two-row index; row 0 selects users, row 1
                selects movies.

        Returns:
            A flattened (1-D) tensor of scores.
        """
        user_idx, movie_idx = edge_label_index

        user_emb = z_dict['user'][user_idx]
        movie_emb = z_dict['movie'][movie_idx]
        pair = torch.cat((user_emb, movie_emb), dim=-1)

        hidden = torch.relu(self.lin1(pair))
        scores = self.lin2(hidden)
        return scores.view(-1)
|
|
|
class Model(torch.nn.Module):
    """Heterogeneous GNN encoder followed by an edge-score decoder.

    Args:
        hidden_channels: Width used for both encoder layers and the decoder.
        metadata: Graph metadata passed to ``to_hetero`` (what
            ``HeteroData.metadata()`` returns). When omitted, the
            module-level ``data`` global is consulted, which preserves the
            original behaviour.

    Raises:
        NameError: If ``metadata`` is omitted and no module-level ``data``
            has been set (e.g. ``load_hetero_data()`` has not run yet).
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback: the original constructor always
            # read the metadata from the module-level `data` global.
            metadata = data.metadata()

        homogeneous_encoder = GNNEncoder(hidden_channels, hidden_channels)
        # NOTE: attribute names are part of the saved state-dict keys.
        self.encoder = to_hetero(homogeneous_encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then score the pairs in ``edge_label_index``."""
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
|
|
|
def load_hetero_data():
    """Unpickle ``Hgraph.pkl`` from the working directory.

    The loaded object is also stored in the module-level ``data`` global,
    which ``Model`` reads when it is constructed.

    Returns:
        The unpickled object.
    """
    # NOTE(review): pickle executes arbitrary code on load; only use this
    # with a trusted Hgraph.pkl.
    global data
    with open('Hgraph.pkl', 'rb') as graph_file:
        data = pickle.load(graph_file)
    return data
|
|
|
def load_model(data, hidden_channels=32, model_path='model.pt'):
    """Build a ``Model``, load saved weights onto the CPU and return it in eval mode.

    Args:
        data: Graph object exposing ``x_dict`` and ``edge_index_dict``; used
            for one warm-up pass through the encoder.
        hidden_channels: Hidden width of the model; must match the
            checkpoint. Defaults to the previously hard-coded 32.
        model_path: Path of the saved state dict. Defaults to the previously
            hard-coded ``'model.pt'``.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = Model(hidden_channels=hidden_channels)

    # The encoder's SAGEConv layers are declared with (-1, -1) input sizes,
    # so one forward pass is run before loading weights to give every
    # parameter a concrete shape. No gradients are needed for that pass.
    with torch.no_grad():
        model.encoder(data.x_dict, data.edge_index_dict)

    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
# Module-level lookup tables, unpickled once at import time from the working
# directory; get_title() reads both. (A `global` statement is a no-op at
# module scope, so plain assignments are used.)
with open('id_map.pkl', 'rb') as file:
    id_map = pickle.load(file)

with open('m_id.pkl', 'rb') as file:
    m_id = pickle.load(file)
|
|
|
def get_title(idx):
    """Look up the ``id_map`` index label for the movie at position ``idx``.

    ``m_id[idx]`` is compared against the ``movieId`` column of ``id_map``
    and the index label of the first matching row is returned (presumably
    the movie title -- confirm against how ``id_map`` is built).

    Raises:
        IndexError: If no row of ``id_map`` matches.
    """
    movie_id = m_id[idx]
    matching_rows = id_map.loc[id_map['movieId'] == movie_id]
    return matching_rows.index[0]
|
|
|
def get_recommendation(model, data, user_id, total_movies=9025):
    """Recommend up to ten movies for ``user_id``.

    Scores the user against every movie index in ``range(total_movies)``,
    clamps the scores to ``[0, 5]`` and returns the first ten movie indices
    (in ascending index order) whose clamped score equals 5.

    Args:
        model: Callable taking ``(x_dict, edge_index_dict, edge_label_index)``
            and returning one score per column of ``edge_label_index``.
        data: Graph object exposing ``x_dict`` and ``edge_index_dict``.
        user_id: Index of the user node to score.
        total_movies: Number of movie indices to score. Defaults to the
            previously hard-coded 9025.

    Returns:
        ``{'user': user_id, 'rec_movies': [...]}`` where the list holds the
        ``get_title`` result for each selected movie (at most ten, possibly
        empty).
    """
    user_row = torch.tensor([user_id] * total_movies)
    all_movie_ids = torch.arange(total_movies)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)

    # Pure inference: no autograd graph is needed for the scores.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index)
    pred = pred.clamp(min=0, max=5)

    # NOTE(review): this keeps the movies whose score saturates at 5, lowest
    # index first; it is not a ranking by score and may return fewer than ten.
    saturated_idx = (pred == 5).nonzero(as_tuple=True)[0]
    top_ten_recs = [get_title(movie_idx) for movie_idx in saturated_idx[:10].tolist()]
    return {'user': user_id, 'rec_movies': top_ten_recs}
|
|
|
|
|
def make_1_hop_graph(data, user_id, num_users=671, output_path='index.html'):
    """Write an HTML graph of one user and the movies that user rated.

    The user becomes a red node; each rated movie becomes a node labelled
    with ``get_title(movie_index)`` and is linked to the user by an edge
    whose ``title`` is the rating.

    Args:
        data: Graph object whose ``("user", "rates", "movie")`` entry exposes
            ``edge_index`` and ``edge_label``.
        user_id: Index of the user node to visualise.
        num_users: Offset added to movie indices so movie node ids do not
            collide with user ids. Defaults to the previously hard-coded 671.
        output_path: File the graph is saved to. Defaults to the previously
            hard-coded ``'index.html'``.

    Returns:
        None. The graph is written to ``output_path``.
    """
    rating_edges = data["user", "rates", "movie"]
    edge_index = rating_edges.edge_index
    edge_label = rating_edges.edge_label

    # Positions of the edges whose source (row 0) is this user.
    rated = (edge_index[0] == user_id).nonzero(as_tuple=True)[0]
    # Plain ints, so get_title() and the node ids never see 0-dim tensors.
    movie_indices = edge_index[1][rated].tolist()
    ratings = edge_label[rated]

    net = Network(notebook=True)
    # The user node is loop-invariant, so add it once up front; it is drawn
    # even when the user has no ratings.
    net.add_node(user_id, label=str(user_id), color='#FF0000')

    for movie_idx, rating in zip(movie_indices, ratings):
        movie_node = movie_idx + num_users
        net.add_node(movie_node, label=get_title(movie_idx))
        net.add_edge(user_id, movie_node, title=rating.item())

    net.save_graph(output_path)
|
|
|
|
|
|
|
|