| from abc import ABC, abstractmethod |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| from PIL import Image |
| from transformers import BatchEncoding, BatchFeature, ProcessorMixin |
|
|
| from .torch_utils import get_torch_device |
|
|
|
|
class BaseVisualRetrieverProcessor(ABC, ProcessorMixin):
    """
    Base class for visual retriever processors.

    Subclasses implement the model-specific preprocessing of images and text
    queries, plus the scoring of query embeddings against passage embeddings.
    Two ready-made scorers are provided as static methods: dot-product scoring
    for single-vector embeddings and late-interaction (MaxSim) scoring for
    multi-vector embeddings.
    """

    @abstractmethod
    def process_images(
        self,
        images: List[Image.Image],
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a batch of PIL images into model-ready inputs.
        """
        pass

    @abstractmethod
    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a batch of text queries into model-ready inputs.

        Args:
            queries: The text queries to encode.
            max_length (`int`, *optional*, defaults to 50): Maximum token length
                for the processed queries.
            suffix (`str`, *optional*): Optional suffix appended to each query
                before processing.
        """
        pass

    @abstractmethod
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the `(n_queries, n_passages)` score matrix for the given query
        and passage embeddings.
        """
        pass

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.

        Args:
            qs: List of query embeddings, each of shape `(embedding_dim,)`.
            ps: List of passage embeddings, each of shape `(embedding_dim,)`.
            device (`Union[str, torch.device]`, *optional*): Device to use for
                computation. If not provided, uses `get_torch_device("auto")`.

        Returns:
            `torch.Tensor`: A float32 tensor of shape `(n_queries, n_passages)`.

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        # (n_queries, dim) . (n_passages, dim)^T -> (n_queries, n_passages)
        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        # Cast to float32 so callers get a consistent dtype regardless of the
        # embeddings' dtype (e.g. bfloat16).
        scores = scores.to(torch.float32)
        return scores

    @staticmethod
    def score_multi_vector(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            qs (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
            ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
                provided, uses `get_torch_device("auto")`.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
            tensor is saved on the "cpu" device.

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []

        for i in range(0, len(qs), batch_size):
            scores_batch = []
            # Pad the query batch like the passage batch below: when `qs` is a
            # list (input form (1)), the slice is a plain list of variable-length
            # tensors and has no `.to()` — padding both fixes that and handles
            # the pre-padded tensor case (`list(...)` unbinds the first dim).
            # Zero padding is exact for MaxSim: a zero query token has dot
            # product 0 with every passage token, so it contributes 0 to the sum.
            qs_batch = torch.nn.utils.rnn.pad_sequence(
                list(qs[i : i + batch_size]), batch_first=True, padding_value=0
            ).to(device)
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    list(ps[j : j + batch_size]), batch_first=True, padding_value=0
                ).to(device)
                # MaxSim: for each (query, passage) pair, take the best-matching
                # passage token per query token, then sum over query tokens.
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores_list.append(scores_batch)

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @abstractmethod
    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int = 14,
        *args,
        **kwargs,
    ) -> Tuple[int, int]:
        """
        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an
        image of size (height, width) with the given patch size.
        """
        pass
|
|