File size: 7,233 Bytes
1025b47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""

contains functions for training and testing a pytorch model

"""
import torch 

from tqdm.auto import tqdm 
from typing import Dict, List, Tuple 
# from torch.utils.tensorboard.writer import SummaryWriter

def train_step(model: torch.nn.Module, 

              dataloader: torch.utils.data.DataLoader, 

              loss_fn: torch.nn.Module, 

              optimizer: torch.optim.Optimizer, 

              device: torch.device) -> Tuple[float, float]:
    """Trains a pytorch model for a single epoch 



    turns a target model to training mode then runs through all of the required training steps

    (forward pass, loss calculation, optimizer step).



    Args: 

        model: pytorch model

        dataloader: dataloader insatnce for the model to be trained on 

        loss_fn: pytorch loss function to calculate loss

        optimizer: pytorch optimizer to help minimize the loss function

        device: target device



    returns:

        a tuple of training loss and training accuracy metrics

        in the form (train_loss, train_accuracy)

    """
    # put the model into training mode
    model.train()
    
    # setup train loss and train accuracy 
    train_loss, train_accuracy = 0, 0 

    # loop through data laoder batches
    for batch, (X, y) in enumerate(dataloader):
        # send data to target device 
        X, y = X.to(device), y.to(device)

        # forward pass 
        logits = model(X)

        # calculate loss and accumulate loss 
        loss = loss_fn(logits, y)
        train_loss += loss

        # optimizer zero grad 
        optimizer.zero_grad()

        # loss backward 
        loss.backward()

        # optimizer step 
        optimizer.step()

        # calculate and accumulate accuracy metric across all batches
        preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
        train_accuracy += (preds == y).sum().item()/len(preds)

    # adjust metrics to get average loss and accuracy per batch 
    train_loss /= len(dataloader)
    train_accuracy /= len(dataloader)
    return train_loss, train_accuracy

def test_step(model: torch.nn.Module, 

             dataloader: torch.utils.data.DataLoader, 

             loss_fn: torch.nn.Module, 

             device: torch.device) -> Tuple[float, float]:
    """Tests a pytorch model for a single epoch



    Turns a target model to eval mode and then performs a forward pass on a testing

    dataset. 



    Args: 

        model: pytorch model

        dataloader: dataloader insatnce for the model to be tested on 

        loss_fn: loss function to calculate loss (errors)

        device: target device to compute on 



    returns:

        A tuple of testing loss and testing accuracy metrics.

        In the form (test_loss, test_accuracy)

    """
    # put the model in eval mode
    model.eval()

    # setup test loss and test accuracy 
    test_loss, test_accuracy = 0, 0 

    # turn on inference mode 
    with torch.inference_mode():
        # loop through all batches 
        for X, y in dataloader: 
            # send data to target device
            X, y  = X.to(device), y.to(device)

            # forward pass
            logits = model(X)

            # calculate and accumulate loss
            loss = loss_fn(logits, y)
            test_loss += loss.item()

            # calculate and accumulate accuracy 
            test_preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
            test_accuracy += ((test_preds == y).sum().item()/len(test_preds))
    # adjust metrics to get average loss and accuracy per batch 
    test_loss /= len(dataloader)
    test_accuracy /= len(dataloader)
    return test_loss, test_accuracy

def train(model: torch.nn.Module, 

         train_dataloader: torch.utils.data.DataLoader, 

         test_dataloader: torch.utils.data.DataLoader, 

         optimizer: torch.optim.Optimizer, 

         loss_fn: torch.nn.Module, 

         epochs: int, 

         device: torch.device, 

         writer: torch.utils.tensorboard.writer.SummaryWriter) -> Dict[str, List]:
    """Trains and tests pytorch model



    passes a target model through train_step() and test_step() 

    functions for a number of epochs, training and testing the model in the same epoch loop.



    calculates, prints and stores evaluation metric throughout. 



    Args: 

        model: pytorch model

        train_dataloader: DataLoader instance for the model to be trained on

        test_dataloader: DataLoader instance for the model to be tested on

        optimizer: pytorch optimizer

        loss_fn: pytorch loss function

        epochs: integer indicating how many epochs to train for

        device: target device to compute on 



    returns: 

        A dictionaru of training and testing loss as well as training and testing accuracy 

        metrics. Each metric has a value in a list for each epoch. 



        In the form: {train_loss: [...],

                      train_acc: [...],

                      test_loss: [...],

                      test_acc: [...]}

    """
    # create an empty dictionary 
    results = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }

    # loop through training and testing steps for a number of epochs 
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                          dataloader=train_dataloader, 
                                          loss_fn=loss_fn, 
                                          optimizer=optimizer, 
                                          device=device)
        test_loss, test_acc = test_step(model=model, 
                                       dataloader=test_dataloader, 
                                       loss_fn=loss_fn,
                                       device=device)

        if epoch % 1 == 0:
            print(
                f"Epoch: {epoch+1} | " 
                f"train_loss: {train_loss:.4f} | "
                f"train_acc: {train_acc:.4f} | "
                f"test_loss: {test_loss:.4f} | "
                f"test_acc: {test_acc:.4f}"
            )

        # update results dictionary 
        results["train_loss"].append(train_loss.item())
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

        if writer: 
            # NEW: EXPERIMENT TRACKING 
            # add loss to SummaryWriter
            writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train loss": train_loss, "test loss": test_loss}, global_step=epoch)
            # add accuracy to SummaryWriter 
            writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train acc": train_acc, "test acc": test_acc}, global_step=epoch)
            # track the pytorch model architecture 
            writer.add_graph(model=model, input_to_model=torch.randn(size=(32, 3, 224, 224)).to(device))
            writer.close()
    # END SummaryWriter tracking process

    # return the filled results dictionaru 
    return results