"""Fine-tune a COCO-pretrained Faster R-CNN (ResNet-50 FPN) on a COCO-format dataset.

The script trains for a fixed number of epochs, computes mAP / mAP@50 / mAP@75
after every epoch, saves the final weights and plots the loss and mAP curves.
"""
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision import transforms as T
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_Weights,
    fasterrcnn_resnet50_fpn,
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from scripts.dataset import COCODataset, get_transforms

# Paths to dataset
DATASET_ROOT = "dataset/train"
ANNOTATION_FILE = "dataset/annotations/instances_train.json"

# Where the trained weights are written at the end of the run
MODEL_PATH = "models/fasterrcnn_model.pth"


def get_model(num_classes):
    """Build a Faster R-CNN whose box predictor outputs ``num_classes`` classes.

    Starts from the COCO_V1 pretrained weights, swaps the box predictor for a
    fresh ``FastRCNNPredictor`` and lowers both NMS IoU thresholds to 0.3.

    Args:
        num_classes: Number of output classes, background included.

    Returns:
        The configured detection model (not yet moved to any device).
    """
    model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)

    # Replace the classifier with a new one (custom num_classes)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Reduce excessive overlapping boxes by lowering the NMS IoU thresholds
    # (torchvision defaults: 0.7 for the RPN, 0.5 for the box head).
    model.rpn.nms_thresh = 0.3
    model.roi_heads.nms_thresh = 0.3
    return model


def collate_batch(batch):
    """Turn a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    Detection samples differ in size and box count, so they are kept as tuples
    instead of being stacked into a single tensor.
    """
    return tuple(zip(*batch))


# Define dataset and data loader
dataset = COCODataset(
    root=DATASET_ROOT,
    annotation_file=ANNOTATION_FILE,
    transforms=get_transforms(),
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_batch)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the number of classes (5 classes + 1 background)
num_classes = 6
model = get_model(num_classes).to(device)

# Optimizer and learning rate scheduler
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=0.002, weight_decay=0.0001)
# NOTE(review): T_max=10 is shorter than the 15 training epochs below, so with
# one scheduler step per epoch the learning rate appears to reach its minimum
# at epoch 10 and climb again afterwards - confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Initialize Mean Average Precision (mAP) metric
metric = MeanAveragePrecision()


def evaluate_model(model, dataloader, device):
    """Compute mAP, mAP@50 and mAP@75 of ``model`` over ``dataloader``.

    Puts the model in eval mode, resets the module-level ``metric`` and feeds
    it the predictions and ground truths of every batch.

    Args:
        model: Detection model returning dicts with "boxes", "scores", "labels".
        dataloader: Iterable yielding ``(images, targets)`` batches.
        device: Device on which inference is run.

    Returns:
        Tuple ``(map, map_50, map_75)`` of Python floats.
    """
    model.eval()
    metric.reset()

    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            outputs = model(images)

            preds = [
                {
                    "boxes": o["boxes"].cpu(),
                    "scores": o["scores"].cpu(),
                    "labels": o["labels"].cpu(),
                }
                for o in outputs
            ]
            # The metric is fed CPU tensors, so the targets are not moved to
            # the device first; that round trip was redundant.
            gts = [
                {"boxes": t["boxes"].cpu(), "labels": t["labels"].cpu()}
                for t in targets
            ]
            metric.update(preds, gts)

    result = metric.compute()
    return result["map"].item(), result["map_50"].item(), result["map_75"].item()


# Lists to store training loss and accuracy
losses_list = []
mAP_list = []
mAP50_list = []
mAP75_list = []

# Training loop
num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for images, targets in dataloader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in target.items()} for target in targets]

        optimizer.zero_grad()
        # In training mode the model returns a dict of loss terms.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        epoch_loss += losses.item()

        losses.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        optimizer.step()

    # One scheduler step per epoch, after that epoch's optimizer steps.
    lr_scheduler.step()
    # epoch_loss is the sum (not the mean) of the batch losses of this epoch.
    losses_list.append(epoch_loss)

    # Compute mAP
    # NOTE(review): this evaluates on the training dataloader, so the values
    # track fit to the training data rather than generalization.
    mAP, mAP50, mAP75 = evaluate_model(model, dataloader, device)
    mAP_list.append(mAP)
    mAP50_list.append(mAP50)
    mAP75_list.append(mAP75)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    print(f"Epoch {epoch+1}: mAP: {mAP:.4f}, mAP@50: {mAP50:.4f}, mAP@75: {mAP75:.4f}")

# Save model
# Create the output directory first so a missing "models/" folder cannot make
# the save fail after the whole training run has completed.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print("Training complete! Model saved as 'fasterrcnn_model.pth'.")

# Plot Loss Graph
epochs_axis = range(1, num_epochs + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs_axis, losses_list, marker='o', linestyle='-')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Loss over Epochs")
plt.grid()
plt.show()

# Plot mAP Graphs
plt.figure(figsize=(10, 5))
plt.plot(epochs_axis, mAP_list, marker='o', linestyle='-', label="mAP")
plt.plot(epochs_axis, mAP50_list, marker='s', linestyle='-', label="mAP@50")
plt.plot(epochs_axis, mAP75_list, marker='d', linestyle='-', label="mAP@75")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("mAP over Epochs")
plt.legend()
plt.grid()
plt.show()