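# Attention-based graph pooling: AGCNBlock stacks GCNBlock graph-convolution layers,
# scores nodes with a learned attention vector, keeps the top filt_percent fraction of
# nodes per graph, and builds a pooled adjacency matrix for the next level.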
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import math
import numpy as np

torch.set_printoptions(precision=2,threshold=float('inf'))

class AGCNBlock(nn.Module):
    def __init__(self,input_dim,hidden_dim,gcn_layer=2,dropout=0.0,relu=0):
        super(AGCNBlock,self).__init__()
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        self.sort = 'sort'
        self.model='agcn'
        self.gcns=nn.ModuleList()
        self.bn = 0
        self.add_self = 1
        self.normalize_embedding = 1
        self.gcns.append(GCNBlock(input_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
        self.pool = 'mean'
        self.tau = 1.
        self.lamda = 1.
        for i in range(gcn_layer-1):
            if i==gcn_layer-2 and (not 1):
                self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,0))
            else:
                self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
        if self.model=='diffpool':
            self.pool_gcns=nn.ModuleList()
            tmp=input_dim
            self.diffpool_k=200
            for i in range(3):
                self.pool_gcns.append(GCNBlock(tmp,200,0,0,0,dropout,relu))
                tmp=200
        self.w_a=nn.Parameter(torch.zeros(1,hidden_dim,1))
        self.w_b=nn.Parameter(torch.zeros(1,hidden_dim,1))
        torch.nn.init.normal_(self.w_a)
        torch.nn.init.uniform_(self.w_b,-1,1)
        self.pass_dim=hidden_dim
        if self.pool=='mean':
            self.pool=self.mean_pool
        elif self.pool=='max':
            self.pool=self.max_pool
        elif self.pool=='sum':
            self.pool=self.sum_pool
        self.softmax='global'
        if self.softmax=='gcn':
            # GCNBlock args: (input_dim,output_dim,bn,add_self,normalize_embedding,dropout,relu)
            self.att_gcn=GCNBlock(2,1,0,0,0,dropout,relu)
        self.khop=1
        self.adj_norm='none'
        self.filt_percent=0.25 #default 0.5
        self.eps=1e-10
        self.tau_config=1
        if self.tau_config==-1:
            self.tau=nn.Parameter(torch.tensor(1.),requires_grad=False)
        elif self.tau_config==-2:
            self.tau_fc=nn.Linear(hidden_dim,1)
            torch.nn.init.constant_(self.tau_fc.bias,1)
            torch.nn.init.xavier_normal_(self.tau_fc.weight.t())
        else:
            self.tau=nn.Parameter(torch.tensor(self.tau))
        self.lamda1=nn.Parameter(torch.tensor(self.lamda))
        self.lamda2=nn.Parameter(torch.tensor(self.lamda))
        self.att_norm=0
        self.dnorm=0
        self.dnorm_coe=1
        self.att_out=0
        self.single_att=0

    def forward(self,X,adj,mask,is_print=False):
        '''
        input:
            X: node input features, [batch,node_num,input_dim], dtype=float
            adj: adj matrix, [batch,node_num,node_num], dtype=float
            mask: mask for nodes, [batch,node_num]
        outputs:
            out: unnormalized classification prob, [batch,hidden_dim]
            H: batch of pooled node hidden features, [batch,k_max,pass_dim]
            new_adj: pooled new adj matrix, [batch,k_max,k_max]
            new_mask: [batch,k_max]
        '''
        hidden=X
        #adj = adj.float()
        # print('input size:')
        # print(hidden.shape)
        is_print1=is_print2=is_print
        if adj.shape[-1]>100:
            is_print1=False
        # run the stacked GCN layers, then zero out the features of padded nodes
        for gcn in self.gcns:
            hidden=gcn(hidden,adj,mask)
        # print('gcn:')
        # print(hidden.shape)
        # print('mask:')
        # print(mask.unsqueeze(2).shape)
        # print(mask.sum(dim=1))
        hidden=mask.unsqueeze(2)*hidden
        # print(hidden[0][0])
        # print(hidden[0][-1])
        if self.model=='unet':
            att=torch.matmul(hidden,self.w_a).squeeze()
            att=att/torch.sqrt((self.w_a.squeeze(2)**2).sum(dim=1,keepdim=True))
        elif self.model=='agcn':
            if self.softmax=='global' or self.softmax=='mix':
                # global attention: one score per node, softmaxed over all real nodes of
                # each graph (padded nodes get a large negative score via the mask)
                if False:
                    degree_w = torch.sum(adj, dim=2) / torch.sum(adj, dim=2).max(1, keepdim=True)[0]
                    att_a=torch.matmul(hidden,self.w_a).squeeze()*degree_w+(mask-1)*1e10
                else:
                    att_a=torch.matmul(hidden,self.w_a).squeeze()+(mask-1)*1e10
                # print(att_a[0][:10])
                # print(att_a[0][-10:-1])
                att_a_1=att_a=torch.nn.functional.softmax(att_a,dim=1)
                # print(att_a[0][:10])
                # print(att_a[0][-10:-1])
                if self.dnorm:
                    scale=mask.sum(dim=1,keepdim=True)/self.dnorm_coe
                    att_a=scale*att_a
            if self.softmax=='neibor' or self.softmax=='mix':
                # neighborhood attention: exponentiated scores normalized by the summed
                # scores of each node's k-hop neighborhood (repeated adj matmuls)
                att_b=torch.matmul(hidden,self.w_b).squeeze()+(mask-1)*1e10
                att_b_max,_=att_b.max(dim=1,keepdim=True)
                if self.tau_config!=-2:
                    att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau))
                else:
                    att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau_fc(self.pool(hidden,mask))))
                denom=att_b.unsqueeze(2)
                for _ in range(self.khop):
                    denom=torch.matmul(adj,denom)
                denom=denom.squeeze()+self.eps
                att_b=(att_b*torch.diagonal(adj,0,1,2))/denom
                if self.dnorm:
                    if self.adj_norm=='diag':
                        diag_scale=mask/(torch.diagonal(adj,0,1,2)+self.eps)
                    elif self.adj_norm=='none':
                        diag_scale=adj.sum(dim=1)
                    att_b=att_b*diag_scale
                att_b=att_b*mask
            if self.softmax=='global':
                att=att_a
            elif self.softmax=='neibor' or self.softmax=='hardnei':
                att=att_b
            elif self.softmax=='mix':
                att=att_a*torch.abs(self.lamda1)+att_b*torch.abs(self.lamda2)
        # print('att:')
        # print(att.shape)
        Z=hidden
        if self.model=='unet':
            Z=torch.tanh(att.unsqueeze(2))*Z
        elif self.model=='agcn':
            if self.single_att:
                Z=Z
            else:
                Z=att.unsqueeze(2)*Z
        # print('Z shape')
        # print(Z.shape)
        # k_max: number of node slots kept per (padded) graph;
        # k_list: number of real nodes kept in each graph of the batch
        k_max=int(math.ceil(self.filt_percent*adj.shape[-1]))
        # print('k_max')
        # print(k_max)
        if self.model=='diffpool':
            k_max=min(k_max,self.diffpool_k)
        k_list=[int(math.ceil(self.filt_percent*x)) for x in mask.sum(dim=1).tolist()]
        # print('k_list')
        # print(k_list)
        if self.model!='diffpool':
            if self.sort=='sample':
                # sample k_max node indices per graph with probability proportional to attention
                att_samp = att * mask
                att_samp = (att_samp/att_samp.sum(1,keepdim=True)).detach().cpu().numpy()
                top_index = ()
                for i in range(att.size(0)):
                    top_index = top_index + (torch.LongTensor(np.random.choice(att_samp.shape[1], k_max, p=att_samp[i])),)
                top_index = torch.stack(top_index,0)
            elif self.sort=='random_sample':
                top_index = torch.LongTensor(att.size(0), k_max)*0
                for i in range(att.size(0)):
                    top_index[i,0:k_list[i]] = torch.randperm(int(mask[i].sum().item()))[0:k_list[i]]
            else: #sort
                _,top_index=torch.topk(att,k_max,dim=1)
        # print('top_index')
        # print(top_index)
        # print(len(top_index[0]))
        new_mask=X.new_zeros(X.shape[0],k_max)
        # print('new_mask')
        # print(new_mask.shape)
        visualize_tools=None
        if self.model=='unet':
            for i,k in enumerate(k_list):
                for j in range(int(k),k_max):
                    top_index[i][j]=adj.shape[-1]-1
                    new_mask[i][j]=-1.
            new_mask=new_mask+1
            top_index,_=torch.sort(top_index,dim=1)
            assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
            for i,x in enumerate(top_index):
                assign_m[i]=torch.index_select(adj[i],0,x)
            new_adj=X.new_zeros(X.shape[0],k_max,k_max)
            H=Z.new_zeros(Z.shape[0],k_max,Z.shape[-1])
            for i,x in enumerate(top_index):
                new_adj[i]=torch.index_select(assign_m[i],1,x)
                H[i]=torch.index_select(Z[i],0,x)
        elif self.model=='agcn':
            # assign_m: each row is the adjacency row of one kept node; columns are then
            # normalized over the kept-node dimension, so every original node distributes
            # (at most) unit weight across the kept nodes it is connected to
            assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
            # print('assign_m.shape')
            # print(assign_m.shape)
            for i,k in enumerate(k_list):
                #print('top_index[i][j]')
                for j in range(int(k)):
                    #print(str(top_index[i][j].item())+' ', end='')
                    assign_m[i][j]=adj[i][top_index[i][j]]
                    #print(assign_m[i][j])
                    new_mask[i][j]=1.
            assign_m=assign_m/(assign_m.sum(dim=1,keepdim=True)+self.eps)
            H=torch.matmul(assign_m,Z)
            # print('H')
            # print(H.shape)
            new_adj=torch.matmul(torch.matmul(assign_m,adj),torch.transpose(assign_m,1,2))
            # print(torch.matmul(assign_m,adj).shape)
            # print('new_adj:')
            # print(new_adj.shape)
        elif self.model=='diffpool':
            hidden1=X
            for gcn in self.pool_gcns:
                hidden1=gcn(hidden1,adj,mask)
            assign_m=X.new_ones(X.shape[0],X.shape[1],k_max)*(-100000000.)
            for i,x in enumerate(hidden1):
                k=min(k_list[i],k_max)
                assign_m[i,:,0:k]=hidden1[i,:,0:k]
                for j in range(int(k)):
                    new_mask[i][j]=1.
            assign_m=torch.nn.functional.softmax(assign_m,dim=2)*mask.unsqueeze(2)
            assign_m_t=torch.transpose(assign_m,1,2)
            new_adj=torch.matmul(torch.matmul(assign_m_t,adj),assign_m)
            H=torch.matmul(assign_m_t,Z)
        # print('pool')
        if self.att_out and self.model=='agcn':
            if self.softmax=='global':
                out=self.pool(att_a_1.unsqueeze(2)*hidden,mask)
            elif self.softmax=='neibor':
                att_b_sum=att_b.sum(dim=1,keepdim=True)
                out=self.pool((att_b/(att_b_sum+self.eps)).unsqueeze(2)*hidden,mask)
        else:
            # print('hidden.shape')
            # print(hidden.shape)
            out=self.pool(hidden,mask)
        # print('out shape')
        # print(out.shape)
        if self.adj_norm=='tanh' or self.adj_norm=='mix':
            new_adj=torch.tanh(new_adj)
        if self.adj_norm=='diag' or self.adj_norm=='mix':
            # symmetric degree normalization: D^-1/2 * A * D^-1/2
            diag_elem=torch.pow(new_adj.sum(dim=2)+self.eps,-0.5)
            diag=new_adj.new_zeros(new_adj.shape)
            for i,x in enumerate(diag_elem):
                diag[i]=torch.diagflat(x)
            new_adj=torch.matmul(torch.matmul(diag,new_adj),diag)
        visualize_tools=[]
        '''
        if (not self.training) and is_print1:
            print('**********************************')
            print('node_feat:',X.type(),X.shape)
            print(X)
            if self.model!='diffpool':
                print('**********************************')
                print('att:',att.type(),att.shape)
                print(att)
                print('**********************************')
                print('top_index:',top_index.type(),top_index.shape)
                print(top_index)
            print('**********************************')
            print('adj:',adj.type(),adj.shape)
            print(adj)
            print('**********************************')
            print('assign_m:',assign_m.type(),assign_m.shape)
            print(assign_m)
            print('**********************************')
            print('new_adj:',new_adj.type(),new_adj.shape)
            print(new_adj)
            print('**********************************')
            print('new_mask:',new_mask.type(),new_mask.shape)
            print(new_mask)
        '''
        #visualization
        from os import path
        if not path.exists('att_1.pt'):
            torch.save(att[0], 'att_1.pt')
            torch.save(top_index[0], 'att_ind1.pt')
        elif not path.exists('att_2.pt'):
            torch.save(att[0], 'att_2.pt')
            torch.save(top_index[0], 'att_ind2.pt')
        else:
            torch.save(att[0], 'att_3.pt')
            torch.save(top_index[0], 'att_ind3.pt')
        if (not self.training) and is_print2:
            if self.model!='diffpool':
                visualize_tools.append(att[0])
                visualize_tools.append(top_index[0])
            visualize_tools.append(new_adj[0])
            visualize_tools.append(new_mask.sum())
        # print('**********************************')
        return out,H,new_adj,new_mask,visualize_tools
    def mean_pool(self,x,mask):
        return x.sum(dim=1)/(self.eps+mask.sum(dim=1,keepdim=True))

    def sum_pool(self,x,mask):
        return x.sum(dim=1)

    def max_pool(self,x,mask):
        #output: [batch,x.shape[2]]
        m=(mask-1)*1e10
        r,_=(x+m.unsqueeze(2)).max(dim=1)
        return r

# GCN basic operation
class GCNBlock(nn.Module):
    def __init__(self, input_dim, output_dim, bn=0, add_self=0, normalize_embedding=0,
                 dropout=0.0, relu=0, bias=True):
        super(GCNBlock,self).__init__()
        self.add_self = add_self
        self.dropout = dropout
        self.relu=relu
        self.bn=bn
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        if self.bn:
            self.bn_layer = torch.nn.BatchNorm1d(output_dim)
        self.normalize_embedding = normalize_embedding
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).to('cuda' if torch.cuda.is_available() else 'cpu'))
        torch.nn.init.xavier_normal_(self.weight)
        if bias:
            self.bias = nn.Parameter(torch.zeros(output_dim).to('cuda' if torch.cuda.is_available() else 'cpu'))
        else:
            self.bias = None

    def forward(self, x, adj, mask):
        # propagate features along edges (y = A x), optionally add a self connection,
        # then apply the linear transform and bias
        y = torch.matmul(adj, x)
        if self.add_self:
            y += x
        y = torch.matmul(y,self.weight)
        if self.bias is not None:
            y = y + self.bias
        if self.normalize_embedding:
            y = F.normalize(y, p=2, dim=2)
        if self.bn:
            # batch-norm over real nodes only: pack the unmasked rows of every graph into
            # one 2-D tensor, normalize, then scatter the result back to the padded layout
            index=mask.sum(dim=1).long().tolist()
            bn_tensor_bf=mask.new_zeros((sum(index),y.shape[2]))
            bn_tensor_af=mask.new_zeros(*y.shape)
            start_index=[]
            ssum=0
            for i in range(x.shape[0]):
                start_index.append(ssum)
                ssum+=index[i]
            start_index.append(ssum)
            for i in range(x.shape[0]):
                bn_tensor_bf[start_index[i]:start_index[i+1]]=y[i,0:index[i]]
            bn_tensor_bf=self.bn_layer(bn_tensor_bf)
            for i in range(x.shape[0]):
                bn_tensor_af[i,0:index[i]]=bn_tensor_bf[start_index[i]:start_index[i+1]]
            y=bn_tensor_af
        if self.dropout > 0.001:
            y = self.dropout_layer(y)
        if self.relu=='relu':
            y=torch.nn.functional.relu(y)
        elif self.relu=='lrelu':
            y=torch.nn.functional.leaky_relu(y,0.1)
        return y

#experimental function, untested
class masked_batchnorm(nn.Module):
    def __init__(self,feat_dim,epsilon=1e-10):
        super().__init__()
        self.alpha=nn.Parameter(torch.ones(feat_dim))
        self.beta=nn.Parameter(torch.zeros(feat_dim))
        self.eps=epsilon

    def forward(self,x,mask):
        '''
        x: node feat, [batch,node_num,feat_dim]
        mask: [batch,node_num]
        '''
        mask1 = mask.unsqueeze(2)
        mask_sum = mask.sum()
        # mean/variance computed over the real (unmasked) nodes of the whole batch
        mean = (x*mask1).sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
        temp = (x - mean)**2
        temp = temp*mask1
        var = temp.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
        rstd = torch.rsqrt(var+self.eps)
        x=(x-mean)*rstd
        return ((x*self.alpha) + self.beta)*mask1
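
# Minimal usage sketch (illustrative only, not part of the original training pipeline):
# it assumes a hypothetical padded batch of two small graphs and simply runs AGCNBlock
# with its default 'agcn'/'global' configuration to check tensor shapes end to end.
# Note that forward() also writes att_*.pt / att_ind*.pt files to the working directory.
if __name__ == '__main__':
    batch, n, in_dim, hid_dim = 2, 6, 5, 8
    X = torch.rand(batch, n, in_dim)
    mask = torch.ones(batch, n)
    mask[1, 4:] = 0  # second graph has only 4 real nodes
    # random symmetric adjacency with self-loops, zeroed outside each graph's real nodes
    adj = (torch.rand(batch, n, n) > 0.5).float()
    adj = ((adj + adj.transpose(1, 2) + torch.eye(n)) > 0).float()
    adj = adj * mask.unsqueeze(1) * mask.unsqueeze(2)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    X, adj, mask = X.to(device), adj.to(device), mask.to(device)
    block = AGCNBlock(input_dim=in_dim, hidden_dim=hid_dim).to(device)
    out, H, new_adj, new_mask, _ = block(X, adj, mask)
    # with these sizes: out [2,8], H [2,2,8], new_adj [2,2,2], new_mask [2,2]
    print(out.shape, H.shape, new_adj.shape, new_mask.shape)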