|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import cv2 |
|
import torch |
|
import numpy as np |
|
from ultralytics import YOLO |
|
import wandb |
|
import matplotlib.pyplot as plt |
|
from datetime import datetime |
|
from google.colab import userdata |
|
|
|
|
|
# Authenticate with Weights & Biases using the API key read from the Colab
# secret named "WANDB" (this runs once, when the script is executed).
wandb.login(key=userdata.get("WANDB"))
|
|
|
def setup_wandb():
    """Start a W&B run in the "Object-detection" project.

    The run is named ``run_<YYYYmmdd_HHMMSS>`` from the current local time and
    records the model name, dataset, image size and batch size as its config.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_config = dict(
        model="yolov8n",
        dataset="coco128",
        img_size=640,
        batch_size=8,
    )
    wandb.init(
        project="Object-detection",
        name=f"run_{timestamp}",
        config=run_config,
    )
|
|
|
def load_model():
    """Load the pretrained ``yolov8n.pt`` checkpoint on the best available device.

    Returns:
        The YOLO model, moved to ``'cuda'`` when CUDA is available and to
        ``'cpu'`` otherwise.
    """
    yolo = YOLO("yolov8n.pt")
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    yolo.to(target)
    return yolo
|
|
|
def train_model(
    model,
    *,
    data="coco128.yaml",
    epochs=20,
    imgsz=640,
    batch=8,
    patience=3,
    device=None,
):
    """Train *model* via its ``train`` method and return the same object.

    Every keyword defaults to the value this script has always used, so
    ``train_model(model)`` behaves as before; pass keywords to override.

    Args:
        model: An Ultralytics ``YOLO`` model (anything exposing
            ``train(**kwargs)``).
        data: Dataset YAML, forwarded as ``data``.
        epochs: Forwarded as ``epochs``.
        imgsz: Forwarded as ``imgsz``.
        batch: Forwarded as ``batch``.
        patience: Forwarded as ``patience``.
        device: Device string forwarded as ``device``. When ``None``, ``"0"``
            is used if CUDA is available and ``"cpu"`` otherwise.

    Returns:
        The same ``model`` object. The value returned by ``model.train`` is
        discarded. NOTE(review): this relies on ``train`` leaving the trained
        weights in *model* — confirm against the installed Ultralytics version.
    """
    if device is None:
        device = '0' if torch.cuda.is_available() else 'cpu'
    model.train(
        data=data,
        epochs=epochs,
        imgsz=imgsz,
        batch=batch,
        device=device,
        patience=patience,
        save=True,
    )
    return model
|
|
|
def validate_model(model):
    """Run ``model.val()`` and log its box metrics to the active W&B run.

    Logs ``map50``, ``map``, ``mp`` and ``mr`` from ``metrics.box`` under the
    keys ``val/mAP50``, ``val/mAP50-95``, ``val/precision`` and ``val/recall``.

    Returns:
        The metrics object returned by ``model.val()``.
    """
    metrics = model.val()
    box = metrics.box
    payload = {
        "val/mAP50": box.map50,
        "val/mAP50-95": box.map,
        "val/precision": box.mp,
        "val/recall": box.mr,
    }
    wandb.log(payload)
    return metrics
|
|
|
def visualize_results(results, img_path, output_path="detection_results.jpg"):
    """Save a side-by-side figure of the original image and its detections.

    Args:
        results: Prediction results for *img_path*. ``results[0].plot()`` must
            return an annotated image array; it is treated as BGR and converted
            with ``cv2.COLOR_BGR2RGB`` before display.
        img_path: Path of the original image, read with ``cv2.imread``.
        output_path: File the figure is written to (default
            ``"detection_results.jpg"``).

    Returns:
        *output_path*.

    Raises:
        ValueError: If *img_path* cannot be read as an image.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Failed to load image: {img_path}")
    pred_img = results[0].plot()

    fig, (ax_orig, ax_pred) = plt.subplots(1, 2, figsize=(20, 10))
    try:
        ax_orig.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax_orig.set_title("Original")
        ax_orig.axis('off')

        ax_pred.imshow(cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB))
        ax_pred.set_title("Detections")
        ax_pred.axis('off')

        fig.savefig(output_path)
    finally:
        # Close this specific figure even if drawing or saving fails, so
        # repeated calls don't accumulate open figures.
        plt.close(fig)
    return output_path
