mboss's picture
Initial commit
d945eeb
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
from sf3d.models.tokenizers.dinov2 import Dinov2Model
from sf3d.models.transformers.attention import Modulation
from sf3d.models.utils import BaseModule
class DINOV2SingleImageTokenizer(BaseModule):
    """Tokenize images with a frozen DINOv2 backbone plus per-layer modulation.

    The pretrained DINOv2 weights are frozen in ``configure``. For every
    encoder layer, two ``Modulation`` modules are created and attached to that
    layer through ``register_ada_norm_modulation``. They are presumably applied
    to the layer's two norms (``norm1`` / ``norm2``); confirm in the Dinov2
    implementation. The modulations are created after the freezing loop, so
    that loop does not freeze their parameters.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Passed to ``Dinov2Model.from_pretrained``.
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        # NOTE(review): width/height are not read anywhere in this class;
        # presumably consumed elsewhere (e.g. by the data pipeline) — confirm.
        width: int = 512
        height: int = 512
        # Conditioning dimension handed to every ``Modulation`` module.
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
        """Build the frozen backbone, its modulation modules and norm buffers.

        NOTE(review): the backbone is put in eval mode here. Nothing in this
        class overrides ``train()``, so a later ``.train()`` on this module or a
        parent will switch ``self.model`` back to training mode. Confirm that
        this is intended (or handled by ``BaseModule``).
        """
        self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
        # Freeze the pretrained weights. This runs before the modulations below
        # are created, so it does not touch their parameters.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.model.eval()
        self.model.set_gradient_checkpointing(False)

        # Add two modulation modules (one per norm) to every encoder layer.
        modulations = []
        for layer in self.model.encoder.layer:
            norm1_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,  # presumably starts as an identity modulation
                single_layer=True,
            )
            norm2_modulation = Modulation(
                self.model.config.hidden_size,
                self.cfg.modulation_cond_dim,
                zero_init=True,
                single_layer=True,
            )
            layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
            modulations += [norm1_modulation, norm2_modulation]
        # Keep them in a ModuleList so they are also registered as submodules
        # of this tokenizer.
        self.modulations = nn.ModuleList(modulations)

        # Per-channel normalization statistics, shaped (1, 1, 3, 1, 1) to
        # broadcast over (B, N, C, H, W) images. The values are the usual
        # ImageNet mean/std, which presumes RGB input scaled to [0, 1];
        # confirm with callers. persistent=False keeps them out of the
        # state_dict.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]] = None,
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """Encode images into per-token DINOv2 features.

        Args:
            images: ``(B, N, C, H, W)`` multi-view batch, or ``(B, C, H, W)``
                single-view ("packed") batch.
            modulation_cond: Optional conditioning for the modulation modules.
                It is ``(B, N, Cc)`` for 5D images, or ``(B, Cc)`` for 4D images.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Features of shape ``(B, N, Ct, Nt)``, or ``(B, Ct, Nt)`` for 4D
            input. The token axis is moved last. This assumes
            ``last_hidden_state`` is ``(B*N, Nt, Ct)``; confirm with the
            Dinov2 implementation.

        Raises:
            ValueError: If ``images`` is 4D and ``modulation_cond`` is not 2D.
        """
        packed = images.ndim == 4
        if packed:
            # Add a singleton view axis so both layouts share one code path.
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                # Explicit check instead of `assert`, so it survives `python -O`.
                if modulation_cond.ndim != 2:
                    raise ValueError(
                        "modulation_cond must be 2D (B, Cc) when images are 4D "
                        f"(B, C, H, W); got shape {tuple(modulation_cond.shape)}"
                    )
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size = images.shape[0]
        images = (images - self.image_mean) / self.image_std

        # The backbone sees views as extra batch entries.
        flat_cond = (
            rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None
        )
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=flat_cond,
        )

        # Split the view axis back out and swap to (channels, tokens) in one step.
        local_features = rearrange(
            out.last_hidden_state, "(B N) Nt Ct -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)
        return local_features

    def detokenize(self, *args, **kwargs):
        """Not supported by this tokenizer."""
        raise NotImplementedError