|
|
|
|
|
""" |
|
|
Oculus Car Part Detection Demo |
|
|
|
|
|
Demonstrates detection on car images using the extended training model. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from pathlib import Path |
|
|
sys.path.insert(0, str(Path(__file__).parent)) |
|
|
|
|
|
from oculus_unified_model import OculusForConditionalGeneration |
|
|
|
|
|
def visualize_results(image, output, filename="output_car_parts.png"): |
|
|
"""Draw bounding boxes and labels on image.""" |
|
|
draw = ImageDraw.Draw(image) |
|
|
|
|
|
|
|
|
try: |
|
|
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16) |
|
|
except: |
|
|
font = ImageFont.load_default() |
|
|
|
|
|
width, height = image.size |
|
|
|
|
|
|
|
|
COCO_CLASSES = [ |
|
|
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', |
|
|
'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', |
|
|
'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', |
|
|
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', |
|
|
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', |
|
|
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', |
|
|
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', |
|
|
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', |
|
|
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', |
|
|
'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', |
|
|
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', |
|
|
'toothbrush' |
|
|
] |
|
|
|
|
|
|
|
|
for box, label, conf in zip(output.boxes, output.labels, output.confidences): |
|
|
|
|
|
x1, y1, x2, y2 = box |
|
|
|
|
|
|
|
|
x1 = max(0.0, min(1.0, x1)) |
|
|
y1 = max(0.0, min(1.0, y1)) |
|
|
x2 = max(0.0, min(1.0, x2)) |
|
|
y2 = max(0.0, min(1.0, y2)) |
|
|
|
|
|
|
|
|
if x2 <= x1 or y2 <= y1: |
|
|
continue |
|
|
|
|
|
x1 *= width |
|
|
y1 *= height |
|
|
x2 *= width |
|
|
y2 *= height |
|
|
|
|
|
|
|
|
color = "red" if conf < 0.5 else "green" |
|
|
|
|
|
draw.rectangle([x1, y1, x2, y2], outline=color, width=3) |
|
|
|
|
|
|
|
|
try: |
|
|
class_name = COCO_CLASSES[int(label)] |
|
|
except: |
|
|
class_name = str(label) |
|
|
|
|
|
label_text = f"{class_name} ({conf:.2f})" |
|
|
|
|
|
|
|
|
text_bbox = draw.textbbox((x1, y1), label_text, font=font) |
|
|
draw.rectangle(text_bbox, fill=color) |
|
|
draw.text((x1, y1), label_text, fill="white", font=font) |
|
|
|
|
|
image.save(filename) |
|
|
print(f"Saved visualization to {filename}") |
|
|
|
|
|
def main(): |
|
|
import argparse |
|
|
parser = argparse.ArgumentParser(description="Oculus General Object Detection Demo") |
|
|
parser.add_argument("--image", type=str, help="Path to image file to test") |
|
|
parser.add_argument("--prompt", type=str, default="Detect objects", help="Text prompt for the model") |
|
|
parser.add_argument("--mode", type=str, default="box", choices=["box", "vqa", "caption"], help="Inference mode") |
|
|
parser.add_argument("--threshold", type=float, default=0.2, help="Detection threshold") |
|
|
parser.add_argument("--output", type=str, default="detection_result.png", help="Output filename") |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
|
|
checkpoint_dir = Path("checkpoints/oculus_detection_v2") |
|
|
model_path = None |
|
|
|
|
|
if checkpoint_dir.exists(): |
|
|
|
|
|
steps = [] |
|
|
for d in checkpoint_dir.iterdir(): |
|
|
if d.is_dir() and d.name.startswith("step_"): |
|
|
try: |
|
|
step = int(d.name.split("_")[1]) |
|
|
steps.append((step, d)) |
|
|
except: |
|
|
pass |
|
|
|
|
|
|
|
|
if steps: |
|
|
steps.sort(key=lambda x: x[0], reverse=True) |
|
|
model_path = str(steps[0][1]) |
|
|
print(f"✨ Found latest checkpoint: {model_path}") |
|
|
|
|
|
if model_path is None: |
|
|
model_path = str(checkpoint_dir / "final") |
|
|
|
|
|
|
|
|
if not Path(model_path).exists(): |
|
|
model_path = "checkpoints/oculus_detection/final" |
|
|
print(f"⚠️ Extended V2 model not found, falling back to V1: {model_path}") |
|
|
|
|
|
print(f"Loading model from {model_path}...") |
|
|
try: |
|
|
model = OculusForConditionalGeneration.from_pretrained(model_path) |
|
|
|
|
|
|
|
|
heads_path = Path(model_path) / "heads.pth" |
|
|
if heads_path.exists(): |
|
|
heads = torch.load(heads_path, map_location="cpu") |
|
|
model.detection_head.load_state_dict(heads['detection']) |
|
|
print("✓ Loaded detection heads") |
|
|
except Exception as e: |
|
|
print(f"Error loading model: {e}") |
|
|
return |
|
|
|
|
|
|
|
|
if args.image: |
|
|
image_path = args.image |
|
|
print(f"\nProcessing Custom Image: {image_path}...") |
|
|
else: |
|
|
|
|
|
|
|
|
image_path = "data/coco/images/000000071345.jpg" |
|
|
print(f"\nProcessing Default Image: {image_path}...") |
|
|
|
|
|
try: |
|
|
if Path(image_path).exists(): |
|
|
image = Image.open(image_path).convert('RGB') |
|
|
else: |
|
|
|
|
|
|
|
|
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/President_Barack_Obama.jpg/800px-President_Barack_Obama.jpg" |
|
|
print(f"Image not found, downloading sample {url}...") |
|
|
response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) |
|
|
image = Image.open(BytesIO(response.content)).convert('RGB') |
|
|
|
|
|
|
|
|
if args.mode == "box": |
|
|
print(f"Running detection with prompt: '{args.prompt}'...") |
|
|
output = model.generate( |
|
|
image, |
|
|
mode="box", |
|
|
prompt=args.prompt, |
|
|
threshold=args.threshold |
|
|
) |
|
|
print(f"Found {len(output.boxes)} objects") |
|
|
visualize_results(image, output, args.output) |
|
|
|
|
|
elif args.mode == "caption": |
|
|
print("Generating caption...") |
|
|
output = model.generate(image, mode="text", prompt="A photo of") |
|
|
print(f"\n📝 Caption: {output.text}\n") |
|
|
|
|
|
elif args.mode == "vqa": |
|
|
question = args.prompt if args.prompt != "Detect objects" else "What is in this image?" |
|
|
print(f"Thinking about question: '{question}'...") |
|
|
output = model.generate(image, mode="text", prompt=question) |
|
|
print(f"\n🤔 Answer: {output.text}\n") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error processing image: {e}") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|