# NOTE: stray paste/notebook output ("Spaces:", "Runtime error") removed here.
| # A portable utility module for the demo programs | |
| # %% | |
| import os | |
| import numpy as np | |
| import einops as ein | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| import fast_pytorch_kmeans as fpk | |
| from typing import Literal, Union, List | |
| # %% | |
| # Extract features from a Dino-v2 model | |
# Model identifiers accepted by torch.hub for the DINO-v2 checkpoints.
_DINO_V2_MODELS = Literal[
    "dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"
]
# Attention facets that can be hooked ("token" = the block's output).
_DINO_FACETS = Literal["query", "key", "value", "token"]
class DinoV2ExtractFeatures:
    """
    Extract features from an intermediate layer in Dino-v2.

    A forward hook is registered on the chosen transformer block (or
    on that block's fused QKV projection for the attention facets);
    the output captured by the hook is returned from `__call__`.
    """
    def __init__(self, dino_model: "_DINO_V2_MODELS", layer: int,
                facet: "_DINO_FACETS"="token", use_cls=False,
                norm_descs=True, device: str = "cpu") -> None:
        """
        Parameters:
        - dino_model:   The DINO-v2 model to use
        - layer:        The layer to extract features from
        - facet:    "query", "key", or "value" for the attention
                    facets. "token" for the output of the layer.
        - use_cls:  If True, the CLS token (first item) is also
                    included in the returned list of descriptors.
                    Otherwise, only patch descriptors are used.
        - norm_descs:   If True, the descriptors are normalized
        - device:   PyTorch device to use
        """
        self.vit_type: str = dino_model
        # Downloads (or loads from the local hub cache) the model
        self.dino_model: nn.Module = torch.hub.load(
                'facebookresearch/dinov2', dino_model)
        self.device = torch.device(device)
        self.dino_model = self.dino_model.eval().to(self.device)
        self.layer: int = layer
        self.facet = facet
        if self.facet == "token":
            # Hook the block output directly
            self.fh_handle = self.dino_model.blocks[self.layer].\
                    register_forward_hook(
                            self._generate_forward_hook())
        else:
            # Hook the fused QKV projection; the requested facet is
            # sliced out of the captured output in `__call__`
            self.fh_handle = self.dino_model.blocks[self.layer].\
                    attn.qkv.register_forward_hook(
                            self._generate_forward_hook())
        self.use_cls = use_cls
        self.norm_descs = norm_descs
        # Output captured by the forward hook on the latest call
        self._hook_out = None

    def _generate_forward_hook(self):
        """Return a hook that stores the module output on `self`."""
        def _forward_hook(module, inputs, output):
            self._hook_out = output
        return _forward_hook

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """
        Run a forward pass and return the hooked descriptors.

        Parameters:
        - img:  The input image batch (assumed [b, 3, h, w] with h, w
                multiples of the patch size — TODO confirm at caller)

        Returns:
        - res:  Hooked descriptors, [b, n(+1), d]; the CLS token (the
                first item) is kept only when `use_cls` is True.
        """
        with torch.no_grad():
            # The model's return value is unused; the forward hook
            # fills `self._hook_out` as a side effect of this call.
            self.dino_model(img)
            if self.use_cls:
                res = self._hook_out
            else:
                res = self._hook_out[:, 1:, ...]    # Drop CLS token
            if self.facet in ["query", "key", "value"]:
                # QKV output is [b, n, 3*d]; slice out one facet
                d_len = res.shape[2] // 3
                if self.facet == "query":
                    res = res[:, :, :d_len]
                elif self.facet == "key":
                    res = res[:, :, d_len:2*d_len]
                else:
                    res = res[:, :, 2*d_len:]
        if self.norm_descs:
            res = F.normalize(res, dim=-1)
        self._hook_out = None   # Reset the hook
        return res

    def __del__(self):
        # Guard with getattr: if __init__ raised before the hook was
        # registered (e.g. torch.hub.load failed), `fh_handle` never
        # existed and an unguarded access would raise a noisy
        # AttributeError during garbage collection.
        fh_handle = getattr(self, "fh_handle", None)
        if fh_handle is not None:
            fh_handle.remove()
