| |
| """ |
| Phase 3a: Vision Encoder Export for ExecuTorch |
| Extracts vision_encoder + vision_projection into a standalone nn.Module |
| with fixed-size input for torch.export compatibility. |
| |
| Fixed resolution: 1120x1540 (snapped to patch_size=14 multiples) |
| -> patch grid: 80 x 110 = 8800 patches |
| -> after PatchMerger (2x2): 40 x 55 = 2200 tokens |
| """ |
|
|
| import os |
| import sys |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
# Export-time fixed input resolution in pixels (snapped to PATCH_SIZE multiples).
FIXED_H = 1120
FIXED_W = 1540
PATCH_SIZE = 14    # ViT patch edge length in pixels
SPATIAL_MERGE = 2  # PatchMerger folds 2x2 patch neighborhoods into one token

# Derived grid sizes: 80 x 110 = 8800 patches -> 40 x 55 = 2200 merged tokens.
PATCHES_H = FIXED_H // PATCH_SIZE
PATCHES_W = FIXED_W // PATCH_SIZE
NUM_PATCHES = PATCHES_H * PATCHES_W
MERGED_H = PATCHES_H // SPATIAL_MERGE
MERGED_W = PATCHES_W // SPATIAL_MERGE
NUM_MERGED = MERGED_H * MERGED_W

