import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torchvision.models import resnet18, resnet50, resnet101, vgg16_bn


class LightningRegressionModel(pl.LightningModule):
    """Regression model with a ResNet or VGG16-BN backbone, implemented as a PyTorch Lightning module.

    The first convolution is replaced so the network accepts single-channel images,
    and the final layer is resized to `num_classes` outputs.
    """

    def __init__(self, learning_rate, weights, num_classes, model_name):
        super().__init__()
        self.save_hyperparameters()

        if model_name == "resnet18":
            self.model = resnet18(weights=weights)
            self.model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
        elif model_name == "resnet50":
            self.model = resnet50(weights=weights)
            self.model.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=True)
        elif model_name == "resnet101":
            self.model = resnet101(weights=weights)
            self.model.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=True)
        elif model_name == "vgg":
            self.model = vgg16_bn(weights=weights)
            # Single-channel input instead of RGB.
            self.model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            # Make the last pooling layer adaptive so arbitrary input sizes map to 7x7 feature maps.
            self.model.features[-1] = nn.AdaptiveMaxPool2d((7, 7))
            # Resize the head to `num_classes` outputs, consistent with the ResNet branches.
            self.model.classifier[-1] = nn.Linear(in_features=4096, out_features=num_classes, bias=True)
        else:
            raise ValueError(f"Unknown model_name: {model_name}")

        if model_name.startswith("resnet"):
            # Single-channel input instead of RGB.
            self.model.conv1 = nn.Conv2d(
                1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
            )

        self.learning_rate = learning_rate
        self.loss_fn = nn.MSELoss()

        self.all_train_loss = []

        self.truth_labels = []
        self.predicted_labels = []

    def forward(self, images):
        # Accept [B, H, W] batches and add the channel dimension expected by the CNN.
        images = torch.as_tensor(images).float()
        images = images.unsqueeze(1)
        output = self.model(images)
        return output

    def training_step(self, batch, batch_idx):
        loss, outputs, labels = self._common_step(batch)
        # Detach so the stored losses do not keep the computation graph alive.
        self.all_train_loss.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        loss, outputs, labels = self._common_step(batch)
        self.predicted_labels.append(outputs)
        self.truth_labels.append(labels.float())
        return loss
    
    def on_validation_epoch_end(self):
        """Log, for every epoch, the (truth, prediction) pairs and the train/validation RMSE."""
        tensorboard = self.logger.experiment

        all_preds = torch.cat(self.predicted_labels)
        all_truths = torch.cat(self.truth_labels)
        all_couple = torch.cat((all_truths, all_preds), dim=1)
        tensorboard.add_embedding(all_couple, tag="couple_label_pred_ep" + str(self.current_epoch))

        # Per-truth-value statistics of the predictions (kept for inspection, not logged here).
        wind_values = torch.unique(all_truths)
        pred_means = []
        pred_std = []
        pred_n = []
        for value in wind_values:
            # All predictions whose ground truth equals `value`.
            preds_for_value = all_couple[all_couple[:, 0] == value][:, 1].float()
            pred_means.append(torch.mean(preds_for_value))
            pred_std.append(torch.std(preds_for_value))
            pred_n.append(len(preds_for_value))

        # RMSE over the training steps seen since the previous validation epoch.
        if self.all_train_loss:
            train_loss = torch.sqrt(torch.stack(self.all_train_loss).mean())
        else:
            train_loss = torch.tensor(float("nan"))

        validation_loss = torch.sqrt(self.loss_fn(all_preds, all_truths).detach())

        if not torch.isnan(train_loss):  # no training step has run yet (e.g. sanity check)
            tensorboard.add_scalars("Loss (RMSE)", {"train": train_loss, "validation": validation_loss}, self.current_epoch)

        self.log("validation_loss", validation_loss)

        # Free memory and restart the accumulation for the next epoch.
        self.all_train_loss.clear()
        self.predicted_labels.clear()
        self.truth_labels.clear()

        print("train_loss:", train_loss.item(), "validation_loss:", validation_loss.item())

    def test_step(self, batch, batch_idx):
        loss, outputs, labels = self._common_step(batch)
        return loss

    def _common_step(self, batch):
        images, labels = batch
        labels = labels.reshape(-1, 1)  # column vector to match the model output shape
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.float())
        return loss, outputs, labels

    def predict_step(self, batch, batch_idx=0, dataloader_idx=0):
        images, _labels = batch  # labels are not needed for prediction
        preds = self.forward(images)
        return preds

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=self.learning_rate)
        # Optional LR schedule, kept here for reference:
        # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30, 45], gamma=0.1)
        # return [optimizer], [scheduler]
        return [optimizer]
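

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# Assumptions: dummy tensors stand in for the real dataset; the actual project
# presumably provides its own DataLoaders/DataModule yielding (images, labels)
# batches with images of shape [B, H, W] and scalar regression targets.
# A TensorBoardLogger is used because `on_validation_epoch_end` calls
# `self.logger.experiment.add_embedding` / `add_scalars`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning.loggers import TensorBoardLogger

    # Dummy single-channel 224x224 images with scalar targets, for illustration only.
    images = torch.randn(16, 224, 224)
    targets = torch.rand(16)
    train_loader = DataLoader(TensorDataset(images, targets), batch_size=4)
    val_loader = DataLoader(TensorDataset(images, targets), batch_size=4)

    model = LightningRegressionModel(
        learning_rate=1e-3,
        weights=None,      # e.g. torchvision.models.ResNet18_Weights.DEFAULT for pretrained weights
        num_classes=1,     # single regression output
        model_name="resnet18",
    )

    trainer = pl.Trainer(
        max_epochs=1,
        accelerator="auto",
        devices=1,
        logger=TensorBoardLogger(save_dir="lightning_logs"),
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    predictions = trainer.predict(model, dataloaders=val_loader)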