vic / src /retrieval.py
altndrr's picture
Add first version
a3ee979
raw
history blame
1.31 kB
from pathlib import Path
import pyarrow as pa
import numpy as np
class ArrowMetadataProvider:
    """Serve per-row metadata from Arrow IPC files, addressed by contiguous row ids.

    Every regular file found (recursively, in sorted path order) under the
    given folder is memory-mapped, read as an Arrow IPC file and concatenated
    into one table, so row id ``i`` is the ``i``-th row across all files in
    that order.

    Adapted from https://github.com/rom1504/clip-retrieval
    """

    def __init__(self, arrow_folder):
        """Load and concatenate every Arrow IPC file under ``arrow_folder``.

        Args:
            arrow_folder: Directory searched recursively; every regular file in
                it is read as an Arrow IPC file.
        """
        # Sorting fixes the file order, which in turn fixes the id -> row mapping.
        arrow_files = [
            str(path) for path in sorted(Path(arrow_folder).glob("**/*")) if path.is_file()
        ]
        # Each file is opened through pa.memory_map rather than read into a buffer;
        # the concatenated table is kept for all later lookups.
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
                for arrow_file in arrow_files
            ]
        )

    def get(self, ids, cols=None):
        """Return the metadata rows at ``ids`` as a list of dicts.

        Args:
            ids: Iterable of integer row ids. Each id ``i`` is looked up as the
                slice ``table[i:i + 1]``, so an id that selects no row (e.g.
                ``i >= len(table)``) contributes nothing and the result can be
                shorter than ``ids``.
            cols: Optional iterable of column names to keep. Names that are not
                in the table are ignored, duplicates are dropped and the
                requested order is kept. ``None`` keeps every column in schema
                order.

        Returns:
            One ``{column: value}`` dict per selected row, in ``ids`` order.
            An empty ``ids`` yields an empty list.
        """
        available = self.table.schema.names
        if cols is None:
            cols = available
        else:
            known = set(available)
            # dict.fromkeys de-duplicates while keeping the caller's order; a set
            # intersection would iterate in an arbitrary, run-dependent order.
            cols = [col for col in dict.fromkeys(cols) if col in known]

        rows = [self.table[i:(i + 1)] for i in ids]
        if not rows:
            # pa.concat_tables requires at least one table; nothing was requested.
            return []
        return pa.concat_tables(rows).select(cols).to_pandas().to_dict("records")
def meta_to_dict(meta):
    """Return a copy of a metadata mapping with values converted to plain Python types.

    Conversion rules, applied per value:

    * ``bytes`` (including ``np.bytes_``) are decoded with ``bytes.decode``'s
      default UTF-8 codec.
    * NumPy scalars (``np.generic``) become the matching built-in value via
      ``.item()``.
    * NumPy arrays become (nested) lists via ``.tolist()``; a single-element
      array is still collapsed to its scalar with ``.item()``, as before.
    * Every other value is kept unchanged.

    Args:
        meta: Mapping from metadata field name to value.

    Returns:
        A new ``dict`` with the same keys and converted values.
    """
    output = {}
    for key, value in meta.items():
        if isinstance(value, bytes):
            value = value.decode()
        elif isinstance(value, np.ndarray):
            # ndarray.item() only works for single-element arrays; keep that legacy
            # result for size 1 and convert any other array element-wise.
            value = value.item() if value.size == 1 else value.tolist()
        elif isinstance(value, np.generic):
            value = value.item()
        output[key] = value
    return output