# AioMedica/models/gcn.py
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import math
import numpy as np
torch.set_printoptions(precision=2,threshold=float('inf'))
class AGCNBlock(nn.Module):
def __init__(self,input_dim,hidden_dim,gcn_layer=2,dropout=0.0,relu=0):
super(AGCNBlock,self).__init__()
if dropout > 0.001:
self.dropout_layer = nn.Dropout(p=dropout)
self.sort = 'sort'
self.model='agcn'
self.gcns=nn.ModuleList()
self.bn = 0
self.add_self = 1
self.normalize_embedding = 1
self.gcns.append(GCNBlock(input_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
self.pool = 'mean'
self.tau = 1.
self.lamda = 1.
        # stack the remaining GCN layers (hidden_dim -> hidden_dim)
        for i in range(gcn_layer - 1):
            self.gcns.append(GCNBlock(hidden_dim, hidden_dim, self.bn, self.add_self, self.normalize_embedding, dropout, relu))
if self.model=='diffpool':
self.pool_gcns=nn.ModuleList()
tmp=input_dim
self.diffpool_k=200
for i in range(3):
self.pool_gcns.append(GCNBlock(tmp,200,0,0,0,dropout,relu))
tmp=200
self.w_a=nn.Parameter(torch.zeros(1,hidden_dim,1))
self.w_b=nn.Parameter(torch.zeros(1,hidden_dim,1))
torch.nn.init.normal_(self.w_a)
torch.nn.init.uniform_(self.w_b,-1,1)
self.pass_dim=hidden_dim
if self.pool=='mean':
self.pool=self.mean_pool
elif self.pool=='max':
self.pool=self.max_pool
elif self.pool=='sum':
self.pool=self.sum_pool
self.softmax='global'
        if self.softmax == 'gcn':
            self.att_gcn = GCNBlock(2, 1, 0, 0, 0, dropout, relu)
self.khop=1
self.adj_norm='none'
self.filt_percent=0.25 #default 0.5
self.eps=1e-10
self.tau_config=1
        if self.tau_config == -1:
            self.tau = nn.Parameter(torch.tensor(self.tau), requires_grad=False)
        elif self.tau_config == -2:
            self.tau_fc = nn.Linear(hidden_dim, 1)
            torch.nn.init.constant_(self.tau_fc.bias, 1)
            torch.nn.init.xavier_normal_(self.tau_fc.weight)
        else:
            self.tau = nn.Parameter(torch.tensor(self.tau))
self.lamda1=nn.Parameter(torch.tensor(self.lamda))
self.lamda2=nn.Parameter(torch.tensor(self.lamda))
self.att_norm=0
self.dnorm=0
self.dnorm_coe=1
self.att_out=0
self.single_att=0
def forward(self,X,adj,mask,is_print=False):
        '''
        inputs:
            X: node input features, [batch, node_num, input_dim], dtype=float
            adj: adjacency matrix, [batch, node_num, node_num], dtype=float
            mask: mask for valid nodes, [batch, node_num]
        outputs:
            out: unnormalized classification prob, [batch, hidden_dim]
            H: pooled node hidden features, [batch, k_max, pass_dim]
            new_adj: pooled (coarsened) adjacency matrix, [batch, k_max, k_max]
            new_mask: [batch, k_max]
        '''
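        # Pipeline (as implemented below): run the stacked GCN layers, compute per-node
        # attention scores (global softmax / neighborhood / mix), scale node features by
        # the attention, keep roughly filt_percent of the nodes per graph, build the pooled
        # feature matrix H and coarsened adjacency new_adj, and read out a graph-level
        # vector via the configured pooling (mean/max/sum).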
        hidden = X
is_print1=is_print2=is_print
if adj.shape[-1]>100:
is_print1=False
        for gcn in self.gcns:
            hidden = gcn(hidden, adj, mask)
        # zero out features of padded (invalid) nodes
        hidden = mask.unsqueeze(2) * hidden
        if self.model == 'unet':
            att = torch.matmul(hidden, self.w_a).squeeze(-1)
            att = att / torch.sqrt((self.w_a.squeeze(2) ** 2).sum(dim=1, keepdim=True))
        elif self.model == 'agcn':
            if self.softmax == 'global' or self.softmax == 'mix':
                # global attention: score every node, mask out padding, softmax over the graph
                att_a = torch.matmul(hidden, self.w_a).squeeze(-1) + (mask - 1) * 1e10
                att_a_1 = att_a = torch.nn.functional.softmax(att_a, dim=1)
                if self.dnorm:
                    scale = mask.sum(dim=1, keepdim=True) / self.dnorm_coe
                    att_a = scale * att_a
            if self.softmax == 'neibor' or self.softmax == 'mix':
                # neighborhood attention: exponentiate scores and normalize by the k-hop neighborhood sum
                att_b = torch.matmul(hidden, self.w_b).squeeze(-1) + (mask - 1) * 1e10
                att_b_max, _ = att_b.max(dim=1, keepdim=True)
                if self.tau_config != -2:
                    att_b = torch.exp((att_b - att_b_max) * torch.abs(self.tau))
                else:
                    att_b = torch.exp((att_b - att_b_max) * torch.abs(self.tau_fc(self.pool(hidden, mask))))
                denom = att_b.unsqueeze(2)
                for _ in range(self.khop):
                    denom = torch.matmul(adj, denom)
                denom = denom.squeeze(-1) + self.eps
                att_b = (att_b * torch.diagonal(adj, 0, 1, 2)) / denom
                if self.dnorm:
                    if self.adj_norm == 'diag':
                        diag_scale = mask / (torch.diagonal(adj, 0, 1, 2) + self.eps)
                    elif self.adj_norm == 'none':
                        diag_scale = adj.sum(dim=1)
                    att_b = att_b * diag_scale
                att_b = att_b * mask
            if self.softmax == 'global':
                att = att_a
            elif self.softmax == 'neibor' or self.softmax == 'hardnei':
                att = att_b
            elif self.softmax == 'mix':
                att = att_a * torch.abs(self.lamda1) + att_b * torch.abs(self.lamda2)
        Z = hidden
        if self.model == 'unet':
            Z = torch.tanh(att.unsqueeze(2)) * Z
        elif self.model == 'agcn':
            if not self.single_att:
                Z = att.unsqueeze(2) * Z
        # number of nodes kept after pooling
        k_max = int(math.ceil(self.filt_percent * adj.shape[-1]))
if self.model=='diffpool':
k_max=min(k_max,self.diffpool_k)
        # per-graph number of nodes to keep (proportional to each graph's valid node count)
        k_list = [int(math.ceil(self.filt_percent * x)) for x in mask.sum(dim=1).tolist()]
        if self.model != 'diffpool':
            if self.sort == 'sample':
                # sample k_max node indices per graph with probability proportional to attention
                att_samp = att * mask
                att_samp = (att_samp / att_samp.sum(dim=1, keepdim=True)).detach().cpu().numpy()
                top_index = ()
                for i in range(att.size(0)):
                    top_index += (torch.LongTensor(np.random.choice(att_samp.shape[1], k_max, p=att_samp[i])),)
                top_index = torch.stack(top_index, 0)
            elif self.sort == 'random_sample':
                # pick k_list[i] valid nodes uniformly at random for each graph
                top_index = torch.zeros(att.size(0), k_max, dtype=torch.long)
                for i in range(att.size(0)):
                    top_index[i, 0:k_list[i]] = torch.randperm(int(mask[i].sum().item()))[0:k_list[i]]
            else:  # 'sort': keep the k_max highest-attention nodes
                _, top_index = torch.topk(att, k_max, dim=1)
        new_mask = X.new_zeros(X.shape[0], k_max)
visualize_tools=None
if self.model=='unet':
for i,k in enumerate(k_list):
for j in range(int(k),k_max):
top_index[i][j]=adj.shape[-1]-1
new_mask[i][j]=-1.
new_mask=new_mask+1
top_index,_=torch.sort(top_index,dim=1)
assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
for i,x in enumerate(top_index):
assign_m[i]=torch.index_select(adj[i],0,x)
new_adj=X.new_zeros(X.shape[0],k_max,k_max)
H=Z.new_zeros(Z.shape[0],k_max,Z.shape[-1])
for i,x in enumerate(top_index):
new_adj[i]=torch.index_select(assign_m[i],1,x)
H[i]=torch.index_select(Z[i],0,x)
        elif self.model == 'agcn':
            # assignment matrix: row j of graph i is the adjacency row of its j-th selected node
            assign_m = X.new_zeros(X.shape[0], k_max, adj.shape[-1])
            for i, k in enumerate(k_list):
                for j in range(int(k)):
                    assign_m[i][j] = adj[i][top_index[i][j]]
                    new_mask[i][j] = 1.
            assign_m = assign_m / (assign_m.sum(dim=1, keepdim=True) + self.eps)
            H = torch.matmul(assign_m, Z)
            new_adj = torch.matmul(torch.matmul(assign_m, adj), torch.transpose(assign_m, 1, 2))
elif self.model=='diffpool':
hidden1=X
for gcn in self.pool_gcns:
hidden1=gcn(hidden1,adj,mask)
assign_m=X.new_ones(X.shape[0],X.shape[1],k_max)*(-100000000.)
for i,x in enumerate(hidden1):
k=min(k_list[i],k_max)
assign_m[i,:,0:k]=hidden1[i,:,0:k]
for j in range(int(k)):
new_mask[i][j]=1.
assign_m=torch.nn.functional.softmax(assign_m,dim=2)*mask.unsqueeze(2)
assign_m_t=torch.transpose(assign_m,1,2)
new_adj=torch.matmul(torch.matmul(assign_m_t,adj),assign_m)
H=torch.matmul(assign_m_t,Z)
        # graph-level readout
        if self.att_out and self.model == 'agcn':
            if self.softmax == 'global':
                out = self.pool(att_a_1.unsqueeze(2) * hidden, mask)
            elif self.softmax == 'neibor':
                att_b_sum = att_b.sum(dim=1, keepdim=True)
                out = self.pool((att_b / (att_b_sum + self.eps)).unsqueeze(2) * hidden, mask)
        else:
            out = self.pool(hidden, mask)
        if self.adj_norm == 'tanh' or self.adj_norm == 'mix':
            new_adj = torch.tanh(new_adj)
        if self.adj_norm == 'diag' or self.adj_norm == 'mix':
            # symmetric degree normalization: D^-1/2 A D^-1/2
            diag_elem = torch.pow(new_adj.sum(dim=2) + self.eps, -0.5)
            diag = torch.diag_embed(diag_elem)
            new_adj = torch.matmul(torch.matmul(diag, new_adj), diag)
visualize_tools=[]
'''
if (not self.training) and is_print1:
print('**********************************')
print('node_feat:',X.type(),X.shape)
print(X)
if self.model!='diffpool':
print('**********************************')
print('att:',att.type(),att.shape)
print(att)
print('**********************************')
print('top_index:',top_index.type(),top_index.shape)
print(top_index)
print('**********************************')
print('adj:',adj.type(),adj.shape)
print(adj)
print('**********************************')
print('assign_m:',assign_m.type(),assign_m.shape)
print(assign_m)
print('**********************************')
print('new_adj:',new_adj.type(),new_adj.shape)
print(new_adj)
print('**********************************')
print('new_mask:',new_mask.type(),new_mask.shape)
print(new_mask)
'''
        # visualization: cache attention scores and selected node indices to disk
from os import path
if not path.exists('att_1.pt'):
torch.save(att[0], 'att_1.pt')
torch.save(top_index[0], 'att_ind1.pt')
elif not path.exists('att_2.pt'):
torch.save(att[0], 'att_2.pt')
torch.save(top_index[0], 'att_ind2.pt')
else:
torch.save(att[0], 'att_3.pt')
torch.save(top_index[0], 'att_ind3.pt')
if (not self.training) and is_print2:
if self.model!='diffpool':
visualize_tools.append(att[0])
visualize_tools.append(top_index[0])
visualize_tools.append(new_adj[0])
visualize_tools.append(new_mask.sum())
        return out, H, new_adj, new_mask, visualize_tools
def mean_pool(self,x,mask):
return x.sum(dim=1)/(self.eps+mask.sum(dim=1,keepdim=True))
def sum_pool(self,x,mask):
return x.sum(dim=1)
@staticmethod
def max_pool(x,mask):
#output: [batch,x.shape[2]]
m=(mask-1)*1e10
r,_=(x+m.unsqueeze(2)).max(dim=1)
return r
# GCN basic operation
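# Computes one propagation step: Y = (A X [+ X if add_self]) W + b, optionally followed by
# per-node L2 normalization of the output embeddings, batch normalization restricted to the
# valid (masked-in) nodes, dropout, and a ReLU / LeakyReLU nonlinearity, depending on the
# constructor flags.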
class GCNBlock(nn.Module):
def __init__(self, input_dim, output_dim, bn=0,add_self=0, normalize_embedding=0,
dropout=0.0,relu=0, bias=True):
super(GCNBlock,self).__init__()
self.add_self = add_self
self.dropout = dropout
self.relu=relu
self.bn=bn
if dropout > 0.001:
self.dropout_layer = nn.Dropout(p=dropout)
if self.bn:
self.bn_layer = torch.nn.BatchNorm1d(output_dim)
self.normalize_embedding = normalize_embedding
self.input_dim = input_dim
self.output_dim = output_dim
self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
torch.nn.init.xavier_normal_(self.weight)
if bias:
self.bias = nn.Parameter(torch.zeros(output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
else:
self.bias = None
def forward(self, x, adj, mask):
y = torch.matmul(adj, x)
if self.add_self:
y += x
y = torch.matmul(y,self.weight)
if self.bias is not None:
y = y + self.bias
if self.normalize_embedding:
y = F.normalize(y, p=2, dim=2)
        if self.bn:
            # apply BatchNorm1d only over the valid nodes: pack the valid rows of every graph
            # into one contiguous tensor, normalize, then scatter back into the padded shape
            index = mask.sum(dim=1).long().tolist()
            bn_tensor_bf = mask.new_zeros((sum(index), y.shape[2]))
            bn_tensor_af = mask.new_zeros(*y.shape)
start_index=[]
ssum=0
for i in range(x.shape[0]):
start_index.append(ssum)
ssum+=index[i]
start_index.append(ssum)
for i in range(x.shape[0]):
bn_tensor_bf[start_index[i]:start_index[i+1]]=y[i,0:index[i]]
bn_tensor_bf=self.bn_layer(bn_tensor_bf)
for i in range(x.shape[0]):
bn_tensor_af[i,0:index[i]]=bn_tensor_bf[start_index[i]:start_index[i+1]]
y=bn_tensor_af
if self.dropout > 0.001:
y = self.dropout_layer(y)
        if self.relu == 'relu':
            y = torch.nn.functional.relu(y)
        elif self.relu == 'lrelu':
            y = torch.nn.functional.leaky_relu(y, 0.1)
        return y
#experimental function, untested
class masked_batchnorm(nn.Module):
def __init__(self,feat_dim,epsilon=1e-10):
super().__init__()
self.alpha=nn.Parameter(torch.ones(feat_dim))
self.beta=nn.Parameter(torch.zeros(feat_dim))
self.eps=epsilon
def forward(self,x,mask):
'''
x: node feat, [batch,node_num,feat_dim]
mask: [batch,node_num]
'''
        mask1 = mask.unsqueeze(2)
        mask_sum = mask.sum()
        # mask padded nodes before computing the mean so they do not bias the statistics
        mean = (x * mask1).sum(dim=(0, 1), keepdim=True) / (self.eps + mask_sum)
        temp = (x - mean) ** 2
        temp = temp * mask1
        var = temp.sum(dim=(0, 1), keepdim=True) / (self.eps + mask_sum)
        rstd = torch.rsqrt(var + self.eps)
        x = (x - mean) * rstd
        return ((x * self.alpha) + self.beta) * mask1
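

# A minimal smoke-test sketch (not from the original repository): it builds a small random
# graph batch and checks the tensor contract documented in AGCNBlock.forward(). All sizes
# below are arbitrary assumptions chosen only for illustration.
if __name__ == '__main__':
    torch.manual_seed(0)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch, node_num, input_dim, hidden_dim = 2, 16, 8, 32
    x = torch.rand(batch, node_num, input_dim, device=device)
    # random symmetric adjacency with self-loops
    adj = (torch.rand(batch, node_num, node_num, device=device) > 0.7).float()
    adj = ((adj + adj.transpose(1, 2) + torch.eye(node_num, device=device)) > 0).float()
    mask = torch.ones(batch, node_num, device=device)
    mask[1, 10:] = 0.  # the second graph only has 10 valid nodes
    block = AGCNBlock(input_dim, hidden_dim, gcn_layer=2, dropout=0.0, relu='relu').to(device)
    # note: forward() also writes att_*.pt files to the working directory for visualization
    out, H, new_adj, new_mask, _ = block(x, adj, mask)
    # expected shapes: out [2, 32], H [2, 4, 32], new_adj [2, 4, 4], new_mask [2, 4]
    print(out.shape, H.shape, new_adj.shape, new_mask.shape)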