# Recommender / app.py
# Committed by Vermeer in 6ebe235 ("Create app.py", verified) — 1.35 kB.
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from torch_geometric.nn import GCNConv, LGConv
from torch_geometric.utils import degree
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import HeteroData, Data
import torch_geometric.transforms as T
from torch_geometric.nn import LightGCN
import utils
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Remap saved tensors onto the CPU only when no GPU is available: artifacts
# saved from a CUDA session otherwise fail to deserialize on CPU-only hosts.
# With a GPU present, None keeps torch.load's default placement, exactly as
# the files were loaded before.
_map_location = None if device.type == 'cuda' else device

# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses to unpickle arbitrary Python objects (presumably `data` is a
# torch_geometric Data object). If loading fails there, pass
# weights_only=False -- only if these files are trusted.
data = torch.load("processed_MVL_light.pt", map_location=_map_location)
ch = torch.load(
    './lightGCNModel_num_layers_MovieLens100K_checkpoint.pt',
    map_location=_map_location,
)

# Architecture hyperparameters must match the ones the checkpoint was trained
# with, otherwise load_state_dict below fails on mismatched parameter shapes.
lightGCNModel = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=3,
).to(device)

# The optimizer is not used anywhere in this file; it is rebuilt and restored
# only to mirror the checkpoint contents.
optimizer = torch.optim.Adam(lightGCNModel.parameters(), lr=0.005)

# Keep only edges whose source id is smaller than their target id
# (presumably one direction of an undirected graph stored both ways);
# this edge set is handed to utils.predict as train_edge_label_index.
mask_train = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask_train]

lightGCNModel.load_state_dict(ch['model_state_dict'])
optimizer.load_state_dict(ch['optimizer_state_dict'])
# Inference only. PyG's LightGCN has no dropout or batch-norm layers, so this
# does not change its outputs, but it makes the inference intent explicit.
lightGCNModel.eval()

# Node-count split passed to utils.predict. These are the MovieLens-100K item
# and user counts (the dataset named in the checkpoint filename).
# NOTE(review): presumably num_users + num_items == data.num_nodes -- confirm.
num_items = 1682
num_users = 943
def recommend(user_id):
    """Return ground-truth items and top-5 recommendations for one user.

    Args:
        user_id: Value from the Gradio number input. Gradio's ``Number``
            component typically delivers a float (or ``None`` when the box is
            empty), so the value is validated and converted to ``int`` before
            it is used as a node id.

    Returns:
        The ``(ground_truth_items, recommendations)`` pair produced by
        ``utils.predict`` with ``k=5``; each is shown in its own output box.

    Raises:
        gr.Error: If ``user_id`` is missing, non-integral, or negative.
    """
    if user_id is None:
        raise gr.Error("Please enter a user id.")
    if not float(user_id).is_integer() or user_id < 0:
        raise gr.Error(f"User id must be a non-negative integer, got {user_id!r}.")
    # NOTE(review): whether ids are 0- or 1-based is decided inside
    # utils.predict (not visible here), so no upper bound is enforced.
    user_id = int(user_id)
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        ground_truth_items, recommendations = utils.predict(
            lightGCNModel, device, data, num_users, num_items, user_id,
            train_edge_label_index, k=5,
        )
    return ground_truth_items, recommendations
# Gradio UI: one user-id input and two text outputs, matching the
# (ground_truth_items, recommendations) pair returned by recommend().
# precision=0 makes Gradio hand recommend() an int rather than a float.
iface = gr.Interface(
    fn=recommend,
    inputs=gr.Number(label="User ID", precision=0),
    outputs=[
        gr.Textbox(label="Ground-truth items"),
        gr.Textbox(label="Top-5 recommendations"),
    ],
)
iface.launch()