# Source: Oculus / oculus_unified_model/modeling_oculus.py
# Uploaded by kobiakor15 ("Upload oculus_unified_model/modeling_oculus.py with huggingface_hub"),
# commit 4b92f99 (verified).
"""
Oculus Unified Model
HuggingFace-compatible vision-language model with:
- Multi-encoder vision (DINOv3 + SigLIP2)
- Trained projector for vision-to-language
- Optional reasoning with thinking traces
- Multiple output modes (Text, Point, Box, Polygon)
- Focus/Zoom tool calling for fine-grained perception
"""
import os
import json
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoImageProcessor,
AutoModel,
AutoTokenizer,
AutoModelForCausalLM,
GenerationConfig,
)
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from PIL import Image
from .configuration_oculus import OculusConfig
# ============================================================================
# Output Data Classes
# ============================================================================
@dataclass
class OculusOutput:
    """Base output class for Oculus model.

    Every field defaults to None. The mode-specific subclasses below append
    their own fields after these, so all fields are keyword-friendly.
    """
    # Generated text; populated by the text mode.
    text: Optional[str] = None
    # Reasoning trace produced when ``think=True`` (None otherwise).
    thinking_trace: Optional[str] = None
    # Not populated by the generate paths in this file; reserved.
    logits: Optional[torch.Tensor] = None
    # Not populated by the generate paths in this file; reserved.
    hidden_states: Optional[torch.Tensor] = None
    # Projected vision tokens from ``encode_image`` ([1, num_tokens, embed_dim]).
    vision_tokens: Optional[torch.Tensor] = None
@dataclass
class OculusTextOutput(OculusOutput):
    """Output for text/caption mode.

    Adds no fields: the answer or caption is stored in the inherited ``text``.
    """
    pass
@dataclass
class OculusPointOutput(OculusOutput):
    """Output for point detection mode (counting objects)."""
    # (x, y) pairs produced by a sigmoid, i.e. in [0, 1] (presumably
    # normalized image coordinates -- confirm against training data).
    points: Optional[List[Tuple[float, float]]] = None
    # Predicted class ids rendered as strings (e.g. "3"), one per point.
    labels: Optional[List[str]] = None
    # Per-point confidence in [0, 1], aligned with ``points``.
    confidences: Optional[List[float]] = None
@dataclass
class OculusBoxOutput(OculusOutput):
    """Output for bounding box detection mode."""
    # Boxes as (x1, y1, x2, y2), sigmoid-squashed into [0, 1].
    boxes: Optional[List[Tuple[float, float, float, float]]] = None  # x1, y1, x2, y2
    # Predicted class ids rendered as strings, one per box.
    labels: Optional[List[str]] = None
    # Max softmax class probability per box, aligned with ``boxes``.
    confidences: Optional[List[float]] = None
@dataclass
class OculusPolygonOutput(OculusOutput):
    """Output for polygon/segmentation mode."""
    # One polygon per non-background class. Currently a placeholder unit
    # square; real contour extraction is not implemented.
    polygons: Optional[List[List[Tuple[float, float]]]] = None
    # Class ids (as strings) of the non-background classes present in ``mask``.
    labels: Optional[List[str]] = None
    # Per-cell argmax class-index map for the first image in the batch
    # (14x14 with the default segmentation head).
    mask: Optional[np.ndarray] = None
