from transformers import PreTrainedModel
from OmicsConfig import OmicsConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch_geometric.utils import negative_sampling
from torch.nn.functional import cosine_similarity
from torch.optim.lr_scheduler import StepLR

from GATv2EncoderModel import GATv2EncoderModel
from GATv2DecoderModel import GATv2DecoderModel
from EdgeWeightPredictorModel import EdgeWeightPredictorModel


class MultiOmicsGraphAttentionAutoencoderModel(PreTrainedModel):
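    """Graph attention autoencoder for multi-omics graphs.

    Wraps a GATv2 encoder/decoder pair and provides training, validation,
    and evaluation loops for node-feature, edge, and edge-weight reconstruction.
    """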
    config_class = OmicsConfig
    base_model_prefix = "graph-attention-autoencoder"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = GATv2EncoderModel(config)
        self.decoder = GATv2DecoderModel(config)
        self.optimizer = AdamW(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=config.learning_rate)
        self.scheduler = StepLR(self.optimizer, step_size=30, gamma=0.7)

    def forward(self, x, edge_index, edge_attr):
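        """Encode the graph and decode it back; returns (x_reconstructed, attention_weights)."""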
        z, attention_weights = self.encoder(x, edge_index, edge_attr)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, attention_weights

    def predict_edge_weights(self, z, edge_index):
        return self.decoder.predict_edge_weights(z, edge_index)

    def train_model(self, data_loader, device):
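        """Run one training pass over data_loader; returns (avg_loss, avg_cosine_similarity)."""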
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.train()
        self.decoder.train()
        total_loss = 0
        total_cosine_similarity = 0
        loss_weight_node = 1.0
        loss_weight_edge = 1.0
        loss_weight_edge_attr = 1.0

        for data in data_loader:
            data = data.to(device)
            self.optimizer.zero_grad()
            z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
            x_reconstructed = self.decoder(z)
            node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
            edge_loss = edge_reconstruction_loss(z, data.edge_index)
            cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
            total_cosine_similarity += cos_sim.item()
            pred_edge_weights = self.decoder.predict_edge_weights(z, data.edge_index)
            edge_weight_loss = edge_weight_reconstruction_loss(pred_edge_weights, data.edge_attr)
            loss = (loss_weight_node * node_loss) + (loss_weight_edge * edge_loss) + (loss_weight_edge_attr * edge_weight_loss)
            print(f"node_loss: {node_loss}, edge_loss: {edge_loss:.4f}, edge_weight_loss: {edge_weight_loss:.4f}, cosine_similarity: {cos_sim:.4f}")
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

        avg_loss, avg_cosine_similarity = total_loss / len(data_loader), total_cosine_similarity / len(data_loader)
        return avg_loss, avg_cosine_similarity

    def fit(self, train_loader, validation_loader, epochs, device):
        train_losses = []
        val_losses = []

        for epoch in range(1, epochs + 1):
            train_loss, train_cosine_similarity = self.train_model(train_loader, device)
            torch.cuda.empty_cache()
            val_loss, val_cosine_similarity = self.validate(validation_loader, device)
            # Record per-epoch losses so the returned histories are actually populated
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            print(f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Cosine Similarity: {train_cosine_similarity:.4f}, Validation Loss: {val_loss:.4f}, Validation Cosine Similarity: {val_cosine_similarity:.4f}")
            self.scheduler.step()

        return train_losses, val_losses

    def validate(self, validation_loader, device):
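        """Compute reconstruction loss and cosine similarity on the validation set without gradient updates."""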
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.eval()
        self.decoder.eval()
        total_loss = 0
        total_cosine_similarity = 0

        with torch.no_grad():
            for data in validation_loader:
                data = data.to(device)
                z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
                x_reconstructed = self.decoder(z)
                node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
                edge_loss = edge_reconstruction_loss(z, data.edge_index)
                cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
                total_cosine_similarity += cos_sim.item()
                loss = node_loss + edge_loss
                total_loss += loss.item()

        avg_loss = total_loss / len(validation_loader)
        avg_cosine_similarity = total_cosine_similarity / len(validation_loader)
        return avg_loss, avg_cosine_similarity

    def evaluate(self, test_loader, device):
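        """Compute reconstruction loss and cosine similarity on the held-out test set."""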
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.eval()
        self.decoder.eval()
        total_loss = 0
        total_cosine_similarity = 0

        with torch.no_grad():
            for data in test_loader:
                data = data.to(device)
                z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
                x_reconstructed = self.decoder(z)
                node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
                edge_loss = edge_reconstruction_loss(z, data.edge_index)
                cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
                total_cosine_similarity += cos_sim.item()
                loss = node_loss + edge_loss
                total_loss += loss.item()
                
        avg_loss = total_loss / len(test_loader)
        avg_cosine_similarity = total_cosine_similarity / len(test_loader)
        return avg_loss, avg_cosine_similarity

# Define a collate function for the DataLoader
def collate_graph_data(batch):
    return Batch.from_data_list(batch)

# Define a function to create a DataLoader
def create_data_loader(train_data, batch_size=1, shuffle=True):
    graph_data = list(train_data.values())
    return DataLoader(graph_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_graph_data)

# Define functions for the losses
def graph_reconstruction_loss(pred_features, true_features):
    return F.mse_loss(pred_features, true_features)

def edge_reconstruction_loss(z, pos_edge_index, neg_edge_index=None):
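    """Binary cross-entropy on inner-product edge logits; negative edges are sampled when not provided."""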
    pos_logits = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)
    pos_loss = F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits))
    if neg_edge_index is None:
        neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
    neg_logits = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)
    neg_loss = F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits))
    return pos_loss + neg_loss

def edge_weight_reconstruction_loss(pred_weights, true_weights):
    pred_weights = pred_weights.squeeze(-1)
    return F.mse_loss(pred_weights, true_weights)
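

# Minimal usage sketch (illustrative only). The OmicsConfig fields used below
# (in_channels, hidden_channels, out_channels) are assumptions about what the
# config exposes beyond learning_rate, and the toy tensor shapes assume the
# encoder accepts one scalar attribute per edge; adjust both to the actual
# OmicsConfig / encoder definitions and to your own graph dataset.
if __name__ == "__main__":
    from torch_geometric.data import Data

    # Hypothetical configuration; only learning_rate is referenced in this module.
    config = OmicsConfig(
        in_channels=64,
        hidden_channels=32,
        out_channels=64,
        learning_rate=1e-3,
    )
    model = MultiOmicsGraphAttentionAutoencoderModel(config)

    # Toy graph: 10 nodes with random features, 4 directed edges, one weight per edge.
    x = torch.randn(10, 64)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long)
    edge_attr = torch.rand(edge_index.size(1), 1)
    toy_data = {"sample_0": Data(x=x, edge_index=edge_index, edge_attr=edge_attr)}

    loader = create_data_loader(toy_data, batch_size=1, shuffle=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Reuse the same loader for validation purely for demonstration.
    model.fit(loader, loader, epochs=2, device=device)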