"""Loaders for corpora, queries, qrels and run files used by the Streamlit app."""

import copy
import json
import os
import pathlib
from collections import defaultdict

import beir
import ir_datasets
import pandas as pd
import pytrec_eval
import streamlit as st
from beir import util
from beir.datasets.data_loader import GenericDataLoader

from constants import BEIR, IR_DATASETS, LOCAL_DATASETS


def _as_text(value):
    """Return *value* unchanged if it is a ``str``, else decode it as UTF-8."""
    return value if isinstance(value, str) else value.decode("utf-8")


@st.cache_data
def load_local_corpus(corpus_file, columns_to_combine=("title", "text")):
    """Read a JSONL corpus into ``{doc_id: {"text": ..., "title": ...}}``.

    Args:
        corpus_file: An open file-like object yielding ``str`` or ``bytes``
            lines, or ``None``. It is used as a context manager and is
            therefore closed when loading finishes.
        columns_to_combine: Record fields whose values are joined with a
            single space to build the ``"text"`` entry; fields missing from a
            record are ignored.

    Returns:
        The mapping described above, or ``None`` when ``corpus_file`` is
        ``None``.
    """
    if corpus_file is None:
        return None

    did2text = {}
    id_key = "_id"
    with corpus_file as f:
        for idx, raw_line in enumerate(f):
            line = _as_text(raw_line)
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            # NOTE(review): the first line is treated as a header whenever it
            # contains "doc_id"; that also skips a real first record keyed by
            # "doc_id". Kept as-is -- confirm this is intended.
            if idx == 0 and "doc_id" in line:
                continue
            inst = json.loads(line)
            all_text = " ".join(inst[col] for col in columns_to_combine if col in inst)
            # Fall back to the alternative id field once "_id" is absent.
            if id_key not in inst:
                id_key = "doc_id"
            did2text[inst[id_key]] = {
                "text": all_text,
                "title": inst.get("title", ""),
            }
    return did2text


@st.cache_data
def load_local_queries(queries_file):
    """Read a JSONL queries file into ``{query_id: text}``.

    Args:
        queries_file: An open file-like object yielding ``str`` or ``bytes``
            lines, or ``None``. It is closed when loading finishes.

    Returns:
        The query mapping, or ``None`` when ``queries_file`` is ``None``.
    """
    if queries_file is None:
        return None

    qid2text = {}
    id_key = "_id"
    with queries_file as f:
        for idx, raw_line in enumerate(f):
            line = _as_text(raw_line)
            if not line.strip():
                continue
            # NOTE(review): same header heuristic as the corpus loader -- a
            # first record keyed by "query_id" is skipped as well.
            if idx == 0 and "query_id" in line:
                continue
            inst = json.loads(line)
            if id_key not in inst:
                id_key = "query_id"
            qid2text[inst[id_key]] = inst["text"]
    return qid2text


@st.cache_data
def load_local_qrels(qrels_file):
    """Read whitespace-separated qrels into ``{qid: {doc_id: label}}``.

    Each row has either four fields (``qid <ignored> doc_id label``) or three
    (``qid doc_id label``). A first line containing ``"qid"`` or
    ``"query-id"`` is treated as a header and skipped.

    Args:
        qrels_file: An open file-like object yielding ``str`` or ``bytes``
            lines, or ``None``. It is closed when loading finishes.

    Returns:
        A ``defaultdict(dict)`` with string keys and integer labels, or
        ``None`` when ``qrels_file`` is ``None``.

    Raises:
        ValueError: If a non-blank row has neither three nor four fields, or
            its label is not an integer.
    """
    if qrels_file is None:
        return None

    qid2did2label = defaultdict(dict)
    with qrels_file as f:
        for idx, raw_line in enumerate(f):
            line = _as_text(raw_line)
            # Parenthesised on purpose: without the parentheses `and` binds
            # tighter than `or`, so the "query-id" test ran on every line.
            if idx == 0 and ("qid" in line or "query-id" in line):
                continue
            fields = line.split()
            if not fields:
                continue
            if len(fields) == 4:
                qid, _, doc_id, label = fields
            else:
                # Raises ValueError for anything that is not a 3-field row.
                qid, doc_id, label = fields
            qid2did2label[str(qid)][str(doc_id)] = int(label)
    return qid2did2label