# Local path to the pretrained checkpoint directory.
MODEL_DIR = "./models/LightOnOCR-2-1B"
|
|
|
|
class FixedPatchMerger(nn.Module):
    """
    Rewritten PatchMerger that works with fixed single-image input.
    No Python loops, no dynamic shapes.

    Original: loops over variable-size images, dynamic unfold
    This: single fixed-size image, vectorized unfold

    Generalization: the patch-grid dimensions may be supplied explicitly via
    ``patches_h`` / ``patches_w``; when left as ``None`` they fall back to the
    module-level PATCHES_H / PATCHES_W constants at call time, preserving the
    original fixed-size behavior for existing callers.
    """

    def __init__(self, hidden_size: int, spatial_merge_size: int = 2,
                 patches_h=None, patches_w=None):
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        # None -> resolve against the module-level fixed grid in forward().
        self.patches_h = patches_h
        self.patches_w = patches_w
        self.merging_layer = nn.Linear(
            hidden_size * spatial_merge_size ** 2, hidden_size, bias=False
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_features: [num_patches, hidden_size] where
                num_patches = patches_h * patches_w

        Returns:
            [num_merged, hidden_size] where
                num_merged = (patches_h // merge) * (patches_w // merge)
        """
        h = self.patches_h if self.patches_h is not None else PATCHES_H
        w = self.patches_w if self.patches_w is not None else PATCHES_W
        d = image_features.shape[-1]

        # [num_patches, d] -> [1, d, h, w]: channel-first grid for F.unfold.
        image_grid = image_features.view(h, w, d).permute(2, 0, 1).unsqueeze(0)

        # Extract non-overlapping s x s neighborhoods: [1, d*s*s, num_merged].
        grid = F.unfold(
            image_grid,
            kernel_size=self.spatial_merge_size,
            stride=self.spatial_merge_size
        )

        # [1, d*s*s, num_merged] -> [num_merged, d*s*s]
        grid = grid.squeeze(0).t()

        # Project the concatenated neighborhood back to hidden_size.
        return self.merging_layer(grid)
|
|
|
|
class FixedMultiModalProjector(nn.Module):
    """Fixed-size multimodal projector: RMSNorm -> PatchMerger -> 2-layer GELU MLP."""

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 spatial_merge_size: int = 2, rms_eps: float = 1e-6):
        super().__init__()
        # Plain parameter + inline norm instead of the hub RMSNorm module.
        self.norm_weight = nn.Parameter(torch.ones(vision_hidden_size))
        self.norm_eps = rms_eps
        self.patch_merger = FixedPatchMerger(vision_hidden_size, spatial_merge_size)
        self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size, bias=False)
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size, bias=False)

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Inline RMSNorm — avoids @use_kernel_forward_from_hub decorator."""
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        # Normalize by the root-mean-square over the last dimension.
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.norm_eps)
        return self.norm_weight * (x32 * inv_rms).to(orig_dtype)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_features: [num_patches, vision_hidden_size]
        Returns:
            [num_merged, text_hidden_size]
        """
        merged = self.patch_merger(self._rms_norm(image_features))
        return self.linear_2(F.gelu(self.linear_1(merged)))
|
|
|
|
class VisionEncoderFixed(nn.Module):
    """
    Standalone vision encoder for ExecuTorch export.
    Wraps PixtralVisionModel + MultiModalProjector with fixed-size input.

    Input: pixel_values [1, 3, 1120, 1540]
    Output: image_features [1, 2200, 1024]
    """

    def __init__(self, vision_encoder, projector):
        super().__init__()
        # Reuse the pretrained submodules directly; only ln_pre is flattened
        # into a bare parameter so the norm can be inlined below.
        self.patch_conv = vision_encoder.patch_conv
        self.ln_pre_weight = nn.Parameter(vision_encoder.ln_pre.weight.clone())
        self.ln_pre_eps = vision_encoder.ln_pre.variance_epsilon
        self.transformer = vision_encoder.transformer
        self.rope = vision_encoder.patch_positional_embedding

        self.projector = projector

        # Position ids depend only on the fixed grid, so bake them in as a buffer.
        max_width = vision_encoder.config.image_size // PATCH_SIZE
        self.register_buffer(
            "position_ids",
            self._compute_fixed_position_ids(PATCHES_H, PATCHES_W, max_width)
        )

    @staticmethod
    def _compute_fixed_position_ids(h: int, w: int, max_width: int) -> torch.Tensor:
        """Pre-compute position IDs for fixed-size image grid."""
        # Row-major enumeration of the h x w grid: id = row * max_width + col.
        rows = torch.arange(h).repeat_interleave(w)
        cols = torch.arange(w).repeat(h)
        return (rows * max_width + cols).unsqueeze(0)

    def _rms_norm_pre(self, x: torch.Tensor) -> torch.Tensor:
        """Inline RMSNorm for ln_pre."""
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.ln_pre_eps)
        return self.ln_pre_weight * (x32 * inv_rms).to(orig_dtype)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values: [1, 3, 1120, 1540]
        Returns:
            image_features: [1, 2200, 1024]
        """
        # Conv patchify: [1, 3, H, W] -> [1, d, h, w] -> [1, num_patches, d].
        embeds = self.patch_conv(pixel_values).flatten(2).transpose(1, 2)

        embeds = self._rms_norm_pre(embeds)

        # Rotary position embeddings from the pre-computed fixed ids.
        rope_embeds = self.rope(embeds, self.position_ids)

        outputs = self.transformer(
            embeds,
            attention_mask=None,
            position_embeddings=rope_embeds,
            output_hidden_states=True,
            output_attentions=False,
            return_dict=True,
        )

        # Last hidden state, batch dim dropped for the projector.
        hidden = outputs.hidden_states[-1].squeeze(0)

        # Project to text space and restore the batch dim.
        return self.projector(hidden).unsqueeze(0)
|
|
|
|
def load_original_model():
    """Load the original model with proper weight remapping."""
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("Loading original model...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="cpu",
    )

    # Checkpoint keys use the original prefixes; remap them onto the
    # transformers module tree before loading.
    checkpoint = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {
        key.replace("model.vision_encoder.", "model.vision_tower.")
           .replace("model.vision_projection.", "model.multi_modal_projector."): tensor
        for key, tensor in checkpoint.items()
    }
    # strict=False: checkpoint may lack keys for modules kept at their defaults.
    model.load_state_dict(remapped, strict=False)

    return model
|
|
|
|
def build_vision_module(original_model):
    """Build the fixed-size vision module from the original model."""
    cfg = original_model.config
    tower = original_model.model.vision_tower
    src_proj = original_model.model.multi_modal_projector

    # Fresh projector with the export-friendly fixed-size implementation.
    projector = FixedMultiModalProjector(
        vision_hidden_size=cfg.vision_config.hidden_size,
        text_hidden_size=cfg.text_config.hidden_size,
        spatial_merge_size=cfg.spatial_merge_size,
        rms_eps=cfg.text_config.rms_norm_eps,
    )

    # Copy the pretrained projector weights into the rewritten module.
    weight_pairs = (
        (projector.norm_weight, src_proj.norm.weight),
        (projector.patch_merger.merging_layer.weight,
         src_proj.patch_merger.merging_layer.weight),
        (projector.linear_1.weight, src_proj.linear_1.weight),
        (projector.linear_2.weight, src_proj.linear_2.weight),
    )
    for dst, src in weight_pairs:
        dst.data.copy_(src.data)

    # Wrap encoder + projector and freeze into eval mode for export.
    vision_module = VisionEncoderFixed(tower, projector)
    vision_module.eval()

    return vision_module
|
|
|
|
def test_vision_module(vision_module, original_model):
    """Test that the fixed module produces similar output to the original."""
    print("\nTesting vision module output consistency...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    vision_module = vision_module.to(device).to(torch.bfloat16)

    # Random image at the fixed export resolution.
    pixels = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.bfloat16, device=device)

    with torch.no_grad():
        out_fixed = vision_module(pixels)
        print(f" Fixed module output shape: {out_fixed.shape}")
        print(f" Expected: [1, {NUM_MERGED}, {original_model.config.text_config.hidden_size}]")

        # Reference path through the original model's image-feature extractor.
        original_model = original_model.to(device)
        image_sizes = torch.tensor([[FIXED_H, FIXED_W]], device=device)
        ref = original_model.model.get_image_features(
            pixel_values=pixels,
            image_sizes=image_sizes,
            vision_feature_layer=-1,
            return_dict=True,
        )
        out_ref = torch.cat(ref.pooler_output, dim=0).unsqueeze(0)
        print(f" Original model output shape: {out_ref.shape}")

        # Elementwise comparison only makes sense when shapes agree.
        if out_fixed.shape == out_ref.shape:
            diff = (out_fixed - out_ref).abs()
            print(f" Max absolute difference: {diff.max().item():.6f}")
            print(f" Mean absolute difference: {diff.mean().item():.6f}")
            print(f" Cosine similarity: {F.cosine_similarity(out_fixed.flatten(), out_ref.flatten(), dim=0).item():.6f}")
        else:
            print(f" Shape mismatch! Fixed: {out_fixed.shape}, Original: {out_ref.shape}")

    return out_fixed
|
|
|
|
def try_torch_export(vision_module):
    """Attempt torch.export.export() on the vision module."""
    print("\n" + "=" * 60)
    print("ATTEMPTING torch.export.export()")
    print("=" * 60)

    # Export on CPU in float32 — the safest configuration for tracing.
    module = vision_module.to("cpu").to(torch.float32)
    module.eval()

    sample = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.float32)

    try:
        print(" Running torch.export.export() on CPU/float32...")
        exported = torch.export.export(
            module,
            (sample,),
            strict=False,
        )
    except Exception as exc:
        print(f" FAILED: {type(exc).__name__}: {exc}")
        import traceback
        traceback.print_exc()
        return None

    print(" SUCCESS! torch.export completed!")
    return exported