|
|
|
def test_image(model, img_path="test_image.jpg"):
    """Run detection on one image, save a comparison figure, and log it to W&B.

    Args:
        model: Callable detector; ``model(img_path)`` returns the results list.
        img_path: Image to run on (default ``"test_image.jpg"``).

    Returns:
        The results returned by ``model(img_path)``.

    Raises:
        FileNotFoundError: If *img_path* does not exist.
    """
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image not found: {img_path}")

    results = model(img_path)
    figure_path = visualize_results(results, img_path)

    first = results[0]
    payload = {
        "test_results": wandb.Image(figure_path),
        "detections": first.boxes.cls.tolist(),
        "confidences": first.boxes.conf.tolist(),
    }
    wandb.log(payload)
    return results
|
|
|
def webcam_demo(model, camera_index=0):
    """Run live detection on webcam frames until 'q' is pressed or frames stop.

    Annotated frames are shown with Colab's ``cv2_imshow`` when it is
    importable. Otherwise they are shown in a regular OpenCV window.

    Args:
        model: Callable detector. ``model(frame)`` must return a sequence whose
            first element has a ``plot()`` method returning an image.
        camera_index: OpenCV capture device index (default 0).

    Any exception raised while capturing or detecting is printed rather than
    propagated, since this is a best-effort demo.

    NOTE(review): inside Colab, ``cv2.VideoCapture`` presumably runs on the
    remote VM rather than the browser, so the demo will usually be skipped
    there — confirm.
    """
    try:
        from google.colab.patches import cv2_imshow as show_frame
    except ImportError:
        # Outside Colab: fall back to a normal OpenCV window.
        def show_frame(frame):
            cv2.imshow("webcam_demo", frame)

    # Initialised before the try so the finally clause never sees an unbound
    # name if VideoCapture (or anything before it) raises.
    cap = None
    try:
        cap = cv2.VideoCapture(camera_index)
        if not cap.isOpened():
            print("Webcam not available - skipping demo")
            return

        print("Press 'q' to quit webcam demo")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            results = model(frame)
            show_frame(results[0].plot())
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    except Exception as e:
        print(f"Webcam error: {e}")
    finally:
        if cap is not None:
            cap.release()
        cv2.destroyAllWindows()
|
|
|
def export_model(trained_weights="runs/detect/train/weights/best.pt"):
    """Export trained YOLO weights to TorchScript and upload the file to W&B.

    Args:
        trained_weights: Path of the ``.pt`` checkpoint to export. The default
            matches the first training run directory. NOTE(review): repeated
            training runs may land in ``train2``, ``train3``, ... — pass the
            real path in that case.

    Returns:
        Path of the exported TorchScript file, as returned by ``model.export``.

    Raises:
        FileNotFoundError: If *trained_weights* does not exist.
    """
    if not os.path.exists(trained_weights):
        raise FileNotFoundError(f"Trained weights not found: {trained_weights}")

    model = YOLO(trained_weights)
    # export() writes the file next to the weights, not into the CWD, so the
    # returned path is uploaded instead of a bare "best.torchscript" (which
    # matched nothing).
    exported_path = str(model.export(format="torchscript"))
    wandb.save(exported_path)
    return exported_path
|
|
|
def main():
    """Run the pipeline: W&B setup, load, train, validate, test image, export.

    A missing sample image is reported and skipped so the export still
    happens. The W&B run is always finished, with exit code 1 if any other
    step raised and 0 on success.
    """
    setup_wandb()
    exit_code = 1
    try:
        model = load_model()
        model = train_model(model)
        validate_model(model)
        try:
            test_image(model)
        except FileNotFoundError as err:
            # The sample image is optional; don't lose the trained export over it.
            print(f"Skipping test image: {err}")
        export_model()
        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)
|
|
|
if __name__ == "__main__":
    main()

    # Download the exported TorchScript model to the local machine (Colab).
    # Kept inside the guard so merely importing this file never triggers it.
    from google.colab import files

    exported_file = "runs/detect/train/weights/best.torchscript"
    if os.path.exists(exported_file):
        files.download(exported_file)
    else:
        print(f"Exported model not found: {exported_file}")
|
|
|
|