Alex Hortua committed on
Commit b87aa54 · 1 Parent(s): 9a6ea32

Creating a faster version with a different approach (training with a frozen, COCO-pretrained backbone)

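In short, the new trainer starts from the COCO-pretrained Faster R-CNN weights, freezes the backbone, and retrains only the detection heads on the LEGO annotations. The following is a condensed sketch of the key steps from the `src/new_trainer.py` added below, not a standalone script:

```python
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Start from the COCO-pretrained Faster R-CNN
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# Freeze the backbone; only the RPN and ROI heads keep requires_grad=True
for param in model.backbone.parameters():
    param.requires_grad = False

# Swap in a 2-class box predictor (background + LEGO)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)

# Same optimizer settings as the committed script
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
```

Because backpropagation stops at the frozen ResNet-50 FPN backbone, each step only updates the comparatively small head parameters, which is the source of the speedup mentioned above.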
.gitignore CHANGED
@@ -2,7 +2,11 @@
  /Annotations
  .qodo
  venv/
+ src/__pycache__/
+ models/lego
+ models/records/*
+ models/records
+
+ # Datasets
  /datasets/annotations/*
  /datasets/images/*
- src/__pycache__/
- models/lego
datasets/examples.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "examples": [
+     [
+       "datasets/test_images/0abd88dc-e306-11eb-b5b0-b0c090bd3910.jpg",
+       "[[373, 523, 438, 599], [78, 444, 278, 563], [124, 0, 221, 63], [471, 156, 535, 213]]"
+     ],
+     [
+       "datasets/test_images/0abe1e54-e691-11eb-8391-b0c090bd3910.jpg",
+       "[[29, 82, 270, 299], [23, 0, 78, 80]]"
+     ],
+     [
+       "datasets/test_images/0abd3f80-daff-11eb-8755-3497f683a169.jpg",
+       "[[0, 77, 172, 317], [523, 202, 582, 277], [410, 112, 600, 544], [136, 39, 187, 110]]"
+     ],
+     [
+       "datasets/test_images/0abf3f3e-e4c8-11eb-8f0d-b0c090bd3910.jpg",
+       "[[207, 289, 238, 333], [338, 94, 599, 496], [0, 84, 136, 407], [73, 383, 141, 419]]"
+     ],
+     [
+       "datasets/test_images/0abdc2ee-d9cc-11eb-8cf3-3497f683a169.jpg",
+       "[[225, 437, 386, 600], [305, 0, 369, 100], [353, 346, 453, 445], [113, 28, 234, 134]]"
+     ],
+     [
+       "datasets/test_images/000abf76-e67a-11eb-b56d-b0c090bd3910.jpg",
+       "[[44, 167, 185, 300], [97, 0, 262, 167]]"
+     ],
+     [
+       "datasets/test_images/0abccbc6-e661-11eb-9915-b0c090bd3910.jpg",
+       "[[173, 124, 242, 194], [0, 3, 300, 296]]"
+     ],
+     [
+       "datasets/test_images/0abf4764-e480-11eb-b391-b0c090bd3910.jpg",
+       "[[418, 87, 599, 306], [0, 339, 154, 490], [230, 114, 353, 234], [173, 118, 275, 227]]"
+     ],
+     [
+       "datasets/test_images/0a82869a-e3c3-11eb-9d75-b0c090bd3910.jpg",
+       "[[387, 0, 536, 119], [378, 509, 463, 600], [94, 193, 288, 368], [74, 301, 237, 486]]"
+     ],
+     [
+       "datasets/test_images/0abd5fec-e5c6-11eb-9ac6-b0c090bd3910.jpg",
+       "[[32, 50, 299, 266], [0, 22, 78, 82]]"
+     ],
+     [
+       "datasets/test_images/0abfa74a-e2f3-11eb-abe7-b0c090bd3910.jpg",
+       "[[229, 104, 311, 169], [380, 0, 503, 118], [235, 236, 284, 304], [109, 434, 350, 600]]"
+     ],
+     [
+       "datasets/test_images/0abfa012-e660-11eb-8710-b0c090bd3910.jpg",
+       "[[0, 43, 299, 254], [50, 58, 215, 238]]"
+     ]
+   ]
+ }
datasets/test_images/0aa7d4a4-e675-11eb-98bd-b0c090bd3910.jpg DELETED
Binary file (21.7 kB)
 
datasets/test_images/0abfb048-e3a1-11eb-9018-b0c090bd3910.jpg DELETED
Binary file (53 kB)
 
logs/training_log.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "loss": [
+     5521765.083993731,
+     813867.9704230821
+   ],
+   "mAP": [
+     0.7336118575736044,
+     0.00042658527250095333
+   ]
+ }
logs/training_log.txt ADDED
@@ -0,0 +1,31 @@
+ Starting Epoch 1/10
+ Starting Epoch 1/10
+ Iteration 1, Loss: 1.19198739528656
+ Iteration 101, Loss: 27.87972640991211
+ Iteration 201, Loss: 7.156171798706055
+ Iteration 301, Loss: 8.546396255493164
+ Iteration 401, Loss: 1.7727022171020508
+ Iteration 501, Loss: 5.378680229187012
+ Iteration 601, Loss: 15.277275085449219
+ Iteration 701, Loss: 4.097675800323486
+ Iteration 801, Loss: 4.272053241729736
+ Iteration 901, Loss: 1.443131446838379
+ Starting Epoch 2/10
+ Iteration 1, Loss: 2.3286213874816895
+ Iteration 101, Loss: 1.8097801208496094
+ Iteration 201, Loss: 1.6668422222137451
+ Iteration 301, Loss: 2.1733906269073486
+ Iteration 401, Loss: 1.8349155187606812
+ Iteration 501, Loss: 0.9883778095245361
+ Iteration 601, Loss: 1.3832241296768188
+ Iteration 701, Loss: 1.6653320789337158
+ Iteration 801, Loss: 1.2079124450683594
+ Iteration 901, Loss: 28428.40625
+ Starting Epoch 3/10
+ Iteration 1, Loss: 1006.0529174804688
+ Iteration 101, Loss: 0.4594302177429199
+ Iteration 201, Loss: 0.4397536516189575
+ Iteration 301, Loss: 0.31954655051231384
+ Iteration 401, Loss: 0.4685922861099243
+ Iteration 501, Loss: 0.31720075011253357
+ Iteration 601, Loss: 0.2652203440666199
src/Attempt1/app.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ import torch
+ import torchvision
+ import gradio as gr
+ import numpy as np
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+ from torchvision.transforms import functional as F
+ from PIL import Image, ImageDraw
+
+ # Load Trained Model
+ def load_model(model_path):
+     model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
+     in_features = model.roi_heads.box_predictor.cls_score.in_features
+     model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)  # Background + 1 LEGO class
+     model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+     model.eval()
+     return model
+
+ model = load_model("models/lego_fasterrcnn.pth")
+
+ def predict(image):
+     image = Image.fromarray(image).convert("RGB")
+     image_tensor = F.to_tensor(image).unsqueeze(0)  # Add batch dimension
+
+     with torch.no_grad():
+         predictions = model(image_tensor)[0]
+
+     boxes = predictions['boxes'].cpu().numpy()
+     labels = predictions['labels'].cpu().numpy()
+     scores = predictions['scores'].cpu().numpy()
+
+     results = []
+     draw = ImageDraw.Draw(image)
+     for box, label, score in zip(boxes, labels, scores):
+         if score > 0.7:  # Confidence threshold
+             results.append({
+                 "box": box.tolist(),
+                 "label": str(label),
+                 "score": float(score)
+             })
+             draw.rectangle(box.tolist(), outline="red", width=3)
+             draw.text((box[0], box[1]), f"{label} ({score:.2f})", fill="red")
+
+     return image, results
+
+ def get_examples():
+     return [os.path.join("datasets/test_images", f) for f in os.listdir("datasets/test_images")]
+
+ # Gradio Interface
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="numpy"),
+     outputs=[gr.Image(type="pil"), gr.JSON()],
+     title="LEGO Detection with Faster R-CNN",
+     description="Upload an image and the model will detect LEGO bricks with bounding boxes.",
+     examples=get_examples()
+ )
+
+ demo.launch()
src/{dataset.py → Attempt1/dataset.py} RENAMED
File without changes
src/{evaluate.py → Attempt1/evaluate.py} RENAMED
File without changes
src/{train.py → Attempt1/train.py} RENAMED
File without changes
src/{utils.py → Attempt1/utils.py} RENAMED
File without changes
src/app.py CHANGED
@@ -1,59 +1,101 @@
- import os
  import torch
  import torchvision
  import gradio as gr
- import numpy as np
- from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
- from torchvision.transforms import functional as F
  from PIL import Image, ImageDraw

- # Load Trained Model
- def load_model(model_path):
-     model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
-     in_features = model.roi_heads.box_predictor.cls_score.in_features
-     model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2) # Background + 4 LEGO classes
-     model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
-     model.eval()
-     return model

- model = load_model("models/lego_fasterrcnn.pth")

- def predict(image):
-     image = Image.fromarray(image).convert("RGB")
-     image_tensor = F.to_tensor(image).unsqueeze(0) # Add batch dimension

      with torch.no_grad():
-         predictions = model(image_tensor)[0]

-     boxes = predictions['boxes'].cpu().numpy()
-     labels = predictions['labels'].cpu().numpy()
-     scores = predictions['scores'].cpu().numpy()

-     results = []
      draw = ImageDraw.Draw(image)
-     for box, label, score in zip(boxes, labels, scores):
          if score > 0.5: # Confidence threshold
-             results.append({
-                 "box": box.tolist(),
-                 "label": str(label),
-                 "score": float(score)
-             })
-             draw.rectangle(box.tolist(), outline="red", width=3)
-             draw.text((box[0], box[1]), f"{label} ({score:.2f})", fill="red")

-     return image, results

  def get_examples():
-     return [os.path.join("datasets/test_images", f) for f in os.listdir("datasets/test_images")]

- # Gradio Interface
  demo = gr.Interface(
      fn=predict,
-     inputs=gr.Image(type="numpy"),
-     outputs=[gr.Image(type="pil"), gr.JSON()],
-     title="LEGO Detection with Faster R-CNN",
-     description="Upload an image and the model will detect LEGO bricks with bounding boxes.",
-     examples=get_examples()
  )

- demo.launch()

  import torch
  import torchvision
+ import torchvision.transforms as T
  import gradio as gr
  from PIL import Image, ImageDraw
+ import torchvision.ops as ops
+ import numpy as np
+ import json
+ import os
+
+
+ # LOAD_MODEL_PATH = "models/lego_fasterrcnn.pth"
+ LOAD_MODEL_PATH = "models/faster_rcnn_custom.pth"

+ # Load trained Faster R-CNN model
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
+ model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes=2)
+ model.load_state_dict(torch.load(LOAD_MODEL_PATH, map_location=torch.device("cpu")))
+ model.eval()

+ def compute_iou(box1, box2):
+     x1 = max(box1[0], box2[0])
+     y1 = max(box1[1], box2[1])
+     x2 = min(box1[2], box2[2])
+     y2 = min(box1[3], box2[3])
+
+     intersection = max(0, x2 - x1) * max(0, y2 - y1)
+     area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+     area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+
+     union = area_box1 + area_box2 - intersection
+     return intersection / union if union > 0 else 0
+
+ def mean_average_precision(predictions, ground_truths, iou_threshold=0.5):
+     iou_scores = []
+     for pred_box in predictions:
+         best_iou = 0
+         for gt_box in ground_truths:
+             iou = compute_iou(pred_box, gt_box)
+             best_iou = max(best_iou, iou)
+         if best_iou >= iou_threshold:
+             iou_scores.append(best_iou)
+     return np.mean(iou_scores) if iou_scores else None

+ def predict(image, ground_truths_json=""):
+     transform = T.Compose([T.ToTensor()])
+     image_tensor = transform(image).unsqueeze(0)

      with torch.no_grad():
+         predictions = model(image_tensor)

+     boxes = predictions[0]['boxes'].tolist()
+     scores = predictions[0]['scores'].tolist()

+     # Draw boxes on image
      draw = ImageDraw.Draw(image)
+     for box, score in zip(boxes, scores):
          if score > 0.5: # Confidence threshold
+             draw.rectangle(box, outline="red", width=3)
+             draw.text((box[0], box[1]), f"{score:.2f}", fill="red")

+     # Compute mAP if ground truths are provided
+     mAP = None
+     if ground_truths_json:
+         try:
+             ground_truths = json.loads(ground_truths_json)
+             mAP = mean_average_precision(boxes, ground_truths, iou_threshold=0.5)
+             # Draw ground truth boxes in a different color
+             for gt_box in ground_truths:
+                 draw.rectangle(gt_box, outline="green", width=3)
+                 draw.text((gt_box[0], gt_box[1]), "GT", fill="green")
+         except json.JSONDecodeError:
+             print("⚠️ Invalid ground truth format. Expecting JSON array of bounding boxes.")
+
+     # Filter boxes and scores based on confidence threshold
+     filtered_boxes = [box for box, score in zip(boxes, scores) if score > 0.5]
+     return image, filtered_boxes, mAP
+

  def get_examples():
+     # Load examples from JSON file
+     with open("datasets/examples.json", "r") as f:
+         examples_json = json.load(f)
+     examples_with_annotations = examples_json["examples"]
+
+     return examples_with_annotations

+ # Create Gradio interface
  demo = gr.Interface(
      fn=predict,
+     inputs=[gr.Image(type="pil"), gr.Textbox(placeholder="Enter ground truth bounding boxes as JSON (optional)")],
+     outputs=[gr.Image(type="pil", label="Detected LEGO pieces (Red predictions, green ground truth)"), gr.JSON(label="Predicted bounding boxes"), gr.Textbox(label="Mean Average Precision (mAP @ IoU 0.5)")],
+     title="LEGO Piece Detector",
+     examples=get_examples(),
+     description="Upload an image to detect LEGO pieces using Faster R-CNN. Optionally, enter ground truth bounding boxes to compute mAP. If left empty, mAP will be null."
  )

+ # Launch Gradio app
+ if __name__ == "__main__":
+     demo.launch()
src/new_trainer.py ADDED
@@ -0,0 +1,136 @@
+ import torch
+ import torchvision
+ import torchvision.transforms as T
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, Dataset, Subset
+ import os
+ import json
+ from PIL import Image
+ from tqdm import tqdm  # Import tqdm for loading bar
+
+ # Paths (Modify These)
+ DATASET_DIR = "datasets/images"  # Folder containing images
+ ANNOTATIONS_FILE = "datasets/annotations.json"  # Path to COCO JSON
+
+ # Define Custom COCO Dataset Class (Without pycocotools)
+ class CocoDataset(Dataset):
+     def __init__(self, root, annotation_file, transforms=None):
+         self.root = root
+         with open(annotation_file, 'r') as f:
+             self.coco_data = json.load(f)
+         self.image_data = {img["id"]: img for img in self.coco_data["images"]}
+         self.annotations = self.coco_data["annotations"]
+         self.transforms = transforms
+
+     def __len__(self):
+         return len(self.image_data)
+
+     def __getitem__(self, idx):
+         try:
+             image_info = self.image_data[idx]
+             image_path = os.path.join(self.root, image_info["file_name"])
+             image = Image.open(image_path).convert("RGB")
+             img_width, img_height = image.size  # Get image dimensions
+
+             # Get Annotations
+             annotations = [ann for ann in self.annotations if ann["image_id"] == image_info["id"]]
+             boxes = []
+             labels = []
+
+             for ann in annotations:
+                 xmin, ymin, xmax, ymax = ann["bbox"]  # Now using [xmin, ymin, xmax, ymax]
+                 xmin = max(0, xmin)
+                 ymin = max(0, ymin)
+                 xmax = min(img_width, xmax)
+                 ymax = min(img_height, ymax)
+
+                 if xmax > xmin and ymax > ymin:
+                     boxes.append([xmin, ymin, xmax, ymax])
+                     labels.append(ann["category_id"])
+                 else:
+                     print(f"⚠️ Skipping invalid bbox {ann['bbox']} in image {image_info['file_name']} (image_id: {image_info['id']})")
+
+             if len(boxes) == 0:
+                 print(f"⚠️ Skipping entire image {image_info['file_name']} because no valid bounding boxes remain.")
+                 return None, None
+
+             # Convert to tensors
+             boxes = torch.as_tensor(boxes, dtype=torch.float32)
+             labels = torch.as_tensor(labels, dtype=torch.int64)
+             target = {"boxes": boxes, "labels": labels}
+
+             if self.transforms:
+                 image = self.transforms(image)
+
+             return image, target
+         except Exception as e:
+             print(f"⚠️ Skipping image {image_info['file_name']} due to error: {e}")
+             return None, None
+
+ # Define Image Transformations
+ transform = T.Compose([T.ToTensor()])
+
+ # Load Dataset
+ full_dataset = CocoDataset(root=DATASET_DIR, annotation_file=ANNOTATIONS_FILE, transforms=transform)
+ subset_size = min(10000, len(full_dataset))  # Limit dataset to 10,000 samples or less
+ subset_indices = list(range(subset_size))
+ dataset = Subset(full_dataset, subset_indices)
+
+ data_loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*[item for item in x if item[0] is not None])))
+
+ # Load Faster R-CNN Model
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
+
+ # Freeze Backbone Layers
+ for param in model.backbone.parameters():
+     param.requires_grad = False
+
+ # Modify Classifier Head for Custom Classes
+ num_classes = 2  # One object class + background
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
+ model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+
+ device = torch.device("cpu")
+
+ # # Check for MPS Availability
+ # if torch.backends.mps.is_available():
+ #     print("✅ Using MPS (Apple Metal GPU)")
+ #     device = torch.device("mps")
+ # else:
+ #     print("⚠️ MPS not available, using CPU")
+ #     device = torch.device("cpu")
+
+ model.to(device)
+
+ # Training Setup
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
+ num_epochs = 5
+
+ # Training Loop
+ for epoch in range(num_epochs):
+     model.train()
+     epoch_loss = 0
+
+     print(f"Epoch {epoch+1}/{num_epochs}...")
+
+     for images, targets in tqdm(data_loader, desc=f"Training Epoch {epoch+1}"):
+         images = list(img.to(device) for img in images)
+         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+         if any(len(t["boxes"]) == 0 for t in targets):
+             print("⚠️ Skipping batch with no valid bounding boxes")
+             continue
+
+         optimizer.zero_grad()
+         loss_dict = model(images, targets)
+         loss = sum(loss for loss in loss_dict.values())
+         loss.backward()
+         optimizer.step()
+
+         epoch_loss += loss.item()
+
+     print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
+
+ # Save Trained Model
+ torch.save(model.state_dict(), "faster_rcnn_custom.pth")
+ print("Training Complete! Model saved as 'faster_rcnn_custom.pth'")
src/transformdata.py ADDED
@@ -0,0 +1,70 @@
+ import os
+ import json
+ import xml.etree.ElementTree as ET
+
+ # Paths (Modify These)
+ ANNOTATIONS_DIR = "datasets/annotations"  # Change to your XML annotations folder
+ OUTPUT_JSON = "datasets/annotations.json"  # Where to save the COCO JSON
+
+ # COCO JSON Format
+ coco_data = {
+     "images": [],
+     "annotations": [],
+     "categories": [{"id": 1, "name": "object"}]  # Only one class
+ }
+
+ annotation_id = 0
+
+ # Process Each XML File
+ for xml_file in os.listdir(ANNOTATIONS_DIR):
+     if not xml_file.endswith(".xml"):
+         continue
+
+
+     try:
+         tree = ET.parse(os.path.join(ANNOTATIONS_DIR, xml_file))
+         root = tree.getroot()
+     except ET.ParseError:
+         print(f"Skipping file due to parsing error: {xml_file}")
+         continue
+
+     # Extract Image Info
+     filename = root.find("filename").text
+     width = int(root.find("size/width").text)
+     height = int(root.find("size/height").text)
+     image_id = len(coco_data["images"])
+
+     coco_data["images"].append({
+         "id": image_id,
+         "file_name": filename,
+         "width": width,
+         "height": height
+     })
+
+     # Extract Objects
+     for obj in root.findall("object"):
+         bbox = obj.find("bndbox")
+         xmin = int(bbox.find("xmin").text)
+         ymin = int(bbox.find("ymin").text)
+         xmax = int(bbox.find("xmax").text)
+         ymax = int(bbox.find("ymax").text)
+
+         # Keep the VOC corner format [xmin, ymin, xmax, ymax]; the trainer reads corners directly rather than COCO's [x, y, width, height]
+         bbox_coco = [xmin, ymin, xmax, ymax]
+
+         # Add Annotation
+         coco_data["annotations"].append({
+             "id": annotation_id,
+             "image_id": image_id,
+             "category_id": 1,  # Only one class
+             "bbox": bbox_coco,
+             "area": (bbox_coco[2] - bbox_coco[0]) * (bbox_coco[3] - bbox_coco[1]),
+             "iscrowd": 0
+         })
+         annotation_id += 1
+
+ # Save to JSON File
+ with open(OUTPUT_JSON, "w") as f:
+     json.dump(coco_data, f, indent=4)
+
+ print(f"COCO annotations saved to {OUTPUT_JSON}")