|
|
|
|
def export_to_pte(exported_model, vision_module, example_input):
    """Convert exported model to .pte using XNNPACK backend."""
    print("\n" + "=" * 60)
    print("EXPORTING TO .pte (XNNPACK)")
    print("=" * 60)

    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        # Guard: only a torch.export ExportedProgram can be lowered here.
        if not hasattr(exported_model, 'graph_module'):
            print(" Cannot export non-torch.export model to .pte directly")
            return None

        print(" Running to_edge_transform_and_lower...")
        edge_program = to_edge_transform_and_lower(
            exported_model,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        print(" Running to_executorch()...")
        program = edge_program.to_executorch()

        output_path = "vision_encoder.pte"
        with open(output_path, "wb") as fh:
            fh.write(program.buffer)

        size_mb = os.path.getsize(output_path) / (1024 * 1024)
        print(f" Saved to {output_path} ({size_mb:.1f} MB)")
        return output_path

    except ImportError as e:
        print(f" ExecuTorch import failed: {e}")
        print(" Make sure executorch is properly installed")
        return None
    except Exception as e:
        print(f" Export failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
def main():
    """End-to-end driver: load model, rebuild fixed module, verify, export, save."""
    print("=" * 60)
    print("Vision Encoder Export for ExecuTorch")
    print(f"Fixed resolution: {FIXED_H}x{FIXED_W}")
    print(f"Patches: {PATCHES_H}x{PATCHES_W} = {NUM_PATCHES}")
    print(f"After merge: {MERGED_H}x{MERGED_W} = {NUM_MERGED}")
    print("=" * 60)

    original_model = load_original_model()

    print("\nBuilding fixed-size vision module...")
    vision_module = build_vision_module(original_model)
    print(f" Vision module parameters: {sum(p.numel() for p in vision_module.parameters())/1e6:.2f}M")

    test_vision_module(vision_module, original_model)

    # Drop the full model before export to reduce peak memory.
    # (Proper statement instead of the former conditional-expression-as-statement.)
    del original_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    exported = try_torch_export(vision_module)

    if exported is not None:
        # try_torch_export moved the module to CPU/float32, so match that here
        # (export_to_pte currently ignores example_input, but keep it consistent).
        example_input = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.float32)
        export_to_pte(exported, vision_module, example_input)

    # Always save the remapped weights so they can be reloaded without the
    # original checkpoint, even if export failed.
    torch.save(vision_module.state_dict(), "vision_encoder_fixed.pt")
    print("\nSaved fixed vision module state dict to vision_encoder_fixed.pt")
    print("Export script complete!")
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|