| # %% | |
| # VLAD global descriptor implementation | |
class VLAD:
    """
    An implementation of the VLAD algorithm given database and query
    descriptors.

    Constructor arguments:
    - num_clusters:     Number of cluster centers for VLAD
    - desc_dim:     Descriptor dimension. If None, then it is
                    inferred when running `fit` method.
    - intra_norm:   If True, intra normalization is applied
                    when constructing VLAD
    - norm_descs:   If True, the given descriptors are
                    normalized before training and predicting
                    VLAD descriptors. Different from the
                    `intra_norm` argument.
    - dist_mode:    Distance mode for KMeans clustering for
                    vocabulary (not residuals). Must be in
                    {'euclidean', 'cosine'}.
    - vlad_mode:    Mode for descriptor assignment (to cluster
                    centers) in VLAD generation. Must be in
                    {'soft', 'hard'}
    - soft_temp:    Temperature for softmax (if 'vlad_mode' is
                    'soft') for assignment
    - cache_dir:    Directory to cache the VLAD vectors. If
                    None, then no caching is done. If a str,
                    then it is assumed as the folder path. Use
                    absolute paths.

    Notes:
    - Arandjelovic, Relja, and Andrew Zisserman. "All about VLAD."
        Proceedings of the IEEE conference on Computer Vision and
        Pattern Recognition. 2013.
    """
    def __init__(self, num_clusters: int,
                desc_dim: Union[int, None]=None,
                intra_norm: bool=True, norm_descs: bool=True,
                dist_mode: str="cosine", vlad_mode: str="hard",
                soft_temp: float=1.0,
                cache_dir: Union[str,None]=None) -> None:
        self.num_clusters = num_clusters
        self.desc_dim = desc_dim
        self.intra_norm = intra_norm
        self.norm_descs = norm_descs
        self.mode = dist_mode
        self.vlad_mode = str(vlad_mode).lower()
        assert self.vlad_mode in ['soft', 'hard']
        self.soft_temp = soft_temp
        # Set in the training phase (`fit`)
        self.c_centers = None   # Cluster centers: [num_clusters, desc_dim]
        self.kmeans = None
        # Set up caching (cluster centers, residuals, assignments)
        self.cache_dir = cache_dir
        if self.cache_dir is not None:
            self.cache_dir = os.path.abspath(os.path.expanduser(
                    self.cache_dir))
            if not os.path.exists(self.cache_dir):
                os.makedirs(self.cache_dir)
                print(f"Created cache directory: {self.cache_dir}")
            else:
                print("Warning: Cache directory already exists: " \
                        f"{self.cache_dir}")
        else:
            print("VLAD caching is disabled.")

    def can_use_cache_vlad(self):
        """
        Checks if the cache directory is a valid cache directory.
        For it to be valid, it must exist and should at least
        include the cluster centers file.

        Returns:
        - True if the cache directory is valid
        - False if
            - the cache directory doesn't exist
            - exists but doesn't contain the cluster centers
            - no caching is set in constructor
        """
        if self.cache_dir is None:
            return False
        if not os.path.exists(self.cache_dir):
            return False
        return os.path.exists(f"{self.cache_dir}/c_centers.pt")

    def can_use_cache_ids(self,
                cache_ids: Union[List[str], str, None],
                only_residuals: bool=False) -> bool:
        """
        Checks if the given cache IDs exist in the cache directory
        and returns True if all of them exist.
        The cache is stored in the following files:
        - c_centers.pt:     Cluster centers
        - `cache_id`_r.pt:  Residuals for VLAD
        - `cache_id`_l.pt:  Labels for VLAD (hard assignment)
        - `cache_id`_s.pt:  Soft assignment for VLAD
        The function returns False if cache cannot be used or if
        any of the cache IDs are not found. If all cache IDs are
        found, then True is returned.
        This function is mainly for use outside the VLAD class.
        """
        if not self.can_use_cache_vlad():
            return False
        if cache_ids is None:
            return False
        if isinstance(cache_ids, str):
            cache_ids = [cache_ids]
        for cache_id in cache_ids:
            if not os.path.exists(
                    f"{self.cache_dir}/{cache_id}_r.pt"):
                return False
            if not only_residuals:
                # Assignment files are needed only when the caller
                # wants more than the residuals
                if self.vlad_mode == "hard" and not os.path.exists(
                        f"{self.cache_dir}/{cache_id}_l.pt"):
                    return False
                if self.vlad_mode == "soft" and not os.path.exists(
                        f"{self.cache_dir}/{cache_id}_s.pt"):
                    return False
        return True

    # Generate cluster centers
    def fit(self, train_descs: Union[np.ndarray, torch.Tensor, None]):
        """
        Using the training descriptors, generate the cluster
        centers (vocabulary). Function expects all descriptors in
        a single list (see `fit_and_generate` for a batch of
        images).
        If the cache directory is valid, then retrieves cluster
        centers from there (the `train_descs` are ignored).
        Otherwise, stores the cluster centers in the cache
        directory (if using caching).

        Parameters:
        - train_descs:  Training descriptors of shape
                        [num_train_desc, desc_dim]. If None, then
                        caching should be valid (else ValueError).
        """
        # Clustering to create vocabulary
        self.kmeans = fpk.KMeans(self.num_clusters, mode=self.mode)
        # Check if cache exists
        if self.can_use_cache_vlad():
            print("Using cached cluster centers")
            self.c_centers = torch.load(
                    f"{self.cache_dir}/c_centers.pt")
            self.kmeans.centroids = self.c_centers
            if self.desc_dim is None:
                self.desc_dim = self.c_centers.shape[1]
                print(f"Desc dim set to {self.desc_dim}")
        else:
            if train_descs is None:
                raise ValueError("No training descriptors given")
            if isinstance(train_descs, np.ndarray):
                train_descs = torch.from_numpy(train_descs).\
                        to(torch.float32)
            if self.desc_dim is None:
                self.desc_dim = train_descs.shape[1]
            if self.norm_descs:
                train_descs = F.normalize(train_descs)
            self.kmeans.fit(train_descs)
            self.c_centers = self.kmeans.centroids
            if self.cache_dir is not None:
                print("Caching cluster centers")
                torch.save(self.c_centers,
                        f"{self.cache_dir}/c_centers.pt")

    def fit_and_generate(self,
                train_descs: Union[np.ndarray, torch.Tensor]) \
                -> torch.Tensor:
        """
        Given a batch of descriptors over images, `fit` the VLAD
        and generate the global descriptors for the training
        images. Use only when there are a fixed number of
        descriptors in each image.

        Parameters:
        - train_descs:  Training image descriptors of shape
                        [num_imgs, num_descs, desc_dim]. There are
                        'num_imgs' images, each image has
                        'num_descs' descriptors and each
                        descriptor is 'desc_dim' dimensional.

        Returns:
        - train_vlads:  The VLAD vectors of all training images.
                        Shape: [num_imgs, num_clusters*desc_dim]
        """
        # Generate vocabulary: flatten [n, k, d] -> [n*k, d]
        all_descs = train_descs.reshape(-1, train_descs.shape[-1])
        self.fit(all_descs)
        # For each image, stack VLAD
        return torch.stack([self.generate(tr) for tr in train_descs])

    def generate(self, query_descs: Union[np.ndarray, torch.Tensor],
                cache_id: Union[str, None]=None) -> torch.Tensor:
        """
        Given the query descriptors, generate a VLAD vector. Call
        `fit` before using this method. Use this for only single
        images and with descriptors stacked. Use function
        `generate_multi` for multiple images.

        Parameters:
        - query_descs:  Query descriptors of shape [n_q, desc_dim]
                        where 'n_q' is number of 'desc_dim'
                        dimensional descriptors in a query image.
        - cache_id:     If not None, then the VLAD vector is
                        constructed using the residual and labels
                        from this file.

        Returns:
        - n_vlad:   Normalized VLAD: [num_clusters*desc_dim]
        """
        # Convert numpy input once up front: the soft-assignment
        # branch and `kmeans.predict` both need a torch tensor
        # (previously a numpy query crashed in soft mode)
        if isinstance(query_descs, np.ndarray):
            query_descs = torch.from_numpy(query_descs)\
                    .to(torch.float32)
        residuals = self.generate_res_vec(query_descs, cache_id)
        # Un-normalized VLAD vector: [c*d,]
        un_vlad = torch.zeros(self.num_clusters * self.desc_dim)
        if self.vlad_mode == 'hard':
            # Get labels for assignment of descriptors
            if cache_id is not None and self.can_use_cache_vlad() \
                    and os.path.isfile(
                        f"{self.cache_dir}/{cache_id}_l.pt"):
                labels = torch.load(
                        f"{self.cache_dir}/{cache_id}_l.pt")
            else:
                labels = self.kmeans.predict(query_descs)   # [q]
                if cache_id is not None and self.can_use_cache_vlad():
                    torch.save(labels,
                            f"{self.cache_dir}/{cache_id}_l.pt")
            # Create VLAD from residuals and labels
            used_clusters = set(labels.numpy())
            for k in used_clusters:
                # Sum of residuals for the descriptors assigned to
                # cluster k. Shape: [q, c, d] -> [q', d] -> [d]
                cd_sum = residuals[labels==k, k].sum(dim=0)
                if self.intra_norm:
                    cd_sum = F.normalize(cd_sum, dim=0)
                un_vlad[k*self.desc_dim:(k+1)*self.desc_dim] = cd_sum
        else:       # Soft cluster assignment
            # Cosine similarity: 1 = close, -1 = away
            if cache_id is not None and self.can_use_cache_vlad() \
                    and os.path.isfile(
                        f"{self.cache_dir}/{cache_id}_s.pt"):
                soft_assign = torch.load(
                        f"{self.cache_dir}/{cache_id}_s.pt")
            else:
                cos_sims = F.cosine_similarity(     # [q, c]
                        query_descs.unsqueeze(1),       # [q, 1, d]
                        self.c_centers.unsqueeze(0),    # [1, c, d]
                        dim=2)
                soft_assign = F.softmax(self.soft_temp*cos_sims,
                        dim=1)
                if cache_id is not None and self.can_use_cache_vlad():
                    torch.save(soft_assign,
                            f"{self.cache_dir}/{cache_id}_s.pt")
            # Soft assignment scores (as probabilities): [q, c]
            for k in range(0, self.num_clusters):
                w = soft_assign[:, k].view(-1, 1, 1)    # [q, 1, 1]
                # Weighted sum of residuals over all descriptors and
                # all clusters (weights are cluster k's): [d]
                cd_sum = (w * residuals).sum(dim=(0, 1))
                if self.intra_norm:
                    cd_sum = F.normalize(cd_sum, dim=0)
                un_vlad[k*self.desc_dim:(k+1)*self.desc_dim] = cd_sum
        # Normalize the VLAD vector
        n_vlad = F.normalize(un_vlad, dim=0)
        return n_vlad

    @staticmethod
    def _collate(res: list) -> Union[torch.Tensor, np.ndarray, list]:
        """
        Stack per-image results into one tensor (or numpy array) when
        their shapes agree; otherwise keep them as a list.
        Fix: torch.stack raises RuntimeError (not TypeError) on
        mismatched shapes and np.stack raises ValueError, so both are
        caught here — previously only TypeError was caught and the
        documented "result is also a list" fallback crashed.
        """
        try:    # Most likely pytorch
            return torch.stack(res)
        except (TypeError, RuntimeError):
            try:    # Otherwise numpy
                return np.stack(res)
            except (TypeError, ValueError):
                return res  # Let it remain as a list

    def generate_multi(self,
            multi_query: Union[np.ndarray, torch.Tensor, list],
            cache_ids: Union[List[str], None]=None) \
            -> Union[torch.Tensor, list]:
        """
        Given query descriptors from multiple images, generate
        the VLAD for them.

        Parameters:
        - multi_query:  Descriptors of shape [n_imgs, n_kpts, d]
                        There are 'n_imgs' and each image has
                        'n_kpts' keypoints, with 'd' dimensional
                        descriptor each. If a List (can then have
                        different number of keypoints in each
                        image), then the result is also a list.
        - cache_ids:    Cache IDs for the VLAD vectors. If None,
                        then no caching is done (stored or
                        retrieved). If a list, then the length
                        should be 'n_imgs' (one per image).

        Returns:
        - multi_res:    VLAD descriptors for the queries
        """
        if cache_ids is None:
            cache_ids = [None] * len(multi_query)
        res = [self.generate(q, c)
                for (q, c) in zip(multi_query, cache_ids)]
        return self._collate(res)

    def generate_res_vec(self,
                query_descs: Union[np.ndarray, torch.Tensor],
                cache_id: Union[str, None]=None) -> torch.Tensor:
        """
        Given the query descriptors, compute the residuals of each
        descriptor to every cluster center. Call `fit` before using
        this method. Use this for only single images and with
        descriptors stacked. Use function `generate_multi_res_vec`
        for multiple images.

        Parameters:
        - query_descs:  Query descriptors of shape [n_q, desc_dim]
                        where 'n_q' is number of 'desc_dim'
                        dimensional descriptors in a query image.
        - cache_id:     If not None, then the residuals are
                        loaded from (or stored to) this cache file.

        Returns:
        - residuals:    Residual vector: shape [n_q, n_c, d]
        """
        assert self.kmeans is not None
        assert self.c_centers is not None
        # Compute residuals (all query to cluster): [q, c, d]
        if cache_id is not None and self.can_use_cache_vlad() and \
                os.path.isfile(f"{self.cache_dir}/{cache_id}_r.pt"):
            residuals = torch.load(
                    f"{self.cache_dir}/{cache_id}_r.pt")
        else:
            if isinstance(query_descs, np.ndarray):
                query_descs = torch.from_numpy(query_descs)\
                        .to(torch.float32)
            if self.norm_descs:
                query_descs = F.normalize(query_descs)
            # Broadcast [q, 1, d] - [1, c, d] -> [q, c, d]
            residuals = query_descs.unsqueeze(1) \
                    - self.c_centers.unsqueeze(0)
            if cache_id is not None and self.can_use_cache_vlad():
                # `cache_id` may contain a sub-directory component
                cid_dir = f"{self.cache_dir}/"\
                        f"{os.path.split(cache_id)[0]}"
                if not os.path.isdir(cid_dir):
                    os.makedirs(cid_dir)
                    print(f"Created directory: {cid_dir}")
                torch.save(residuals,
                        f"{self.cache_dir}/{cache_id}_r.pt")
        return residuals

    def generate_multi_res_vec(self,
            multi_query: Union[np.ndarray, torch.Tensor, list],
            cache_ids: Union[List[str], None]=None) \
            -> Union[torch.Tensor, list]:
        """
        Given query descriptors from multiple images, compute the
        residuals for each of them.

        Parameters:
        - multi_query:  Descriptors of shape [n_imgs, n_kpts, d]
                        There are 'n_imgs' and each image has
                        'n_kpts' keypoints, with 'd' dimensional
                        descriptor each. If a List (can then have
                        different number of keypoints in each
                        image), then the result is also a list.
        - cache_ids:    Cache IDs for the residual vectors. If None,
                        then no caching is done (stored or
                        retrieved). If a list, then the length
                        should be 'n_imgs' (one per image).

        Returns:
        - multi_res:    Residual vectors for the queries
        """
        if cache_ids is None:
            cache_ids = [None] * len(multi_query)
        res = [self.generate_res_vec(q, c)
                for (q, c) in zip(multi_query, cache_ids)]
        return self._collate(res)