#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib | |
import os | |
import re | |
import shutil | |
import tarfile | |
import logging | |
from urllib.error import HTTPError, URLError | |
from urllib.request import urlretrieve | |
import pandas as pd | |
from tqdm import tqdm | |
from pyserini.encoded_query_info import QUERY_INFO | |
from pyserini.encoded_corpus_info import CORPUS_INFO | |
from pyserini.evaluate_script_info import EVALUATION_INFO | |
from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO | |
logger = logging.getLogger(__name__) | |
# Progress-bar helper; adapted from
# https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
class TqdmUpTo(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        """Progress callback compatible with ``urlretrieve``'s reporthook.

        b : int, optional
            Number of blocks transferred so far [default: 1].
        bsize : int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize : int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # update() expects an increment; this also advances self.n to `transferred`.
        self.update(transferred - self.n)
# For large files, we need to compute MD5 block by block. See:
# https://stackoverflow.com/questions/1131220/get-md5-hash-of-big-files-in-python
def compute_md5(file, block_size=2**20):
    """Return the hex MD5 digest of ``file``, read in ``block_size``-byte chunks.

    Reading block by block keeps memory usage bounded for large files.
    """
    digest = hashlib.md5()
    with open(file, 'rb') as fin:
        # iter() with a b'' sentinel stops cleanly at end-of-file.
        for chunk in iter(lambda: fin.read(block_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def download_url(url, save_dir, local_filename=None, md5=None, force=False, verbose=True):
    """Download ``url`` into ``save_dir`` and return the destination path.

    Parameters
    ----------
    url : str
        URL to fetch.
    save_dir : str
        Directory to save into (assumed to already exist).
    local_filename : str, optional
        Filename to save as; derived from the URL's last segment if not given.
    md5 : str, optional
        Expected MD5 checksum; if supplied, the downloaded file is verified.
    force : bool
        If True, remove an existing destination file and download a fresh copy.
    verbose : bool
        If True, print progress messages.

    Returns
    -------
    str
        Path of the downloaded (or already-present) file.
    """
    # If caller does not specify local filename, figure it out from the download URL:
    if not local_filename:
        filename = url.split('/')[-1]
        filename = re.sub(r'\?dl=1$', '', filename)  # Remove the Dropbox 'force download' parameter
    else:
        # Otherwise, use the specified local_filename:
        filename = local_filename

    destination_path = os.path.join(save_dir, filename)

    if verbose:
        print(f'Downloading {url} to {destination_path}...')

    # Check to see if file already exists, if so, simply return (quietly) unless force=True, in which case we remove
    # destination file and download fresh copy.
    if os.path.exists(destination_path):
        if verbose:
            print(f'{destination_path} already exists!')
        if not force:
            if verbose:
                # Fixed: was an f-string with no placeholders.
                print('Skipping download.')
            return destination_path
        if verbose:
            print(f'force=True, removing {destination_path}; fetching fresh copy...')
        os.remove(destination_path)

    with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=filename) as t:
        urlretrieve(url, filename=destination_path, reporthook=t.update_to)

    if md5:
        md5_computed = compute_md5(destination_path)
        assert md5_computed == md5, f'{destination_path} does not match checksum! Expecting {md5} got {md5_computed}.'

    return destination_path
def get_cache_home():
    """Return the Pyserini cache directory, honoring a $PYSERINI_CACHE override."""
    override = os.environ.get("PYSERINI_CACHE")
    # A set, non-empty environment variable wins over the default location.
    if override:
        return override
    return os.path.expanduser(os.path.join(f'~{os.path.sep}.cache', "pyserini"))
def download_and_unpack_index(url, index_directory='indexes', local_filename=False,
                              force=False, verbose=True, prebuilt=False, md5=None):
    """Download a ``.tar.gz`` index from ``url`` and unpack it into ``index_directory``.

    Parameters
    ----------
    url : str
        Download URL; the index name is derived from its last path segment
        unless ``local_filename`` is given.
    index_directory : str
        Destination directory (resolved under the cache home when ``prebuilt``).
    local_filename : str or False
        Explicit tarball filename; falsy means "derive from URL".
    force : bool
        If True, remove an existing unpacked index and fetch a fresh copy.
    verbose : bool
        If True, print progress messages.
    prebuilt : bool
        If True, store under the Pyserini cache and suffix the unpacked
        directory with ``.{md5}`` so different versions can coexist.
    md5 : str, optional
        Expected MD5 checksum of the tarball, verified by ``download_url``.

    Returns
    -------
    str
        Path to the unpacked index directory.
    """
    # If caller does not specify local filename, figure it out from the download URL:
    if not local_filename:
        index_name = url.split('/')[-1]
    else:
        # Otherwise, use the specified local_filename:
        index_name = local_filename
    # Remove the suffix (and anything trailing it, e.g. a query parameter):
    index_name = re.sub('.tar.gz.*$', '', index_name)

    if prebuilt:
        index_directory = os.path.join(get_cache_home(), index_directory)
        index_path = os.path.join(index_directory, f'{index_name}.{md5}')

        if not os.path.exists(index_directory):
            os.makedirs(index_directory)

        local_tarball = os.path.join(index_directory, f'{index_name}.tar.gz')
        # If there's a local tarball, it's likely corrupted, because we remove the local tarball on success (below).
        # So, we want to remove.
        if os.path.exists(local_tarball):
            os.remove(local_tarball)
    else:
        local_tarball = os.path.join(index_directory, f'{index_name}.tar.gz')
        index_path = os.path.join(index_directory, f'{index_name}')

    # Check to see if index already exists, if so, simply return (quietly) unless force=True, in which case we remove
    # index and download fresh copy.
    if os.path.exists(index_path):
        if not force:
            if verbose:
                print(f'{index_path} already exists, skipping download.')
            return index_path
        if verbose:
            print(f'{index_path} already exists, but force=True, removing {index_path} and fetching fresh copy...')
        shutil.rmtree(index_path)

    if verbose:
        # Consistency fix: this message previously printed even when verbose=False.
        print(f'Downloading index at {url}...')
    download_url(url, index_directory, local_filename=local_filename, verbose=False, md5=md5)

    if verbose:
        print(f'Extracting {local_tarball} into {index_path}...')
    try:
        tarball = tarfile.open(local_tarball)
    except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
        # Fall back to the payload saved without the .tar.gz suffix
        # (presumably some mirrors name the file that way — TODO confirm).
        local_tarball = os.path.join(index_directory, f'{index_name}')
        tarball = tarfile.open(local_tarball)
    dirs_in_tarball = [member.name for member in tarball if member.isdir()]
    # Message fixed to match the condition: we require at least one directory member.
    assert len(dirs_in_tarball), f'No directory member found in the tarball {local_tarball}.'
    tarball.extractall(index_directory)
    tarball.close()
    os.remove(local_tarball)

    if prebuilt:
        dir_in_tarball = dirs_in_tarball[0]
        if dir_in_tarball != index_name:
            logger.info(f"Renaming {index_directory}/{dir_in_tarball} into {index_directory}/{index_name}.")
            index_name = dir_in_tarball
        # Move the extracted directory to the md5-suffixed index_path.
        os.rename(os.path.join(index_directory, f'{index_name}'), index_path)

    return index_path
def check_downloaded(index_name):
    """Return True if the named prebuilt index already exists in the local cache."""
    # Registry entries are non-empty dicts, so `or`-chaining is safe here.
    target = TF_INDEX_INFO.get(index_name) or IMPACT_INDEX_INFO.get(index_name)
    if target is None:
        target = FAISS_INDEX_INFO[index_name]
    first_url = target['urls'][0]
    expected_md5 = target['md5']
    # The cached directory is named after the tarball basename plus its md5.
    basename = re.sub('.tar.gz.*$', '', first_url.split('/')[-1])
    cached_path = os.path.join(get_cache_home(), 'indexes', f'{basename}.{expected_md5}')
    return os.path.exists(cached_path)
def get_sparse_indexes_info():
    """Print a table of all sparse (TF and impact) prebuilt indexes and their download status."""
    df = pd.DataFrame.from_dict({**TF_INDEX_INFO, **IMPACT_INDEX_INFO})
    for index in df.keys():
        df[index]['downloaded'] = check_downloaded(index)
    # max_colwidth=None disables truncation; the old -1 sentinel is deprecated
    # and rejected by pandas >= 1.0.
    with pd.option_context('display.max_rows', None, 'display.max_columns', None,
                           'display.max_colwidth', None, 'display.colheader_justify', 'left'):
        print(df)
def get_impact_indexes_info():
    """Print a table of all impact prebuilt indexes and their download status."""
    df = pd.DataFrame.from_dict(IMPACT_INDEX_INFO)
    for index in df.keys():
        df[index]['downloaded'] = check_downloaded(index)
    # max_colwidth=None disables truncation; the old -1 sentinel is deprecated
    # and rejected by pandas >= 1.0.
    with pd.option_context('display.max_rows', None, 'display.max_columns', None,
                           'display.max_colwidth', None, 'display.colheader_justify', 'left'):
        print(df)
def get_dense_indexes_info():
    """Print a table of all dense (FAISS) prebuilt indexes and their download status."""
    df = pd.DataFrame.from_dict(FAISS_INDEX_INFO)
    for index in df.keys():
        df[index]['downloaded'] = check_downloaded(index)
    # max_colwidth=None disables truncation; the old -1 sentinel is deprecated
    # and rejected by pandas >= 1.0.
    with pd.option_context('display.max_rows', None, 'display.max_columns', None,
                           'display.max_colwidth', None, 'display.colheader_justify', 'left'):
        print(df)
def download_prebuilt_index(index_name, force=False, verbose=True, mirror=None):
    """Download (if necessary) the named prebuilt index; return its local path.

    Each known mirror URL is tried in turn until one succeeds.

    Raises
    ------
    ValueError
        If ``index_name`` is unrecognized, or every URL fails.
    """
    if index_name in TF_INDEX_INFO:
        target_index = TF_INDEX_INFO[index_name]
    elif index_name in IMPACT_INDEX_INFO:
        target_index = IMPACT_INDEX_INFO[index_name]
    elif index_name in FAISS_INDEX_INFO:
        target_index = FAISS_INDEX_INFO[index_name]
    else:
        raise ValueError(f'Unrecognized index name {index_name}')
    index_md5 = target_index['md5']
    # Loop-invariant: the filename override does not change per URL.
    local_filename = target_index.get('filename')
    for url in target_index['urls']:
        try:
            # Bug fix: `force` was previously accepted but never forwarded.
            return download_and_unpack_index(url, local_filename=local_filename,
                                             prebuilt=True, md5=index_md5,
                                             force=force, verbose=verbose)
        except (HTTPError, URLError):
            print(f'Unable to download pre-built index at {url}, trying next URL...')
    raise ValueError('Unable to download pre-built index at any known URLs.')
def download_encoded_queries(query_name, force=False, verbose=True, mirror=None):
    """Download (if necessary) the named pre-encoded queries; return their local path.

    Each known mirror URL is tried in turn until one succeeds.

    Raises
    ------
    ValueError
        If ``query_name`` is unrecognized, or every URL fails.
    """
    if query_name not in QUERY_INFO:
        raise ValueError(f'Unrecognized query name {query_name}')
    query_md5 = QUERY_INFO[query_name]['md5']
    for url in QUERY_INFO[query_name]['urls']:
        try:
            # Bug fix: `force` and `verbose` were previously accepted but never forwarded.
            return download_and_unpack_index(url, index_directory='queries', prebuilt=True,
                                             md5=query_md5, force=force, verbose=verbose)
        except (HTTPError, URLError):
            print(f'Unable to download encoded query at {url}, trying next URL...')
    raise ValueError('Unable to download encoded query at any known URLs.')
def download_encoded_corpus(corpus_name, force=False, verbose=True, mirror=None):
    """Download (if necessary) the named pre-encoded corpus; return its local path.

    Each known mirror URL is tried in turn until one succeeds.

    Raises
    ------
    ValueError
        If ``corpus_name`` is unrecognized, or every URL fails.
    """
    if corpus_name not in CORPUS_INFO:
        raise ValueError(f'Unrecognized corpus name {corpus_name}')
    corpus_md5 = CORPUS_INFO[corpus_name]['md5']
    # Loop-invariant: the filename override does not change per URL.
    local_filename = CORPUS_INFO[corpus_name].get('filename')
    for url in CORPUS_INFO[corpus_name]['urls']:
        try:
            # Bug fix: `force` and `verbose` were previously accepted but never forwarded.
            return download_and_unpack_index(url, local_filename=local_filename,
                                             index_directory='corpus', prebuilt=True,
                                             md5=corpus_md5, force=force, verbose=verbose)
        except (HTTPError, URLError):
            print(f'Unable to download encoded corpus at {url}, trying next URL...')
    raise ValueError('Unable to download encoded corpus at any known URLs.')
def download_evaluation_script(evaluation_name, force=False, verbose=True, mirror=None):
    """Download (if necessary) the named evaluation script; return its local path.

    Each known mirror URL is tried in turn until one succeeds.

    Raises
    ------
    ValueError
        If ``evaluation_name`` is unrecognized, or every URL fails.
    """
    if evaluation_name not in EVALUATION_INFO:
        raise ValueError(f'Unrecognized evaluation name {evaluation_name}')
    # Loop-invariant: resolve and create the save directory once.
    save_dir = os.path.join(get_cache_home(), 'eval')
    os.makedirs(save_dir, exist_ok=True)
    for url in EVALUATION_INFO[evaluation_name]['urls']:
        try:
            # Bug fixes: URLError is now retried (consistent with the other
            # downloaders), and force/verbose are forwarded instead of ignored.
            return download_url(url, save_dir=save_dir, force=force, verbose=verbose)
        except (HTTPError, URLError):
            print(f'Unable to download evaluation script at {url}, trying next URL...')
    raise ValueError('Unable to download evaluation script at any known URLs.')
def get_sparse_index(index_name):
    """Return the 'texts' entry registered for ``index_name`` in FAISS_INDEX_INFO.

    Raises ValueError when the name is not a known FAISS index.
    """
    if index_name in FAISS_INDEX_INFO:
        return FAISS_INDEX_INFO[index_name]["texts"]
    raise ValueError(f'Unrecognized index name {index_name}')