"""
Oculus Unified Model

HuggingFace-compatible vision-language model with:
- Multi-encoder vision (DINOv3 + SigLIP2)
- Trained projector for vision-to-language
- Optional reasoning with thinking traces
- Multiple output modes (Text, Point, Box, Polygon)
- Focus/Zoom tool calling for fine-grained perception
"""

import os
import json
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoImageProcessor,
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
)
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from PIL import Image

from .configuration_oculus import OculusConfig


@dataclass
class OculusOutput:
    """Base output class for Oculus model."""
    text: Optional[str] = None
    thinking_trace: Optional[str] = None
    logits: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    vision_tokens: Optional[torch.Tensor] = None


@dataclass
class OculusTextOutput(OculusOutput):
    """Output for text/caption mode."""
    pass


@dataclass
class OculusPointOutput(OculusOutput):
    """Output for point detection mode (counting objects)."""
    points: Optional[List[Tuple[float, float]]] = None
    labels: Optional[List[str]] = None
    confidences: Optional[List[float]] = None


@dataclass
class OculusBoxOutput(OculusOutput):
    """Output for bounding box detection mode."""
    boxes: Optional[List[Tuple[float, float, float, float]]] = None
    labels: Optional[List[str]] = None
    confidences: Optional[List[float]] = None


@dataclass
class OculusPolygonOutput(OculusOutput):
    """Output for polygon/segmentation mode."""
    polygons: Optional[List[List[Tuple[float, float]]]] = None
    labels: Optional[List[str]] = None
    mask: Optional[np.ndarray] = None
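
# Consumption sketch (hypothetical `model` / `image` objects): every mode returns one of
# the dataclasses above, so callers just read the fields they need. The box values are the
# raw 4-value, sigmoid-normalized output of the box head; their exact coordinate
# convention depends on how the head was trained.
#
#   out = model.generate(image, mode="box", prompt="Find all cars")
#   for box, label, conf in zip(out.boxes, out.labels, out.confidences):
#       print(f"class {label}: {conf:.2f} at {box}")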


class OculusVisionEncoder(nn.Module):
    """
    Dual vision encoder combining DINOv3 and SigLIP2.

    DINOv3: excellent at semantic understanding and object boundaries.
    SigLIP2: strong at text/language alignment.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        # Encoders are loaded lazily via load_encoders().
        self.dinov3 = None
        self.dinov3_processor = None
        self.siglip = None
        self.siglip_processor = None

        self._loaded = False

    def load_encoders(self, device: str = "cpu"):
        """Load vision encoders from HuggingFace, falling back to base checkpoints on failure."""
        if self._loaded:
            return

        print("[Oculus] Loading vision encoders...")

        try:
            self.dinov3_processor = AutoImageProcessor.from_pretrained(
                self.config.dinov3_model_id
            )
            self.dinov3 = AutoModel.from_pretrained(
                self.config.dinov3_model_id
            ).eval().to(device)
            print(f" ✓ DINOv3: {self.config.dinov3_model_id}")
        except Exception as e:
            warnings.warn(f"Failed to load DINOv3: {e}")
            self.dinov3_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
            self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-base").eval().to(device)
            print(" ✓ DINOv2-base (fallback)")

        try:
            self.siglip_processor = AutoImageProcessor.from_pretrained(
                self.config.siglip_model_id
            )
            self.siglip = AutoModel.from_pretrained(
                self.config.siglip_model_id
            ).eval().to(device)
            print(f" ✓ SigLIP: {self.config.siglip_model_id}")
        except Exception as e:
            warnings.warn(f"Failed to load SigLIP: {e}")
            from transformers import SiglipVisionModel
            self.siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
            self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval().to(device)
            print(" ✓ SigLIP-base (fallback)")

        self._loaded = True

    @torch.no_grad()
    def forward(self, image: Union[Image.Image, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Encode image with both vision encoders and fuse features.

        Returns:
            Fused vision features [batch, fused_dim]
        """
        if not self._loaded:
            self.load_encoders()

        # Normalize the input to an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, torch.Tensor):
            arr = image.detach().cpu().numpy()
            if arr.ndim == 3 and arr.shape[0] in (1, 3):  # CHW -> HWC for PIL
                arr = arr.transpose(1, 2, 0)
            image = Image.fromarray(arr.astype(np.uint8))

        if isinstance(image, Image.Image):
            image = image.convert('RGB')

        device = next(self.dinov3.parameters()).device

        # DINOv3 branch: prefer the pooled output, fall back to the CLS token.
        d_inputs = self.dinov3_processor(images=image, return_tensors="pt")
        d_inputs = {k: v.to(device) for k, v in d_inputs.items()}
        d_out = self.dinov3(**d_inputs)
        if getattr(d_out, 'pooler_output', None) is not None:
            d_pooled = d_out.pooler_output
        else:
            d_pooled = d_out.last_hidden_state[:, 0]

        # SigLIP branch.
        s_inputs = self.siglip_processor(images=image, return_tensors="pt")
        s_inputs = {k: v.to(device) for k, v in s_inputs.items()}

        if hasattr(self.siglip, 'vision_model'):
            s_hidden = self.siglip.vision_model.embeddings(s_inputs['pixel_values'])
            s_pooled = s_hidden.mean(dim=1)
        else:
            s_out = self.siglip(**s_inputs)
            s_pooled = s_out.pooler_output if hasattr(s_out, 'pooler_output') else s_out.last_hidden_state[:, 0]

        # Concatenate pooled features from both encoders.
        fused = torch.cat([d_pooled, s_pooled], dim=-1)

        return fused
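
# Standalone encoder sketch: the fused width is the sum of the two pooled widths
# (e.g. 768 + 768 = 1536 with the DINOv2-base / SigLIP-base fallbacks above) and may
# differ from config.fused_vision_dim; OculusForConditionalGeneration.encode_image
# bridges that gap with an adapter when needed.
#
#   encoder = OculusVisionEncoder(config)
#   encoder.load_encoders(device="cpu")
#   fused = encoder(Image.open("photo.jpg"))   # -> [1, dinov3_dim + siglip_dim]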


class OculusProjector(nn.Module):
    """
    Projects fused vision features to language model token space.

    Converts [batch, fused_dim] → [batch, num_tokens, lm_hidden_size]
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        fused_dim = config.fused_vision_dim
        hidden_dim = config.projector_hidden_dim
        num_tokens = config.num_vision_tokens
        embed_dim = config.lm_hidden_size

        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.act2 = nn.GELU()
        self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

        self.num_tokens = num_tokens
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project vision features to token embeddings.

        Args:
            x: Vision features [batch, fused_dim]

        Returns:
            Vision tokens [batch, num_tokens, embed_dim]
        """
        batch_size = x.shape[0]

        h = self.fc1(x)
        h = self.act1(h)
        h = self.fc2(h)
        h = self.act2(h)
        h = self.fc3(h)

        h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
        h = self.norm(h)

        return h

    @classmethod
    def from_pretrained(cls, path: str, config: OculusConfig):
        """Load projector from saved weights."""
        projector = cls(config)

        weights_path = Path(path) / "projector.npz"
        if weights_path.exists():
            weights = np.load(weights_path, allow_pickle=True)

            # Each npz entry holds a {param_name: array} dict for one layer.
            state_dict = {}
            for key in weights.files:
                layer_dict = weights[key].item()
                for param_name, param_val in layer_dict.items():
                    full_key = f"{key}.{param_name}"
                    if hasattr(param_val, 'tolist'):
                        param_val = np.array(param_val.tolist())
                    state_dict[full_key] = torch.from_numpy(np.array(param_val))

            projector.load_state_dict(state_dict, strict=False)
            print(f" ✓ Loaded projector from {path}")

        return projector
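
# Shape sketch for the projector (names are the config fields used in __init__ above):
#
#   projector = OculusProjector(config)
#   fused = torch.randn(2, config.fused_vision_dim)    # [batch, fused_dim]
#   tokens = projector(fused)                          # [batch, num_tokens, embed_dim]
#   assert tokens.shape == (2, config.num_vision_tokens, config.lm_hidden_size)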


class OculusDetectionHead(nn.Module):
    """Head for bounding box detection."""

    def __init__(self, config: OculusConfig):
        super().__init__()
        hidden_dim = config.lm_hidden_size
        num_classes = config.num_detection_classes

        self.cls_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self.box_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 4)
        )

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict boxes and classes from vision tokens.

        Returns:
            cls_logits: [batch, num_tokens, num_classes]
            box_coords: [batch, num_tokens, 4]
        """
        cls_logits = self.cls_head(vision_tokens)
        box_coords = self.box_head(vision_tokens).sigmoid()
        return cls_logits, box_coords


class OculusPointHead(nn.Module):
    """Head for point detection (object counting)."""

    def __init__(self, config: OculusConfig):
        super().__init__()
        hidden_dim = config.lm_hidden_size
        num_classes = config.num_detection_classes

        self.point_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 2)
        )

        self.cls_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self.conf_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 4),
            nn.GELU(),
            nn.Linear(hidden_dim // 4, 1)
        )

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        points = self.point_head(vision_tokens).sigmoid()
        cls_logits = self.cls_head(vision_tokens)
        confidence = self.conf_head(vision_tokens).sigmoid()
        return points, cls_logits, confidence


class OculusSegmentationHead(nn.Module):
    """Head for polygon/mask segmentation."""

    def __init__(self, config: OculusConfig):
        super().__init__()
        hidden_dim = config.lm_hidden_size
        num_classes = config.num_segmentation_classes

        self.mask_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 14 * 14 * num_classes)
        )

        self.num_classes = num_classes

    def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
        batch_size = vision_tokens.shape[0]
        pooled = vision_tokens.mean(dim=1)
        mask_logits = self.mask_head(pooled)
        mask_logits = mask_logits.reshape(batch_size, self.num_classes, 14, 14)
        return mask_logits
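
# Post-processing sketch: the head predicts a coarse 14x14 class grid; callers will
# usually upsample to the image resolution before taking the per-pixel argmax. This is
# an optional step outside the model's forward pass (H, W are the target image size).
#
#   logits = segmentation_head(vision_tokens)                              # [B, C, 14, 14]
#   full = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
#   mask = full.argmax(dim=1)                                              # [B, H, W]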


class OculusForConditionalGeneration(PreTrainedModel):
    """
    Oculus: Unified Vision-Language Model

    Features:
    - Multi-encoder vision (DINOv3 + SigLIP2)
    - Optional reasoning with thinking traces
    - Multiple output modes: Text, Point, Box, Polygon
    - Focus/Zoom tool calling for fine-grained perception

    Usage:
    ```python
    from oculus_unified_model import OculusForConditionalGeneration

    model = OculusForConditionalGeneration.from_pretrained("OceanirAI/oculus-0.2")

    # Caption mode
    output = model.generate(image, mode="text", prompt="Describe this image")

    # VQA mode
    output = model.generate(image, mode="text", prompt="What color is the cat?")

    # With reasoning
    output = model.generate(image, mode="text", prompt="Count the people", think=True)

    # Detection mode
    output = model.generate(image, mode="box", prompt="Find all cars")

    # Point mode (counting)
    output = model.generate(image, mode="point", prompt="Count the birds")

    # Segmentation mode
    output = model.generate(image, mode="polygon", prompt="Segment the road")
    ```
    """

    config_class = OculusConfig
    base_model_prefix = "oculus"

    def __init__(self, config: OculusConfig):
        super().__init__(config)
        self.config = config

        # Dual vision encoder (loaded lazily).
        self.vision_encoder = OculusVisionEncoder(config)

        # Optional adapter, created on the fly if the encoders' fused width
        # differs from config.fused_vision_dim.
        self.vision_adapter = None
        self._actual_vision_dim = None

        # Trained projector: fused vision features -> LM token space.
        self.projector = OculusProjector(config)

        # Task-specific heads.
        self.detection_head = OculusDetectionHead(config)
        self.point_head = OculusPointHead(config)
        self.segmentation_head = OculusSegmentationHead(config)

        # Language model components (loaded lazily).
        self.lm_tokenizer = None
        self.lm_model = None
        self._lm_loaded = False

        # Special tokens for reasoning and Focus/Zoom.
        self.thinking_token = config.thinking_token
        self.thinking_end_token = config.thinking_end_token
        self.focus_token = config.focus_token
        self.focus_end_token = config.focus_end_token

    def load_language_model(self, device: str = "cpu"):
        """Load BLIP captioning and VQA models for text generation."""
        if self._lm_loaded:
            return

        print("[Oculus] Loading language model...")

        try:
            from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering

            self.lm_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            self.lm_caption_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base"
            ).to(device)

            self.lm_vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(
                "Salesforce/blip-vqa-base"
            ).to(device)

            print(" ✓ BLIP (captioning + VQA)")
            self._lm_loaded = True

        except Exception as e:
            warnings.warn(f"Failed to load language model: {e}")

    def encode_image(self, image: Union[Image.Image, str, np.ndarray]) -> torch.Tensor:
        """
        Encode image to vision tokens.

        Args:
            image: PIL Image, file path, or numpy array

        Returns:
            Vision tokens [1, num_tokens, embed_dim]
        """
        if isinstance(image, str):
            image = Image.open(image)

        # Fused features from both vision encoders: [batch, actual_dim].
        vision_features = self.vision_encoder(image)

        # If the loaded encoders produce a different width than the projector was
        # trained for, bridge the gap with a lazily created linear adapter.
        actual_dim = vision_features.shape[-1]
        expected_dim = self.config.fused_vision_dim

        if actual_dim != expected_dim:
            if self.vision_adapter is None or self._actual_vision_dim != actual_dim:
                print(f" [Adapter] Creating vision adapter: {actual_dim} -> {expected_dim}")
                self.vision_adapter = nn.Linear(actual_dim, expected_dim).to(vision_features.device)
                self._actual_vision_dim = actual_dim

                nn.init.xavier_uniform_(self.vision_adapter.weight)
                nn.init.zeros_(self.vision_adapter.bias)

            vision_features = self.vision_adapter(vision_features)

        # Project to LM token space: [batch, num_tokens, embed_dim].
        vision_tokens = self.projector(vision_features)

        return vision_tokens
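
    # encode_image sketch: the lazily built adapter bridges whatever width the loaded
    # encoders actually produce to the config.fused_vision_dim the projector was trained
    # for. Because that adapter is randomly initialized here, faithful results require
    # adapter weights that come from training rather than this on-the-fly fallback.
    #
    #   tokens = model.encode_image("photo.jpg")   # [1, num_vision_tokens, lm_hidden_size]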

    def _generate_thinking_trace(
        self,
        image: Image.Image,
        prompt: str,
        max_tokens: int = 256
    ) -> str:
        """
        Generate a thinking/reasoning trace before answering.

        This enables multi-step reasoning for complex tasks.
        """
        thinking_prompt = f"""Let me think about this step by step:
1. First, I'll analyze what I see in the image.
2. Then, I'll consider the question: "{prompt}"
3. Finally, I'll formulate my answer.

Observation: """

        if self._lm_loaded and hasattr(self, 'lm_caption_model'):
            inputs = self.lm_processor(image, thinking_prompt, return_tensors="pt")
            inputs = {k: v.to(self.lm_caption_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                out = self.lm_caption_model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    temperature=0.7
                )
            thinking = self.lm_processor.decode(out[0], skip_special_tokens=True)
        else:
            thinking = "I observe the image and analyze its contents."

        return thinking

    def _detect_focus_regions(
        self,
        image: Image.Image,
        prompt: str
    ) -> List[Tuple[int, int, int, int]]:
        """
        Detect regions that need closer inspection (Focus/Zoom system).

        Returns list of (x1, y1, x2, y2) crop regions.
        """
        # Placeholder: currently returns the full frame as the single focus region.
        w, h = image.size
        return [(0, 0, w, h)]
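
    # Focus/Zoom sketch: a caller (or a future version of this method) can crop the
    # returned regions and re-encode them for fine-grained perception. Hypothetical
    # usage, not wired into generate() yet:
    #
    #   for (x1, y1, x2, y2) in model._detect_focus_regions(image, prompt):
    #       crop = image.crop((x1, y1, x2, y2))
    #       crop_tokens = model.encode_image(crop)   # tokens for the zoomed-in view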

    def generate(
        self,
        image: Union[Image.Image, str, np.ndarray],
        prompt: str = "Describe this image",
        mode: str = "text",
        think: bool = False,
        focus: bool = False,
        max_new_tokens: Optional[int] = None,
        temperature: float = 0.7,
        return_thinking: bool = True,
        **kwargs
    ) -> Union[OculusTextOutput, OculusPointOutput, OculusBoxOutput, OculusPolygonOutput]:
        """
        Generate output from image.

        Args:
            image: Input image (PIL, path, or array)
            prompt: Text prompt/question
            mode: Output mode ("text", "point", "box", "polygon")
            think: Enable reasoning traces
            focus: Enable zoom/crop for fine-grained perception
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            return_thinking: Include thinking trace in output

        Returns:
            Mode-specific output dataclass
        """
        # Lazily load the components this call needs.
        self.vision_encoder.load_encoders()
        if mode == "text":
            self.load_language_model()

        # Normalize the input image.
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')

        # Encode the image to vision tokens.
        vision_tokens = self.encode_image(image)

        # Optional reasoning trace.
        thinking_trace = None
        if think and self.config.reasoning_enabled:
            thinking_trace = self._generate_thinking_trace(image, prompt)

        # Optional Focus/Zoom pass (currently informational only; the detected
        # regions are not yet re-encoded).
        if focus and self.config.enable_focus:
            focus_regions = self._detect_focus_regions(image, prompt)

        # Dispatch on output mode.
        if mode == "text":
            return self._generate_text(image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs)
        elif mode == "point":
            return self._generate_points(vision_tokens, thinking_trace, **kwargs)
        elif mode == "box":
            return self._generate_boxes(vision_tokens, thinking_trace, **kwargs)
        elif mode == "polygon":
            return self._generate_polygons(vision_tokens, thinking_trace, **kwargs)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _generate_text(
        self,
        image: Image.Image,
        prompt: str,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        max_new_tokens: Optional[int],
        **kwargs
    ) -> OculusTextOutput:
        """Generate text output (caption or VQA)."""
        max_tokens = max_new_tokens or self.config.max_new_tokens

        # Route question-style prompts to the VQA model, everything else to captioning.
        # Match whole words rather than substrings so prompts like "Describe this image"
        # are not misrouted by the "is" inside "this".
        question_words = {"what", "where", "who", "how", "why", "is", "are", "does", "do", "can"}
        prompt_tokens = [w.strip("?!.,:;\"'") for w in prompt.lower().split()]
        is_question = "?" in prompt or any(w in question_words for w in prompt_tokens)

        if is_question and hasattr(self, 'lm_vqa_model'):
            device = self.lm_vqa_model.device
            inputs = self.lm_vqa_processor(image, prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                out = self.lm_vqa_model.generate(**inputs, max_new_tokens=50)
            text = self.lm_vqa_processor.decode(out[0], skip_special_tokens=True)
        else:
            device = self.lm_caption_model.device
            inputs = self.lm_processor(image, prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                out = self.lm_caption_model.generate(**inputs, max_new_tokens=max_tokens)
            text = self.lm_processor.decode(out[0], skip_special_tokens=True)

        return OculusTextOutput(
            text=text,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_points(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        threshold: float = 0.5,
        **kwargs
    ) -> OculusPointOutput:
        """Generate point detections."""
        points, cls_logits, confidence = self.point_head(vision_tokens)

        mask = confidence.squeeze(-1) > threshold

        filtered_points = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            token_mask = mask[i]
            pts = points[i][token_mask].detach().cpu().numpy().tolist()
            confs = confidence[i][token_mask].squeeze(-1).detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][token_mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_points.extend([tuple(p) for p in pts])
            filtered_conf.extend(confs)
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusPointOutput(
            points=filtered_points,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_boxes(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        threshold: float = 0.3,
        **kwargs
    ) -> OculusBoxOutput:
        """Generate bounding box detections."""
        cls_logits, box_coords = self.detection_head(vision_tokens)

        confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values

        filtered_boxes = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            mask = confidence[i] > threshold
            boxes = box_coords[i][mask].detach().cpu().numpy()
            confs = confidence[i][mask].detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_boxes.extend([tuple(b) for b in boxes])
            filtered_conf.extend(confs)
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusBoxOutput(
            boxes=filtered_boxes,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_polygons(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        **kwargs
    ) -> OculusPolygonOutput:
        """Generate polygon/mask segmentation."""
        mask_logits = self.segmentation_head(vision_tokens)

        # Per-pixel class ids on the coarse 14x14 grid.
        mask = mask_logits.argmax(dim=1).detach().cpu().numpy()

        # Placeholder polygons: one full-image quad per detected class. Callers that
        # need real contours should trace them from `mask`.
        polygons = []
        labels = []

        unique_classes = np.unique(mask[0])
        for cls_id in unique_classes:
            if cls_id == 0:
                continue
            labels.append(str(cls_id))
            polygons.append([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])

        return OculusPolygonOutput(
            polygons=polygons,
            labels=labels,
            mask=mask[0],
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from pretrained weights.

        Args:
            pretrained_model_name_or_path: Local directory containing config.json,
                projector.npz and heads.pth. If those files are not found, the model
                is built from the default OculusConfig with untrained weights.
        """
        path = Path(pretrained_model_name_or_path)

        # Rebuild the config from the saved projector metadata when available.
        config_path = path / "config.json"
        if config_path.exists():
            with open(config_path) as f:
                proj_config = json.load(f)

            config = OculusConfig(
                dinov3_hidden_size=proj_config.get("fused_dim", 2048) - 768,
                siglip_hidden_size=768,
                projector_hidden_dim=proj_config.get("hidden_dim", 2048),
                num_vision_tokens=proj_config.get("num_tokens", 64),
                lm_hidden_size=proj_config.get("embed_dim", 1536),
            )
        else:
            config = OculusConfig()

        model = cls(config)

        # Trained projector weights.
        projector_path = path / "projector.npz"
        if projector_path.exists():
            model.projector = OculusProjector.from_pretrained(path, config)

        # Task head weights.
        heads_path = path / "heads.pth"
        if heads_path.exists():
            heads_state = torch.load(heads_path, map_location="cpu")
            model.detection_head.load_state_dict(heads_state.get("detection", {}), strict=False)
            model.point_head.load_state_dict(heads_state.get("point", {}), strict=False)
            model.segmentation_head.load_state_dict(heads_state.get("segmentation", {}), strict=False)

        return model

    def save_pretrained(self, save_directory: str):
        """Save model to directory."""
        path = Path(save_directory)
        path.mkdir(parents=True, exist_ok=True)

        # Config.
        self.config.save_pretrained(path)

        # Projector: group parameters by layer and save as an npz of per-layer dicts
        # (the format expected by OculusProjector.from_pretrained).
        projector_state = self.projector.state_dict()

        np_weights = {}
        for k, v in projector_state.items():
            parts = k.split(".")
            layer = parts[0]
            param = ".".join(parts[1:])
            if layer not in np_weights:
                np_weights[layer] = {}
            np_weights[layer][param] = v.cpu().numpy()
        np.savez(path / "projector.npz", **np_weights)

        # Task heads.
        torch.save({
            "detection": self.detection_head.state_dict(),
            "point": self.point_head.state_dict(),
            "segmentation": self.segmentation_head.state_dict(),
        }, path / "heads.pth")

        print(f"✓ Saved model to {path}")


OculusForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")