# Timber-identification-CNN / S4_Training.py
# Yapp99 — commit 687ef3d: "Included project files"
from collections import Counter
from tqdm import tqdm
from time import time
import torch
from torch import nn, Tensor
from S3_intermediateDataset import build_intermediate_dataset_if_not_exists, intermediate_dataset
from S2_TimberDataset import build_dataloader
from S1_CNN_Model import build_model
def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass of ``model`` over every batch in ``loader``.

    Every 10th batch, the latest loss and the wall-clock time (ms) since the
    previous report are shown on a nested tqdm bar, which is removed when the
    epoch finishes.
    """
    # Validation switches the model to eval mode, so restore training mode each epoch.
    model.train()
    timer = time()  # reset per epoch so validation time is not counted here
    pbar = tqdm(enumerate(loader), total=len(loader), position=1, ncols=100, leave=False)
    for i, (images, labels) in pbar:
        # Drop the leading size-1 dimension (reshape to shape[1:] only succeeds
        # when that dimension is 1). Presumably each item of the intermediate
        # dataset is already a full batch — TODO confirm.
        images = images.reshape(images.shape[1:])
        labels = labels.reshape(labels.shape[1:])

        # Forward: call the modules (not .forward()) so registered hooks run.
        out = model(images)
        loss = criterion(out, labels)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            pbar.set_description(f" loss = {loss.item():.4f} ({(time() - timer)*1000:.4f} ms)")
            timer = time()


def _validate(model, loader):
    """Evaluate ``model`` on ``loader`` and return top-1 accuracy in percent.

    While iterating, a status line shows the running correct/total count and a
    histogram of predicted class indices (most common first, truncated to 80
    characters).

    Raises:
        RuntimeError: if ``loader`` yields no samples, so accuracy is undefined.
    """
    # Use inference behaviour for layers such as dropout/batch-norm, if the model has any.
    model.eval()
    n_correct = 0
    n_samples = 0
    tally = Counter()
    with torch.no_grad(), tqdm(bar_format='{desc}{postfix}', position=1, leave=False) as val_desc:
        for images, labels in tqdm(loader, position=2, ncols=100, leave=False):
            # Same leading-dimension removal as in training.
            images = images.reshape(images.shape[1:])
            labels = labels.reshape(labels.shape[1:])

            logits = model(images)
            _, predictions = torch.max(logits, 1)  # index of the highest logit per sample
            tally.update(predictions.tolist())
            n_samples += labels.shape[0]
            n_correct += (predictions == labels).sum().item()

            tally_desc = ' '.join(f"{n}:{c}" for n, c in tally.most_common())
            if len(tally_desc) > 80:
                tally_desc = tally_desc[:80] + "..."
            val_desc.set_description(f"{n_correct}/{n_samples} correct")
            val_desc.set_postfix_str(tally_desc)

    if n_samples == 0:
        raise RuntimeError("validation loader yielded no samples; cannot compute accuracy")
    return n_correct / n_samples * 100


def main(img_size=(320, 320), batch_size=16, n_epoch=40, lr=5e-5, min_epochs_before_checkpoint=3):
    """Train the timber-identification CNN and save checkpoints.

    Builds the train/val loaders, materialises the intermediate datasets if
    they do not exist yet, then trains with Adam and cross-entropy loss.

    Files written:
        ``model_{epoch}.pt``: the full model, saved whenever validation accuracy
            beats every earlier epoch. This is only checked once at least
            ``min_epochs_before_checkpoint`` epochs have been recorded.
        ``model.pt``: the model after the final epoch.

    Finally prints the per-epoch validation accuracies.
    """
    train_loader, val_loader = build_dataloader(
        # train_ratio= 0.005,
        img_size=img_size,
        batch_size=batch_size,
    )
    # The identity function is passed as the first argument. Presumably the
    # loader output is stored unchanged — TODO confirm against S3_intermediateDataset.
    build_intermediate_dataset_if_not_exists(lambda x: x, "train", train_loader)
    build_intermediate_dataset_if_not_exists(lambda x: x, "val", val_loader)
    train_loader = intermediate_dataset("train")
    val_loader = intermediate_dataset("val")

    model = build_model(img_size=img_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Numeric percentages. Formatted strings would compare lexicographically
    # ("9.50%" > "10.00%") and pick the wrong best model.
    accuracies = []
    pbar_0 = tqdm(range(n_epoch), position=0, ncols=100)
    pbar_0.set_description(f"epoch 1/{n_epoch}")
    for epoch in pbar_0:
        _train_one_epoch(model, train_loader, criterion, optimizer)

        pbar_0.set_description(f"epoch {epoch+1}/{n_epoch}, Validating . . .")
        accuracy = _validate(model, val_loader)

        # The label names the epoch about to train; cap it so the last one is not n_epoch + 1.
        next_epoch = min(epoch + 2, n_epoch)
        pbar_0.set_description(f"epoch {next_epoch}/{n_epoch}, accuracy: {accuracy:.2f}%")

        if len(accuracies) >= min_epochs_before_checkpoint and accuracy > max(accuracies):
            torch.save(model, f"model_{epoch}.pt")
        accuracies.append(accuracy)

    torch.save(model, "model.pt")
    print([f"{a:.2f}%" for a in accuracies])


if __name__ == '__main__':
    main()