Alex Hortua committed · Commit b87aa54 · 1 Parent(s): 9a6ea32

Creating a faster version with a different approach (Training with a frozen Backbone of COCO images)

Files changed:
- .gitignore +6 -2
- datasets/examples.json +52 -0
- datasets/test_images/0aa7d4a4-e675-11eb-98bd-b0c090bd3910.jpg +0 -0
- datasets/test_images/0abfb048-e3a1-11eb-9018-b0c090bd3910.jpg +0 -0
- logs/training_log.json +10 -0
- logs/training_log.txt +31 -0
- src/Attempt1/app.py +59 -0
- src/{dataset.py → Attempt1/dataset.py} +0 -0
- src/{evaluate.py → Attempt1/evaluate.py} +0 -0
- src/{train.py → Attempt1/train.py} +0 -0
- src/{utils.py → Attempt1/utils.py} +0 -0
- src/app.py +80 -38
- src/new_trainer.py +136 -0
- src/transformdata.py +70 -0
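
The commit message refers to fine-tuning a COCO-pretrained Faster R-CNN with its backbone frozen; the full training script is src/new_trainer.py below. As a quick orientation, the core of that setup is roughly the following sketch (PyTorch/torchvision, mirroring the calls used in new_trainer.py):

# Sketch of the frozen-backbone idea in this commit; see src/new_trainer.py for the real script.
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")  # COCO-pretrained
for param in model.backbone.parameters():
    param.requires_grad = False  # backbone stays fixed; only the heads are trained

in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)  # background + LEGO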

.gitignore CHANGED

@@ -2,7 +2,11 @@
 /Annotations
 .qodo
 venv/
+src/__pycache__/
+models/lego
+models/records/*
+models/records
+
+# Datasets
 /datasets/annotations/*
 /datasets/images/*
-src/__pycache__/
-models/lego

datasets/examples.json ADDED

@@ -0,0 +1,52 @@
{
    "examples": [
        [
            "datasets/test_images/0abd88dc-e306-11eb-b5b0-b0c090bd3910.jpg",
            "[[373, 523, 438, 599], [78, 444, 278, 563], [124, 0, 221, 63], [471, 156, 535, 213]]"
        ],
        [
            "datasets/test_images/0abe1e54-e691-11eb-8391-b0c090bd3910.jpg",
            "[[29, 82, 270, 299], [23, 0, 78, 80]]"
        ],
        [
            "datasets/test_images/0abd3f80-daff-11eb-8755-3497f683a169.jpg",
            "[[0, 77, 172, 317], [523, 202, 582, 277], [410, 112, 600, 544], [136, 39, 187, 110]]"
        ],
        [
            "datasets/test_images/0abf3f3e-e4c8-11eb-8f0d-b0c090bd3910.jpg",
            "[[207, 289, 238, 333], [338, 94, 599, 496], [0, 84, 136, 407], [73, 383, 141, 419]]"
        ],
        [
            "datasets/test_images/0abdc2ee-d9cc-11eb-8cf3-3497f683a169.jpg",
            "[[225, 437, 386, 600], [305, 0, 369, 100], [353, 346, 453, 445], [113, 28, 234, 134]]"
        ],
        [
            "datasets/test_images/000abf76-e67a-11eb-b56d-b0c090bd3910.jpg",
            "[[44, 167, 185, 300], [97, 0, 262, 167]]"
        ],
        [
            "datasets/test_images/0abccbc6-e661-11eb-9915-b0c090bd3910.jpg",
            "[[173, 124, 242, 194], [0, 3, 300, 296]]"
        ],
        [
            "datasets/test_images/0abf4764-e480-11eb-b391-b0c090bd3910.jpg",
            "[[418, 87, 599, 306], [0, 339, 154, 490], [230, 114, 353, 234], [173, 118, 275, 227]]"
        ],
        [
            "datasets/test_images/0a82869a-e3c3-11eb-9d75-b0c090bd3910.jpg",
            "[[387, 0, 536, 119], [378, 509, 463, 600], [94, 193, 288, 368], [74, 301, 237, 486]]"
        ],
        [
            "datasets/test_images/0abd5fec-e5c6-11eb-9ac6-b0c090bd3910.jpg",
            "[[32, 50, 299, 266], [0, 22, 78, 82]]"
        ],
        [
            "datasets/test_images/0abfa74a-e2f3-11eb-abe7-b0c090bd3910.jpg",
            "[[229, 104, 311, 169], [380, 0, 503, 118], [235, 236, 284, 304], [109, 434, 350, 600]]"
        ],
        [
            "datasets/test_images/0abfa012-e660-11eb-8710-b0c090bd3910.jpg",
            "[[0, 43, 299, 254], [50, 58, 215, 238]]"
        ]
    ]
}

datasets/test_images/0aa7d4a4-e675-11eb-98bd-b0c090bd3910.jpg DELETED
Binary file (21.7 kB)

datasets/test_images/0abfb048-e3a1-11eb-9018-b0c090bd3910.jpg DELETED
Binary file (53 kB)

logs/training_log.json ADDED

@@ -0,0 +1,10 @@
{
    "loss": [
        5521765.083993731,
        813867.9704230821
    ],
    "mAP": [
        0.7336118575736044,
        0.00042658527250095333
    ]
}

logs/training_log.txt ADDED

@@ -0,0 +1,31 @@
Starting Epoch 1/10
Starting Epoch 1/10
Iteration 1, Loss: 1.19198739528656
Iteration 101, Loss: 27.87972640991211
Iteration 201, Loss: 7.156171798706055
Iteration 301, Loss: 8.546396255493164
Iteration 401, Loss: 1.7727022171020508
Iteration 501, Loss: 5.378680229187012
Iteration 601, Loss: 15.277275085449219
Iteration 701, Loss: 4.097675800323486
Iteration 801, Loss: 4.272053241729736
Iteration 901, Loss: 1.443131446838379
Starting Epoch 2/10
Iteration 1, Loss: 2.3286213874816895
Iteration 101, Loss: 1.8097801208496094
Iteration 201, Loss: 1.6668422222137451
Iteration 301, Loss: 2.1733906269073486
Iteration 401, Loss: 1.8349155187606812
Iteration 501, Loss: 0.9883778095245361
Iteration 601, Loss: 1.3832241296768188
Iteration 701, Loss: 1.6653320789337158
Iteration 801, Loss: 1.2079124450683594
Iteration 901, Loss: 28428.40625
Starting Epoch 3/10
Iteration 1, Loss: 1006.0529174804688
Iteration 101, Loss: 0.4594302177429199
Iteration 201, Loss: 0.4397536516189575
Iteration 301, Loss: 0.31954655051231384
Iteration 401, Loss: 0.4685922861099243
Iteration 501, Loss: 0.31720075011253357
Iteration 601, Loss: 0.2652203440666199

src/Attempt1/app.py ADDED

@@ -0,0 +1,59 @@
import os
import torch
import torchvision
import gradio as gr
import numpy as np
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw

# Load Trained Model
def load_model(model_path):
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)  # Background + 1 LEGO class
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model

model = load_model("models/lego_fasterrcnn.pth")

def predict(image):
    image = Image.fromarray(image).convert("RGB")
    image_tensor = F.to_tensor(image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        predictions = model(image_tensor)[0]

    boxes = predictions['boxes'].cpu().numpy()
    labels = predictions['labels'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()

    results = []
    draw = ImageDraw.Draw(image)
    for box, label, score in zip(boxes, labels, scores):
        if score > 0.7:  # Confidence threshold
            results.append({
                "box": box.tolist(),
                "label": str(label),
                "score": float(score)
            })
            draw.rectangle(box.tolist(), outline="red", width=3)
            draw.text((box[0], box[1]), f"{label} ({score:.2f})", fill="red")

    return image, results

def get_examples():
    return [os.path.join("datasets/test_images", f) for f in os.listdir("datasets/test_images")]

# Gradio Interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="pil"), gr.JSON()],
    title="LEGO Detection with Faster R-CNN",
    description="Upload an image and the model will detect LEGO bricks with bounding boxes.",
    examples=get_examples()
)

demo.launch()

src/{dataset.py → Attempt1/dataset.py} RENAMED
File without changes

src/{evaluate.py → Attempt1/evaluate.py} RENAMED
File without changes

src/{train.py → Attempt1/train.py} RENAMED
File without changes

src/{utils.py → Attempt1/utils.py} RENAMED
File without changes

src/app.py CHANGED (rewritten; new version shown below)

@@ -1,59 +1,101 @@
import torch
import torchvision
import torchvision.transforms as T
import gradio as gr
from PIL import Image, ImageDraw
import torchvision.ops as ops
import numpy as np
import json
import os


# LOAD_MODEL_PATH = "models/lego_fasterrcnn.pth"
LOAD_MODEL_PATH = "models/faster_rcnn_custom.pth"

# Load trained Faster R-CNN model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes=2)
model.load_state_dict(torch.load(LOAD_MODEL_PATH, map_location=torch.device("cpu")))
model.eval()

def compute_iou(box1, box2):
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union = area_box1 + area_box2 - intersection
    return intersection / union if union > 0 else 0

def mean_average_precision(predictions, ground_truths, iou_threshold=0.5):
    iou_scores = []
    for pred_box in predictions:
        best_iou = 0
        for gt_box in ground_truths:
            iou = compute_iou(pred_box, gt_box)
            best_iou = max(best_iou, iou)
        if best_iou >= iou_threshold:
            iou_scores.append(best_iou)
    return np.mean(iou_scores) if iou_scores else None

def predict(image, ground_truths_json=""):
    transform = T.Compose([T.ToTensor()])
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        predictions = model(image_tensor)

    boxes = predictions[0]['boxes'].tolist()
    scores = predictions[0]['scores'].tolist()

    # Draw boxes on image
    draw = ImageDraw.Draw(image)
    for box, score in zip(boxes, scores):
        if score > 0.5:  # Confidence threshold
            draw.rectangle(box, outline="red", width=3)
            draw.text((box[0], box[1]), f"{score:.2f}", fill="red")

    # Compute mAP if ground truths are provided
    mAP = None
    if ground_truths_json:
        try:
            ground_truths = json.loads(ground_truths_json)
            mAP = mean_average_precision(boxes, ground_truths, iou_threshold=0.5)
            # Draw ground truth boxes in a different color
            for gt_box in ground_truths:
                draw.rectangle(gt_box, outline="green", width=3)
                draw.text((gt_box[0], gt_box[1]), "GT", fill="green")
        except json.JSONDecodeError:
            print("⚠️ Invalid ground truth format. Expecting JSON array of bounding boxes.")

    # Filter boxes and scores based on confidence threshold
    filtered_boxes = [box for box, score in zip(boxes, scores) if score > 0.5]
    return image, filtered_boxes, mAP


def get_examples():
    # Load examples from JSON file
    with open("datasets/examples.json", "r") as f:
        examples_json = json.load(f)
    examples_with_annotations = examples_json["examples"]

    return examples_with_annotations

# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil"), gr.Textbox(placeholder="Enter ground truth bounding boxes as JSON (optional)")],
    outputs=[gr.Image(type="pil", label="Detected LEGO pieces (Red predictions, green ground truth)"), gr.JSON(label="Predicted bounding boxes"), gr.Textbox(label="Mean Average Precision (mAP @ IoU 0.5)")],
    title="LEGO Piece Detector",
    examples=get_examples(),
    description="Upload an image to detect LEGO pieces using Faster R-CNN. Optionally, enter ground truth bounding boxes to compute mAP. If left empty, mAP will be null."
)

# Launch Gradio app
if __name__ == "__main__":
    demo.launch()
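
A note on the metric: despite its name, mean_average_precision above returns the mean IoU of predicted boxes that overlap some ground-truth box at IoU ≥ 0.5, not a full COCO-style mAP, so the value shown in the UI should be read that way. For quick testing without the Gradio UI, predict can be called directly; the sketch below is a hypothetical local check that assumes src/app.py is importable (e.g. it is run from the src/ directory), the weights at models/faster_rcnn_custom.pth exist, and one of the example images from datasets/examples.json is present.

# Hypothetical local check of src/app.py's predict(); paths are taken from datasets/examples.json.
from PIL import Image
from app import predict  # assumes this runs from src/ with app.py on the import path

img = Image.open("datasets/test_images/0abe1e54-e691-11eb-8391-b0c090bd3910.jpg").convert("RGB")
gt_json = "[[29, 82, 270, 299], [23, 0, 78, 80]]"  # same string format the Gradio textbox expects

annotated, boxes, map_score = predict(img, gt_json)
annotated.save("prediction_preview.jpg")  # red = predictions, green = ground truth
print("Predicted boxes above 0.5 confidence:", boxes)
print("Reported 'mAP' (mean IoU of matched predictions):", map_score)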

src/new_trainer.py ADDED

@@ -0,0 +1,136 @@
import torch
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
import os
import json
from PIL import Image
from tqdm import tqdm  # Import tqdm for loading bar

# Paths (Modify These)
DATASET_DIR = "datasets/images"  # Folder containing images
ANNOTATIONS_FILE = "datasets/annotations.json"  # Path to COCO JSON

# Define Custom COCO Dataset Class (Without pycocotools)
class CocoDataset(Dataset):
    def __init__(self, root, annotation_file, transforms=None):
        self.root = root
        with open(annotation_file, 'r') as f:
            self.coco_data = json.load(f)
        self.image_data = {img["id"]: img for img in self.coco_data["images"]}
        self.annotations = self.coco_data["annotations"]
        self.transforms = transforms

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        try:
            image_info = self.image_data[idx]
            image_path = os.path.join(self.root, image_info["file_name"])
            image = Image.open(image_path).convert("RGB")
            img_width, img_height = image.size  # Get image dimensions

            # Get Annotations
            annotations = [ann for ann in self.annotations if ann["image_id"] == image_info["id"]]
            boxes = []
            labels = []

            for ann in annotations:
                xmin, ymin, xmax, ymax = ann["bbox"]  # Now using [xmin, ymin, xmax, ymax]
                xmin = max(0, xmin)
                ymin = max(0, ymin)
                xmax = min(img_width, xmax)
                ymax = min(img_height, ymax)

                if xmax > xmin and ymax > ymin:
                    boxes.append([xmin, ymin, xmax, ymax])
                    labels.append(ann["category_id"])
                else:
                    print(f"⚠️ Skipping invalid bbox {ann['bbox']} in image {image_info['file_name']} (image_id: {image_info['id']})")

            if len(boxes) == 0:
                print(f"⚠️ Skipping entire image {image_info['file_name']} because no valid bounding boxes remain.")
                return None, None

            # Convert to tensors
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            target = {"boxes": boxes, "labels": labels}

            if self.transforms:
                image = self.transforms(image)

            return image, target
        except Exception as e:
            print(f"⚠️ Skipping image {image_info['file_name']} due to error: {e}")
            return None, None

# Define Image Transformations
transform = T.Compose([T.ToTensor()])

# Load Dataset
full_dataset = CocoDataset(root=DATASET_DIR, annotation_file=ANNOTATIONS_FILE, transforms=transform)
subset_size = min(10000, len(full_dataset))  # Limit dataset to 10,000 samples or less
subset_indices = list(range(subset_size))
dataset = Subset(full_dataset, subset_indices)

data_loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*[item for item in x if item[0] is not None])))

# Load Faster R-CNN Model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# Freeze Backbone Layers
for param in model.backbone.parameters():
    param.requires_grad = False

# Modify Classifier Head for Custom Classes
num_classes = 2  # One object class + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

device = torch.device("cpu")

# # Check for MPS Availability
# if torch.backends.mps.is_available():
#     print("✅ Using MPS (Apple Metal GPU)")
#     device = torch.device("mps")
# else:
#     print("⚠️ MPS not available, using CPU")
#     device = torch.device("cpu")

model.to(device)

# Training Setup
optimizer = optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 5

# Training Loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    print(f"Epoch {epoch+1}/{num_epochs}...")

    for images, targets in tqdm(data_loader, desc=f"Training Epoch {epoch+1}"):
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if any(len(t["boxes"]) == 0 for t in targets):
            print("⚠️ Skipping batch with no valid bounding boxes")
            continue

        optimizer.zero_grad()
        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Save Trained Model
torch.save(model.state_dict(), "faster_rcnn_custom.pth")
print("Training Complete! Model saved as 'faster_rcnn_custom.pth'")

src/transformdata.py ADDED

@@ -0,0 +1,70 @@
import os
import json
import xml.etree.ElementTree as ET

# Paths (Modify These)
ANNOTATIONS_DIR = "datasets/annotations"  # Change to your XML annotations folder
OUTPUT_JSON = "datasets/annotations.json"  # Where to save the COCO JSON

# COCO JSON Format
coco_data = {
    "images": [],
    "annotations": [],
    "categories": [{"id": 1, "name": "object"}]  # Only one class
}

annotation_id = 0

# Process Each XML File
for xml_file in os.listdir(ANNOTATIONS_DIR):
    if not xml_file.endswith(".xml"):
        continue


    try:
        tree = ET.parse(os.path.join(ANNOTATIONS_DIR, xml_file))
        root = tree.getroot()
    except ET.ParseError:
        print(f"Skipping file due to parsing error: {xml_file}")
        continue

    # Extract Image Info
    filename = root.find("filename").text
    width = int(root.find("size/width").text)
    height = int(root.find("size/height").text)
    image_id = len(coco_data["images"])

    coco_data["images"].append({
        "id": image_id,
        "file_name": filename,
        "width": width,
        "height": height
    })

    # Extract Objects
    for obj in root.findall("object"):
        bbox = obj.find("bndbox")
        xmin = int(bbox.find("xmin").text)
        ymin = int(bbox.find("ymin").text)
        xmax = int(bbox.find("xmax").text)
        ymax = int(bbox.find("ymax").text)

        # Keep the VOC bbox format [xmin, ymin, xmax, ymax]; new_trainer.py reads boxes in this order
        bbox_coco = [xmin, ymin, xmax, ymax]

        # Add Annotation
        coco_data["annotations"].append({
            "id": annotation_id,
            "image_id": image_id,
            "category_id": 1,  # Only one class
            "bbox": bbox_coco,
            "area": (bbox_coco[2] - bbox_coco[0]) * (bbox_coco[3] - bbox_coco[1]),
            "iscrowd": 0
        })
        annotation_id += 1

# Save to JSON File
with open(OUTPUT_JSON, "w") as f:
    json.dump(coco_data, f, indent=4)

print(f"COCO annotations saved to {OUTPUT_JSON}")
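
For reference, the annotations.json written by this script has roughly the shape sketched below (the values are illustrative, not taken from the real dataset). These are exactly the fields that CocoDataset in src/new_trainer.py reads; note the bbox stays in [xmin, ymin, xmax, ymax] order rather than COCO's usual [x, y, width, height].

# Illustrative structure of datasets/annotations.json as written by transformdata.py.
# Field names match what new_trainer.py's CocoDataset expects; the numbers are made up.
example_output = {
    "images": [
        {"id": 0, "file_name": "0abe1e54-e691-11eb-8391-b0c090bd3910.jpg", "width": 600, "height": 600}
    ],
    "annotations": [
        {
            "id": 0,
            "image_id": 0,
            "category_id": 1,             # single "object" class
            "bbox": [29, 82, 270, 299],   # kept as [xmin, ymin, xmax, ymax]
            "area": (270 - 29) * (299 - 82),
            "iscrowd": 0,
        }
    ],
    "categories": [{"id": 1, "name": "object"}],
}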