# ============================================================================
# Vision Encoder (DINOv3 + SigLIP2)
# ============================================================================
class OculusVisionEncoder(nn.Module):
    """
    Dual vision encoder combining DINOv3 and SigLIP2.

    DINOv3: Excellent at semantic understanding, object boundaries
    SigLIP2: Strong at text/language alignment

    Both backbones are loaded lazily (on the first ``forward`` call or via an
    explicit ``load_encoders``). ``forward`` pools each backbone's output to
    one vector per image and concatenates them as ``[dino, siglip]``.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config
        # Will be loaded lazily
        self.dinov3 = None
        self.dinov3_processor = None
        self.siglip = None
        self.siglip_processor = None
        self._loaded = False

    def load_encoders(self, device: str = "cpu"):
        """
        Load both vision encoders and their image processors from HuggingFace.

        If a configured model id fails to load, a warning is emitted and a
        base-size fallback is used instead (``facebook/dinov2-base`` /
        ``google/siglip-base-patch16-224``). Errors from the fallback itself
        propagate. Calling this again after a successful load is a no-op.

        Args:
            device: Device the encoder weights are moved to.
        """
        if self._loaded:
            return
        print("[Oculus] Loading vision encoders...")

        # DINOv3
        try:
            self.dinov3_processor = AutoImageProcessor.from_pretrained(
                self.config.dinov3_model_id
            )
            self.dinov3 = AutoModel.from_pretrained(
                self.config.dinov3_model_id
            ).eval().to(device)
            print(f" ✓ DINOv3: {self.config.dinov3_model_id}")
        except Exception as e:
            # Best-effort: fall back to a public DINOv2 checkpoint.
            warnings.warn(f"Failed to load DINOv3: {e}")
            self.dinov3_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
            self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-base").eval().to(device)
            print(" ✓ DINOv2-base (fallback)")

        # SigLIP2
        try:
            self.siglip_processor = AutoImageProcessor.from_pretrained(
                self.config.siglip_model_id
            )
            self.siglip = AutoModel.from_pretrained(
                self.config.siglip_model_id
            ).eval().to(device)
            print(f" ✓ SigLIP: {self.config.siglip_model_id}")
        except Exception as e:
            # Best-effort: fall back to a public SigLIP vision tower.
            warnings.warn(f"Failed to load SigLIP: {e}")
            from transformers import SiglipVisionModel
            self.siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
            self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval().to(device)
            print(" ✓ SigLIP-base (fallback)")

        self._loaded = True

    @staticmethod
    def _pool(outputs) -> torch.Tensor:
        """
        Reduce a backbone output to one vector per image.

        Uses ``pooler_output`` when the output has it AND it is not None
        (vision towers built without a pooling head may return None there);
        otherwise takes the first token of ``last_hidden_state``.
        """
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            return pooled
        return outputs.last_hidden_state[:, 0]

    @torch.no_grad()
    def forward(self, image: Union[Image.Image, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Encode image with both vision encoders and fuse features.

        Args:
            image: PIL image, numpy array or tensor. Arrays/tensors are passed
                to ``Image.fromarray`` (presumably HxW or HxWxC); tensors are
                cast with ``astype(np.uint8)`` without rescaling, so float
                tensors in [0, 1] are NOT supported correctly.
        Returns:
            Fused vision features [batch, fused_dim]
        """
        if not self._loaded:
            self.load_encoders()

        # Normalize every accepted input type to an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, torch.Tensor):
            image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
        if isinstance(image, Image.Image):
            image = image.convert('RGB')

        # DINOv3 encoding (inputs go to this encoder's own device)
        d_device = next(self.dinov3.parameters()).device
        d_inputs = self.dinov3_processor(images=image, return_tensors="pt")
        d_inputs = {k: v.to(d_device) for k, v in d_inputs.items()}
        d_pooled = self._pool(self.dinov3(**d_inputs))

        # SigLIP encoding (inputs go to this encoder's own device)
        s_device = next(self.siglip.parameters()).device
        s_inputs = self.siglip_processor(images=image, return_tensors="pt")
        s_inputs = {k: v.to(s_device) for k, v in s_inputs.items()}
        if hasattr(self.siglip, 'vision_model'):
            # NOTE(review): this appears to run only the embedding layer
            # (``vision_model.embeddings``) rather than the full vision
            # transformer, so the feature is a mean of un-contextualized patch
            # embeddings. The trained projector may depend on this exact
            # feature, so it is intentionally left unchanged -- confirm before
            # switching to ``self.siglip.vision_model(pixel_values=...)``.
            s_hidden = self.siglip.vision_model.embeddings(s_inputs['pixel_values'])
            s_pooled = s_hidden.mean(dim=1)
        else:
            s_pooled = self._pool(self.siglip(**s_inputs))

        # Fuse features: [dino_dim + siglip_dim] per image.
        fused = torch.cat([d_pooled, s_pooled.to(d_pooled.device)], dim=-1)
        return fused
