Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| import torch | |
| from typing import Union | |
| from einops import rearrange, repeat | |
| from PIL import Image | |
| from safetensors.torch import load_file as load_sft | |
| from torch import nn | |
| from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel | |
| from ..utils.process_util import print_load_warning | |
class ReduxImageEncoder(nn.Module):
    """Encode a PIL image with SigLIP and project it through the Redux MLP.

    The module owns two linear layers (``redux_up`` and ``redux_down``) whose
    weights are read from a safetensors checkpoint, plus a SigLIP vision model
    and its matching image processor loaded with ``from_pretrained``.
    """

    def __init__(
        self,
        redux_path: str,
        siglip_path: str = "google/siglip-so400m-patch14-384",
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device: Union[str, torch.device, None] = None,
    ) -> None:
        """Build the projection layers, load their weights and load SigLIP.

        Args:
            redux_path: Path of the safetensors file holding the projection
                weights.
            siglip_path: Identifier or path handed to ``from_pretrained`` for
                both the SigLIP vision model and its image processor.
            redux_dim: Input feature size of ``redux_up``.
            txt_in_features: Output feature size of ``redux_down``; the hidden
                layer between the two projections is three times this size.
            device: Device the whole module is moved to after loading.
        """
        super().__init__()

        self.redux_dim = redux_dim
        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
        self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)

        # Load the checkpoint before SigLIP is attached, so only the two
        # projection layers take part in the missing/unexpected key report.
        sd = load_sft(redux_path)
        missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)

        self.siglip = SiglipVisionModel.from_pretrained(siglip_path)
        self.normalize = SiglipImageProcessor.from_pretrained(siglip_path)

        self.to(device)

    @staticmethod
    def _resolve_dtype(dtype: Union[str, torch.dtype, None]) -> Union[torch.dtype, None]:
        """Turn a dtype name such as ``"float16"`` into the ``torch.dtype`` object.

        ``torch.dtype`` cannot be called to build an instance, so names are
        looked up as attributes of the ``torch`` module instead. A leading
        ``"torch."`` is accepted. Non-string values are returned unchanged.

        Raises:
            ValueError: If the string does not name a ``torch.dtype``.
        """
        if not isinstance(dtype, str):
            return dtype

        prefix = "torch."
        name = dtype[len(prefix):] if dtype.startswith(prefix) else dtype
        resolved = getattr(torch, name, None)
        # getattr can also return modules or functions (e.g. "nn"); reject those.
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unknown torch dtype: {dtype!r}")
        return resolved

    def __call__(
        self,
        x: "Image.Image",
        device: Union[str, torch.device, None] = None,
        dtype: Union[str, torch.dtype, None] = None,
    ) -> torch.Tensor:
        """Return the Redux projection of the SigLIP features of one image.

        Args:
            x: Image to encode; it is preprocessed by the SigLIP image
                processor with resizing and RGB conversion enabled.
            device: Device the preprocessed inputs are moved to. Defaults to
                the device of the module's first parameter.
            dtype: Dtype the preprocessed inputs are cast to, given either as
                a ``torch.dtype`` or as its name. Defaults to the dtype of the
                module's first parameter.

        Returns:
            ``redux_down(silu(redux_up(h)))`` where ``h`` is the SigLIP
            ``last_hidden_state`` for the image.

        Raises:
            ValueError: If ``dtype`` is a string that names no ``torch.dtype``.
        """
        if isinstance(device, str):
            device = torch.device(device)
        dtype = self._resolve_dtype(dtype)

        if device is None:
            device = next(self.parameters()).device
        if dtype is None:
            dtype = next(self.parameters()).dtype

        imgs = self.normalize.preprocess(
            images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True
        )
        _encoded_x = self.siglip(**imgs.to(device=device, dtype=dtype)).last_hidden_state
        projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
        return projected_x