import os

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision import transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from scripts.dataset import COCODataset, get_transforms
|
|
|
|
|
# Directory handed to COCODataset as its ``root`` argument.
DATASET_ROOT = "dataset/train"
# JSON file handed to COCODataset as its ``annotation_file`` argument.
ANNOTATION_FILE = "dataset/annotations/instances_train.json"
|
|
|
|
|
def get_model(num_classes: int, nms_thresh: float = 0.3) -> torch.nn.Module:
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a new box-predictor head.

    The network is created with the ``COCO_V1`` pretrained weights, then its
    box predictor is replaced by a freshly initialised ``FastRCNNPredictor``
    sized for ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the new predictor head.
            NOTE(review): torchvision detection heads normally count the
            background as one class -- confirm the caller's value includes it.
        nms_thresh: Non-maximum-suppression threshold written to both the
            RPN and the ROI heads. Defaults to 0.3, the value that used to
            be hard-coded here.

    Returns:
        The configured model. It is not moved to any device here.
    """
    model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)

    # Reuse the input width of the pretrained head so the replacement head
    # plugs into the same feature vector.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Apply the same suppression threshold at both stages.
    model.rpn.nms_thresh = nms_thresh
    model.roi_heads.nms_thresh = nms_thresh

    return model
|
|
|
|
|
def _collate_batch(batch):
    """Regroup a list of samples into per-field tuples (all images, all targets).

    A named function (instead of a lambda) stays picklable, which matters if
    the DataLoader is ever switched to worker processes.
    """
    return tuple(zip(*batch))


# Training data comes from the project-local COCODataset.
dataset = COCODataset(
    root=DATASET_ROOT,
    annotation_file=ANNOTATION_FILE,
    transforms=get_transforms(),
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=_collate_batch)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes = 6
model = get_model(num_classes).to(device)

# Only hand the optimizer parameters that still require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=0.002, weight_decay=0.0001)

# CosineAnnealingLR does not stop at T_max: stepping it further makes the
# learning rate climb back up the cosine. The scheduler is stepped once per
# epoch and the training loop runs 15 epochs, so T_max is 15 (it used to be
# 10, which drove the LR to its minimum after epoch 10 and raised it again
# for the remaining epochs). Keep this equal to the number of epochs.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

# Module-level metric object; evaluate_model() resets and updates it.
metric = MeanAveragePrecision()
|
|
|
|
|
def evaluate_model(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and return mean-average-precision scores.

    The module-level ``metric`` object is reset, fed every batch, and then
    computed. The model is switched to eval mode and left in that mode; the
    caller is responsible for calling ``model.train()`` again.

    Args:
        model: Detection model; called as ``model(images)`` and expected to
            return one dict per image with "boxes", "scores" and "labels".
        dataloader: Iterable yielding ``(images, targets)`` batches, where
            each target is a dict with at least "boxes" and "labels".
        device: Device the images are moved to before the forward pass.

    Returns:
        A ``(map, map_50, map_75)`` tuple of Python floats taken from the
        metric result.
    """
    model.eval()
    metric.reset()

    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            outputs = model(images)

            preds = [
                {
                    "boxes": out["boxes"].cpu(),
                    "scores": out["scores"].cpu(),
                    "labels": out["labels"].cpu(),
                }
                for out in outputs
            ]
            # Ground truth is only consumed by the CPU-side metric, so it is
            # no longer copied to ``device`` first. ``.cpu()`` is a no-op for
            # tensors already on the CPU and still handles GPU-resident ones.
            gts = [
                {"boxes": tgt["boxes"].cpu(), "labels": tgt["labels"].cpu()}
                for tgt in targets
            ]

            metric.update(preds, gts)

    result = metric.compute()
    return result["map"].item(), result["map_50"].item(), result["map_75"].item()
|
|
|
|
|
# Per-epoch history; one entry is appended to each list per epoch.
losses_list = []
mAP_list = []
mAP50_list = []
mAP75_list = []

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0  # sum (not mean) of the batch losses in this epoch

    for images, targets in dataloader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in target.items()} for target in targets]

        optimizer.zero_grad()
        # Called with targets, the model returns a dict of loss terms; the
        # optimised objective is their plain sum.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        epoch_loss += losses.item()

        losses.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        optimizer.step()

    # The learning-rate schedule advances once per epoch.
    lr_scheduler.step()
    losses_list.append(epoch_loss)

    # NOTE(review): evaluation runs on the same dataloader used for training,
    # so these scores measure fit to the training set, not generalisation.
    mAP, mAP50, mAP75 = evaluate_model(model, dataloader, device)
    mAP_list.append(mAP)
    mAP50_list.append(mAP50)
    mAP75_list.append(mAP75)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    print(f"Epoch {epoch+1}: mAP: {mAP:.4f}, mAP@50: {mAP50:.4f}, mAP@75: {mAP75:.4f}")

# Create the output directory first: torch.save raises if it is missing,
# which would discard the whole training run at the very last step.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/fasterrcnn_model.pth")
print("Training complete! Model saved as 'fasterrcnn_model.pth'.")
|
|
|
|
|
# Shared x-axis for both figures: epochs numbered from 1.
epoch_ticks = range(1, num_epochs + 1)

# Figure 1: per-epoch training loss as recorded in losses_list.
plt.figure(figsize=(10, 5))
plt.plot(epoch_ticks, losses_list, marker='o', linestyle='-')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Loss over Epochs")
plt.grid()
plt.show()

# Figure 2: the three mean-average-precision curves on one set of axes.
plt.figure(figsize=(10, 5))
plt.plot(epoch_ticks, mAP_list, marker='o', linestyle='-', label="mAP")
plt.plot(epoch_ticks, mAP50_list, marker='s', linestyle='-', label="mAP@50")
plt.plot(epoch_ticks, mAP75_list, marker='d', linestyle='-', label="mAP@75")
plt.xlabel("Epochs")
# The plotted values are mean average precision, not accuracy, so the axis
# is labelled accordingly.
plt.ylabel("mAP")
plt.title("mAP over Epochs")
plt.legend()
plt.grid()
plt.show()
|
|
|