|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
from typing import Tuple, List |
|
|
|
import mobile_sam |
|
from mobile_sam.modeling import Sam |
|
from mobile_sam.utils.amg import calculate_stability_score |
|
|
|
|
|
class ImageEncoderOnnxModel(nn.Module):
    """ONNX-export wrapper around Sam's image encoder.

    This model should not be called directly, but is used in ONNX export.
    It combines the image encoder of Sam with an (optional) traced-friendly
    preprocessing step, so the exported graph can accept a raw HWC image.
    See the ONNX export script for details.
    """

    def __init__(
        self,
        model: "Sam",
        use_preprocess: bool,
        pixel_mean: Tuple[float, ...] = (123.675, 116.28, 103.53),
        pixel_std: Tuple[float, ...] = (58.395, 57.12, 57.375),
    ):
        """
        Arguments:
          model: the Sam model whose image encoder is exported.
          use_preprocess: if True, normalization/permute/pad is folded into
            the exported graph (forward expects a raw HWC image).
          pixel_mean: per-channel mean used for normalization (ImageNet
            values scaled to [0, 255]).
          pixel_std: per-channel std used for normalization.
        """
        super().__init__()
        self.use_preprocess = use_preprocess
        # Register as non-persistent buffers (instead of plain attributes)
        # so they follow .to(device)/.half() with the module, while keeping
        # the state_dict unchanged.
        self.register_buffer(
            "pixel_mean", torch.tensor(pixel_mean, dtype=torch.float), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.tensor(pixel_std, dtype=torch.float), persistent=False
        )
        self.image_encoder = model.image_encoder

    @torch.no_grad()
    def forward(self, input_image: torch.Tensor) -> torch.Tensor:
        """Encode one image, optionally preprocessing it first.

        input_image is an HWC image when use_preprocess is True, otherwise
        an already-preprocessed (1, C, img_size, img_size) tensor.
        """
        if self.use_preprocess:
            input_image = self.preprocess(input_image)
        image_embeddings = self.image_encoder(input_image)
        return image_embeddings

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize an HWC image, permute to CHW, pad, and add a batch dim.

        Returns a (1, C, img_size, img_size) float tensor.
        """
        # Per-channel normalization; pixel_mean/pixel_std broadcast over the
        # trailing channel dimension of the HWC input.
        x = (x - self.pixel_mean) / self.pixel_std

        # HWC -> CHW, as expected by the image encoder.
        x = torch.permute(x, (2, 0, 1))

        # Zero-pad bottom/right up to the encoder's square input size.
        # NOTE(review): assumes h, w <= img_size; a larger input would make
        # the pads negative and silently crop — confirm callers resize first.
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))

        # Add the batch dimension.
        x = torch.unsqueeze(x, 0)
        return x
|
|