File size: 1,345 Bytes
6ebe235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle

from torch_geometric.nn import GCNConv, LGConv
from torch_geometric.utils import degree
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import HeteroData, Data
import  torch_geometric.transforms as T
from torch_geometric.nn import LightGCN
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessed graph; its num_nodes sizes the model and it is handed to
# utils.predict. NOTE(review): presumably a torch_geometric ``Data`` object
# written with torch.save; on torch >= 2.6 ``torch.load`` defaults to
# weights_only=True, which may reject such objects — confirm the pinned
# torch version.
data = torch.load("processed_MVL_light.pt")

# map_location lets a checkpoint saved on a GPU machine load on a CPU-only
# host. load_state_dict below copies these tensors into the model's (and
# optimizer's) own storage on `device`, so where they are first materialised
# does not change the resulting model.
ch = torch.load(
    './lightGCNModel_num_layers_MovieLens100K_checkpoint.pt',
    map_location=device,
)

# Hyper-parameters must match the ones used when the checkpoint was trained,
# otherwise load_state_dict will fail on shape mismatches.
lightGCNModel = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=3,
).to(device)

# Nothing in this file calls optimizer.step(); the optimizer state is restored
# only so the module keeps exposing the same names as before.
optimizer = torch.optim.Adam(lightGCNModel.parameters(), lr=0.005)

# Keep only edges whose source index is smaller than the destination index.
# NOTE(review): assumes edge_index stores each undirected edge in both
# directions, so this keeps one copy per pair (user -> item if users are
# numbered before items) — confirm against the preprocessing code.
mask_train = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask_train]

lightGCNModel.load_state_dict(ch['model_state_dict'])
optimizer.load_state_dict(ch['optimizer_state_dict'])
# This script only serves predictions; make that explicit.
lightGCNModel.eval()

# Hard-coded node counts; the checkpoint name refers to MovieLens100K.
# NOTE(review): presumably data.num_nodes == num_users + num_items (943 + 1682
# = 2625) — confirm.
num_items = 1682
num_users = 943

def recommend(user_id):
    """Return the ground-truth items and top-5 recommendations for one user.

    This is the callback wired to the Gradio interface's number input.

    Args:
        user_id: User id typed into the UI. Gradio's number component may
            deliver it as a float (e.g. ``3.0``) or ``None`` when left empty,
            so it is validated and converted to ``int`` before being used as
            an index by ``utils.predict``.
            NOTE(review): whether ids are 0- or 1-based is decided by
            ``utils.predict`` — confirm.

    Returns:
        The ``(ground_truth_items, recommendations)`` pair produced by
        ``utils.predict`` with ``k=5``.

    Raises:
        ValueError: if no id was entered or the value is not a whole number.
    """
    if user_id is None:
        raise ValueError("Please enter a user id.")
    # int() on a non-finite float raises ValueError/OverflowError, which is fine.
    if float(user_id) != int(user_id):
        raise ValueError(f"User id must be a whole number, got {user_id!r}.")
    user_id = int(user_id)

    # Pure inference: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        ground_truth_items, recommendations = utils.predict(
            lightGCNModel,
            device,
            data,
            num_users,
            num_items,
            user_id,
            train_edge_label_index,
            k=5,
        )
    return ground_truth_items, recommendations

# Web UI: one user-id input and two text outputs, one per value returned by
# `recommend`. precision=0 makes Gradio round the input and deliver it as an
# int instead of a float, since it is used as an index downstream.
iface = gr.Interface(
    fn=recommend,
    inputs=gr.Number(label="User ID", precision=0),
    outputs=[
        gr.Textbox(label="Ground-truth items"),
        gr.Textbox(label="Recommendations"),
    ],
)
iface.launch()