import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature

logger = logging.getLogger(__name__)


def get_torch_device(device: str = "auto") -> str:
    """
    Returns the device (string) to be used by PyTorch.

    `device` arg defaults to "auto" which will use:
    - "cuda:0" if available
    - else "mps" if available
    - else "cpu".
    """
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda:0"
        elif torch.backends.mps.is_available():  # for Apple Silicon
            device = "mps"
        else:
            device = "cpu"
        logger.info(f"Using device: {device}")
    return device
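
# Example (a sketch, not part of the original module): `get_torch_device` resolves
# "auto" to the best available backend and returns explicit device strings as-is.
#
#   get_torch_device()        # -> "cuda:0", "mps", or "cpu", depending on the host
#   get_torch_device("cpu")   # -> "cpu"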


class BaseVisualRetrieverProcessor(ABC):
    """
    Base class for visual retriever processors.
    """

    @abstractmethod
    def process_images(
        self,
        images: List[Image.Image],
    ) -> Union[BatchFeature, BatchEncoding]:
        pass

    @abstractmethod
    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        pass

    @abstractmethod
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        pass

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot-product score for the given single-vector query and passage embeddings.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        # Pairwise dot products: (n_queries, dim) x (n_passages, dim) -> (n_queries, n_passages).
        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @staticmethod
    def score_multi_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []

        for i in range(0, len(qs), batch_size):
            scores_batch = []
            # Pad the variable-length query token embeddings in this batch to a common length.
            qs_batch = torch.nn.utils.rnn.pad_sequence(
                qs[i : i + batch_size], batch_first=True, padding_value=0
            ).to(device)
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[j : j + batch_size], batch_first=True, padding_value=0
                ).to(device)
                # MaxSim: for each query token, take the maximum similarity over all passage
                # tokens, then sum over the query tokens.
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores_list.append(scores_batch)

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
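

# Usage sketch (not part of the original file): exercises both scoring helpers
# with random embeddings. The embedding dimension (128) and sequence lengths are
# illustrative; real embeddings come from a concrete visual retriever processor
# implementing the abstract methods above.
if __name__ == "__main__":
    # Single-vector scoring: one embedding per query / passage, scored by dot product.
    qs_single = [torch.randn(128) for _ in range(2)]
    ps_single = [torch.randn(128) for _ in range(3)]
    single_scores = BaseVisualRetrieverProcessor.score_single_vector(qs_single, ps_single)
    print(single_scores.shape)  # torch.Size([2, 3])

    # Multi-vector scoring: one embedding per token, variable lengths padded internally,
    # scored with the ColBERT-like MaxSim operator.
    qs_multi = [torch.randn(15, 128), torch.randn(20, 128)]
    ps_multi = [torch.randn(64, 128) for _ in range(3)]
    multi_scores = BaseVisualRetrieverProcessor.score_multi_vector(qs_multi, ps_multi)
    print(multi_scores.shape)  # torch.Size([2, 3])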