|
import torch |
|
import torch.nn as nn |
|
from DeePixBis.Metrics import test_accuracy, test_loss |
|
|
|
|
|
class Trainer(): |
|
def __init__(self, train_dl, val_dl, model, epochs, opt, loss_fn, device='cpu'): |
|
self.train_dl = train_dl |
|
self.val_dl = val_dl |
|
self.model = model.to(device) |
|
self.epochs = epochs |
|
self.opt = opt |
|
self.loss_fn = loss_fn |
|
self.device = device |
|
|
|
def train_one_epoch(self, num): |
|
print(f'\nEpoch ({num+1}/{self.epochs})') |
|
print('----------------------------------') |
|
|
|
for batch, (img, mask, label) in enumerate(self.train_dl): |
|
img, mask, label = img.to(self.device), mask.to(self.device), label.to(self.device) |
|
net_mask, net_label = self.model(img) |
|
loss = self.loss_fn(net_mask, net_label, mask, label) |
|
|
|
|
|
self.opt.zero_grad() |
|
loss.backward() |
|
self.opt.step() |
|
|
|
if batch % 9 == 0: |
|
print(f'Loss : {loss}') |
|
|
|
|
|
test_acc = test_accuracy(self.model, self.val_dl) |
|
test_los = test_loss(self.model, self.val_dl, self.loss_fn) |
|
|
|
print(f'Test Accuracy : {test_acc} Test Loss : {test_los}') |
|
return test_acc, test_los |
|
|
|
def fit(self): |
|
training_loss = [] |
|
training_acc = [] |
|
self.model.train() |
|
for epoch in range(self.epochs): |
|
train_acc, train_loss = self.train_one_epoch(epoch) |
|
training_acc.append(train_acc) |
|
training_loss.append(train_loss) |
|
|
|
return training_acc, training_loss |
|
|