# ============================================================================
# Vision Projector
# ============================================================================
class OculusProjector(nn.Module):
    """
    Projects fused vision features to language model token space.

    Converts [batch, fused_dim] → [batch, num_tokens, lm_hidden_size] with a
    three-layer GELU MLP followed by a LayerNorm over each token.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        fused_dim = config.fused_vision_dim
        hidden_dim = config.projector_hidden_dim
        num_tokens = config.num_vision_tokens
        embed_dim = config.lm_hidden_size

        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.act2 = nn.GELU()
        # fc3 emits all tokens at once; forward reshapes to [num_tokens, embed_dim].
        self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

        self.num_tokens = num_tokens
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project vision features to token embeddings.

        Args:
            x: Vision features [batch, fused_dim]
        Returns:
            Vision tokens [batch, num_tokens, embed_dim]
        """
        batch_size = x.shape[0]
        h = self.fc1(x)
        h = self.act1(h)
        h = self.fc2(h)
        h = self.act2(h)
        h = self.fc3(h)
        h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
        h = self.norm(h)
        return h

    @classmethod
    def from_pretrained(cls, path: str, config: OculusConfig):
        """
        Load projector from saved weights.

        Reads ``<path>/projector.npz``. Each entry in that file is named after a
        top-level layer (e.g. ``fc1``) and holds a pickled ``{param_name: array}``
        dict. Arrays that are not numpy arrays (e.g. MLX arrays) are converted
        through ``tolist()``.

        SECURITY: ``allow_pickle=True`` is required by this file format and lets
        a crafted file execute code when loaded. Only load trusted checkpoints.

        Args:
            path: Directory containing ``projector.npz``.
            config: Config used to build the projector skeleton.
        Returns:
            The projector. If the file is missing, it is left randomly
            initialized and a warning is emitted.
        """
        projector = cls(config)
        weights_path = Path(path) / "projector.npz"
        if not weights_path.exists():
            warnings.warn(
                f"No projector weights found at {weights_path}; "
                "projector is randomly initialized."
            )
            return projector

        state_dict = {}
        # Context manager closes the NpzFile's underlying file handle.
        with np.load(weights_path, allow_pickle=True) as weights:
            for key in weights.files:
                layer_dict = weights[key].item()
                for param_name, param_val in layer_dict.items():
                    # Convert from MLX (or other array-likes) if needed.
                    if not isinstance(param_val, np.ndarray) and hasattr(param_val, "tolist"):
                        param_val = np.array(param_val.tolist())
                    state_dict[f"{key}.{param_name}"] = torch.from_numpy(np.array(param_val))

        # strict=False keeps loading best-effort, but mismatches are reported
        # instead of being silently dropped.
        incompatible = projector.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Projector weights at {weights_path} did not match the model: "
                f"missing={incompatible.missing_keys}, "
                f"unexpected={incompatible.unexpected_keys}"
            )
        print(f" ✓ Loaded projector from {path}")
        return projector
