from cortex_DIM.nn_modules.mi_networks import MIFCNet, MI1x1ConvNet from torch import optim from torch.autograd import Variable import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class GlobalDiscriminator(nn.Module): def __init__(self, args, input_dim): super().__init__() self.l0 = nn.Linear(32, 32) self.l1 = nn.Linear(32, 32) self.l2 = nn.Linear(512, 1) def forward(self, y, M, data): adj = Variable(data['adj'].float(), requires_grad=False).cuda() # h0 = Variable(data['feats'].float()).cuda() batch_num_nodes = data['num_nodes'].int().numpy() M, _ = self.encoder(M, adj, batch_num_nodes) # h = F.relu(self.c0(M)) # h = self.c1(h) # h = h.view(y.shape[0], -1) h = torch.cat((y, M), dim=1) h = F.relu(self.l0(h)) h = F.relu(self.l1(h)) return self.l2(h) class PriorDiscriminator(nn.Module): def __init__(self, input_dim): super().__init__() self.l0 = nn.Linear(input_dim, input_dim) self.l1 = nn.Linear(input_dim, input_dim) self.l2 = nn.Linear(input_dim, 1) def forward(self, x): h = F.relu(self.l0(x)) h = F.relu(self.l1(h)) return torch.sigmoid(self.l2(h)) class FF(nn.Module): def __init__(self, input_dim): super().__init__() # self.c0 = nn.Conv1d(input_dim, 512, kernel_size=1) # self.c1 = nn.Conv1d(512, 512, kernel_size=1) # self.c2 = nn.Conv1d(512, 1, kernel_size=1) self.block = nn.Sequential( nn.Linear(input_dim, input_dim), nn.ReLU(), nn.Linear(input_dim, input_dim), nn.ReLU(), nn.Linear(input_dim, input_dim), nn.ReLU() ) self.linear_shortcut = nn.Linear(input_dim, input_dim) # self.c0 = nn.Conv1d(input_dim, 512, kernel_size=1, stride=1, padding=0) # self.c1 = nn.Conv1d(512, 512, kernel_size=1, stride=1, padding=0) # self.c2 = nn.Conv1d(512, 1, kernel_size=1, stride=1, padding=0) def forward(self, x): return self.block(x) + self.linear_shortcut(x)