"""Gradio demo serving top-k LightGCN recommendations for MovieLens-100K users."""
import pickle

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData, Data
from torch_geometric.nn import GCNConv, LGConv
from torch_geometric.nn import LightGCN
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import degree

import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessed graph; presumably a torch_geometric ``Data`` object, since
# ``num_nodes`` and ``edge_index`` are read from it below.
# NOTE(review): PyTorch >= 2.6 defaults to ``torch.load(weights_only=True)``,
# which may refuse to unpickle a PyG ``Data`` object; if loading fails and the
# file is trusted, pass ``weights_only=False`` here.
data = torch.load("processed_MVL_light.pt")

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# The tensors are copied into the model's own parameters by load_state_dict
# below, so mapping them to ``device`` does not change the restored weights.
ch = torch.load(
    './lightGCNModel_num_layers_MovieLens100K_checkpoint.pt',
    map_location=device,
)

lightGCNModel = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=3,
).to(device)
optimizer = torch.optim.Adam(lightGCNModel.parameters(), lr=0.005)

# Keep only edges whose source index is smaller than their target index.
# If the graph stores each undirected edge in both directions, this keeps one
# copy of each. Presumably these are the user->item edges, assuming users are
# indexed before items (TODO confirm against the preprocessing script).
mask_train = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask_train]

lightGCNModel.load_state_dict(ch['model_state_dict'])
# The optimizer is not used anywhere in this script after this point; its
# state is restored only to mirror the saved checkpoint.
optimizer.load_state_dict(ch['optimizer_state_dict'])
# This script only serves predictions, so put the model in inference mode.
lightGCNModel.eval()

# Dataset sizes, matching the MovieLens-100K checkpoint loaded above.
num_items = 1682
num_users = 943


def recommend(user_id):
    """Return the user's ground-truth items and their top-5 recommendations.

    Args:
        user_id: Value from the Gradio ``number`` input. ``gr.Number`` yields
            a float by default (or ``None`` when left empty), so the value is
            validated and converted to ``int`` before use.

    Returns:
        Tuple ``(ground_truth_items, recommendations)`` as produced by
        ``utils.predict``. Each element is shown in one of the two text
        outputs.

    Raises:
        ValueError: if no id was entered or the value is not a whole number.
    """
    if user_id is None:
        raise ValueError("Please enter a user id.")
    if not float(user_id).is_integer():
        raise ValueError(f"User id must be a whole number, got {user_id!r}.")
    # Turn e.g. 5.0 into 5; utils.predict presumably uses the id as an index.
    user_id = int(user_id)
    # No gradients are needed when only predicting.
    with torch.no_grad():
        ground_truth_items, recommendations = utils.predict(
            lightGCNModel,
            device,
            data,
            num_users,
            num_items,
            user_id,
            train_edge_label_index,
            k=5,
        )
    return ground_truth_items, recommendations


iface = gr.Interface(fn=recommend, inputs="number", outputs=["text", "text"])
iface.launch()