# test-analysis/dataset_loading.py
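"""Dataset loading utilities for the Streamlit retrieval-analysis app.

Provides cached loaders for local JSONL corpora and queries, TREC-style qrels
and run files, plus wrappers around BEIR and ir_datasets. All loaders return
(corpus, queries, qrels)-style structures consumed elsewhere in the app.
"""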
import streamlit as st
import os
import pathlib
import beir
from beir import util
from beir.datasets.data_loader import GenericDataLoader
import pytrec_eval
import pandas as pd
from collections import defaultdict
import json
import copy
import ir_datasets
from constants import BEIR, IR_DATASETS, LOCAL_DATASETS
@st.cache_data
def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
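    """Load a local JSONL corpus from an open (text or binary) file handle.

    A first line containing "doc_id" is treated as a header and skipped.
    Ids are read from "_id" or "doc_id", and the requested columns are joined
    into a single text. Returns {doc_id: {"text": ..., "title": ...}}.
    """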
if corpus_file is None:
return None
did2text = {}
id_key = "_id"
with corpus_file as f:
for idx, line in enumerate(f):
uses_bytes = not (type(line) == str)
if uses_bytes:
if idx == 0 and "doc_id" in line.decode("utf-8"):
continue
inst = json.loads(line.decode("utf-8"))
else:
if idx == 0 and "doc_id" in line:
continue
inst = json.loads(line)
all_text = " ".join([inst[col] for col in columns_to_combine if col in inst])
if id_key not in inst:
id_key = "doc_id"
did2text[inst[id_key]] = {
"text": all_text,
"title": inst["title"] if "title" in inst else "",
}
return did2text
@st.cache_data
def load_local_queries(queries_file):
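    """Load local queries from an open JSONL file handle.

    A first line containing "query_id" is treated as a header and skipped.
    Ids are read from "_id" or "query_id". Returns {query_id: query_text}.
    """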
if queries_file is None:
return None
qid2text = {}
id_key = "_id"
with queries_file as f:
for idx, line in enumerate(f):
uses_bytes = not (type(line) == str)
if uses_bytes:
if idx == 0 and "query_id" in line.decode("utf-8"):
continue
inst = json.loads(line.decode("utf-8"))
else:
if idx == 0 and "query_id" in line:
continue
inst = json.loads(line)
if id_key not in inst:
id_key = "query_id"
qid2text[inst[id_key]] = inst["text"]
return qid2text
@st.cache_data
def load_local_qrels(qrels_file):
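    """Load TREC-style qrels from a whitespace-separated file handle.

    Accepts four columns (qid, iteration, doc_id, label) or three
    (qid, doc_id, label), skips a header row mentioning "qid"/"query-id",
    and returns {qid: {doc_id: int_label}}.
    """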
if qrels_file is None:
return None
qid2did2label = defaultdict(dict)
with qrels_file as f:
for idx, line in enumerate(f):
uses_bytes = not (type(line) == str)
if uses_bytes:
if idx == 0 and "qid" in line.decode("utf-8") or "query-id" in line.decode("utf-8"):
continue
cur_line = line.decode("utf-8")
else:
if idx == 0 and "qid" in line or "query-id" in line:
continue
cur_line = line
            try:
                # standard TREC qrels have four columns: qid, iteration, doc_id, label
                qid, _, doc_id, label = cur_line.split()
            except ValueError:
                # some qrels files omit the iteration column: qid, doc_id, label
                qid, doc_id, label = cur_line.split()
qid2did2label[str(qid)][str(doc_id)] = int(label)
return qid2did2label
@st.cache_data
def load_run(f_run):
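    """Load a TREC run file both as a pytrec_eval dict and as a DataFrame.

    Expects a binary, tab-separated file handle with columns
    (qid, generic, doc_id, rank, score, model). Ranks are recomputed per query
    after sorting by descending score. Returns (run_dict, run_dataframe).
    """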
run = pytrec_eval.parse_run(copy.deepcopy(f_run))
# convert bytes to strings for keys
new_run = defaultdict(dict)
for key, sub_dict in run.items():
new_run[key.decode("utf-8")] = {k.decode("utf-8"): v for k, v in sub_dict.items()}
run_pandas = pd.read_csv(f_run, header=None, index_col=None, sep="\t")
run_pandas.columns = ["qid", "generic", "doc_id", "rank", "score", "model"]
run_pandas.doc_id = run_pandas.doc_id.astype(str)
run_pandas.qid = run_pandas.qid.astype(str)
run_pandas["rank"] = run_pandas["rank"].astype(int)
run_pandas.score = run_pandas.score.astype(float)
all_groups = []
for qid, sub_df in run_pandas.groupby("qid"):
        sub_df = sub_df.sort_values(["score", "doc_id"], ascending=[False, False])
sub_df["rank"] = list(range(1, len(sub_df) + 1))
all_groups.append(sub_df)
run_pandas = pd.concat(all_groups)
return new_run, run_pandas
@st.cache_data
def load_jsonl(f):
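    """Group a JSONL file of generated questions, texts, or queries by document id.

    Each line must contain one of "question", "text", or "query". Returns
    (did2text, sub_did2text), where did2text maps a document id to the list of
    associated strings and sub_did2text maps a passage-level "did" to its text.
    """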
did2text = defaultdict(list)
sub_did2text = {}
for idx, line in enumerate(f):
inst = json.loads(line)
if "question" in inst:
docid = inst["metadata"][0]["passage_id"] if "doc_id" not in inst else inst["doc_id"]
did2text[docid].append(inst["question"])
elif "text" in inst:
docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
did2text[docid].append(inst["text"])
sub_did2text[inst["did"]] = inst["text"]
elif "query" in inst:
docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
did2text[docid].append(inst["query"])
else:
            raise NotImplementedError("Need to handle this case")
return did2text, sub_did2text
@st.cache_data(persist="disk")
def get_beir(dataset: str):
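    """Download (if needed) and load the test split of a BEIR dataset.

    Uses the official BEIR GenericDataLoader and caches the unzipped data
    under ./datasets next to this file. Returns (corpus, queries, qrels).
    """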
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
return GenericDataLoader(data_folder=data_path).load(split="test")
@st.cache_data(persist="disk")
def get_ir_datasets(dataset_name: str, input_fields_doc: str = None, input_fields_query: str = None):
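    """Load (corpus, queries, qrels) from an ir_datasets dataset.

    If input_fields_query / input_fields_doc are given, only those fields are
    concatenated into the query / document text; otherwise the available text
    fields of each record are combined into a single string.
    """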
dataset = ir_datasets.load(dataset_name)
queries = {}
    for query in dataset.queries_iter():
        # queries_iter yields namedtuples; unpacking into (qid, query) would fail
        # for query types with more than two fields, so read the id explicitly
        qid = query.query_id
        if input_fields_query is None:
            if len(query._fields) == 2:
                # just a query_id and a single text field
                queries[qid] = query[1]
            else:
                # get all fields that exist in the query
                all_fields = {field: getattr(query, field) for field in query._fields}
                # put all fields into a single string
                queries[qid] = " ".join([str(v) for v in all_fields.values()])
        else:
            all_fields = {field: getattr(query, field) for field in input_fields_query}
            queries[qid] = " ".join([str(v) for v in all_fields.values()])
corpus = {}
for doc in dataset.docs_iter():
if input_fields_doc is None:
if type(doc) == str:
corpus[doc.doc_id] = {"text": doc}
else: # get all fields that exist in query
all_fields = {field: getattr(doc, field) for field in doc._fields}
corpus[doc.doc_id] = {"text": " ".join([str(v) for v in all_fields.values()])}
else:
all_fields = {field: getattr(doc, field) for field in input_fields_doc}
corpus[doc.doc_id] = {"text": " ".join([str(v) for v in all_fields.values()])}
    return corpus, queries, dataset.qrels_dict()
@st.cache_data(persist="disk")
def get_dataset(dataset_name: str, input_fields_doc, input_fields_query):
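    """Load (corpus, queries, qrels) for a dataset by name.

    Dispatches to the BEIR, ir_datasets, or local JSONL/TSV loaders depending
    on which of the BEIR / IR_DATASETS / LOCAL_DATASETS constants contains the
    name; comma-separated field strings are split into lists first.
    """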
if type(input_fields_doc) == str:
input_fields_doc = input_fields_doc.strip().split(",")
if type(input_fields_query) == str:
input_fields_query = input_fields_query.strip().split(",")
if dataset_name == "":
return {}, {}, {}
if dataset_name in BEIR:
return get_beir(dataset_name)
elif dataset_name in IR_DATASETS:
return get_ir_datasets(dataset_name, input_fields_doc, input_fields_query)
elif dataset_name in LOCAL_DATASETS:
base_path = f"local_datasets/{dataset_name}"
corpus_file = open(f"{base_path}/corpus.jsonl", "r")
queries_file = open(f"{base_path}/queries.jsonl", "r")
qrels_file = open(f"{base_path}/qrels/test.tsv", "r")
return load_local_corpus(corpus_file), load_local_queries(queries_file), load_local_qrels(qrels_file)
else:
raise NotImplementedError("Dataset not implemented")
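

# Minimal usage sketch (not part of the Streamlit app): load a dataset by name
# and print a few counts. "scifact" is only an illustrative example here; it
# must actually appear in the BEIR / IR_DATASETS / LOCAL_DATASETS constants
# imported from constants.py for this call to succeed.
if __name__ == "__main__":
    corpus, queries, qrels = get_dataset("scifact", None, None)
    print(f"{len(corpus)} documents, {len(queries)} queries, {len(qrels)} judged queries")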