import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.nn import dense_mincut_pool

from .ViT import *
from .gcn import GCNBlock


class Classifier(nn.Module):
    """Graph-Transformer classifier: GCN node embedding, mincut pooling, and a Vision Transformer head."""

    def __init__(self, n_class):
        super(Classifier, self).__init__()

        self.n_class = n_class
        self.embed_dim = 64
        self.num_layers = 3
        self.node_cluster_num = 100

        # Transformer head consumes the pooled cluster tokens plus a CLS token.
        self.transformer = VisionTransformer(num_classes=n_class, embed_dim=self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.criterion = nn.CrossEntropyLoss()

        # GCN embedding layer: 512-d input node features -> embed_dim; the
        # flags enable batch norm, self-loops, and embedding normalization.
        self.bn = 1
        self.add_self = 1
        self.normalize_embedding = 1
        self.conv1 = GCNBlock(512, self.embed_dim, self.bn, self.add_self,
                              self.normalize_embedding, 0., 0)
        # Per-node soft assignment scores over node_cluster_num clusters.
        self.pool1 = Linear(self.embed_dim, self.node_cluster_num)

    def forward(self, node_feat, labels, adj, mask, is_print=False,
                graphcam_flag=False, to_file=True):
        # Zero out padded nodes, embed them with the GCN, and compute the soft
        # cluster-assignment scores used by mincut pooling.
        X = mask.unsqueeze(2) * node_feat
        X = self.conv1(X, adj, mask)
        s = self.pool1(X)

        # Holds intermediate tensors when GraphCAM runs in in-memory mode.
        graphcam_tensors = {}

        if graphcam_flag:
            # Hard cluster assignment per node (argmax over assignment scores).
            s_matrix = torch.argmax(s[0], dim=1)
            if to_file:
                os.makedirs('graphcam', exist_ok=True)
                torch.save(s_matrix, 'graphcam/s_matrix.pt')
                torch.save(s[0], 'graphcam/s_matrix_ori.pt')

                # Drop attention maps left over from a previous forward pass.
                if os.path.exists('graphcam/att_1.pt'):
                    os.remove('graphcam/att_1.pt')
                    os.remove('graphcam/att_2.pt')
                    os.remove('graphcam/att_3.pt')
            else:
                graphcam_tensors['s_matrix'] = s_matrix
                graphcam_tensors['s_matrix_ori'] = s[0]

        # Mincut pooling coarsens the graph from N nodes to node_cluster_num
        # clusters; mc1 and o1 are the mincut and orthogonality losses.
        X, adj, mc1, o1 = dense_mincut_pool(X, adj, s, mask)
        b, _, _ = X.shape
        cls_token = self.cls_token.repeat(b, 1, 1)
        X = torch.cat([cls_token, X], dim=1)

        out = self.transformer(X)

        loss = None
        if labels is not None:
            loss = self.criterion(out, labels)
            loss = loss + mc1 + o1

        pred = out.data.max(1)[1]  # predicted class index per graph

        if graphcam_flag:
            p = F.softmax(out, dim=1)
            if to_file:
                torch.save(p, 'graphcam/prob.pt')
            else:
                graphcam_tensors['prob'] = p

            for index_ in range(self.n_class):
                # Build a one-hot-like vector carrying the logit of class
                # index_, then backpropagate from it so the transformer's
                # attention gradients are populated for relevance propagation.
                one_hot = np.zeros((1, out.size()[-1]), dtype=np.float32)
                one_hot[0, index_] = out[0][index_]
                one_hot_vector = one_hot
                one_hot = torch.from_numpy(one_hot).requires_grad_(True)
                one_hot = torch.sum(one_hot.to(out.device) * out)
                self.transformer.zero_grad()
                one_hot.backward(retain_graph=True)

                # Transformer attribution (layer-wise relevance propagation).
                kwargs = {"alpha": 1}
                cam = self.transformer.relprop(torch.tensor(one_hot_vector).to(X.device),
                                               method="transformer_attribution",
                                               is_ablation=False, start_layer=0, **kwargs)
                if to_file:
                    torch.save(cam, 'graphcam/cam_{}.pt'.format(index_))
                else:
                    graphcam_tensors[f'cam_{index_}'] = cam

        if not to_file:
            return pred, labels, loss, graphcam_tensors
        return pred, labels, loss
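

if __name__ == '__main__':
    # Minimal smoke test: an illustrative sketch, not part of the original
    # module. Shapes are assumptions inferred from the code above: node_feat
    # is [batch, num_nodes, 512] (GCNBlock's input width), adj is a dense
    # [batch, num_nodes, num_nodes] adjacency, and mask is [batch, num_nodes]
    # with 1 for valid nodes. VisionTransformer and GCNBlock come from the
    # package-local .ViT and .gcn modules, so run this with
    # `python -m <package>.<this_module>` from the repository root.
    model = Classifier(n_class=2)
    node_feat = torch.randn(1, 100, 512)
    adj = torch.ones(1, 100, 100)
    mask = torch.ones(1, 100)
    labels = torch.tensor([0])
    pred, labels, loss = model(node_feat, labels, adj, mask)
    print(pred, loss.item())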