# Author: Polo123
# Last commit: "Update logic2.py" (ce40dcc, verified)
import pandas as pd
from tqdm import tqdm
import numpy as np
import itertools
import requests
import sys
from pyvis.network import Network
import torch
import torch.nn.functional as F
from torch.nn import Linear
from arango import ArangoClient
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from sentence_transformers import SentenceTransformer
from torch_geometric.data import HeteroData
import yaml
import pickle
def net_repr_html(self):
    """Render a pyvis ``Network`` as an HTML string.

    Installed below as ``Network._repr_html_`` (IPython's rich-display hook),
    so a network object shown in a notebook cell is drawn inline using the
    network's own Jinja ``template``.

    NOTE(review): this unpacks exactly five values from
    ``get_network_data()``; other pyvis releases may return a different
    tuple -- confirm against the installed pyvis version.
    """
    network_data = self.get_network_data()
    nodes, edges, height, width, options = network_data
    return self.template.render(
        height=height,
        width=width,
        nodes=nodes,
        edges=edges,
        options=options,
    )


Network._repr_html_ = net_repr_html
#----------------------------------------------
# SAGE model
class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder (ReLU between the layers).

    Both convolutions use lazy input sizes ``(-1, -1)``, so their input
    dimensions are inferred on the first forward pass. This module is written
    for a homogeneous graph; ``Model`` converts it with ``to_hetero``, which
    replicates the layers per edge type.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        lazy_sizes = (-1, -1)  # (source, target) feature sizes inferred lazily
        self.conv1 = SAGEConv(lazy_sizes, hidden_channels)
        self.conv2 = SAGEConv(lazy_sizes, out_channels)

    def forward(self, x, edge_index):
        """Return node embeddings of width ``out_channels``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)
class EdgeDecoder(torch.nn.Module):
    """MLP that scores (user, movie) pairs from their node embeddings.

    For each pair the user and movie embeddings are concatenated
    (``2 * hidden_channels`` features), passed through ``lin1`` + ReLU and
    reduced to a single value by ``lin2``.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        in_features = hidden_channels * 2  # user embedding ++ movie embedding
        self.lin1 = Linear(in_features, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return a flat tensor with one score per (user, movie) pair.

        ``edge_label_index`` is unpacked into (user indices, movie indices);
        ``z_dict`` must hold ``'user'`` and ``'movie'`` embedding tensors.
        """
        user_idx, movie_idx = edge_label_index
        pair_emb = torch.cat(
            (z_dict['user'][user_idx], z_dict['movie'][movie_idx]), dim=-1
        )
        hidden = torch.relu(self.lin1(pair_emb))
        scores = self.lin2(hidden)
        return scores.view(-1)
class Model(torch.nn.Module):
    """Heterogeneous GraphSAGE encoder plus MLP decoder for (user, movie) pairs.

    Args:
        hidden_channels: width of the encoder outputs and of the decoder's
            hidden layer.
        metadata: graph metadata passed to ``to_hetero`` (e.g. the result of
            ``HeteroData.metadata()``). When omitted, falls back to
            ``data.metadata()`` of the module-level ``data`` set by
            ``load_hetero_data`` -- the original behaviour, which raises
            ``NameError`` if that graph has not been loaded yet.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback to the global graph.
            metadata = data.metadata()
        base_encoder = GNNEncoder(hidden_channels, hidden_channels)
        # to_hetero replicates the homogeneous encoder per node/edge type.
        self.encoder = to_hetero(base_encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Score the (user, movie) pairs listed in ``edge_label_index``.

        The encoder yields a dict of per-node-type embeddings which the
        decoder consumes; returns a flat tensor with one score per pair.
        """
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
#----------------------------------------------
def load_hetero_data(path='Hgraph.pkl'):
    """Unpickle the heterogeneous graph and publish it as the global ``data``.

    ``Model.__init__`` reads the module-level ``data`` when building its
    encoder, so this must run before a ``Model`` is constructed that way.

    Args:
        path: pickle file to load (default ``'Hgraph.pkl'``, relative to the
            current working directory).

    Returns:
        The unpickled object, which is also bound to the global ``data``.

    Warning:
        ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    global data
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return data
def load_model(data, path='model.pt', hidden_channels=32):
    """Build a ``Model``, initialise its lazy layers and load trained weights.

    Args:
        data: graph whose ``x_dict`` / ``edge_index_dict`` are pushed through
            the encoder once so the lazily sized ``(-1, -1)`` SAGEConv layers
            get concrete parameter shapes before the state dict is loaded.
        path: saved state-dict file (default ``'model.pt'``); tensors are
            mapped to CPU.
        hidden_channels: must match the checkpoint (default 32).

    Returns:
        The model in eval mode.

    NOTE(review): ``Model(hidden_channels=...)`` builds its hetero encoder from
    the global ``data`` set by ``load_hetero_data``, not from this argument,
    so call ``load_hetero_data`` first.
    """
    model = Model(hidden_channels=hidden_channels)
    # One gradient-free dry run materialises the lazy parameter shapes.
    with torch.no_grad():
        model.encoder(data.x_dict, data.edge_index_dict)
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Lookup tables used by get_title(): ``m_id[idx]`` gives the movieId of movie
# node ``idx``, and ``id_map`` is a DataFrame with a ``movieId`` column whose
# index holds the label returned for it (presumably the title).
# NOTE(review): both are unpickled at import time -- trusted files only.
# (``global`` was dropped: at module scope it is a no-op.)
with open('id_map.pkl', 'rb') as file:
    id_map = pickle.load(file)
with open('m_id.pkl', 'rb') as file:
    m_id = pickle.load(file)
def get_title(idx):
    """Return the ``id_map`` index label for movie node ``idx``.

    ``m_id[idx]`` yields the node's movieId; the first ``id_map`` row whose
    ``movieId`` column equals it supplies the label (presumably the title).
    Raises ``IndexError`` when no row matches.
    """
    movie_id = m_id[idx]
    is_match = id_map['movieId'] == movie_id
    matching_rows = id_map.loc[is_match]
    return matching_rows.index[0]
def get_recommendation(model, data, user_id, k=10, total_movies=None):
    """Return up to ``k`` movies the model scores at the maximum rating.

    ``user_id`` is paired with every movie node and the pairs are scored by
    ``model``. Scores are clamped to ``[0, 5]``; only movies whose clamped
    score equals 5 (i.e. predicted >= 5) are kept, in ascending node-index
    order, and the first ``k`` are mapped to labels via ``get_title``.

    Args:
        model: trained ``Model`` (see ``load_model``).
        data: graph providing ``x_dict`` / ``edge_index_dict``.
        user_id: user node index.
        k: maximum number of recommendations (default 10, as before).
        total_movies: number of movie nodes to score; defaults to
            ``data['movie'].num_nodes`` instead of the former hard-coded 9025.

    Returns:
        ``{'user': user_id, 'rec_movies': [label, ...]}``.
    """
    if total_movies is None:
        total_movies = data['movie'].num_nodes
    user_row = torch.full((total_movies,), int(user_id), dtype=torch.long)
    all_movie_ids = torch.arange(total_movies)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index)
    pred = pred.clamp(min=0, max=5)
    # Keep only movies whose predicted rating reaches the maximum of 5.
    top_ids = (pred == 5).nonzero(as_tuple=True)[0].tolist()[:k]
    return {'user': user_id, 'rec_movies': [get_title(i) for i in top_ids]}
def make_1_hop_graph(data, user_id, movie_offset=671, path='index.html'):
    """Draw ``user_id`` and the movies it rated as a pyvis graph and save it.

    Args:
        data: graph whose ``("user", "rates", "movie")`` edge store has
            ``edge_index`` (row 0 = user index, row 1 = movie index) and
            ``edge_label`` (the rating, shown as the edge title).
        user_id: user node index.
        movie_offset: added to movie indices so their pyvis node ids do not
            collide with user ids. Default 671 as before -- presumably the
            number of users; TODO confirm against the graph.
        path: output HTML file (default ``'index.html'``).

    Returns:
        The populated ``Network`` (previously ``None``; callers may ignore
        it). The user node is always drawn, even when it has no ratings.
    """
    store = data["user", "rates", "movie"]
    edge_index = store.edge_index
    # Positions of the edges that start at this user.
    idxs = (edge_index[0] == user_id).nonzero(as_tuple=True)[0]
    ratings = store.edge_label[idxs]
    movie_idxs = edge_index[1][idxs]

    net = Network(notebook=True)
    # The user node is the same on every edge, so add it once up front.
    net.add_node(user_id, label=str(user_id), color='#FF0000')
    for movie_idx, rating in zip(movie_idxs.tolist(), ratings.tolist()):
        target = movie_idx + movie_offset
        net.add_node(target, label=get_title(movie_idx))
        net.add_edge(user_id, target, title=rating)
    net.save_graph(path)
    return net