File size: 3,277 Bytes
687ef3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from collections import Counter
from tqdm import tqdm
from time import time
import torch
from torch import nn, Tensor
from S3_intermediateDataset import build_intermediate_dataset_if_not_exists, intermediate_dataset
from S2_TimberDataset import build_dataloader
from S1_CNN_Model import build_model

if __name__ == '__main__':
    img_size = (320,320)
    train_loader, val_loader = build_dataloader(
        # train_ratio= 0.005,
        img_size=img_size,
        batch_size=16,
    )
    
    build_intermediate_dataset_if_not_exists(lambda x:x, "train", train_loader)
    build_intermediate_dataset_if_not_exists(lambda x:x, "val", val_loader)
    
    train_loader = intermediate_dataset("train") 
    val_loader = intermediate_dataset("val") 

    model = build_model(img_size=img_size)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    accuracies = []

    n_epoch = 40
    timer = time()
    pbar_0 = tqdm(range(n_epoch), position=0, ncols=100)
    pbar_0.set_description(f"epoch 1/{n_epoch}")
    img = None
    for epoch in pbar_0:
        pbar_1 = tqdm(enumerate(train_loader), total=len(train_loader), position=1, ncols=100, leave=False)
        for i, (images, labels) in pbar_1:
            # # Reshape
            images = images.reshape(images.shape[1:])
            labels = labels.reshape(labels.shape[1:])
            # Forward
            out = model.forward(images)
            loss = criterion.forward(out, labels)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            
            optimizer.step()

            if i % 10 == 0:
                pbar_1.set_description(f"  loss = {loss.item():.4f} ({(time() - timer)*1000:.4f} ms)")
                timer = time()
        
        n_correct = 0
        n_samples = 0
        pbar_0.set_description(f"epoch {epoch+1}/{n_epoch}, Validating . . .")

        with torch.no_grad(): 
            with tqdm(bar_format='{desc}{postfix}', position=1, leave=False) as val_desc:
                tally = Counter()
                for images, labels in tqdm(val_loader, position=2, ncols=100, leave=False):
                    # # Reshape
                    images = images.reshape(images.shape[1:])
                    labels = labels.reshape(labels.shape[1:])

                    x = model.forward(images)

                    _, predictions = torch.max(x,1)
                    tally += Counter(predictions.tolist())
                    n_samples += labels.shape[0]
                    n_correct += (predictions == labels).sum().item()
                    tally_desc = ' '.join([f"{n}:{c}" for n,c in tally.most_common()])[:80] + "..."
                    val_desc.set_description(f"{n_correct}/{n_samples} correct")
                    val_desc.set_postfix_str(tally_desc)
                
                accuracy = f"{n_correct/n_samples * 100:.2f}%"
                pbar_0.set_description(f"epoch {epoch+2}/{n_epoch}, accuracy: {accuracy}")
                
                if len(accuracies) >= 3 and accuracy > max(accuracies):
                    torch.save(model,f"model_{epoch}.pt")

                accuracies.append(accuracy)

    torch.save(model,"model.pt")
    
    print(accuracies)