# Commit b338c44 ("initial commit") by ragebhanukiran; original file size 4.58 kB.
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision import transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from scripts.dataset import COCODataset, get_transforms
# Training data root and its annotation file. Both are relative paths, so they
# resolve against the current working directory the script is launched from.
DATASET_ROOT, ANNOTATION_FILE = (
    "dataset/train",
    "dataset/annotations/instances_train.json",
)
# Function to build the detection model with tightened NMS thresholds
def get_model(num_classes, *, rpn_nms_thresh=0.3, box_nms_thresh=0.3):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) for ``num_classes``.

    The box predictor head is replaced so it outputs ``num_classes`` classes
    (the count includes the background class), and the NMS IoU thresholds of
    both the region proposal network and the final RoI heads are overridden.

    Args:
        num_classes: Number of output classes, background included.
        rpn_nms_thresh: IoU threshold for NMS over RPN proposals
            (torchvision's own default is 0.7).
        box_nms_thresh: IoU threshold for NMS over the final per-class
            detections (torchvision's own default is 0.5).

    Returns:
        The configured Faster R-CNN model, not yet moved to any device.
    """
    model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)

    # Swap in a freshly initialised classification/box-regression head sized
    # for our classes; everything else keeps its pretrained COCO weights.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Lower IoU thresholds suppress more overlapping boxes, which reduces
    # duplicate detections. The RPN threshold also applies while training, so
    # setting it very low prunes proposals and can reduce proposal recall.
    model.rpn.nms_thresh = rpn_nms_thresh
    model.roi_heads.nms_thresh = box_nms_thresh
    return model
# Define dataset and data loader
def collate_fn(batch):
    """Transpose a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    Detection images may differ in size, so they are kept as a sequence rather
    than stacked into one tensor. A named module-level function (unlike a
    lambda) can be pickled, which ``DataLoader`` requires when
    ``num_workers > 0`` under the ``spawn`` start method.
    """
    return tuple(zip(*batch))


dataset = COCODataset(root=DATASET_ROOT, annotation_file=ANNOTATION_FILE, transforms=get_transforms())
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the number of classes (5 object classes + 1 background)
num_classes = 6
model = get_model(num_classes).to(device)

# Optimizer: only parameters with requires_grad=True are optimized.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=0.002, weight_decay=0.0001)

# T_max must equal the total number of lr_scheduler.step() calls. This assumes
# the training loop steps once per epoch for num_epochs = 15 epochs; if T_max is
# smaller, CosineAnnealingLR climbs back up after reaching its minimum.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

# Initialize Mean Average Precision (mAP) metric
metric = MeanAveragePrecision()
# Function to evaluate the model
def evaluate_model(model, dataloader, device, metric=None):
    """Compute mean average precision of ``model`` over ``dataloader``.

    Args:
        model: Detection model that, in eval mode, returns one dict per image
            with ``"boxes"``, ``"scores"`` and ``"labels"`` tensors.
        dataloader: Iterable of ``(images, targets)`` batches; each target dict
            must hold ground-truth ``"boxes"`` and ``"labels"`` tensors (any
            other keys are ignored).
        device: Device the images are moved to before inference.
        metric: Optional metric object exposing ``reset``/``update``/``compute``
            whose ``compute()`` result has ``"map"``, ``"map_50"`` and
            ``"map_75"`` entries. A fresh ``MeanAveragePrecision()`` is created
            when omitted, so calls do not depend on module-level state.

    Returns:
        Tuple ``(map, map_50, map_75)`` of Python floats.

    The model's previous train/eval mode is restored before returning, even if
    evaluation raises.
    """
    if metric is None:
        metric = MeanAveragePrecision()
    metric.reset()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, targets in dataloader:
                images = [img.to(device) for img in images]
                outputs = model(images)
                preds = [
                    {"boxes": o["boxes"].cpu(), "scores": o["scores"].cpu(), "labels": o["labels"].cpu()}
                    for o in outputs
                ]
                # Ground truth is only consumed by the metric on CPU, so it is
                # never copied to the accelerator and back.
                gts = [{"boxes": t["boxes"].cpu(), "labels": t["labels"].cpu()} for t in targets]
                metric.update(preds, gts)
        result = metric.compute()
    finally:
        model.train(was_training)
    return result["map"].item(), result["map_50"].item(), result["map_75"].item()
# Per-epoch history: summed training loss and detection mAP (overall, @IoU 0.50, @IoU 0.75)
losses_list = []
mAP_list = []
mAP50_list = []
mAP75_list = []

# Training loop
num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0  # sum of the per-batch total losses over this epoch
    for images, targets in dataloader:
        images = [image.to(device) for image in images]
        # Only tensors can be moved; leave any non-tensor metadata as-is.
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in target.items()}
            for target in targets
        ]

        optimizer.zero_grad()
        # Called with targets, the model returns a dict of named loss terms
        # (see the .values() sum below); their sum is the training objective.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        epoch_loss += losses.item()

        losses.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        optimizer.step()

    # Advance the cosine schedule once per epoch so T_max counts epochs;
    # stepping per batch would finish a half-period every T_max batches.
    lr_scheduler.step()
    losses_list.append(epoch_loss)

    # NOTE(review): mAP is computed on the same (shuffled) loader used for
    # training, so it measures fit to the training set, not generalization.
    # Pass a held-out validation loader here once one exists.
    mAP, mAP50, mAP75 = evaluate_model(model, dataloader, device)
    mAP_list.append(mAP)
    mAP50_list.append(mAP50)
    mAP75_list.append(mAP75)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    print(f"Epoch {epoch+1}: mAP: {mAP:.4f}, mAP@50: {mAP50:.4f}, mAP@75: {mAP75:.4f}")
# Save the trained weights. Create the output directory first so that a long
# training run cannot fail at this final step on a missing "models/" folder.
MODEL_PATH = Path("models") / "fasterrcnn_model.pth"
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print(f"Training complete! Model saved as '{MODEL_PATH}'.")

# Plot Loss Graph. The x-axis comes from the recorded history, so the plot
# stays consistent with the data even if fewer epochs were completed.
loss_epochs = range(1, len(losses_list) + 1)
plt.figure(figsize=(10, 5))
plt.plot(loss_epochs, losses_list, marker='o', linestyle='-')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Loss over Epochs")
plt.grid()
plt.show()

# Plot mAP Graphs
map_epochs = range(1, len(mAP_list) + 1)
plt.figure(figsize=(10, 5))
for values, marker, label in (
    (mAP_list, 'o', "mAP"),
    (mAP50_list, 's', "mAP@50"),
    (mAP75_list, 'd', "mAP@75"),
):
    plt.plot(map_epochs, values, marker=marker, linestyle='-', label=label)
plt.xlabel("Epochs")
plt.ylabel("mAP")  # mean average precision, not classification accuracy
plt.title("mAP over Epochs")
plt.legend()
plt.grid()
plt.show()