File size: 1,926 Bytes
44504f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from tqdm import tqdm

from models import (
    SigmoidNNAutoencoder,
    TanhNNAutoencoder,
    TanhPNAutoencoder,
    ReLUNNAutoencoder,
    ReLUPNAutoencoder,
    TanhSwishNNAutoencoder,
    ReLUSigmoidNRAutoencoder,
    ReLUSigmoidRRAutoencoder,
)


def get_network(name):
    """Build a fresh autoencoder for the given identifier.

    Args:
        name: One of "nn_sigmoid", "nn_tanh", "pn_tanh", "nn_relu",
            "pn_relu", "nn_tanh_swish", "nr_relu_sigmoid" or
            "rr_relu_sigmoid".

    Returns:
        A newly constructed instance of the matching autoencoder class,
        created with no constructor arguments.

    Raises:
        NotImplementedError: If ``name`` is not one of the identifiers above.
    """
    if name == "nn_sigmoid":
        return SigmoidNNAutoencoder()
    if name == "nn_tanh":
        return TanhNNAutoencoder()
    if name == "pn_tanh":
        return TanhPNAutoencoder()
    if name == "nn_relu":
        return ReLUNNAutoencoder()
    if name == "pn_relu":
        return ReLUPNAutoencoder()
    if name == "nn_tanh_swish":
        return TanhSwishNNAutoencoder()
    if name == "nr_relu_sigmoid":
        return ReLUSigmoidNRAutoencoder()
    if name == "rr_relu_sigmoid":
        return ReLUSigmoidRRAutoencoder()

    # Nothing matched: report the unknown identifier to the caller.
    message = f"Autoencoder of name '{name}' currently is not supported"
    raise NotImplementedError(message)


class AverageMeter:
    """Track the most recent value and the running weighted average.

    Public attributes (read directly by callers):
        val:   the value passed to the latest ``update`` call.
        sum:   the running total of ``val * n`` over all updates.
        count: the running total of ``n`` over all updates.
        avg:   ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: The new measurement.
            n: Weight of the measurement (e.g. a batch size). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update has n=0) would otherwise
        # raise ZeroDivisionError; report 0, the same value reset() uses.
        self.avg = self.sum / self.count if self.count else 0


def epoch(loader, model, device, criterion, opt=None):
    """Run one pass over ``loader`` and return the weighted mean loss.

    When ``opt`` is given the model is put in training mode and updated after
    every batch; otherwise it is put in evaluation mode and only scored.
    Each batch is reshaped to rows of 28 * 28 values, and the criterion
    compares the model output against that same reshaped input. The second
    element of every batch is ignored.

    Args:
        loader: Iterable yielding ``(inputs, _)`` pairs of tensors.
        model: Network to run. It must provide ``train()``/``eval()`` and,
            when training, a ``clamp()`` method.
        device: Target passed to ``inputs.to``.
        criterion: Callable ``criterion(outputs, inputs)`` returning a scalar
            loss tensor.
        opt: Optimizer used for the update step; ``None`` (the default)
            selects evaluation.

    Returns:
        The per-batch loss averaged with batch size as weight (this presumes
        the criterion returns a per-batch mean -- confirm against the caller),
        or 0 when the loader yields no batches.
    """
    training = opt is not None
    losses = AverageMeter()

    if training:
        model.train()
    else:
        model.eval()

    # Gradients are only needed when the weights are updated. Turning autograd
    # off for evaluation avoids building a graph that is never used.
    with torch.set_grad_enabled(training):
        for inputs, _ in tqdm(loader, leave=False):
            inputs = inputs.view(-1, 28 * 28).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            if training:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
                # Model-specific hook run after every optimizer step;
                # presumably re-applies weight constraints -- confirm in models.
                model.clamp()

            losses.update(loss.item(), inputs.size(0))

    return losses.avg