# ============================================================================
# Detection/Segmentation Heads
# ============================================================================
class OculusDetectionHead(nn.Module):
    """
    Per-token bounding-box head.

    Two independent Linear -> GELU -> Linear stacks read every vision token:
    one produces class logits, the other a box squashed into [0, 1].
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size
        classes = config.num_detection_classes
        # Built in the same order as before so parameter names and random
        # initialization order are unchanged.
        self.cls_head = self._mlp(width, width // 2, classes)
        self.box_head = self._mlp(width, width // 2, 4)  # x1, y1, x2, y2

    @staticmethod
    def _mlp(in_dim: int, mid_dim: int, out_dim: int) -> nn.Sequential:
        """Two-layer GELU MLP; the Linear layers sit at indices 0 and 2."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, out_dim),
        )

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Score and localize every vision token.

        Returns:
            cls_logits: [batch, num_tokens, num_classes]
            box_coords: [batch, num_tokens, 4], sigmoid-normalized to [0, 1]
        """
        logits = self.cls_head(vision_tokens)
        coords = torch.sigmoid(self.box_head(vision_tokens))
        return logits, coords
class OculusPointHead(nn.Module):
    """
    Per-token point head used for object counting.

    For each vision token it predicts a sigmoid-normalized (x, y) point,
    class logits, and a sigmoid confidence score.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size
        classes = config.num_detection_classes
        # Construction order preserved: point, class, confidence.
        self.point_head = self._mlp(width, width // 2, 2)  # x, y
        self.cls_head = self._mlp(width, width // 2, classes)
        self.conf_head = self._mlp(width, width // 4, 1)

    @staticmethod
    def _mlp(in_dim: int, mid_dim: int, out_dim: int) -> nn.Sequential:
        """Two-layer GELU MLP; the Linear layers sit at indices 0 and 2."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, out_dim),
        )

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            points: [..., 2] in [0, 1]
            cls_logits: [..., num_classes]
            confidence: [..., 1] in [0, 1]
        """
        coords = torch.sigmoid(self.point_head(vision_tokens))
        logits = self.cls_head(vision_tokens)
        scores = torch.sigmoid(self.conf_head(vision_tokens))
        return coords, logits, scores
class OculusSegmentationHead(nn.Module):
    """
    Head for polygon/mask segmentation.

    Mean-pools the vision tokens, then predicts a coarse square grid of
    per-class mask logits.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        hidden_dim = config.lm_hidden_size
        num_classes = config.num_segmentation_classes
        # Side length of the square output mask. Optional config attribute;
        # the default of 14 matches the layout of previously saved head weights.
        mask_size = int(getattr(config, "segmentation_mask_size", 14))

        # Predict mask logits
        self.mask_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, mask_size * mask_size * num_classes),  # Output spatial mask
        )
        self.num_classes = num_classes
        self.mask_size = mask_size

    def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_tokens: [batch, num_tokens, hidden_dim]
        Returns:
            Mask logits [batch, num_classes, mask_size, mask_size]
        """
        batch_size = vision_tokens.shape[0]
        pooled = vision_tokens.mean(dim=1)
        mask_logits = self.mask_head(pooled)
        return mask_logits.reshape(batch_size, self.num_classes, self.mask_size, self.mask_size)
# ============================================================================
# Main Model
# ============================================================================
class OculusForConditionalGeneration(PreTrainedModel):
    """
    Oculus: Unified Vision-Language Model

    Features:
    - Multi-encoder vision (DINOv3 + SigLIP2)
    - Optional reasoning with thinking traces
    - Multiple output modes: Text, Point, Box, Polygon
    - Focus/Zoom tool calling for fine-grained perception

    Text mode currently delegates to BLIP captioning / VQA models (see
    ``load_language_model``). The point, box and polygon modes run the task
    heads on the projected vision tokens.

    Usage:
    ```python
    from oculus_unified_model import OculusForConditionalGeneration

    # from_pretrained reads a LOCAL directory (config.json, projector.npz,
    # heads.pth); download a Hub snapshot first.
    model = OculusForConditionalGeneration.from_pretrained("path/to/oculus-0.2")

    # Caption mode
    output = model.generate(image, mode="text", prompt="Describe this image")

    # VQA mode
    output = model.generate(image, mode="text", prompt="What color is the cat?")

    # With reasoning
    output = model.generate(image, mode="text", prompt="Count the people", think=True)

    # Detection mode
    output = model.generate(image, mode="box", prompt="Find all cars")

    # Point mode (counting)
    output = model.generate(image, mode="point", prompt="Count the birds")

    # Segmentation mode
    output = model.generate(image, mode="polygon", prompt="Segment the road")
    ```
    """

    config_class = OculusConfig
    base_model_prefix = "oculus"

    # Output modes accepted by ``generate``.
    _SUPPORTED_MODES = ("text", "point", "box", "polygon")
    # A text prompt containing any of these whole words (or a "?") is routed
    # to the VQA model instead of the captioning model.
    _QUESTION_WORDS = frozenset(
        {"what", "where", "who", "how", "why", "is", "are", "does", "do", "can"}
    )

    def __init__(self, config: OculusConfig):
        super().__init__(config)
        self.config = config

        # Vision encoder
        self.vision_encoder = OculusVisionEncoder(config)

        # Vision adapter (created lazily in encode_image on a dim mismatch)
        self.vision_adapter = None
        self._actual_vision_dim = None

        # Projector
        self.projector = OculusProjector(config)

        # Task-specific heads
        self.detection_head = OculusDetectionHead(config)
        self.point_head = OculusPointHead(config)
        self.segmentation_head = OculusSegmentationHead(config)

        # Language model (loaded lazily)
        self.lm_tokenizer = None
        self.lm_model = None
        self._lm_loaded = False

        # Special tokens for reasoning
        self.thinking_token = config.thinking_token
        self.thinking_end_token = config.thinking_end_token
        self.focus_token = config.focus_token
        self.focus_end_token = config.focus_end_token

    def load_language_model(self, device: str = "cpu"):
        """
        Load language model for text generation.

        Currently loads BLIP captioning and BLIP VQA models. On failure a
        warning is emitted and ``_lm_loaded`` stays False, so the next call
        retries. A partial failure may leave only some ``lm_*`` attributes set.

        Args:
            device: Device the BLIP models are moved to.
        """
        if self._lm_loaded:
            return

        print("[Oculus] Loading language model...")

        try:
            # Try BLIP for now (works well for captioning/VQA)
            from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering

            self.lm_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            self.lm_caption_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base"
            ).to(device)

            self.lm_vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(
                "Salesforce/blip-vqa-base"
            ).to(device)

            print(" ✓ BLIP (captioning + VQA)")
            self._lm_loaded = True

        except Exception as e:
            warnings.warn(f"Failed to load language model: {e}")

    def encode_image(self, image: Union[Image.Image, str, np.ndarray]) -> torch.Tensor:
        """
        Encode image to vision tokens.

        Args:
            image: PIL Image, file path, or numpy array
        Returns:
            Vision tokens [1, num_tokens, embed_dim]
        """
        # Load image if path
        if isinstance(image, str):
            image = Image.open(image)

        # Encode with vision encoders
        vision_features = self.vision_encoder(image)

        # Match the projector's device/dtype; the encoders may live elsewhere.
        proj_param = next(self.projector.parameters())
        vision_features = vision_features.to(device=proj_param.device, dtype=proj_param.dtype)

        # Check if we need an adapter for dimension mismatch (e.g. when a
        # fallback encoder with a different hidden size was loaded).
        actual_dim = vision_features.shape[-1]
        expected_dim = self.config.fused_vision_dim

        if actual_dim != expected_dim:
            if self.vision_adapter is None or self._actual_vision_dim != actual_dim:
                # The adapter is freshly initialized and never saved, so its
                # output is not a trained mapping -- make that visible.
                warnings.warn(
                    f"[Adapter] Vision features have dim {actual_dim} but the projector "
                    f"expects {expected_dim}; inserting a randomly initialized adapter. "
                    "Outputs will not reflect trained weights."
                )
                adapter = nn.Linear(actual_dim, expected_dim)
                nn.init.xavier_uniform_(adapter.weight)
                nn.init.zeros_(adapter.bias)
                self.vision_adapter = adapter.to(device=proj_param.device, dtype=proj_param.dtype)
                self._actual_vision_dim = actual_dim

            vision_features = self.vision_adapter(vision_features)

        # Project to language space
        vision_tokens = self.projector(vision_features)

        return vision_tokens

    def _generate_thinking_trace(
        self,
        image: Image.Image,
        prompt: str,
        max_tokens: int = 256,
        temperature: float = 0.7,
    ) -> str:
        """
        Generate a thinking/reasoning trace before answering.

        Samples from the BLIP captioning model conditioned on a fixed
        step-by-step template. If the language model is not loaded, a fixed
        placeholder sentence is returned instead.

        Args:
            image: Input image.
            prompt: User question, embedded in the template.
            max_tokens: Maximum new tokens for the trace.
            temperature: Sampling temperature for the trace.
        """
        thinking_prompt = f"""Let me think about this step by step:
1. First, I'll analyze what I see in the image.
2. Then, I'll consider the question: "{prompt}"
3. Finally, I'll formulate my answer.
Observation: """

        # Generate reasoning (simplified for now)
        if self._lm_loaded and hasattr(self, 'lm_caption_model'):
            inputs = self.lm_processor(image, thinking_prompt, return_tensors="pt")
            inputs = {k: v.to(self.lm_caption_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                out = self.lm_caption_model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    temperature=temperature,
                )
            thinking = self.lm_processor.decode(out[0], skip_special_tokens=True)
        else:
            thinking = "I observe the image and analyze its contents."

        return thinking

    def _detect_focus_regions(
        self,
        image: Image.Image,
        prompt: str
    ) -> List[Tuple[int, int, int, int]]:
        """
        Detect regions that need closer inspection (Focus/Zoom system).

        Stub: always returns the full image as a single (x1, y1, x2, y2) region.
        """
        # In full implementation, would use attention maps to find regions of interest
        w, h = image.size
        return [(0, 0, w, h)]

    @torch.no_grad()
    def generate(
        self,
        image: Union[Image.Image, str, np.ndarray],
        prompt: str = "Describe this image",
        mode: str = "text",
        think: bool = False,
        focus: bool = False,
        max_new_tokens: Optional[int] = None,
        temperature: float = 0.7,
        return_thinking: bool = True,
        **kwargs
    ) -> Union[OculusTextOutput, OculusPointOutput, OculusBoxOutput, OculusPolygonOutput]:
        """
        Generate output from image (inference only; runs without autograd).

        Args:
            image: Input image (PIL, path, or array)
            prompt: Text prompt/question
            mode: Output mode ("text", "point", "box", "polygon")
            think: Enable reasoning traces (requires ``config.reasoning_enabled``)
            focus: Enable zoom/crop for fine-grained perception (stub)
            max_new_tokens: Maximum tokens to generate (text mode)
            temperature: Sampling temperature used for the thinking trace
            return_thinking: Include thinking trace in output
            **kwargs: Forwarded to the mode handler (e.g. ``threshold``)
        Returns:
            Mode-specific output dataclass
        Raises:
            ValueError: If ``mode`` is unknown (checked before any model loading).
        """
        # Fail fast, before loading models or encoding the image.
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unknown mode: {mode}")

        use_thinking = think and self.config.reasoning_enabled

        # Load models if needed; the thinking trace also needs the LM.
        self.vision_encoder.load_encoders()
        if mode == "text" or use_thinking:
            self.load_language_model()

        # Load image
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')

        # Encode image
        vision_tokens = self.encode_image(image)

        # Generate thinking trace if enabled
        thinking_trace = None
        if use_thinking:
            thinking_trace = self._generate_thinking_trace(image, prompt, temperature=temperature)

        # Focus system: region detection is a stub and the regions are not
        # yet re-encoded.
        if focus and self.config.enable_focus:
            _focus_regions = self._detect_focus_regions(image, prompt)

        if not return_thinking:
            thinking_trace = None

        # Mode-specific generation
        if mode == "text":
            return self._generate_text(image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs)
        if mode == "point":
            return self._generate_points(vision_tokens, thinking_trace, **kwargs)
        if mode == "box":
            return self._generate_boxes(vision_tokens, thinking_trace, **kwargs)
        return self._generate_polygons(vision_tokens, thinking_trace, **kwargs)

    def _generate_text(
        self,
        image: Image.Image,
        prompt: str,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        max_new_tokens: Optional[int],
        **kwargs
    ) -> OculusTextOutput:
        """
        Generate text output (caption or VQA).

        The prompt is treated as a question if it contains a "?" or any whole
        word from ``_QUESTION_WORDS``. Matching whole words avoids false hits
        such as "is" inside "this". Questions go to the BLIP VQA model when it
        is loaded; everything else goes to the captioning model.

        Raises:
            RuntimeError: If the captioning model is needed but not loaded.
        """
        import re

        max_tokens = max_new_tokens or self.config.max_new_tokens

        # Determine if this is a question
        words = set(re.findall(r"[a-z]+", prompt.lower()))
        is_question = "?" in prompt or not self._QUESTION_WORDS.isdisjoint(words)

        if is_question and hasattr(self, 'lm_vqa_model'):
            # VQA mode; inputs go to the VQA model's own device.
            device = self.lm_vqa_model.device
            inputs = self.lm_vqa_processor(image, prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                # Short answers by default; honor an explicit max_new_tokens.
                out = self.lm_vqa_model.generate(**inputs, max_new_tokens=max_new_tokens or 50)
            text = self.lm_vqa_processor.decode(out[0], skip_special_tokens=True)
        else:
            if not hasattr(self, 'lm_caption_model'):
                raise RuntimeError(
                    "Language model is not loaded; call load_language_model() "
                    "(see earlier warnings for the load failure)."
                )
            # Caption mode; inputs go to the caption model's own device.
            device = self.lm_caption_model.device
            inputs = self.lm_processor(image, prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                out = self.lm_caption_model.generate(**inputs, max_new_tokens=max_tokens)
            text = self.lm_processor.decode(out[0], skip_special_tokens=True)

        return OculusTextOutput(
            text=text,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_points(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        threshold: float = 0.5,
        **kwargs
    ) -> OculusPointOutput:
        """
        Generate point detections.

        Keeps tokens whose predicted confidence exceeds ``threshold``; results
        from all batch items are flattened into one list.
        """
        points, cls_logits, confidence = self.point_head(vision_tokens)

        # Filter by confidence
        mask = confidence.squeeze(-1) > threshold

        filtered_points = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            token_mask = mask[i]
            pts = points[i][token_mask].detach().cpu().numpy().tolist()
            confs = confidence[i][token_mask].squeeze(-1).detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][token_mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_points.extend([tuple(p) for p in pts])
            filtered_conf.extend(confs)
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusPointOutput(
            points=filtered_points,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_boxes(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        threshold: float = 0.3,
        **kwargs
    ) -> OculusBoxOutput:
        """
        Generate bounding box detections.

        Confidence is the max softmax class probability per token. Tokens
        above ``threshold`` are kept and results from all batch items are
        flattened into one list.
        """
        cls_logits, box_coords = self.detection_head(vision_tokens)

        # Get confidence from class logits
        confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values

        filtered_boxes = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            mask = confidence[i] > threshold
            # tolist() yields plain Python floats, matching the dataclass types
            # and the point-mode output.
            boxes = box_coords[i][mask].detach().cpu().tolist()
            confs = confidence[i][mask].detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_boxes.extend([tuple(b) for b in boxes])
            filtered_conf.extend(confs)
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusBoxOutput(
            boxes=filtered_boxes,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_polygons(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        **kwargs
    ) -> OculusPolygonOutput:
        """
        Generate polygon/mask segmentation.

        Only the first batch item is returned. Polygons are placeholder unit
        squares, one per non-background class present in the mask.
        """
        mask_logits = self.segmentation_head(vision_tokens)

        # Get predicted mask
        mask = mask_logits.argmax(dim=1).detach().cpu().numpy()

        # Convert to polygons (simplified)
        # In full implementation, would use cv2.findContours
        polygons = []
        labels = []

        unique_classes = np.unique(mask[0])
        for cls_id in unique_classes:
            if cls_id == 0:  # Skip background
                continue
            labels.append(str(cls_id))
            # Placeholder polygon
            polygons.append([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])

        return OculusPolygonOutput(
            polygons=polygons,
            labels=labels,
            mask=mask[0],
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from pretrained weights in a local directory.

        Two ``config.json`` layouts are accepted:
        - a projector-style config (keys ``fused_dim``/``hidden_dim``/
          ``num_tokens``/``embed_dim``), mapped onto ``OculusConfig``;
        - a full ``OculusConfig`` as written by ``save_pretrained``.

        ``projector.npz`` and ``heads.pth`` are loaded when present. This
        loader does not download from the Hub: a non-local id produces a
        warning and a default, untrained model.

        Args:
            pretrained_model_name_or_path: Local directory path
            **kwargs: Accepted for API compatibility and ignored.
        """
        path = Path(pretrained_model_name_or_path)
        if not path.is_dir():
            warnings.warn(
                f"'{pretrained_model_name_or_path}' is not a local directory; this loader "
                "does not download from the Hub. Using the default config and untrained weights."
            )

        # Load config
        config_path = path / "config.json"
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                raw_config = json.load(f)

            projector_keys = {"fused_dim", "hidden_dim", "num_tokens", "embed_dim"}
            if not projector_keys.isdisjoint(raw_config):
                # Create config with correct dimensions from projector
                config = OculusConfig(
                    dinov3_hidden_size=raw_config.get("fused_dim", 2048) - 768,  # Infer from fused
                    siglip_hidden_size=768,
                    projector_hidden_dim=raw_config.get("hidden_dim", 2048),
                    num_vision_tokens=raw_config.get("num_tokens", 64),
                    lm_hidden_size=raw_config.get("embed_dim", 1536),
                )
            else:
                # Full OculusConfig (e.g. written by save_pretrained).
                config = OculusConfig.from_dict(raw_config)
        else:
            config = OculusConfig()

        # Create model
        model = cls(config)

        # Load projector weights
        projector_path = path / "projector.npz"
        if projector_path.exists():
            model.projector = OculusProjector.from_pretrained(path, config)

        # Load detection/segmentation heads if available
        heads_path = path / "heads.pth"
        if heads_path.exists():
            # weights_only=True refuses arbitrary pickled objects, so a
            # tampered checkpoint cannot execute code on load.
            heads_state = torch.load(heads_path, map_location="cpu", weights_only=True)
            model.detection_head.load_state_dict(heads_state.get("detection", {}), strict=False)
            model.point_head.load_state_dict(heads_state.get("point", {}), strict=False)
            model.segmentation_head.load_state_dict(heads_state.get("segmentation", {}), strict=False)

        model.eval()
        return model

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save config, projector and task heads to ``save_directory``.

        Writes ``config.json`` (via ``OculusConfig.save_pretrained``),
        ``projector.npz`` (one pickled ``{param_name: ndarray}`` dict per
        top-level projector layer, the layout ``OculusProjector.from_pretrained``
        reads) and ``heads.pth``. Vision encoders, BLIP models and any lazily
        created ``vision_adapter`` are NOT saved.

        Args:
            save_directory: Target directory (created if missing).
            **kwargs: Accepted for compatibility with ``PreTrainedModel``
                callers and ignored.
        """
        path = Path(save_directory)
        path.mkdir(parents=True, exist_ok=True)

        # Save config
        self.config.save_pretrained(path)

        # Save projector, grouped by top-level layer (numpy for MLX compatibility)
        np_weights = {}
        for key, tensor in self.projector.state_dict().items():
            layer, _, param = key.partition(".")
            np_weights.setdefault(layer, {})[param] = tensor.detach().cpu().numpy()
        np.savez(path / "projector.npz", **np_weights)

        # Save heads
        torch.save({
            "detection": self.detection_head.state_dict(),
            "point": self.point_head.state_dict(),
            "segmentation": self.segmentation_head.state_dict(),
        }, path / "heads.pth")

        print(f"✓ Saved model to {path}")
# Register for auto-loading: associates this class with transformers'
# "AutoModelForVision2Seq" auto class (module-level side effect at import).
# NOTE(review): save_pretrained is overridden in this file and only calls
# config.save_pretrained, so confirm the auto_map entry is still written.
OculusForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")