vic / src /retrieval.py
altndrr's picture
Load CLIP model from transformers
fcf6714
raw
history blame
1.08 kB
from pathlib import Path
from typing import Optional
import numpy as np
import pyarrow as pa
class ArrowMetadataProvider:
    """Serve row metadata addressed by contiguous integer ids, backed by Arrow.

    Every regular file found under a folder is memory-mapped, read as an Arrow
    IPC file, and concatenated (in sorted path order) into a single table. An
    id is the position of a row in that concatenated table.

    Code taken from: https://github.dev/rom1504/clip-retrieval
    """

    def __init__(self, arrow_folder: Path):
        """Load all Arrow files located (recursively) under ``arrow_folder``.

        Args:
            arrow_folder: Directory holding the Arrow IPC files. A plain string
                path is accepted as well.
        """
        arrow_folder = Path(arrow_folder)
        # Sorting pins the concatenation order, and therefore which row each
        # id refers to, independently of the filesystem's listing order.
        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
                for arrow_file in arrow_files
            ]
        )

    def get(self, ids: np.ndarray, cols: Optional[list] = None) -> list:
        """Return the metadata rows located at ``ids`` as a list of dicts.

        Args:
            ids: Integer row positions into the concatenated table. Any
                array-like of integers is accepted.
            cols: Names of the columns to return. Names that are not part of
                the table schema are ignored. ``None`` selects every column.

        Returns:
            A list of ``{column: value}`` dicts for the selected rows, in the
            order given by ``ids``. An empty ``ids`` yields an empty list.
        """
        schema_names = self.table.schema.names
        if cols is None:
            cols = schema_names
        else:
            # Filter in schema order instead of materialising a set
            # intersection: set iteration order for strings changes between
            # processes, which made the key order of the records unstable.
            requested = set(cols)
            cols = [name for name in schema_names if name in requested]

        ids = np.asarray(ids)
        if ids.size == 0:
            # concat_tables rejects an empty list of tables, so answer the
            # "no ids requested" case directly.
            return []

        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
        return t.select(cols).to_pandas().to_dict("records")