Spaces:
Running
Running
import torch | |
def train(
    model,
    device,
    train_loader,
    criterion,
    optimizer,
    epoch,
    train_loss,
    train_acc,
    mse=None,
):
    """Run one training epoch and record loss/accuracy history.

    Iterates ``train_loader`` once. Each batch gets a forward pass, a
    ``criterion`` loss, a backward pass and ``optimizer.step()``. Every 10th
    batch prints a progress line and appends that batch's per-sample loss and
    accuracy to ``train_loss`` / ``train_acc``. After the loop, the epoch-level
    per-sample loss and accuracy are appended to the same lists.

    Args:
        model: Network to train; switched to training mode.
        device: Device each batch of images and targets is moved to.
        train_loader: Iterable of ``(images, targets)`` batches that also
            exposes ``.dataset`` and ``len()`` (e.g. a
            ``torch.utils.data.DataLoader``).
        criterion: Loss callable ``criterion(output, targets)`` returning a
            scalar tensor. If it has ``reduction == "mean"`` the batch loss is
            rescaled by the batch size; otherwise it is treated as a sum.
        optimizer: Optimizer that updates ``model``'s parameters.
        epoch: Epoch number, used only in printed messages.
        train_loss: List extended in place with loss values.
        train_acc: List extended in place with accuracy fractions in [0, 1].
        mse: Unused here. NOTE(review): presumably meant for a regression
            mode, but accuracy below uses ``argmax(dim=1)`` and so assumes
            2-D (batch, classes) output — confirm intended use.

    Returns:
        ``(train_loss, train_acc, epoch_loss)``: the two (mutated) history
        lists and the average per-sample loss for this epoch.
    """
    model.train()
    num_samples = len(train_loader.dataset)
    curr_loss = 0.0
    t_pred = 0
    for batch_idx, (images, targets) in enumerate(train_loader):
        images, targets = images.to(device), targets.to(device)
        batch_size = len(images)

        optimizer.zero_grad()
        output = model(images)
        # Drop singleton dims like ``squeeze()`` would, but always keep the
        # batch dim so a final batch of one sample still yields (1, C).
        output = output.reshape(
            output.shape[0], *[size for size in output.shape[1:] if size != 1]
        )
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        # Turn the batch loss into a per-batch *sum* so dividing the running
        # total by the dataset size gives a per-sample average. torch loss
        # modules default to reduction="mean"; anything else is taken as a sum.
        if getattr(criterion, "reduction", None) == "mean":
            batch_loss = loss.item() * batch_size
        else:
            batch_loss = loss.sum().item()
        curr_loss += batch_loss

        preds = output.argmax(dim=1)
        batch_correct = (preds == targets).sum().item()
        t_pred += batch_correct

        if batch_idx % 10 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * batch_size,
                    num_samples,
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
            train_loss.append(batch_loss / batch_size)
            # Fraction of correct predictions in this batch.
            train_acc.append(batch_correct / batch_size)
    epoch_loss = curr_loss / num_samples
    epoch_acc = t_pred / num_samples
    # NOTE(review): the epoch summary goes into the same lists as the
    # per-batch samples above; callers must know that layout.
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    print(
        "\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            epoch_loss,
            t_pred,
            num_samples,
            100.0 * t_pred / num_samples,
        )
    )
    return train_loss, train_acc, epoch_loss
def valid(
    model, device, test_loader, criterion, epoch, valid_loss, valid_acc, mse=None
):
    """Evaluate ``model`` for one pass over ``test_loader`` without gradients.

    Every 10th batch prints a progress line and appends that batch's
    per-sample loss and accuracy to ``valid_loss`` / ``valid_acc``. After the
    loop, the epoch-level per-sample loss and accuracy are appended to the
    same lists and a summary line is printed.

    Args:
        model: Network to evaluate; switched to eval mode.
        device: Device each batch of images and targets is moved to.
        test_loader: Iterable of ``(images, targets)`` batches that also
            exposes ``.dataset`` and ``len()`` (e.g. a
            ``torch.utils.data.DataLoader``).
        criterion: Loss callable ``criterion(output, targets)`` returning a
            scalar tensor. If it has ``reduction == "mean"`` the batch loss is
            rescaled by the batch size; otherwise it is treated as a sum.
        epoch: Epoch number, used only in printed messages.
        valid_loss: List extended in place with loss values.
        valid_acc: List extended in place with accuracy fractions in [0, 1].
        mse: Unused here. NOTE(review): presumably meant for a regression
            mode, but accuracy below uses ``argmax(dim=1)`` and so assumes
            2-D (batch, classes) output — confirm intended use.

    Returns:
        ``(valid_loss, valid_acc)``: the two (mutated) history lists.
    """
    model.eval()
    num_samples = len(test_loader.dataset)
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(test_loader):
            images, targets = images.to(device), targets.to(device)
            batch_size = len(images)

            output = model(images)
            # Drop singleton dims like ``squeeze()`` would, but always keep the
            # batch dim so a final batch of one sample still yields (1, C).
            output = output.reshape(
                output.shape[0], *[size for size in output.shape[1:] if size != 1]
            )
            loss = criterion(output, targets)

            # Turn the batch loss into a per-batch *sum* so dividing the total
            # by the dataset size gives a per-sample average. torch loss
            # modules default to reduction="mean"; anything else is a sum.
            if getattr(criterion, "reduction", None) == "mean":
                batch_loss = loss.item() * batch_size
            else:
                batch_loss = loss.sum().item()
            test_loss += batch_loss

            preds = output.argmax(dim=1)
            # Count as a Python int so the summary prints a number, not a
            # device tensor.
            batch_correct = (preds == targets).sum().item()
            correct += batch_correct

            if batch_idx % 10 == 0:
                print(
                    "Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * batch_size,
                        num_samples,
                        100.0 * batch_idx / len(test_loader),
                        loss.item(),
                    )
                )
                valid_loss.append(batch_loss / batch_size)
                # Fraction of correct predictions in this batch.
                valid_acc.append(batch_correct / batch_size)
    epoch_loss = test_loss / num_samples
    epoch_acc = correct / num_samples
    # NOTE(review): the epoch summary goes into the same lists as the
    # per-batch samples above; callers must know that layout.
    valid_loss.append(epoch_loss)
    valid_acc.append(epoch_acc)
    print(
        "Valid Set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            epoch_loss,
            correct,
            num_samples,
            100.0 * correct / num_samples,
        )
    )
    return valid_loss, valid_acc