# NetsPresso_QA / pyserini / query_iterator.py
# Uploaded by geonmin-kim using huggingface_hub (commit d6585f5)
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
from abc import ABC, abstractmethod
from enum import Enum, unique
from pathlib import Path
from pyserini.search import get_topics, get_topics_with_reader
from pyserini.util import download_url, get_cache_home
from pyserini.external_query_info import KILT_QUERY_INFO
from urllib.error import HTTPError, URLError
@unique
class TopicsFormat(Enum):
    """Supported topic file formats, used to pick a query iterator.

    Member values are the lowercase format names, so ``TopicsFormat('kilt')``
    resolves to ``TopicsFormat.KILT``.
    """

    # General topics (JSON / TSV / TREC / CACM files, or named topic sets).
    DEFAULT = 'default'
    # KILT JSON-lines topics.
    KILT = 'kilt'
class QueryIterator(ABC):
    """Abstract base for iterating over ``(query_id, query_text)`` pairs.

    Subclasses implement :meth:`get_query` (turn a stored topic into query
    text) and :meth:`from_topics` (alternate constructor that loads topics).
    """

    # Topic names (after normalization: filename stem with '_' replaced by '-')
    # for which a fixed query order is provided.
    PREDEFINED_ORDER = {'msmarco-doc-dev',
                        'msmarco-doc-test',
                        'msmarco-passage-dev-subset',
                        'msmarco-passage-test-subset'}

    def __init__(self, topics: dict, order: list = None):
        """Store the topics and the order in which to iterate them.

        Args:
            topics: Mapping from query id to the stored topic record.
            order: Query ids in iteration order. If falsy (``None`` or empty),
                the sorted keys of ``topics`` are used instead.
        """
        self.order = order if order else sorted(topics.keys())
        self.topics = topics

    @abstractmethod
    def get_query(self, id_):
        """Return the query text for query id ``id_``."""
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def from_topics(cls, topics_path: str):
        """Build an iterator from a topics file path or topic-set name."""
        raise NotImplementedError()

    def __iter__(self):
        """Yield ``(id_, query_text)`` pairs following ``self.order``."""
        for id_ in self.order:
            yield id_, self.get_query(id_)

    def __len__(self):
        """Return the number of loaded topics (not the length of ``self.order``)."""
        return len(self.topics)

    @staticmethod
    def get_predefined_order(topics_path: str):
        """Return the predefined query-id order for ``topics_path``, or ``None``.

        The path is normalized to its filename stem with '_' replaced by '-',
        so e.g. ``msmarco_passage_dev_subset`` and ``some/dir/msmarco-doc-dev.tsv``
        match the entries in :attr:`PREDEFINED_ORDER`.
        """
        normalized_path = Path(topics_path).stem.replace('_', '-')
        if normalized_path not in QueryIterator.PREDEFINED_ORDER:
            return None

        print(f'Using pre-defined topic order for {normalized_path}')
        # Lazy import: only load the order tables when actually needed.
        from pyserini.query_iterator_order_info import QUERY_IDS
        # Index by the normalized name that passed the membership check above;
        # the raw path (underscores, directories, extensions) would not match.
        # NOTE(review): assumes QUERY_IDS is keyed by the same hyphenated names
        # as PREDEFINED_ORDER — confirm in query_iterator_order_info.
        return QUERY_IDS[normalized_path]
class DefaultQueryIterator(QueryIterator):
    """Query iterator whose query text is each topic's ``'title'`` field."""

    def get_query(self, id_):
        """Return the ``'title'`` of topic ``id_`` (``None`` if it has no title)."""
        return self.topics[id_].get('title')

    @classmethod
    def from_topics(cls, topics_path: str):
        """Load topics from a local file, or by topic-set name if no such file exists.

        Existing files are parsed based on their name, checked in this order:
        ``*.json`` is read with :func:`json.load`; paths containing ``beir``
        use the string-id TSV reader; ``*.tsv`` / ``*.tsv.gz`` try the
        integer-id TSV reader and fall back to the string-id reader on
        ``ValueError``; ``*.trec`` use the TREC reader; paths containing
        ``cacm`` use the CACM reader.

        Args:
            topics_path: Path to a topics file, or a topic-set name passed to
                ``get_topics`` when the path does not exist.

        Returns:
            An instance of ``cls`` over the loaded topics. It uses a predefined
            query order when :meth:`QueryIterator.get_predefined_order` finds one.

        Raises:
            NotImplementedError: if an existing file's format cannot be
                inferred from its name.
            FileNotFoundError: if no topics were loaded.
        """
        if os.path.exists(topics_path):
            if topics_path.endswith('.json'):
                # Decode explicitly as UTF-8 (the JSON standard encoding) instead
                # of the platform/locale default, which breaks on non-ASCII text.
                # NOTE(review): assumes the JSON maps query id -> dict with a
                # 'title' key, since that is what get_query reads — confirm.
                with open(topics_path, 'r', encoding='utf-8') as f:
                    topics = json.load(f)
            elif 'beir' in topics_path:
                # Checked before the extension test so BEIR files (including
                # .tsv ones) always go through the string-id reader.
                topics = get_topics_with_reader('io.anserini.search.topicreader.TsvStringTopicReader', topics_path)
            elif topics_path.endswith(('.tsv', '.tsv.gz')):
                try:
                    topics = get_topics_with_reader('io.anserini.search.topicreader.TsvIntTopicReader', topics_path)
                except ValueError:
                    # Presumably raised for non-integer query ids; retry with
                    # the string-id reader.
                    topics = get_topics_with_reader('io.anserini.search.topicreader.TsvStringTopicReader', topics_path)
            elif topics_path.endswith('.trec'):
                topics = get_topics_with_reader('io.anserini.search.topicreader.TrecTopicReader', topics_path)
            elif 'cacm' in topics_path:
                topics = get_topics_with_reader('io.anserini.search.topicreader.CacmTopicReader', topics_path)
            else:
                raise NotImplementedError(f"Not sure how to parse {topics_path}. Please specify the file extension.")
        else:
            # Not a local file: treat it as the name of a known topic set.
            topics = get_topics(topics_path)

        if not topics:
            raise FileNotFoundError(f'Topic {topics_path} Not Found')
        order = QueryIterator.get_predefined_order(topics_path)
        return cls(topics, order)
