# MobileSAM / mobile_sam_encoder_onnx / onnx_image_encoder.py
# Export script for the image encoder.
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple, List
import mobile_sam
from mobile_sam.modeling import Sam
from mobile_sam.utils.amg import calculate_stability_score
class ImageEncoderOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It wraps the image encoder of Sam, with some functions modified to
    enable model tracing. Also supports extra options controlling what
    information is returned. See the ONNX export script for details.
    """

    def __init__(
        self,
        model: Sam,
        use_preprocess: bool,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ):
        """
        Args:
            model: the Sam model whose ``image_encoder`` is exported.
            use_preprocess: if True, normalization/padding is baked into the
                exported graph and ``forward`` expects a raw (H, W, 3) image;
                otherwise the input must already be preprocessed.
            pixel_mean: per-channel mean used for color normalization.
            pixel_std: per-channel std used for color normalization.
        """
        super().__init__()
        self.use_preprocess = use_preprocess
        # Register the normalization constants as non-persistent buffers so
        # they follow the module across .to(device)/.half() during export
        # (plain attributes would stay on CPU float32 and break preprocess),
        # while persistent=False keeps the state_dict identical to before.
        self.register_buffer(
            "pixel_mean", torch.tensor(pixel_mean, dtype=torch.float), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.tensor(pixel_std, dtype=torch.float), persistent=False
        )
        self.image_encoder = model.image_encoder

    @torch.no_grad()
    def forward(self, input_image: torch.Tensor) -> torch.Tensor:
        """Encode an image, optionally preprocessing it first.

        Args:
            input_image: raw (H, W, 3) image if ``use_preprocess`` is True,
                otherwise an already-preprocessed encoder input.

        Returns:
            The image embeddings produced by the Sam image encoder.
        """
        if self.use_preprocess:
            input_image = self.preprocess(input_image)
        return self.image_encoder(input_image)

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize colors, transpose HWC->CHW, pad, and add a batch dim.

        Args:
            x: image tensor of shape (H, W, 3); H and W must not exceed
               ``image_encoder.img_size``.

        Returns:
            Tensor of shape (1, 3, img_size, img_size).
        """
        # Normalize colors (mean/std broadcast over the trailing channel
        # dimension of the HWC input).
        x = (x - self.pixel_mean) / self.pixel_std
        # HWC -> CHW, as expected by the encoder.
        x = torch.permute(x, (2, 0, 1))
        # Zero-pad on the bottom/right so both spatial dims reach img_size.
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        # Add the batch dimension.
        return torch.unsqueeze(x, 0)