| from src.model import Classifier |
| from src.dataloader import ImageDataset,collate_fn |
| from torch.utils.data import DataLoader |
| import torch.optim as optim |
| import torch.nn.functional as F |
| from tqdm import tqdm |
| import matplotlib.pyplot as plt |
| import torch |
| import random |
| import numpy as np |
| import torch.nn as nn |
| import time |
|
|
def seed_worker(worker_id):
    """Reseed NumPy and ``random`` from the current torch seed.

    Passed to ``DataLoader`` as ``worker_init_fn``. ``worker_id`` is accepted
    to satisfy that callback signature but is not used.
    """
    # Fold torch's seed into 32 bits before handing it to the other RNGs.
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
|
|
class ModelTrainer:
    """Training/validation driver for a classifier over image datasets.

    Wraps the datasets in ``DataLoader`` objects, owns an Adam optimizer and a
    cross-entropy criterion, and exposes a generator-based training loop that
    reports per-epoch metrics and, optionally, matplotlib figures.
    """

    def __init__(self, model: "Classifier", train_set: "ImageDataset",
                 val_set: "ImageDataset" = None, batch_size=32, lr=1e-3,
                 device='cpu', return_fig=False, seed=None):
        """Build the data loaders, optimizer and loss.

        Args:
            model: Model to train. This class reads ``model.classes``, calls
                ``model(imgs)`` on each batch and calls
                ``model.visualize_feature(img, show=False)`` when visualizing.
            train_set: Dataset used for training (shuffled every epoch).
            val_set: Optional dataset used for validation (not shuffled).
            batch_size: Batch size for both loaders.
            lr: Learning rate for the Adam optimizer.
            device: Device the label tensors are moved to. The model and the
                image batches are NOT moved here; presumably the model handles
                its own inputs (TODO confirm against ``Classifier``).
            return_fig: If True, ``visualize_batch`` returns its figures
                instead of showing them.
            seed: Optional seed for the generator that drives training-set
                shuffling.
        """
        shuffle_rng = torch.Generator()
        if seed is not None:
            shuffle_rng.manual_seed(seed)

        self.train_loader = DataLoader(
            train_set,
            batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker,
            generator=shuffle_rng,
        )

        self.device = device

        if val_set is not None:
            self.val_loader = DataLoader(
                val_set,
                batch_size,
                shuffle=False,
                collate_fn=collate_fn,
                worker_init_fn=seed_worker,
            )
        else:
            self.val_loader = None

        self.class_names = model.classes
        self.model = model
        self.lr = lr
        self.optim = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
        self.optim.zero_grad()
        self.criterion = nn.CrossEntropyLoss()
        self.return_fig = return_fig

    @staticmethod
    def _to_display_array(imgs):
        """Return a batch as one channel-last NumPy array for ``imshow``.

        A list is stacked as-is, keeping its dtype: converting integer images
        to float would push values above 1.0, which ``imshow`` clips for float
        RGB data. A tensor is treated as channel-first (N, C, H, W) and
        transposed to (N, H, W, C).
        """
        if isinstance(imgs, list):
            return np.stack(imgs, axis=0)
        return imgs.detach().cpu().numpy().transpose(0, 2, 3, 1)

    def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
        """Plot a random sample of a batch with predicted and true labels.

        Also asks the model for feature visualizations of the first sampled
        image and wraps any non-figure result in a matplotlib figure.

        Args:
            imgs: Batch of images, either a list of channel-last arrays or a
                channel-first tensor.
            preds: Tensor of predicted class indices.
            labels: Tensor of true class indices.
            class_names: Optional index -> name lookup used in the titles.
            max_samples: Maximum number of images to draw.

        Returns:
            ``[prediction_figure, *feature_figures]`` when ``self.return_fig``
            is True (the prediction figure has already been closed in pyplot;
            the feature figures are left open for the caller). Otherwise the
            figures are shown, closed, and ``None`` is returned.
        """
        display_imgs = self._to_display_array(imgs)
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()

        batch_size = display_imgs.shape[0]
        indices = random.sample(range(batch_size), min(max_samples, batch_size))
        # Index the batch in its original container/layout: this is what the
        # model's feature visualizer receives.
        first_image = imgs[indices[0]]

        fig_pred = plt.figure(figsize=(6 * len(indices), 5))
        grid = fig_pred.add_gridspec(1, len(indices))

        for col, idx in enumerate(indices):
            ax = fig_pred.add_subplot(grid[0, col])
            ax.imshow(display_imgs[idx])

            if class_names:
                title = f"P: {class_names[preds[idx]]} | T: {class_names[labels[idx]]}"
            else:
                title = f"P: {preds[idx]} | T: {labels[idx]}"

            ax.set_title(title)
            ax.axis("off")

        fig_pred.tight_layout()

        raw_features = self.model.visualize_feature(first_image, show=False)
        feature_figs = []

        for f in raw_features:
            if isinstance(f, plt.Figure):
                feature_figs.append(f)
                continue

            # Objects with a ``mode`` attribute are presumably PIL images;
            # convert them so ``shape`` is available.
            if hasattr(f, "mode"):
                f = np.array(f)
            h, w = f.shape[:2]

            # Size the figure to the image's pixel size, at least 4x4 inches.
            dpi = 100
            fig_w = max(4, w / dpi)
            fig_h = max(4, h / dpi)
            fig = plt.figure(figsize=(fig_w, fig_h), dpi=dpi)
            ax = fig.add_subplot(111)
            ax.imshow(f)
            ax.axis("off")
            feature_figs.append(fig)

        all_figs = [fig_pred] + feature_figs

        if self.return_fig:
            plt.close(fig_pred)
            return all_figs

        plt.show()
        # Close every figure, not just the prediction grid, so repeated
        # visualizations do not pile up open figures in pyplot.
        for shown in all_figs:
            plt.close(shown)
        return None

    def train_one_epoch(self):
        """Run one optimization pass over the training loader.

        Returns:
            ``(avg_loss, avg_acc)``: the mean of the per-batch losses and the
            fraction of correctly classified samples.

        Raises:
            ZeroDivisionError: If the training loader yields no batches.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        train_pbar = tqdm(self.train_loader, desc="Training", leave=False)

        for imgs, labels in train_pbar:
            labels = labels.to(self.device)

            outputs = self.model(imgs)
            loss = self.criterion(outputs, labels)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            total_loss += loss.item()
            train_pbar.set_postfix(acc=correct / total, loss=loss.item())

        avg_loss = total_loss / len(self.train_loader)
        avg_acc = correct / total
        return avg_loss, avg_acc

    def train(self, epochs=10, visualize_every=5):
        """Train for ``epochs`` epochs, yielding metrics after each one.

        This is a generator. After every epoch it yields
        ``(train_loss, train_acc, val_loss, val_acc, fig)``; the last three
        are ``None`` when there is no validation loader. After the final epoch
        it yields one extra tuple holding the full metric histories:
        ``(train_losses, train_accuracies, val_losses, val_accuracies, None)``.

        Args:
            epochs: Number of epochs to run.
            visualize_every: Visualize a validation batch on epoch 1 and on
                every epoch divisible by this value. ``0`` or ``None``
                disables the periodic visualization (epoch 1 is still shown).
        """
        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []

        for epoch in range(1, epochs + 1):
            train_loss, train_acc = self.train_one_epoch()
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)

            if self.val_loader is None:
                print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
                yield train_loss, train_acc, None, None, None
                continue

            # Check for a falsy interval first so the modulo never divides by zero.
            show = epoch == 1 or (bool(visualize_every) and epoch % visualize_every == 0)
            val_loss, val_acc, fig = self.validate(epoch, visualize=show)
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
            yield train_loss, train_acc, val_loss, val_acc, fig

        yield train_losses, train_accuracies, val_losses, val_accuracies, None

    def validate(self, epoch, visualize=False):
        """Evaluate the model on the validation loader.

        Args:
            epoch: Current epoch number (currently unused).
            visualize: If True, the first validation batch is passed to
                ``visualize_batch`` once the pass is complete.

        Returns:
            ``None`` when there is no validation loader; otherwise
            ``(avg_loss, acc, fig)`` where ``fig`` is whatever
            ``visualize_batch`` returned, or ``None`` if nothing was
            visualized.

        Raises:
            ZeroDivisionError: If the validation loader yields no batches.
        """
        if self.val_loader is None:
            return

        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        # First batch seen, kept for visualization after the loop.
        val_imgs_display = None
        val_preds_display = None
        val_labels_display = None

        val_pbar = tqdm(self.val_loader, desc="Validation", leave=False)
        fig = None
        with torch.no_grad():
            for imgs, labels in val_pbar:
                labels = labels.to(self.device)

                outputs = self.model(imgs)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()

                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

                if visualize and val_imgs_display is None:
                    val_imgs_display = imgs
                    val_preds_display = preds
                    val_labels_display = labels

                val_pbar.set_postfix(loss=loss.item(), acc=correct / total)

        avg_loss = total_loss / len(self.val_loader)
        acc = correct / total

        if visualize and val_imgs_display is not None:
            fig = self.visualize_batch(val_imgs_display, val_preds_display, val_labels_display, self.class_names)

        return avg_loss, acc, fig