#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json | |
import os | |
import faiss | |
import torch | |
import numpy as np | |
from tqdm import tqdm | |
class DocumentEncoder:
    """Abstract base for document encoders; subclasses turn a batch of
    document texts into dense vector representations."""

    def encode(self, texts, **kwargs):
        """Encode a batch of document ``texts``; implemented by subclasses."""
        pass

    # Fix: declared as @staticmethod — the function takes no `self`, so
    # without the decorator a subclass calling `self._mean_pooling(h, m)`
    # would pass the instance as `last_hidden_state`.
    @staticmethod
    def _mean_pooling(last_hidden_state, attention_mask):
        """Mean-pool token embeddings over the sequence dimension,
        counting only positions where ``attention_mask`` is nonzero.

        :param last_hidden_state: tensor of shape (batch, seq, dim).
        :param attention_mask: tensor of shape (batch, seq) with 1 for real
            tokens and 0 for padding.
        :return: tensor of shape (batch, dim) — masked mean of token vectors.
        """
        token_embeddings = last_hidden_state
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so fully-masked rows divide by a tiny epsilon instead of zero.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
class QueryEncoder:
    """Abstract base for query encoders; subclasses turn a query string
    into a dense vector representation."""

    def encode(self, text, **kwargs):
        """Encode a single query ``text``; implemented by subclasses."""
        pass
class PcaEncoder:
    """Wraps another encoder and projects its output embeddings through a
    pre-trained faiss PCA transform."""

    def __init__(self, encoder, pca_model_path):
        """
        :param encoder: underlying encoder exposing ``encode``.
        :param pca_model_path: path to a serialized faiss VectorTransform.
        """
        self.encoder = encoder
        self.pca_mat = faiss.read_VectorTransform(pca_model_path)

    def encode(self, text, **kwargs):
        """Encode ``text`` with the wrapped encoder, then apply the PCA
        projection. A single string yields one projected vector; a batch
        yields a batch of projected vectors.
        """
        embeddings = self.encoder.encode(text, **kwargs)
        if isinstance(text, str):
            # apply_py operates on 2-D batches: wrap, project, unwrap.
            return self.pca_mat.apply_py(np.array([embeddings]))[0]
        return self.pca_mat.apply_py(embeddings)
class JsonlCollectionIterator:
    """Iterates over a JSONL collection (a single file or a directory of
    files), yielding batches of parallel lists: document ids plus each
    requested field."""

    def __init__(self, collection_path: str, fields=None, delimiter="\n\n"):
        """
        :param collection_path: path to a .jsonl file or a directory of them.
        :param fields: field names to extract per document; defaults to ["text"].
        :param delimiter: separator used when all fields are packed into a
            single 'contents' string.
        """
        self.fields = fields if fields else ["text"]
        self.delimiter = delimiter
        self.all_info = self._load(collection_path)
        self.size = len(self.all_info["id"])
        self.batch_size = 1
        self.shard_id = 0
        self.shard_num = 1

    def __call__(self, batch_size=1, shard_id=0, shard_num=1):
        """Configure batching/sharding and return self for iteration."""
        self.batch_size = batch_size
        self.shard_id = shard_id
        self.shard_num = shard_num
        return self

    def __iter__(self):
        total_len = self.size
        shard_size = total_len // self.shard_num
        start_idx = self.shard_id * shard_size
        end_idx = min(start_idx + shard_size, total_len)
        # The last shard absorbs the remainder so no document is dropped.
        if self.shard_id == self.shard_num - 1:
            end_idx = total_len
        for idx in tqdm(range(start_idx, end_idx, self.batch_size)):
            # Fix: build a fresh dict per batch. The original mutated and
            # re-yielded one shared dict, so a consumer holding a previous
            # batch saw it overwritten by the next one.
            yield {
                key: values[idx : min(idx + self.batch_size, end_idx)]
                for key, values in self.all_info.items()
            }

    def _parse_fields_from_info(self, info):
        """
        :params info: dict, containing all fields as specified in self.fields either under
        the key of the field name or under the key of 'contents'. If under `contents`, this
        function will parse the input contents into each fields based the self.delimiter
        return: List, each corresponds to the value of self.fields
        """
        n_fields = len(self.fields)
        # if all fields are under the key of info, read these rather than 'contents'
        if all(field in info for field in self.fields):
            return [info[field].strip() for field in self.fields]
        assert "contents" in info, f"contents not found in info: {info}"
        contents = info["contents"]
        # whether to remove the final self.delimiter (especially \n)
        # in CACM, a \n is always there at the end of contents, which we want to remove;
        # but in SciFact, Fiqa, and more, there are documents that only have title but not text (e.g. "This is title\n")
        # where the trailing \n indicates empty fields
        if contents.count(self.delimiter) == n_fields:
            # the user appends one more delimiter to the end, we remove it
            if contents.endswith(self.delimiter):
                # not using .rstrip() as there might be more than one delimiters at the end
                contents = contents[: -len(self.delimiter)]
        return [field.strip(" ") for field in contents.split(self.delimiter)]

    def _load(self, collection_path):
        """Read every JSONL file under ``collection_path`` into a dict of
        parallel lists keyed by 'id' and each name in self.fields.

        :raises ValueError: if a line has no 'id'/'docid', or parses into a
            different number of fields than expected.
        """
        if os.path.isfile(collection_path):
            filenames = [collection_path]
        else:
            filenames = [
                os.path.join(collection_path, name)
                for name in os.listdir(collection_path)
            ]
        all_info = {field: [] for field in self.fields}
        all_info["id"] = []
        for filename in filenames:
            # JSON is UTF-8 by specification; don't rely on the platform default.
            with open(filename, encoding="utf-8") as f:
                for line_i, line in tqdm(enumerate(f)):
                    info = json.loads(line)
                    _id = info.get("id", info.get("docid", None))
                    if _id is None:
                        # Fix: interpolate the offending filename (the message
                        # previously contained a literal placeholder).
                        raise ValueError(
                            f"Cannot find 'id' or 'docid' from {filename}."
                        )
                    all_info["id"].append(str(_id))
                    fields_info = self._parse_fields_from_info(info)
                    if len(fields_info) != len(self.fields):
                        raise ValueError(
                            f"{len(fields_info)} fields are found at Line#{line_i} in file {filename}."
                            f"{len(self.fields)} fields expected."
                            f"Line content: {info['contents']}"
                        )
                    for field, value in zip(self.fields, fields_info):
                        all_info[field].append(value)
        return all_info
