"""Easy-to-use inference interface for the Oculus Unified Model.

Supports object detection, VQA, and captioning via OculusPredictor.
"""

import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional, Union

import requests
import torch
from PIL import Image

# Make sibling modules importable regardless of the working directory.
OCULUS_ROOT = Path(__file__).parent
sys.path.insert(0, str(OCULUS_ROOT))

try:
    from oculus_unified_model import OculusForConditionalGeneration
except ImportError:
    # Fall back to the package-qualified import when running from outside this directory.
    from Oculus.oculus_unified_model import OculusForConditionalGeneration


class OculusPredictor:
    """
    Easy-to-use interface for the Oculus Unified Model.
    Supports Object Detection, VQA, and Captioning.
    """

    def __init__(self, model_path: Optional[str] = None, device: str = "cpu"):
        self.device = device

        # Default to the newest available detection checkpoint.
        if model_path is None:
            base_dir = OCULUS_ROOT / "checkpoints" / "oculus_detection_v2"
            if (base_dir / "final").exists():
                model_path = str(base_dir / "final")
            else:
                model_path = str(OCULUS_ROOT / "checkpoints" / "oculus_detection" / "final")

        print(f"Loading Oculus model from: {model_path}")
        self.model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Restore the detection head weights saved alongside the backbone.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location=device)
            self.model.detection_head.load_state_dict(heads['detection'])
            print("✓ Detection heads loaded")

        # Prefer the instruction-tuned VQA checkpoint when it is present.
        instruct_path = OCULUS_ROOT / "checkpoints" / "oculus_instruct_v1" / "vqa_model"
        if instruct_path.exists():
            from transformers import BlipForQuestionAnswering
            self.model.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(instruct_path)
            print("✓ Instruction-tuned VQA model loaded")

        # Honor the requested device and switch to inference mode.
        self.model.to(self.device)
        self.model.eval()
        print("✓ Model loaded successfully")

    def load_image(self, image_source: Union[str, Image.Image]) -> Image.Image:
        """Load an image from a local path, a URL, or a PIL Image."""
        if isinstance(image_source, Image.Image):
            return image_source.convert("RGB")

        if image_source.startswith(("http://", "https://")):
            response = requests.get(image_source, headers={'User-Agent': 'Mozilla/5.0'}, timeout=30)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        return Image.open(image_source).convert("RGB")

    def detect(self, image_source: Union[str, Image.Image], prompt: str = "Detect objects", threshold: float = 0.2) -> Dict[str, Any]:
        """
        Run object detection.

        Returns: {'boxes': [[x1, y1, x2, y2], ...], 'labels': [...],
                  'confidences': [...], 'image_size': (width, height)}
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="box", prompt=prompt, threshold=threshold)

        return {
            'boxes': output.boxes,
            'labels': output.labels,
            'confidences': output.confidences,
            'image_size': image.size,
        }
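
    # --- Hedged sketch, not part of the original Oculus API ---
    # A convenience helper illustrating how detect()'s documented return
    # format ('boxes' as [x1, y1, x2, y2] pixel coordinates with parallel
    # 'labels' and 'confidences' lists) can be consumed. Assumes only
    # Pillow's ImageDraw; the styling choices are arbitrary.
    def draw_detections(self, image_source: Union[str, Image.Image], **detect_kwargs) -> Image.Image:
        """Run detect() and return a copy of the image with boxes drawn on it."""
        from PIL import ImageDraw

        image = self.load_image(image_source)
        result = self.detect(image, **detect_kwargs)
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        for box, label, conf in zip(result['boxes'], result['labels'], result['confidences']):
            draw.rectangle(box, outline="red", width=2)
            draw.text((box[0], max(box[1] - 12, 0)), f"{label} {conf:.2f}", fill="red")
        return annotated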

    def ask(self, image_source: Union[str, Image.Image], question: str) -> str:
        """Ask a question about the image (VQA)."""
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="text", prompt=question)
        return output.text

    def caption(self, image_source: Union[str, Image.Image]) -> str:
        """Generate a caption for the image."""
        # Captioning reuses the text/VQA path with a caption-style prefix prompt.
        return self.ask(image_source, "A photo of")
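

if __name__ == "__main__":
    # Minimal usage sketch, assuming the default checkpoints exist under
    # ./checkpoints. "example.jpg" is a hypothetical placeholder; any local
    # path or http(s) URL works.
    predictor = OculusPredictor()

    detections = predictor.detect("example.jpg", threshold=0.2)
    print(f"Found {len(detections['boxes'])} objects: {detections['labels']}")

    print("Answer:", predictor.ask("example.jpg", "What is in this image?"))
    print("Caption:", predictor.caption("example.jpg"))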