# AioMedica2 / models / GraphTransformer.py
# Author: chris1nexus · commit 54660f7 ("First commit") · 4.37 kB
import sys
import os
import torch
import random
import numpy as np
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from .ViT import *
from .gcn import GCNBlock
from torch_geometric.nn import GCNConv, DenseGraphConv, dense_mincut_pool
from torch.nn import Linear
class Classifier(nn.Module):
    """Graph classifier: GCN node embedding -> min-cut pooling -> VisionTransformer head.

    Node features are embedded by one ``GCNBlock``, soft-assigned to
    ``node_cluster_num`` clusters by a linear layer, pooled with
    ``dense_mincut_pool``, and the pooled cluster tokens (prefixed with a
    learnable class token) are classified by ``VisionTransformer``.

    When ``forward`` is called with ``graphcam_flag=True`` it also produces
    GraphCAM artefacts for the first graph of the batch: the cluster
    assignment, the class probabilities and one ``relprop`` relevance map per
    class. These are either written under ``./graphcam/`` or returned in a dict.
    """

    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.embed_dim = 64
        self.num_layers = 3  # not read by forward; kept so external code relying on it still works
        self.node_cluster_num = 100

        # NOTE: submodule/parameter creation order is kept as-is so seeded
        # initialisation and state_dict keys stay identical.
        self.transformer = VisionTransformer(num_classes=n_class, embed_dim=self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.criterion = nn.CrossEntropyLoss()

        # Flags passed positionally to GCNBlock below.
        self.bn = 1
        self.add_self = 1
        self.normalize_embedding = 1
        # Presumably GCNBlock(input_dim=512, output_dim=embed_dim, bn, add_self,
        # normalize_embedding, dropout=0., relu=0) -- confirm against gcn.GCNBlock.
        self.conv1 = GCNBlock(512, self.embed_dim, self.bn, self.add_self, self.normalize_embedding, 0., 0)
        # Per-node cluster-assignment logits: embed_dim (64) -> node_cluster_num (100).
        self.pool1 = Linear(self.embed_dim, self.node_cluster_num)

    def forward(self, node_feat, labels, adj, mask, is_print=False, graphcam_flag=False, to_file=True):
        """Classify a batch of dense, padded graphs.

        Args:
            node_feat: node features; multiplied by ``mask.unsqueeze(2)``, so
                presumably ``(batch, num_nodes, 512)`` -- confirm with caller.
            labels: class-index targets for ``nn.CrossEntropyLoss``, or ``None``
                to skip the loss computation.
            adj: dense adjacency, passed to ``GCNBlock`` and ``dense_mincut_pool``.
            mask: node-validity mask, presumably ``(batch, num_nodes)``.
            is_print: unused; kept for interface compatibility.
            graphcam_flag: if True, also compute GraphCAM artefacts for graph 0.
            to_file: if True, GraphCAM artefacts are saved under ``graphcam/``;
                if False they are collected in a dict that is returned.

        Returns:
            ``(pred, labels, loss)`` when ``to_file`` is True, otherwise
            ``(pred, labels, loss, graphcam_tensors)`` (the dict is empty when
            ``graphcam_flag`` is False). ``pred`` is the argmax over dim 1 of the
            transformer output. ``loss`` is the cross-entropy plus the two
            auxiliary losses returned by ``dense_mincut_pool``, or ``None`` when
            ``labels`` is None.
        """
        X = mask.unsqueeze(2) * node_feat  # zero out padded nodes
        X = self.conv1(X, adj, mask)
        s = self.pool1(X)  # cluster-assignment logits consumed by dense_mincut_pool

        graphcam_tensors = {}
        if graphcam_flag:
            self._graphcam_cluster_assignment(s, to_file, graphcam_tensors)

        X, adj, mc1, o1 = dense_mincut_pool(X, adj, s, mask)
        b, _, _ = X.shape
        cls_token = self.cls_token.repeat(b, 1, 1)
        X = torch.cat([cls_token, X], dim=1)

        out = self.transformer(X)

        loss = None
        if labels is not None:
            # mc1 / o1 are dense_mincut_pool's auxiliary losses
            # (per PyG docs: min-cut loss and orthogonality loss).
            loss = self.criterion(out, labels) + mc1 + o1

        pred = out.detach().argmax(dim=1)

        if graphcam_flag:
            self._graphcam_relevance(out, X.device, to_file, graphcam_tensors)

        if not to_file:
            return pred, labels, loss, graphcam_tensors
        return pred, labels, loss

    def _graphcam_cluster_assignment(self, s, to_file, graphcam_tensors):
        """Save (or collect) graph 0's hard and soft cluster assignment."""
        s_matrix = torch.argmax(s[0], dim=1)  # hard cluster index per node of graph 0
        if to_file:
            os.makedirs('graphcam', exist_ok=True)
            torch.save(s_matrix, 'graphcam/s_matrix.pt')
            torch.save(s[0], 'graphcam/s_matrix_ori.pt')
            # Clear stale attention dumps from a previous run. Each file is
            # removed independently; a missing one is simply nothing to clear.
            for stale in ('graphcam/att_1.pt', 'graphcam/att_2.pt', 'graphcam/att_3.pt'):
                try:
                    os.remove(stale)
                except FileNotFoundError:
                    pass
        else:
            graphcam_tensors['s_matrix'] = s_matrix
            graphcam_tensors['s_matrix_ori'] = s[0]

    def _graphcam_relevance(self, out, device, to_file, graphcam_tensors):
        """Save (or collect) class probabilities and one relprop map per class for graph 0."""
        p = F.softmax(out, dim=1)  # explicit dim: matches the implicit choice for 2-D logits
        if to_file:
            torch.save(p, 'graphcam/prob.pt')
        else:
            graphcam_tensors['prob'] = p

        for index_ in range(self.n_class):
            # One-hot row holding graph 0's logit for class `index_`.
            one_hot_vector = np.zeros((1, out.size()[-1]), dtype=np.float32)
            one_hot_vector[0, index_] = out[0][index_].item()
            one_hot = torch.from_numpy(one_hot_vector).requires_grad_(True)
            # Back-propagate the selected logit so the transformer holds fresh
            # gradients for relprop. Use out's own device rather than a global
            # cuda/cpu guess, which broke CPU models on CUDA hosts.
            # NOTE(original author): "out-->p" -- the logit, not the softmax
            # probability, is what gets back-propagated here.
            target = torch.sum(one_hot.to(out.device) * out)
            self.transformer.zero_grad()
            target.backward(retain_graph=True)

            cam = self.transformer.relprop(
                torch.tensor(one_hot_vector).to(device),
                method="transformer_attribution",
                is_ablation=False,
                start_layer=0,
                alpha=1,
            )
            if to_file:
                torch.save(cam, 'graphcam/cam_{}.pt'.format(index_))
            else:
                graphcam_tensors[f'cam_{index_}'] = cam