# Spaces: Sleeping  (page-scrape residue; not part of the program)
import gradio as gr | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import pickle | |
from torch_geometric.nn import GCNConv, LGConv | |
from torch_geometric.utils import degree | |
from torch_geometric.nn.conv import MessagePassing | |
from torch_geometric.data import HeteroData, Data | |
import torch_geometric.transforms as T | |
from torch_geometric.nn import LightGCN | |
import utils | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): both files are unpickled by torch.load, so they must come from
# a trusted source. Newer PyTorch releases may default to a stricter
# `weights_only` mode that could reject the pickled graph object -- confirm
# against the torch version this app is pinned to.
#
# map_location remaps every stored tensor onto `device`, so artefacts that were
# saved on a CUDA machine still load on a CPU-only host instead of raising a
# deserialization error.
data = torch.load("processed_MVL_light.pt", map_location=device)
ch = torch.load(
    './lightGCNModel_num_layers_MovieLens100K_checkpoint.pt',
    map_location=device,
)

# Rebuild the network with the hyper-parameters below, then restore its
# weights from the checkpoint further down.
lightGCNModel = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=3,
).to(device)
# The optimizer is restored as well so the module keeps exposing the same
# names as before; nothing in this script takes an optimization step.
optimizer = torch.optim.Adam(lightGCNModel.parameters(), lr=0.005)

# Keep only the edge columns whose first-row index is smaller than the
# second-row index.
mask_train = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask_train]

lightGCNModel.load_state_dict(ch['model_state_dict'])
optimizer.load_state_dict(ch['optimizer_state_dict'])
# This script only serves predictions, so switch the model to inference mode.
lightGCNModel.eval()

# Counts forwarded to utils.predict; presumably the MovieLens 100K catalogue
# sizes -- TODO confirm against the preprocessing code.
num_items = 1682
num_users = 943
def recommend(user_id):
    """Look up the known items and the predicted items for one user.

    This is the callback wired into the Gradio interface. It validates the
    incoming id and delegates the actual work to ``utils.predict`` with
    ``k=5``.

    Args:
        user_id: User identifier. The Gradio ``"number"`` input may deliver
            it as a float (or ``None`` when the field is left empty), so it
            is checked and converted to ``int`` before being forwarded.

    Returns:
        The pair ``(ground_truth_items, recommendations)`` exactly as
        returned by ``utils.predict``.

    Raises:
        gr.Error: If ``user_id`` is missing, not numeric, or not a whole
            number; Gradio shows the message to the user.
    """
    try:
        numeric_id = float(user_id)
    except (TypeError, ValueError, OverflowError) as err:
        raise gr.Error("Please enter a numeric user id.") from err
    # Rejects fractional values as well as NaN / infinity.
    if not numeric_id.is_integer():
        raise gr.Error("User id must be a whole number.")

    ground_truth_items, recommendations = utils.predict(
        lightGCNModel,
        device,
        data,
        num_users,
        num_items,
        int(numeric_id),
        train_edge_label_index,
        k=5,
    )
    return ground_truth_items, recommendations
# Web UI: a single numeric field feeds `recommend`, and its two return values
# are rendered as two text outputs.
iface = gr.Interface(
    fn=recommend,
    inputs="number",
    outputs=["text", "text"],
)

# Start the Gradio server as soon as the script runs.
iface.launch()