| """TIPSv2 DPT dense prediction model for HuggingFace.""" |
|
|
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, PreTrainedModel

from .configuration_dpt import TIPSv2DPTConfig
|
|
# Directory containing this file; sibling modules (e.g. dpt_head.py) are
# looked up here first before falling back to a hub download.
_this_dir = Path(__file__).parent
# Cache of already-imported sibling modules, keyed by module name, so each
# sibling is located and executed at most once per process.
_sibling_cache = {}
|
|
|
|
def _load_sibling(name, repo_id=None):
    """Dynamically import a sibling ``<name>.py`` module.

    The file is looked up next to this module first; if it is missing and
    *repo_id* is given, it is fetched from that hub repo instead. Loaded
    modules are cached in ``_sibling_cache`` so each is executed only once.

    Args:
        name: Module name without the ``.py`` suffix (e.g. ``"dpt_head"``).
        repo_id: Optional hub repo to download ``<name>.py`` from when the
            file does not exist locally.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be built for the file
            (e.g. the path does not exist and no repo fallback was usable).
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists() and repo_id:
        # Fall back to fetching the file from the model repo on the hub.
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for unloadable paths; fail
        # with a clear message instead of an AttributeError below.
        raise ImportError(f"Cannot load sibling module {name!r} from {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod
|
|
|
|
@dataclass
class TIPSv2DPTOutput:
    """Container for the three dense-prediction outputs of one forward pass.

    Any field left as ``None`` means that task was not computed.
    """

    # Depth map; the forward() docstring documents this as (B, 1, H, W).
    depth: Optional[torch.Tensor] = None
    # Surface normals; documented as (B, 3, H, W).
    normals: Optional[torch.Tensor] = None
    # Semantic segmentation logits; documented as (B, 150, H, W) for ADE20K.
    segmentation: Optional[torch.Tensor] = None
|
|
|
|
class TIPSv2DPTModel(PreTrainedModel):
    """TIPSv2 DPT dense prediction model (depth, normals, segmentation).

    The backbone is loaded automatically (and lazily) from the base TIPSv2
    model repo named by ``config.backbone_repo``.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-l14-dpt", trust_remote_code=True)
        model.eval().cuda()

        outputs = model(pixel_values)
        outputs.depth         # (B, 1, H, W)
        outputs.normals       # (B, 3, H, W)
        outputs.segmentation  # (B, 150, H, W)

        # Individual tasks
        depth = model.predict_depth(pixel_values)
        normals = model.predict_normals(pixel_values)
        seg = model.predict_segmentation(pixel_values)
    """

    config_class = TIPSv2DPTConfig
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weights are tied in this model; return an empty mapping for
        # transformers versions that query this property.
        return {}

    def __init__(self, config: TIPSv2DPTConfig):
        """Build the three DPT task heads from *config*.

        The head classes live in a sibling ``dpt_head.py`` module that is
        imported dynamically (and fetched from the hub repo when it is not
        shipped next to this file).
        """
        super().__init__(config)

        repo_id = getattr(config, "_name_or_path", None)
        dpt_mod = _load_sibling("dpt_head", repo_id)

        # Keyword arguments shared by all three heads.
        common = dict(
            input_embed_dim=config.embed_dim,
            channels=config.channels,
            post_process_channels=tuple(config.post_process_channels),
            readout_type=config.readout_type,
        )
        self.depth_head = dpt_mod.DPTDepthHead(
            num_depth_bins=config.num_depth_bins,
            min_depth=config.min_depth,
            max_depth=config.max_depth,
            **common,
        )
        self.normals_head = dpt_mod.DPTNormalsHead(**common)
        self.segmentation_head = dpt_mod.DPTSegmentationHead(
            num_classes=config.num_seg_classes,
            **common,
        )
        # Loaded lazily on first use so constructing this model does not
        # trigger a hub download.
        self._backbone = None

    def _get_backbone(self):
        """Load (once) and return the vision encoder of the base model.

        NOTE(review): assigning the loaded nn.Module to ``self._backbone``
        registers it as a submodule on first use, so it will appear in
        ``state_dict()`` afterwards — confirm this is intended.
        """
        if self._backbone is None:
            self._backbone = AutoModel.from_pretrained(
                self.config.backbone_repo, trust_remote_code=True
            )
            self._backbone.to(self.device).eval()
        return self._backbone.vision_encoder

    def _extract_intermediate(self, pixel_values):
        """Extract intermediate backbone features for the DPT heads.

        Returns a list of ``(class_token, patch_features)`` pairs, one per
        entry in ``config.block_indices``. The backbone yields pairs in
        ``(patch_features, class_token)`` order, so each pair is swapped.
        """
        backbone = self._get_backbone()
        intermediate = backbone.get_intermediate_layers(
            pixel_values, n=self.config.block_indices,
            reshape=True, return_class_token=True, norm=True,
        )
        return [(cls_tok, patch_feat) for patch_feat, cls_tok in intermediate]

    def _prepare_inputs(self, pixel_values):
        """Shared preamble for every task: move the batch to the model's
        device, record the input resolution, and extract backbone features.

        Returns ``(dpt_inputs, (height, width))``.
        """
        pixel_values = pixel_values.to(self.device)
        h, w = pixel_values.shape[2:]
        return self._extract_intermediate(pixel_values), (h, w)

    @torch.no_grad()
    def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict depth map. Returns (B, 1, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.depth_head(dpt_inputs, image_size=size)

    @torch.no_grad()
    def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict surface normals. Returns (B, 3, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.normals_head(dpt_inputs, image_size=size)

    @torch.no_grad()
    def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Predict semantic segmentation (ADE20K). Returns (B, 150, H, W)."""
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return self.segmentation_head(dpt_inputs, image_size=size)

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput:
        """Run all three tasks on one batch, sharing the backbone features.

        Unlike the ``predict_*`` helpers this is not wrapped in
        ``torch.no_grad()``, so gradients flow when callers need them.
        """
        dpt_inputs, size = self._prepare_inputs(pixel_values)
        return TIPSv2DPTOutput(
            depth=self.depth_head(dpt_inputs, image_size=size),
            normals=self.normals_head(dpt_inputs, image_size=size),
            segmentation=self.segmentation_head(dpt_inputs, image_size=size),
        )
|
|