class RepresentationWriter:
    """Abstract context-managed sink for encoded document representations."""

    def __enter__(self):
        """Open any underlying resources; implemented by subclasses."""
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release any underlying resources; implemented by subclasses."""
        pass

    def write(self, batch_info, fields=None):
        """Persist one batch of ids/fields/vectors; implemented by subclasses."""
        pass
class JsonlRepresentationWriter(RepresentationWriter): | |
def __init__(self, dir_path): | |
self.dir_path = dir_path | |
self.filename = "embeddings.jsonl" | |
self.file = None | |
def __enter__(self): | |
if not os.path.exists(self.dir_path): | |
os.makedirs(self.dir_path) | |
self.file = open(os.path.join(self.dir_path, self.filename), "w") | |
def __exit__(self, exc_type, exc_val, exc_tb): | |
self.file.close() | |
def write(self, batch_info, fields=None): | |
for i in range(len(batch_info["id"])): | |
contents = "\n".join([batch_info[key][i] for key in fields]) | |
vector = batch_info["vector"][i] | |
vector = vector.tolist() if isinstance(vector, np.ndarray) else vector | |
self.file.write( | |
json.dumps( | |
{"id": batch_info["id"][i], "contents": contents, "vector": vector} | |
) | |
+ "\n" | |
) | |
class FaissRepresentationWriter(RepresentationWriter):
    """Writes encoded representations into a faiss inner-product flat index
    plus a sidecar 'docid' file mapping index row order to document ids."""

    def __init__(self, dir_path, dimension=768):
        """
        :param dir_path: output directory for the index and docid files.
        :param dimension: dimensionality of the vectors to be indexed.
        """
        self.dir_path = dir_path
        self.index_name = "index"
        self.id_file_name = "docid"
        self.dimension = dimension
        self.index = faiss.IndexFlatIP(self.dimension)
        self.id_file = None

    def __enter__(self):
        # exist_ok avoids the check-then-create race of the original.
        os.makedirs(self.dir_path, exist_ok=True)
        self.id_file = open(os.path.join(self.dir_path, self.id_file_name), "w")
        # Fix: return self so `with FaissRepresentationWriter(d) as w:` binds
        # the writer instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.id_file.close()
        # The index is accumulated in memory and persisted only on exit.
        faiss.write_index(self.index, os.path.join(self.dir_path, self.index_name))

    def write(self, batch_info, fields=None):
        """Append the batch's docids to the id file and its vectors to the index.

        :param batch_info: dict with 'id' (list of docids) and 'vector'
            (array-like of shape (batch, dimension)).
        :param fields: unused; kept for interface compatibility.
        """
        for id_ in batch_info["id"]:
            self.id_file.write(f"{id_}\n")
        # faiss requires C-contiguous input; assumes float32 vectors — TODO confirm upstream.
        self.index.add(np.ascontiguousarray(batch_info["vector"]))