@st.cache_data
def load_run(f_run):
    """Load a run file both as a pytrec_eval run dict and as a DataFrame.

    Args:
        f_run: File-like object holding a tab-separated run with six columns
            (qid, generic, doc_id, rank, score, model) and no header row.

    Returns:
        A ``(run, run_pandas)`` tuple: ``run`` is a ``defaultdict(dict)``
        mapping qid -> {doc_id: score} with ``str`` keys, and ``run_pandas``
        is the same file as a DataFrame with typed columns.
    """
    # Parse a copy -- presumably so that reading the stream here does not
    # advance `f_run` before pandas reads it below.
    run = pytrec_eval.parse_run(copy.deepcopy(f_run))

    # Keys come back as bytes when the file yields bytes; normalise to str.
    new_run = defaultdict(dict)
    for qid, doc_scores in run.items():
        new_run[_as_text(qid)] = {
            _as_text(doc_id): score for doc_id, score in doc_scores.items()
        }

    run_pandas = pd.read_csv(f_run, header=None, index_col=None, sep="\t")
    run_pandas.columns = ["qid", "generic", "doc_id", "rank", "score", "model"]
    run_pandas = run_pandas.astype(
        {"doc_id": str, "qid": str, "rank": int, "score": float}
    )
    return new_run, run_pandas


@st.cache_data
def load_jsonl(f):
    """Group texts from a JSONL stream by document id.

    Each record contributes its ``"question"``, ``"text"`` or ``"query"``
    value (checked in that order) to the list stored under its document id.

    Args:
        f: Iterable of JSON lines.

    Returns:
        A ``(did2text, sub_did2text)`` tuple: ``did2text`` is a
        ``defaultdict(list)`` of texts per document id, and ``sub_did2text``
        maps ``"did"`` -> ``"text"`` for ``"text"`` records that carry a
        ``"did"`` field.

    Raises:
        NotImplementedError: If a record has none of the three text fields.
    """
    did2text = defaultdict(list)
    sub_did2text = {}
    for line in f:
        inst = json.loads(line)
        if "question" in inst:
            if "doc_id" in inst:
                docid = inst["doc_id"]
            else:
                docid = inst["metadata"][0]["passage_id"]
            did2text[docid].append(inst["question"])
        elif "text" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["text"])
            # Records identified only by "doc_id" have no "did" to index by.
            if "did" in inst:
                sub_did2text[inst["did"]] = inst["text"]
        elif "query" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["query"])
        else:
            raise NotImplementedError("Need to handle this case")
    return did2text, sub_did2text


@st.cache_data
def get_beir(dataset: str):
    """Download a BEIR dataset archive and load its ``test`` split.

    Args:
        dataset: Dataset name, used to build the archive URL.

    Returns:
        Whatever ``GenericDataLoader(...).load(split="test")`` returns.
    """
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
    # Archives are unpacked into a "datasets" folder next to this file.
    out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
    data_path = util.download_and_unzip(url, out_dir)
    return GenericDataLoader(data_folder=data_path).load(split="test")


@st.cache_data
def get_ir_datasets(dataset_name: str):
    """Load a dataset through ``ir_datasets``.

    Args:
        dataset_name: Identifier passed to ``ir_datasets.load``.

    Returns:
        A ``(doc_store, queries, qrels)`` tuple where ``queries`` maps query
        id to query and the other two come from ``dataset.doc_store()`` and
        ``dataset.qrels_dict()``.
    """
    dataset = ir_datasets.load(dataset_name)
    # NOTE(review): unpacking assumes each query item has exactly two fields.
    queries = {qid: query for qid, query in dataset.queries_iter()}
    return dataset.doc_store(), queries, dataset.qrels_dict()


@st.cache_data
def get_dataset(dataset_name: str):
    """Return ``(corpus, queries, qrels)`` for the named dataset.

    The name is looked up in ``BEIR``, ``IR_DATASETS`` and ``LOCAL_DATASETS``
    in that order; an empty name yields three empty dicts.

    Args:
        dataset_name: Dataset identifier.

    Raises:
        NotImplementedError: If the name is in none of the known collections.
    """
    if dataset_name == "":
        return {}, {}, {}

    if dataset_name in BEIR:
        return get_beir(dataset_name)
    if dataset_name in IR_DATASETS:
        return get_ir_datasets(dataset_name)
    if dataset_name in LOCAL_DATASETS:
        base_path = f"local_datasets/{dataset_name}"
        # `with` guarantees the handles are closed even when a loader result
        # is served from cache and the loader body never runs.
        with open(f"{base_path}/corpus.jsonl", "r", encoding="utf-8") as corpus_file, \
                open(f"{base_path}/queries.jsonl", "r", encoding="utf-8") as queries_file, \
                open(f"{base_path}/qrels/test.tsv", "r", encoding="utf-8") as qrels_file:
            return (
                load_local_corpus(corpus_file),
                load_local_queries(queries_file),
                load_local_qrels(qrels_file),
            )

    raise NotImplementedError("Dataset not implemented")