from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float
from torch import Tensor

from sf3d.models.tokenizers.dinov2 import Dinov2Model
from sf3d.models.transformers.attention import Modulation
from sf3d.models.utils import BaseModule


class DINOV2SingleImageTokenizer(BaseModule):
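    """Image tokenizer built on a frozen DINOv2 backbone.

    Encodes one or more input views into patch-token features. Each encoder
    layer is augmented with zero-initialized adaptive-norm ``Modulation``
    modules, so an external condition vector can steer the otherwise frozen
    backbone.
    """
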
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        width: int = 512
        height: int = 512
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
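        """Load the frozen DINOv2 backbone, attach per-layer modulation
        modules, and register ImageNet normalization buffers."""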
        self.model = Dinov2Model.from_pretrained(
            self.cfg.pretrained_model_name_or_path
        )

        # Freeze the DINOv2 backbone; only the modulation layers added below
        # are trainable.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.model.eval()
        self.model.set_gradient_checkpointing(False)
        # Attach a pair of adaptive-norm modulations to every encoder layer so
        # an external condition vector can scale/shift the frozen features
        # (zero_init keeps the modulation inactive at initialization).
        modulations = []
        for layer in self.model.encoder.layer:
            norm1_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            norm2_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
            modulations += [norm1_modulation, norm2_modulation]
        self.modulations = nn.ModuleList(modulations)

        # ImageNet normalization statistics, broadcastable over
        # (B, N, C, H, W) image batches.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
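        """Tokenize images into per-view DINOv2 feature tokens.

        Accepts either a packed batch (B, C, H, W) or a multi-view batch
        (B, N, C, H, W), with an optional per-view condition vector of size
        ``modulation_cond_dim``. Returns features of shape (B, Ct, Nt) or
        (B, N, Ct, Nt), where Ct is the backbone hidden size and Nt the
        number of output tokens.
        """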
        model = self.model

        # Accept both packed (B, C, H, W) and multi-view (B, N, C, H, W)
        # inputs; packed inputs get a singleton view axis, removed again
        # before returning.
        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        # Fold the view axis into the batch, run the (modulated) backbone,
        # then restore the (B, N) layout with channels before tokens.
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features = out.last_hidden_state
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
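

# A minimal usage sketch, not part of the original file. It assumes
# BaseModule follows the threestudio-style convention where the constructor
# accepts a config dict and calls `configure()`; adjust to the repo's actual
# entry point if it differs.
if __name__ == "__main__":
    tokenizer = DINOV2SingleImageTokenizer(
        {"pretrained_model_name_or_path": "facebook/dinov2-large"}
    )
    images = torch.rand(2, 3, 512, 512)  # packed input: (B, C, H, W)
    cond = torch.rand(2, 768)  # condition: (B, modulation_cond_dim)
    tokens = tokenizer(images, modulation_cond=cond)
    # tokens: (B, Ct, Nt) with Ct equal to the backbone hidden size
    # (1024 for dinov2-large); Nt depends on the input resolution.
    print(tokens.shape)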