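# Training entry point: trains the CNN from S1_CNN_Model on the timber dataset
# (S2_TimberDataset), iterating over the cached batches from
# S3_intermediateDataset, and checkpoints the best-performing model.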
from collections import Counter
from tqdm import tqdm
from time import time
import torch
from torch import nn
from S3_intermediateDataset import build_intermediate_dataset_if_not_exists, intermediate_dataset
from S2_TimberDataset import build_dataloader
from S1_CNN_Model import build_model
if __name__ == '__main__':
    img_size = (320, 320)
    train_loader, val_loader = build_dataloader(
        # train_ratio=0.005,
        img_size=img_size,
        batch_size=16,
    )
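    # Cache the preprocessed batches to disk once (identity transform), then
    # iterate over the cached copies on this and subsequent runs.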
    build_intermediate_dataset_if_not_exists(lambda x: x, "train", train_loader)
    build_intermediate_dataset_if_not_exists(lambda x: x, "val", val_loader)
    train_loader = intermediate_dataset("train")
    val_loader = intermediate_dataset("val")
    model = build_model(img_size=img_size)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    accuracies = []
    n_epoch = 40
    timer = time()
    pbar_0 = tqdm(range(n_epoch), position=0, ncols=100)
    pbar_0.set_description(f"epoch 1/{n_epoch}")
    for epoch in pbar_0:
        model.train()  # return to training mode after the previous epoch's validation
        pbar_1 = tqdm(enumerate(train_loader), total=len(train_loader), position=1, ncols=100, leave=False)
        for i, (images, labels) in pbar_1:
            # Drop the extra leading dimension the intermediate dataset wraps
            # around each cached batch.
            images = images.reshape(images.shape[1:])
            labels = labels.reshape(labels.shape[1:])
            # Forward
            out = model(images)
            loss = criterion(out, labels)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                pbar_1.set_description(f" loss = {loss.item():.4f} ({(time() - timer)*1000:.4f} ms)")
                timer = time()
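        # Validation pass over the held-out set (gradient tracking disabled).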
        n_correct = 0
        n_samples = 0
        pbar_0.set_description(f"epoch {epoch+1}/{n_epoch}, Validating . . .")
        model.eval()
        with torch.no_grad():
            with tqdm(bar_format='{desc}{postfix}', position=1, leave=False) as val_desc:
                tally = Counter()
                for images, labels in tqdm(val_loader, position=2, ncols=100, leave=False):
                    # Same reshape as in training: drop the leading dimension.
                    images = images.reshape(images.shape[1:])
                    labels = labels.reshape(labels.shape[1:])
                    x = model(images)
                    _, predictions = torch.max(x, 1)
                    tally += Counter(predictions.tolist())
                    n_samples += labels.shape[0]
                    n_correct += (predictions == labels).sum().item()
                    tally_desc = ' '.join([f"{n}:{c}" for n, c in tally.most_common()])[:80] + "..."
                    val_desc.set_description(f"{n_correct}/{n_samples} correct")
                    val_desc.set_postfix_str(tally_desc)
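        # Track accuracy as a float so comparisons against the running best are
        # numeric, and checkpoint whenever this epoch beats every previous one
        # (once a few epochs have been recorded).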
        accuracy = n_correct / n_samples * 100
        pbar_0.set_description(f"epoch {epoch+1}/{n_epoch}, accuracy: {accuracy:.2f}%")
        if len(accuracies) >= 3 and accuracy > max(accuracies):
            torch.save(model, f"model_{epoch}.pt")
        accuracies.append(accuracy)

    torch.save(model, "model.pt")
    print(accuracies)