import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv
from transformers import AutoModel, PreTrainedModel

from .config import BERTMultiGATAttentionConfig


class MultiHeadGATAttention(nn.Module):
    """Residual multi-head attention blended with a graph-attention (GAT) branch.

    ``query``/``key``/``value`` are each projected by their own linear layer and
    LayerNorm. Two results are then computed:

    * standard scaled dot-product attention over ``num_heads`` heads, and
    * a ``GATConv`` pass over the projected values using ``edge_index``.

    They are mixed as ``alpha * standard + (1 - alpha) * gat``, where ``alpha``
    is a learnable scalar. The mix goes through an output projection, LayerNorm
    and dropout, and the original ``query`` is added back as a residual.

    Args:
        hidden_size: Feature size of every input and of the output. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads; also the GAT head count (the GAT
            heads are averaged because ``concat=False``).
        dropout: Dropout probability applied to the attention weights and to the
            projected output.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.03):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Fail at construction instead of inside forward()'s per-head view().
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.gat = GATConv(hidden_size, hidden_size, heads=num_heads, concat=False)
        # Learnable weight for combining the standard and GAT attention outputs.
        self.alpha = nn.Parameter(torch.tensor(0.5))

        self.layer_norm_q = nn.LayerNorm(hidden_size)
        self.layer_norm_k = nn.LayerNorm(hidden_size)
        self.layer_norm_v = nn.LayerNorm(hidden_size)
        self.layer_norm_out = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, batch_size, seq_length):
        """Reshape ``(batch, seq, hidden)`` into ``(batch, heads, seq, head_dim)``."""
        return x.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, edge_index):
        """Attend from ``query`` to ``key``/``value`` and add ``query`` back.

        Args:
            query: Tensor indexed as ``(batch, seq, hidden_size)``.
            key: Same layout as ``query``. The per-head ``view`` reuses the
                query's sequence length, so ``key`` must have that same length.
            value: Same layout and sequence length as ``query``. The GAT output
                is also reshaped with the query's ``(batch, seq)``.
            edge_index: Connectivity passed straight to ``GATConv``. Presumably
                ``(2, num_edges)`` indexing the ``batch * seq`` value nodes —
                TODO confirm against the data pipeline.

        Returns:
            ``query + attended``, with the same shape as ``query``.
        """
        # NOTE(review): no padding mask is applied here, so padded positions take
        # part in both the softmax and the GAT branch.
        batch_size, seq_length = query.size(0), query.size(1)
        residual = query

        q = self._split_heads(self.layer_norm_q(self.query(query)), batch_size, seq_length)
        k = self._split_heads(self.layer_norm_k(self.key(key)), batch_size, seq_length)
        v = self._split_heads(self.layer_norm_v(self.value(value)), batch_size, seq_length)

        # Standard scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.dropout(F.softmax(scores, dim=-1))
        attended_std = (
            torch.matmul(weights, v)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, seq_length, self.hidden_size)
        )

        # GAT branch: every (sample, position) pair becomes one graph node.
        value_nodes = v.permute(0, 2, 1, 3).reshape(batch_size * seq_length, self.hidden_size)
        attended_gat = self.gat(value_nodes, edge_index).view(
            batch_size, seq_length, self.hidden_size
        )

        attended = self.alpha * attended_std + (1 - self.alpha) * attended_gat
        attended = self.dropout(self.layer_norm_out(self.out(attended)))
        return residual + attended


class GNNPreProcessor(nn.Module):
    """Mix a GCN view and a GAT view of token features over one shared graph.

    Both convolutions read the same flattened node features. Their ReLU outputs
    are combined as ``alpha * gcn + (1 - alpha) * gat``, where ``alpha`` is a
    learnable scalar.

    Args:
        input_dim: Feature size of the incoming node features.
        hidden_dim: Output feature size of both convolutions.
        gat_heads: Number of GAT heads; they are averaged (``concat=False``).
    """

    def __init__(self, input_dim, hidden_dim, gat_heads=8):
        super().__init__()
        self.gcn = GCNConv(input_dim, hidden_dim)
        # The GAT consumes the same raw input features as the GCN, so its
        # in-channels must be input_dim. With the original hidden_dim the layer
        # only ran when the two sizes happened to be equal; when they are equal,
        # the parameter shapes are identical either way.
        self.gat = GATConv(input_dim, hidden_dim, heads=gat_heads, concat=False)
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, edge_index):
        """Run both convolutions on ``x`` and return the learnable mix.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``, flattened into
                ``batch * seq_len`` graph nodes.
            edge_index: Reshaped to ``(2, -1)`` before use.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_dim)``.
        """
        batch_size, seq_len, feature_dim = x.size()
        nodes = x.reshape(batch_size * seq_len, feature_dim)
        edge_index = edge_index.view(2, -1)

        gcn_out = F.relu(self.gcn(nodes, edge_index))
        gat_out = F.relu(self.gat(nodes, edge_index))
        mixed = self.alpha * gcn_out + (1 - self.alpha) * gat_out
        return mixed.view(batch_size, seq_len, -1)


class DEBERTAMultiGATAttentionModel(PreTrainedModel):
    """Two-input scorer: shared transformer encoder + GNN + GAT attention.

    Pipeline for each input:

    1. Encode it with the same pretrained transformer.
    2. Enrich it with that input's own ``GNNPreProcessor``.
    3. Fuse the encoder and GNN features back to ``hidden_size``.

    The two fused sequences are then compared with self-/cross-attention,
    max-pooled, and mapped to one sigmoid output in ``(0, 1)``.
    """

    config_class = BERTMultiGATAttentionConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.transformer = AutoModel.from_pretrained(config.transformer_model)

        self.gnn_preprocessor1 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)
        self.gnn_preprocessor2 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)

        self.fc_combine = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.layer_norm_combine = nn.LayerNorm(config.hidden_size)
        self.dropout_combine = nn.Dropout(config.dropout)

        # NOTE(review): only the stage-3 modules (self_attention5/6,
        # cross_attention__, fc3, layer_norm_fc1, dropout3) are on the output
        # path. The stage-1/2 attention blocks, fc1, fc2, layer_norm_fc2,
        # layer_norm_fc3, dropout1 and dropout2 are kept so existing state_dicts
        # still load, but they receive no gradients (DDP then needs
        # find_unused_parameters=True).
        self.self_attention1 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention2 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention3 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention4 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention_ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention5 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention6 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention__ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)

        self.fc1 = nn.Linear(config.hidden_size * 2, 256)
        self.fc2 = nn.Linear(config.hidden_size * 2, 256)
        self.fc3 = nn.Linear(config.hidden_size * 2, 256)
        self.layer_norm_fc1 = nn.LayerNorm(256)
        self.layer_norm_fc2 = nn.LayerNorm(256)
        self.layer_norm_fc3 = nn.LayerNorm(256)
        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)
        self.dropout3 = nn.Dropout(config.dropout)
        self.dropout4 = nn.Dropout(config.dropout)

        self.fc_proj = nn.Linear(256, 256)
        self.layer_norm_proj = nn.LayerNorm(256)
        self.fc_final = nn.Linear(256, 1)

    def _fuse(self, encoder_out, gnn_out):
        """Concatenate encoder and GNN features and project back to ``hidden_size``."""
        fused = torch.cat([encoder_out, gnn_out], dim=2)
        fused = self.layer_norm_combine(self.fc_combine(fused))
        return self.dropout_combine(F.relu(fused))

    @staticmethod
    def _interaction_stage(seq1, seq2, edge_index1, edge_index2,
                           self_attn1, self_attn2, cross_attn, fc, layer_norm, dropout):
        """Self-attend each sequence, cross-attend 1 -> 2, then max-pool and project.

        Returns:
            Tensor of shape ``(batch, 1, fc.out_features)``. The ``unsqueeze(1)``
            keeps a length-1 sequence axis for the projection head.
        """
        out1 = self_attn1(seq1, seq1, seq1, edge_index1)
        out2 = self_attn2(seq2, seq2, seq2, edge_index2)
        attended = cross_attn(out1, out2, out2, edge_index1)
        # Concatenate along features, then max-pool over the sequence axis.
        # NOTE(review): padded positions are included in this max.
        pooled, _ = torch.max(torch.cat([out1, attended], dim=2), dim=1)
        pooled = dropout(F.relu(layer_norm(fc(pooled))))
        return pooled.unsqueeze(1)

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2, edge_index1, edge_index2):
        """Score a pair of inputs.

        Args:
            input_ids1, attention_mask1: Backbone inputs for the first sequence.
            input_ids2, attention_mask2: Backbone inputs for the second sequence.
            edge_index1, edge_index2: Graph connectivity for each sequence,
                reshaped to ``(2, -1)``.

        Returns:
            ``sigmoid`` of a single logit per example, shape ``(batch, 1, 1)``.
        """
        # Pass the mask by keyword: the positional order after input_ids is not
        # guaranteed to be the same across Hugging Face backbones.
        output1_bert = self.transformer(input_ids=input_ids1, attention_mask=attention_mask1)[0]
        output2_bert = self.transformer(input_ids=input_ids2, attention_mask=attention_mask2)[0]

        # NOTE(review): view(2, -1) gives a valid (2, E) edge list only if the
        # incoming tensor is laid out as (2, ...), e.g. (2, batch, E). A
        # (batch, 2, E) layout would interleave source and target rows.
        # Node indices must also already be offset into the flattened
        # batch * seq node range. TODO confirm against the collate function.
        edge_index1 = edge_index1.view(2, -1)
        edge_index2 = edge_index2.view(2, -1)

        output1_gnn = self.gnn_preprocessor1(output1_bert, edge_index1)
        output2_gnn = self.gnn_preprocessor2(output2_bert, edge_index2)

        fused1 = self._fuse(output1_bert, output1_gnn)
        fused2 = self._fuse(output2_bert, output2_gnn)

        # The original forward ran three copy-pasted stages but overwrote the
        # results of the first two, so only this stage ever reached the output.
        # The dead stages are no longer computed. This stage keeps the exact
        # modules it always used — note layer_norm_fc1, not layer_norm_fc3 —
        # so trained checkpoints produce the same predictions.
        # NOTE(review): decide how stages 1 and 2 were meant to combine with
        # this one before wiring them back in.
        pooled = self._interaction_stage(
            fused1, fused2, edge_index1, edge_index2,
            self.self_attention5, self.self_attention6, self.cross_attention__,
            self.fc3, self.layer_norm_fc1, self.dropout3,
        )

        hidden_state_proj = self.dropout4(self.layer_norm_proj(self.fc_proj(pooled)))
        final = self.fc_final(hidden_state_proj)
        return torch.sigmoid(final)


AutoModel.register(BERTMultiGATAttentionConfig, DEBERTAMultiGATAttentionModel)