Spaces:
Runtime error
Runtime error
File size: 1,926 Bytes
44504f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import torch
from tqdm import tqdm

from models import (
    SigmoidNNAutoencoder,
    TanhNNAutoencoder,
    TanhPNAutoencoder,
    ReLUNNAutoencoder,
    ReLUPNAutoencoder,
    TanhSwishNNAutoencoder,
    ReLUSigmoidNRAutoencoder,
    ReLUSigmoidRRAutoencoder,
)
def get_network(name):
    """Build a fresh autoencoder for the given identifier.

    Args:
        name: One of "nn_sigmoid", "nn_tanh", "pn_tanh", "nn_relu",
            "pn_relu", "nn_tanh_swish", "nr_relu_sigmoid" or
            "rr_relu_sigmoid".

    Returns:
        A new instance of the autoencoder class registered under ``name``,
        constructed with no arguments.

    Raises:
        NotImplementedError: If ``name`` is not one of the identifiers above.
    """
    # Guard chain: the first identifier equal to ``name`` wins.
    if name == "nn_sigmoid":
        return SigmoidNNAutoencoder()
    if name == "nn_tanh":
        return TanhNNAutoencoder()
    if name == "pn_tanh":
        return TanhPNAutoencoder()
    if name == "nn_relu":
        return ReLUNNAutoencoder()
    if name == "pn_relu":
        return ReLUPNAutoencoder()
    if name == "nn_tanh_swish":
        return TanhSwishNNAutoencoder()
    if name == "nr_relu_sigmoid":
        return ReLUSigmoidNRAutoencoder()
    if name == "rr_relu_sigmoid":
        return ReLUSigmoidRRAutoencoder()

    message = f"Autoencoder of name '{name}' currently is not supported"
    raise NotImplementedError(message)
class AverageMeter:
    """Track the latest value and the running weighted average of a metric.

    State is exposed as plain attributes that callers read directly:
    ``val`` is the value from the most recent ``update``, ``sum`` is the
    weighted total, ``count`` is the total weight, and ``avg`` is
    ``sum / count`` (0 while ``count`` is 0).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to 0."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new observation (e.g. a batch-mean loss).
            n: Weight of the observation, typically the number of samples
                it summarizes. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A total weight of 0 (e.g. ``update(x, n=0)`` right after a reset)
        # used to raise ZeroDivisionError; keep the reset value instead.
        self.avg = self.sum / self.count if self.count else 0
def epoch(loader, model, device, criterion, opt=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    The model is trained when an optimizer is supplied and only evaluated
    otherwise. Every batch is reshaped to rows of ``28 * 28`` values, and the
    model output is scored against that same reshaped input.

    Args:
        loader: Iterable yielding ``(inputs, _)`` pairs; the second element
            of each pair is ignored.
        model: Callable model exposing ``train()``, ``eval()`` and
            ``clamp()``.
        device: Target passed to ``inputs.to(...)``.
        criterion: Callable invoked as ``criterion(outputs, inputs)``; its
            result must support ``backward()`` and ``item()``.
        opt: Optimizer used for the update step. ``None`` (the default)
            selects evaluation mode: no backward pass, no parameter update.

    Returns:
        The average of ``loss.item()`` over all batches, weighted by the
        number of rows in each batch; 0 if ``loader`` yields nothing.
    """
    # One flag drives mode selection, autograd and the update step, so the
    # three can never disagree.
    training = opt is not None

    losses = AverageMeter()
    if training:
        model.train()
    else:
        model.eval()

    # Evaluation never calls backward(), so skip building the autograd graph.
    with torch.set_grad_enabled(training):
        for inputs, _ in tqdm(loader, leave=False):
            inputs = inputs.view(-1, 28 * 28).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            if training:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
                # Project-specific hook run after every optimizer step.
                # NOTE(review): presumably constrains the weights - confirm
                # against the model definitions.
                model.clamp()
            losses.update(loss.item(), inputs.size(0))
    return losses.avg
|