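# NOTE: the three lines below are residue from the web page this file was
# copied from (not code); kept as comments so the module can be parsed.
#   ReubenSun's picture
#   1
#   2ac1c2d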
import random
import torch
from torch import nn
import numpy as np
import re
from einops import rearrange
from dataclasses import dataclass
from torchvision import transforms
from diffusers.models.modeling_utils import ModelMixin
from transformers import AutoImageProcessor, AutoModel
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List
import step1x3d_geometry
from step1x3d_geometry.utils.typing import *
from .base import BaseVisualEncoder, ImageType
from .dinov2.modeling_dinov2 import Dinov2Model
from .dinov2.modeling_conditional_dinov2 import ConditionalDinov2Model
from .dinov2_with_registers.modeling_dinov2_with_registers import (
Dinov2WithRegistersModel,
)
@dataclass
class DINOEmbedOutput(ModelOutput):
    """Container for the two DINO outputs returned when ``return_dict=True``.

    ``ModelOutput`` subclasses have to be dataclasses: without the decorator
    the annotated names below stay plain class attributes (always ``None``)
    and recent ``transformers`` releases refuse to instantiate the class.
    """

    # Copied from the backbone output's ``last_hidden_state``.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Copied from the backbone output's ``pooler_output``.
    pooler_output: Optional[torch.FloatTensor] = None
@step1x3d_geometry.register("dinov2-encoder")
class Dinov2Encoder(BaseVisualEncoder, ModelMixin):
    """Image encoder that returns the token embeddings of a DINOv2 backbone.

    When ``cfg.encode_camera`` is set, a ``ConditionalDinov2Model`` is used and
    the camera embeddings produced by ``self.encode_camera`` are passed to it
    as ``condition``; otherwise a plain DINOv2 (with or without register
    tokens) is used.
    """

    @dataclass
    class Config(BaseVisualEncoder.Config):
        # Checkpoint of the condition model. Entries of its ``state_dict``
        # whose key starts with ``visual_condition.`` are loaded into this
        # module. Its name is also used to pick the DINOv2 variant when no
        # ``pretrained_dino_name_or_path`` is given.
        pretrained_model_name_or_path: Optional[str] = None
        # Name or path handed to ``from_pretrained`` for the DINO backbone.
        pretrained_dino_name_or_path: Optional[str] = None
        # When True, the ``mod_norm1`` / ``mod_norm2`` parameters are frozen
        # together with the rest of the backbone.
        freeze_modulation_dino: bool = False
        # Sets ``dino_model.encoder.gradient_checkpointing = True``.
        enable_gradient_checkpointing: bool = False
        # Target size for resize and centre crop.
        image_size: int = 224
        # Hub id used to build the architecture / image processor when no
        # pretrained DINO path is given.
        dino_type: Optional[str] = None
        # Extra keyword arguments forwarded to ``AutoModel.from_pretrained``.
        kwargs: Optional[dict] = None

    cfg: Config

    # (substring of ``pretrained_model_name_or_path``, hub id) pairs.
    # They are tested in this order, so the order is part of the behaviour.
    _CKPT_NAME_TO_DINO_TYPE = (
        ("dinov2base", "facebook/dinov2-base"),
        ("dinov2regbase", "facebook/dinov2-with-registers-base"),
        ("dinov2reglarge", "facebook/dinov2-with-registers-large"),
    )

    @classmethod
    def _dino_type_from_ckpt_name(cls, ckpt_name: str) -> str:
        """Map a condition-model checkpoint name to the matching DINOv2 hub id.

        Raises:
            ValueError: if none of the known markers occurs in *ckpt_name*.
        """
        for marker, dino_type in cls._CKPT_NAME_TO_DINO_TYPE:
            if marker in ckpt_name:
                return dino_type
        raise ValueError(f"Unknown Dinov2 model: {ckpt_name}")

    @staticmethod
    def _build_dino_from_config(dino_type: str) -> nn.Module:
        """Instantiate a DINOv2 architecture from the config of *dino_type*.

        Only the configuration is fetched with ``from_pretrained``; the model
        itself is constructed from that config, so no pretrained backbone
        weights are loaded by this helper.
        """
        # A hub id containing "reg" selects the register-token variant.
        model_cls = Dinov2WithRegistersModel if "reg" in dino_type else Dinov2Model
        return model_cls(config=model_cls.config_class.from_pretrained(dino_type))

    def configure(self) -> None:
        """Build the backbone, preprocessing, unconditional embeds and load weights."""
        super().configure()

        # ------------------------------------------------------------------
        # 1. Backbone
        # ------------------------------------------------------------------
        dino_path = self.cfg.pretrained_dino_name_or_path
        if self.cfg.encode_camera:
            # Camera-conditioned backbone: the modulation width must match the
            # camera embedding width.
            conditional_vit_config = ConditionalDinov2Model.config_class.from_pretrained(
                dino_path,
            )
            conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim
            self.dino_model = ConditionalDinov2Model.from_pretrained(
                dino_path, config=conditional_vit_config
            )
        elif dino_path is not None:
            # Derive a hub-style id from the path: keep what follows the last
            # "facebook--" up to the first "/".
            self.cfg.dino_type = (
                f"facebook/{dino_path.split('facebook--')[-1].split('/')[0]}"
            )
            self.dino_model = AutoModel.from_pretrained(
                dino_path, **(self.cfg.kwargs or {})
            )
        else:
            ckpt_name = self.cfg.pretrained_model_name_or_path
            if ckpt_name is None:
                assert (
                    self.cfg.dino_type is not None
                ), "The dino_type should be provided"
            else:
                self.cfg.dino_type = self._dino_type_from_ckpt_name(ckpt_name)
            print(f"Loading Dinov2 model from {self.cfg.dino_type}")
            # Fix: the original read ``self.dino_type`` (not set anywhere in
            # this class) instead of ``self.cfg.dino_type`` for the
            # non-register variant.
            self.dino_model = self._build_dino_from_config(self.cfg.dino_type)

        # ------------------------------------------------------------------
        # 2. Preprocessing (processor for the inference path, torchvision
        #    transform for tensor inputs)
        # ------------------------------------------------------------------
        self.image_preprocess_dino = AutoImageProcessor.from_pretrained(
            self.cfg.dino_type if dino_path is None else dino_path
        )
        self.transform_dino = transforms.Compose(
            [
                transforms.Resize(
                    self.cfg.image_size,
                    transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
                # crop a (image_size, image_size) square
                transforms.CenterCrop(self.cfg.image_size),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

        if self.cfg.enable_gradient_checkpointing:
            self.dino_model.encoder.gradient_checkpointing = True

        # ------------------------------------------------------------------
        # 3. Unconditional ("empty image") embeddings
        # ------------------------------------------------------------------
        n_views = self.cfg.n_views
        size = self.cfg.image_size
        if self.cfg.zero_uncond_embeds:
            # One token per 14x14 patch plus one extra token.
            # NOTE(review): 14 is hard-coded; presumably the backbone patch
            # size -- confirm against ``self.dino_model.config``.
            self.empty_image_embeds = torch.zeros(
                (
                    n_views,
                    (size // 14) ** 2 + 1,
                    self.dino_model.config.hidden_size,
                )
            ).detach()
        else:
            # NOTE(review): this uses ``encode_image_dino``, which keeps
            # register tokens, whereas ``encode_image`` removes them -- verify
            # the resulting shapes agree for register models.
            black_images = torch.zeros(n_views, size, size, 3)
            if self.cfg.encode_camera:
                self.empty_image_embeds = self.encode_image_dino(
                    black_images,
                    self.cameras[:n_views],
                ).detach()
            else:
                self.empty_image_embeds = self.encode_image_dino(
                    black_images
                ).detach()

        # ------------------------------------------------------------------
        # 4. Freeze the backbone; only the modulation norms may stay trainable
        # ------------------------------------------------------------------
        self.dino_model.eval()
        train_modulation = not self.cfg.freeze_modulation_dino
        for name, param in self.dino_model.named_parameters():
            parts = name.split(".")
            is_modulation = "mod_norm1" in parts or "mod_norm2" in parts
            param.requires_grad_(is_modulation and train_modulation)

        # ------------------------------------------------------------------
        # 5. Optional checkpoint of the whole condition model
        # ------------------------------------------------------------------
        ckpt_path = self.cfg.pretrained_model_name_or_path
        if ckpt_path is not None:
            print(f"Loading ckpt from {ckpt_path}")
            # NOTE(security): ``torch.load`` unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
            prefix = "visual_condition."
            # Strip only the leading prefix (``str.replace`` would also remove
            # later occurrences inside the key).
            own_state = {
                key[len(prefix):]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
            self.load_state_dict(own_state, strict=True)

    def encode_image_dino(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> Union[torch.FloatTensor, DINOEmbedOutput]:
        """Run the DINO backbone on *images*.

        Args:
            images: either an array/tensor with values in ``[0, 1]`` that is
                permuted with ``(0, 3, 1, 2)`` before the torchvision
                transform (i.e. channels are expected on the last axis), or
                any other input, which is handed to the image processor.
            cameras: camera input for ``self.encode_camera``. Required for
                array/tensor input when ``cfg.encode_camera`` is set; for
                other input the first ``cfg.n_views`` entries of
                ``self.cameras`` are repeated when it is omitted.
            force_none_camera_embeds: drop the camera embeddings and call the
                backbone without ``condition``.
            return_dict: return a ``DINOEmbedOutput`` instead of only the
                last hidden state.
            **kwargs: accepted and ignored.

        Returns:
            The backbone's ``last_hidden_state``, or a ``DINOEmbedOutput``
            holding it together with ``pooler_output``.
        """
        camera_embeds = None
        if isinstance(images, (np.ndarray, torch.Tensor)):  # for training process
            if isinstance(images, np.ndarray):
                # NumPy arrays have no ``permute``; convert before use.
                images = torch.as_tensor(images, dtype=torch.float32)
            assert (
                images.min() >= 0.0 and images.max() <= 1.0
            ), "The pixel values should be in the range of [0, 1]"
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform_dino(images.permute(0, 3, 1, 2))
        else:  # for inference process
            if self.cfg.encode_camera:
                if cameras is None:
                    # Fall back to the default cameras, one set per sample.
                    bs = len(images) // self.cfg.n_views
                    cameras = (
                        self.cameras[: self.cfg.n_views]
                        .repeat(bs, 1, 1)
                        .to(self.dino_model.device)
                    )
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess_dino.preprocess(
                images,
                return_tensors="pt",
                do_rescale=True,
                do_resize=True,
                size=self.cfg.image_size,
                crop_size=self.cfg.image_size,
            ).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        # Bring 4-D input to the 5-D layout used by the rearrange patterns
        # below; the camera embeddings get the matching extra axis.
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        flat_pixels = rearrange(
            pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"
        )
        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.dino_model(
                flat_pixels,
                condition=rearrange(camera_embeds, "B N C -> (B N) C"),
            )
        else:
            vision_outputs = self.dino_model(flat_pixels)

        if return_dict:
            return DINOEmbedOutput(
                last_hidden_state=vision_outputs.last_hidden_state,
                pooler_output=vision_outputs.pooler_output,
            )
        return vision_outputs.last_hidden_state

    def encode_image(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
        """Encode *images* and drop register tokens from the result.

        Args:
            images: forwarded to ``encode_image_dino``.
            cameras: forwarded to ``encode_image_dino``.
            force_none_camera_embeds: forwarded to ``encode_image_dino``
                (the original accepted this flag but silently ignored it).
            return_dict: accepted for interface compatibility but not
                forwarded; the register-token slicing below needs a tensor,
                so a tensor is always returned.
            **kwargs: accepted and ignored.

        Returns:
            The last hidden state; for a register-token backbone the first
            token is kept and the following ``num_register_tokens`` tokens
            are removed.
        """
        dino_embeds = self.encode_image_dino(
            images,
            cameras,
            force_none_camera_embeds=force_none_camera_embeds,
        )
        # Compare by class name rather than ``isinstance`` so that a model of
        # the same name created through ``AutoModel`` is matched as well.
        if self.dino_model.__class__.__name__ == "Dinov2WithRegistersModel":
            # x_norm_clstoken, x_norm_regtokens, x_norm_patchtokens
            first_kept = self.dino_model.config.num_register_tokens + 1
            dino_embeds = torch.cat(
                [dino_embeds[:, :1], dino_embeds[:, first_kept:]],
                dim=1,
            )
        return dino_embeds