class KiltQueryIterator(QueryIterator):
    """Query iterator for KILT topics stored as one JSON object per line."""

    # Entity markers removed from the query text by get_query.
    ENT_START_TOKEN = "[START_ENT]"
    ENT_END_TOKEN = "[END_ENT]"

    def get_query(self, id_):
        """Return topic ``id_``'s ``"input"`` text with entity markers removed and whitespace stripped."""
        datapoint = self.topics[id_]
        query = (
            datapoint["input"]
            .replace(KiltQueryIterator.ENT_START_TOKEN, "")
            .replace(KiltQueryIterator.ENT_END_TOKEN, "")
            .strip()
        )
        return query

    @classmethod
    def from_topics(cls, topics_path: str):
        """Load KILT topics from a JSON-lines file, downloading it first if needed.

        Args:
            topics_path: Path to a local JSON-lines topics file. If no such
                file exists, it is treated as a KILT task name and fetched
                via :meth:`download_kilt_topics`.

        Returns:
            An instance of ``cls`` keyed by each record's ``"id"``, iterating
            in file order.
        """
        topics = {}
        order = []
        if not os.path.exists(topics_path):
            # Download if necessary; the returned value is used as the local path.
            topics_path = cls.download_kilt_topics(topics_path)
        # Decode explicitly as UTF-8 rather than with the locale default.
        with open(topics_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. extra newlines at end of file),
                    # which json.loads would reject.
                    continue
                datapoint = json.loads(line)
                topics[datapoint["id"]] = datapoint
                order.append(datapoint["id"])
        return cls(topics, order)

    @classmethod
    def download_kilt_topics(cls, task: str, force=False):
        """Download the topics file for KILT ``task`` into the cache and return its path.

        Each URL registered for the task in ``KILT_QUERY_INFO`` is tried in
        turn. The first successful download (checked against the registered
        md5) wins.

        Args:
            task: KILT task name; must be a key of ``KILT_QUERY_INFO``.
            force: Passed through to ``download_url``.

        Returns:
            Whatever ``download_url`` returns; ``from_topics`` opens it as the
            local file path.

        Raises:
            ValueError: if ``task`` is unknown or no URL could be downloaded.
        """
        if task not in KILT_QUERY_INFO:
            raise ValueError(f'Unrecognized query name {task}')
        info = KILT_QUERY_INFO[task]
        md5 = info['md5']
        save_dir = os.path.join(get_cache_home(), 'queries')
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_dir, exist_ok=True)
        for url in info['urls']:
            try:
                return download_url(url, save_dir, force=force, md5=md5)
            except (HTTPError, URLError):
                print(f'Unable to download KILT topics at {url}, trying next URL...')
        raise ValueError(f'Unable to download KILT topics for {task} at any known URLs.')
def get_query_iterator(topics_path: str, topics_format: TopicsFormat):
    """Build the query iterator appropriate for ``topics_format``.

    Args:
        topics_path: Topics file path or topic-set name, forwarded to the
            selected iterator class's ``from_topics``.
        topics_format: Which topic format to read.

    Returns:
        A ``DefaultQueryIterator`` or ``KiltQueryIterator`` instance.

    Raises:
        KeyError: if ``topics_format`` is not one of the handled formats.
    """
    iterator_cls = {
        TopicsFormat.DEFAULT: DefaultQueryIterator,
        TopicsFormat.KILT: KiltQueryIterator,
    }[topics_format]
    return iterator_cls.from_topics(topics_path)