NetsPresso_QA / pyserini /output_writer.py
geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw history blame
No virus
3.83 kB
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
from abc import ABC, abstractmethod
from enum import Enum, unique
from typing import List
from pyserini.search import JLuceneSearcherResult
@unique
class OutputFormat(Enum):
    """Closed set of run-file formats; each member's value is its lowercase name string."""

    TREC = 'trec'
    MSMARCO = 'msmarco'
    KILT = 'kilt'
class OutputWriter(ABC):
    """Base class for writers that serialize ranked hits to a file.

    Use instances as context managers: ``__enter__`` opens ``file_path``
    (creating missing parent directories) and ``__exit__`` closes it.
    Subclasses implement :meth:`write` for a concrete output format.

    Args:
        file_path: Destination path of the output file.
        mode: Mode passed to :func:`open` (``'w'`` by default).
        max_hits: Per-topic hit limit used when ``use_max_passage`` is False.
        tag: Stored on the instance for subclasses that need it.
        topics: Stored on the instance for subclasses that need it.
        use_max_passage: If True, deduplicate hits by (possibly truncated)
            docid, keeping only the first occurrence of each.
        max_passage_delimiter: Separator whose first occurrence ends the
            document id inside a hit's docid; discarded (set to None) unless
            ``use_max_passage`` is True.
        max_passage_hits: Per-topic hit limit used when ``use_max_passage``
            is True (replaces ``max_hits``).
    """

    def __init__(self, file_path: str, mode: str = 'w',
                 max_hits: int = 1000, tag: str = None, topics: dict = None,
                 use_max_passage: bool = False, max_passage_delimiter: str = None, max_passage_hits: int = 100):
        self.file_path = file_path
        self.mode = mode
        self.tag = tag
        self.topics = topics
        self.use_max_passage = use_max_passage
        self.max_passage_delimiter = max_passage_delimiter if use_max_passage else None
        # A single effective limit: passage mode has its own (smaller by default) cap.
        self.max_hits = max_passage_hits if use_max_passage else max_hits
        self._file = None

    def __enter__(self):
        dirname = os.path.dirname(self.file_path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        # Use an explicit encoding for text output so the bytes written do not
        # depend on the platform's locale default; binary modes accept none.
        encoding = None if 'b' in self.mode else 'utf-8'
        self._file = open(self.file_path, self.mode, encoding=encoding)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Returns None, so exceptions raised inside the with-block propagate.
        if self._file is not None:
            self._file.close()

    def hits_iterator(self, hits: List['JLuceneSearcherResult']):
        """Yield ``(docid, rank, score, hit)`` for each hit that should be written.

        When both ``use_max_passage`` and ``max_passage_delimiter`` are set,
        ``docid`` is the part of ``hit.docid`` before the first delimiter;
        otherwise it is ``hit.docid`` with surrounding whitespace stripped.
        With ``use_max_passage``, hits whose docid was already yielded are
        skipped. Ranks start at 1 and advance only for yielded hits. At most
        ``self.max_hits`` tuples are produced, and none if it is below 1.
        """
        if self.max_hits < 1:
            # The limit is otherwise only checked after a yield, which would
            # emit one hit even for a zero limit.
            return
        unique_docs = set()
        rank = 1
        for hit in hits:
            if self.use_max_passage and self.max_passage_delimiter:
                docid = hit.docid.split(self.max_passage_delimiter)[0]
            else:
                docid = hit.docid.strip()

            if self.use_max_passage:
                if docid in unique_docs:
                    continue
                unique_docs.add(docid)

            yield docid, rank, hit.score, hit

            rank += 1
            # Stop right after the last allowed hit, without pulling another one.
            if rank > self.max_hits:
                break

    @abstractmethod
    def write(self, topic: str, hits: List['JLuceneSearcherResult']):
        """Write the results for ``topic`` to the open file."""
        raise NotImplementedError()
class TrecWriter(OutputWriter):
    """Emit one whitespace-separated TREC run line per hit.

    Line layout: ``<topic> Q0 <docid> <rank> <score> <tag>``, with the score
    rendered to six decimal places and the tag taken from ``self.tag``.
    """

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        run_lines = (
            f'{topic} Q0 {docid} {rank} {score:.6f} {self.tag}\n'
            for docid, rank, score, _ in self.hits_iterator(hits)
        )
        for run_line in run_lines:
            self._file.write(run_line)
class MsMarcoWriter(OutputWriter):
    """Emit one tab-separated ``topic<TAB>docid<TAB>rank`` line per hit (no score)."""

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        sink = self._file
        for entry in self.hits_iterator(hits):
            doc_key, position = entry[0], entry[1]
            sink.write(f'{topic}\t{doc_key}\t{position}\n')
class KiltWriter(OutputWriter):
    """Emit one KILT-style JSON object per topic.

    The object is the datapoint stored in ``self.topics[topic]`` plus an
    ``"output"`` field holding a single provenance list with one
    ``{"wikipedia_id": docid}`` entry per written hit.
    """

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        """Write the JSON line for ``topic``.

        Raises:
            KeyError: if ``topic`` is not a key of ``self.topics``.
        """
        # Shallow-copy so that writing results does not inject an "output"
        # key into the caller's topics dict; key order (and thus the emitted
        # JSON) is the same as when the original dict was mutated.
        datapoint = dict(self.topics[topic])
        provenance = [{"wikipedia_id": docid} for docid, _, _, _ in self.hits_iterator(hits)]
        datapoint["output"] = [{"provenance": provenance}]
        json.dump(datapoint, self._file)
        self._file.write('\n')
def get_output_writer(file_path: str, output_format: OutputFormat, *args, **kwargs) -> OutputWriter:
    """Build the writer registered for ``output_format``.

    Any extra positional/keyword arguments are forwarded to the writer's
    constructor after ``file_path``.

    Raises:
        KeyError: if ``output_format`` has no registered writer class.
    """
    writer_by_format = {
        OutputFormat.TREC: TrecWriter,
        OutputFormat.MSMARCO: MsMarcoWriter,
        OutputFormat.KILT: KiltWriter,
    }
    writer_cls = writer_by_format[output_format]
    writer = writer_cls(file_path, *args, **kwargs)
    return writer
def tie_breaker(hits):
    """Return a new list of ``hits`` ordered by descending score, then ascending docid."""
    def _rank_key(hit):
        return -hit.score, hit.docid

    ordered = list(hits)
    ordered.sort(key=_rank_key)
    return ordered