NetsPresso_QA / scripts /repro_matrix /generate_html_odqa.py
geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw
history blame
No virus
13.6 kB
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import defaultdict
from string import Template
import argparse
import yaml
# from scripts.repro_matrix.defs_odqa import models
from defs_odqa import models
# Module-level constants.

# Topic-set identifiers; used in run-file names and as keys of the
# per-model score/command tables.
TQA_TOPICS = 'dpr-trivia-test'
NQ_TOPICS = 'nq-test'

# Display names for the two datasets.
PRINT_TQA_TOPICS = 'TriviaQA'
PRINT_NQ_TOPICS = 'Natural Question'

# DPR-DKRR run files that are fused with the GarT5-RRF fused runs.
TQA_DKRR_RUN = f'runs/run.odqa.DPR-DKRR.{TQA_TOPICS}.hits-100.txt'
NQ_DKRR_RUN = f'runs/run.odqa.DPR-DKRR.{NQ_TOPICS}.hits-100.txt'

# Models whose run commands are generated with "--hits 1000" instead of
# "--hits 100".
HITS_1K = {'GarT5-RRF', 'DPR-DKRR'}
def format_run_command(raw):
    """Split a retrieval command so each known option starts a new line.

    Every occurrence of a recognised option in ``raw`` is prefixed with a
    backslash, a newline and one space, i.e. a shell line continuation.

    Args:
        raw: The single-line command string.

    Returns:
        The command with line continuations inserted before the options.
    """
    # Replacements are applied in this order. '--hits 100' is matched as a
    # plain substring, so it also matches the start of '--hits 1000'.
    wrapped_options = (
        '--encoded-queries',
        '--encoder',
        '--topics',
        '--index',
        '--output',
        '--batch',
        '--threads',
        '--bm25',
        '--hits 100',
    )
    formatted = raw
    for option in wrapped_options:
        formatted = formatted.replace(option, '\\\n ' + option)
    return formatted
def format_hybrid_search_command(raw):
    """Split a hybrid-search command across lines for display.

    The sub-command words (dense, sparse, fusion, run) are moved onto their
    own continuation line, and the recognised options are moved onto
    tab-indented continuation lines.

    Args:
        raw: The single-line command string.

    Returns:
        The command with shell line continuations inserted.
    """
    # (old, new) pairs, applied strictly in this order.
    substitutions = (
        ('--encoder', '\\\n\t--encoder'),
        (' dense', ' \\\n dense '),
        (' sparse', ' \\\n sparse'),
        (' fusion', ' \\\n fusion'),
        (' run ', ' \\\n run\t'),
        ('--output', '\\\n\t--output'),
        ('--batch', '\\\n\t--batch'),
        ('--threads', '\\\n\t--threads'),
        ('--lang', '\\\n\t--lang'),
        ('--hits 100', '\\\n\t--hits 100'),
    )
    formatted = raw
    for old, new in substitutions:
        formatted = formatted.replace(old, new)
    return formatted
def format_convert_command(raw):
    """Split a run-conversion command so each option starts a new line.

    Every occurrence of ``--topics``, ``--index``, ``--input`` and
    ``--output`` in ``raw`` is prefixed with a backslash, a newline and one
    space, i.e. a shell line continuation.

    Args:
        raw: The single-line command string.

    Returns:
        The command with line continuations inserted before the options.
    """
    # Parenthesised chain: the original ended with a dangling backslash that
    # spliced the return expression into whatever line followed it.
    return (
        raw.replace('--topics', '\\\n --topics')
        .replace('--index', '\\\n --index')
        .replace('--input', '\\\n --input')
        .replace('--output', '\\\n --output')
    )
def format_eval_command(raw):
    """Split an evaluation command so each option starts a new line.

    Occurrences of ``--retrieval `` (with its trailing space) and ``--topk``
    in ``raw`` are prefixed with a backslash, a newline and one space.

    Args:
        raw: The single-line command string.

    Returns:
        The command with line continuations inserted before the options.
    """
    continuation = '\\\n '
    with_retrieval = raw.replace('--retrieval ', continuation + '--retrieval ')
    return with_retrieval.replace('--topk', continuation + '--topk')
def read_file(f):
    """Return the full text content of the file at path ``f``.

    Args:
        f: Path of the file to read; it is opened in text mode.

    Returns:
        The file content as a single string.
    """
    # The context manager closes the handle even if read() raises.
    with open(f, 'r') as fin:
        return fin.read()
def generate_table_rows(table_id):
    """Render one HTML table row for every model in ``models['models']``.

    Reads the module-level ``table``, ``commands``, ``convert_commands``,
    ``eval_commands``, ``fusion_cmd_tqa``, ``fusion_cmd_nq`` and the three
    ``row_template*`` strings, which the ``__main__`` section of this script
    fills in before calling this function.

    Args:
        table_id: Value substituted for ``table_cnt`` in every row template.

    Returns:
        A list of rendered row strings, in the order of ``models['models']``.
    """
    rendered_rows = []
    for row_cnt, model in enumerate(models['models'], start=1):
        is_gar_rrf = model == "GarT5-RRF"
        is_final_rrf = model == "GarT5RRF-DKRR-RRF"

        # Pick the row template matching the kind of model.
        if is_gar_rrf:
            row = Template(row_template_garrrf)
        elif is_final_rrf:
            row = Template(row_template_rrf)
        else:
            row = Template(row_template)

        # Fields present in every template: identifiers and expected scores.
        tqa_scores = table[model][TQA_TOPICS]
        nq_scores = table[model][NQ_TOPICS]
        fields = {
            'table_cnt': table_id,
            'row_cnt': row_cnt,
            'model': model,
            'TQA_Top20': tqa_scores["Top20"],
            'TQA_Top100': tqa_scores["Top100"],
            'NQ_Top20': nq_scores["Top20"],
            'NQ_Top100': nq_scores["Top100"],
        }

        # Template-specific run / fusion commands.
        if is_gar_rrf:
            # Three runs per dataset, fused by the first fusion command.
            tqa_cmds = commands[model][TQA_TOPICS]
            fields['cmd1'] = tqa_cmds[0]
            fields['cmd2'] = tqa_cmds[1]
            fields['cmd3'] = tqa_cmds[2]
            nq_cmds = commands[model][NQ_TOPICS]
            fields['cmd4'] = nq_cmds[0]
            fields['cmd5'] = nq_cmds[1]
            fields['cmd6'] = nq_cmds[2]
            fields['fusion_cmd1'] = fusion_cmd_tqa[0]
            fields['fusion_cmd2'] = fusion_cmd_nq[0]
        elif is_final_rrf:
            # No run commands of its own, only the second fusion command.
            fields['fusion_cmd1'] = fusion_cmd_tqa[1]
            fields['fusion_cmd2'] = fusion_cmd_nq[1]
        else:
            fields['cmd1'] = commands[model][TQA_TOPICS][0]
            fields['cmd2'] = commands[model][NQ_TOPICS][0]

        # Conversion and evaluation commands are shared by all templates.
        fields['convert_cmd1'] = convert_commands[model][TQA_TOPICS]
        fields['convert_cmd2'] = convert_commands[model][NQ_TOPICS]
        fields['eval_cmd1'] = eval_commands[model][TQA_TOPICS]
        fields['eval_cmd2'] = eval_commands[model][NQ_TOPICS]

        rendered_rows.append(row.substitute(**fields))
    return rendered_rows
if __name__ == '__main__':
    # Entry point: reads the TriviaQA and Natural Questions YAML configs,
    # builds the run / fusion / conversion / evaluation commands and the
    # expected scores for every condition, then prints the rendered HTML page
    # to stdout.
    #
    # NOTE: table, commands, convert_commands, eval_commands, fusion_cmd_tqa,
    # fusion_cmd_nq and the row_template* strings are read as module globals
    # by generate_table_rows(), so they must stay at module level under
    # exactly these names.
    parser = argparse.ArgumentParser(
        description='Generate HTML rendering of regression matrix for open-domain QA datasets.')
    # No options are defined; parsing still provides --help and rejects
    # unexpected arguments.
    args = parser.parse_args()

    # model name -> topic set -> metric name -> expected score
    table = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    # model name -> topic set -> list of formatted run commands
    commands = defaultdict(lambda: defaultdict(list))
    # model name -> topic set -> formatted command string
    eval_commands = defaultdict(lambda: defaultdict(str))
    convert_commands = defaultdict(lambda: defaultdict(str))

    html_template = read_file('scripts/repro_matrix/odqa_html.template')
    table_template = read_file('scripts/repro_matrix/odqa_html_table.template')
    row_template = read_file('scripts/repro_matrix/odqa_html_table_row.template')
    row_template_garrrf = read_file('scripts/repro_matrix/odqa_html_table_row_gar-rrf.template')
    row_template_rrf = read_file('scripts/repro_matrix/odqa_html_table_row_rrf.template')
    tqa_yaml_path = 'pyserini/resources/triviaqa.yaml'
    nq_yaml_path = 'pyserini/resources/naturalquestion.yaml'

    # Run-file name component for each of the GarT5-RRF commands, by position.
    garrrf_ls = ['answers', 'titles', 'sentences']
    # Fusion commands in the order the RRF conditions appear in the YAML.
    fusion_cmd_tqa = []
    fusion_cmd_nq = []
    # model name -> path of the fused run file
    tqa_fused_run = {}
    nq_fused_run = {}

    with open(tqa_yaml_path) as f_tqa, open(nq_yaml_path) as f_nq:
        tqa_yaml_data = yaml.safe_load(f_tqa)
        nq_yaml_data = yaml.safe_load(f_nq)

    # Conditions are paired by position; the model name is taken from the
    # TriviaQA config only.
    for condition_tqa, condition_nq in zip(tqa_yaml_data['conditions'], nq_yaml_data['conditions']):
        name = condition_tqa['model_name']
        cmd_template_tqa = condition_tqa['command']
        cmd_template_nq = condition_nq['command']

        # Work out the run files and the JSON file evaluated for this model.
        if 'RRF' in name:
            if name == 'GarT5-RRF':
                runfile_tqa = [f'runs/run.odqa.{name}.{TQA_TOPICS}.{garrrf_ls[i]}.hits-1000.txt'
                               for i in range(len(cmd_template_tqa))]
                runfile_nq = [f'runs/run.odqa.{name}.{NQ_TOPICS}.{garrrf_ls[i]}.hits-1000.txt'
                              for i in range(len(cmd_template_nq))]
                # The fused run is named after the first ("answers") run file.
                tqa_fused_run[name] = runfile_tqa[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')
                nq_fused_run[name] = runfile_nq[0].replace('.answers.hits-1000.txt', '.hits-100.fusion.txt')
                jsonfile_tqa = tqa_fused_run[name].replace('.txt', '.json')
                jsonfile_nq = nq_fused_run[name].replace('.txt', '.json')
            elif name == 'GarT5RRF-DKRR-RRF':
                jsonfile_tqa = f'runs/run.odqa.{name}.{TQA_TOPICS}.json'
                # Fixed: this path was built from TQA_TOPICS, which made the
                # NQ fusion/convert/eval commands point at the TriviaQA files.
                jsonfile_nq = f'runs/run.odqa.{name}.{NQ_TOPICS}.json'
                tqa_fused_run[name] = jsonfile_tqa.replace('.json', '.txt')
                nq_fused_run[name] = jsonfile_nq.replace('.json', '.txt')
            else:
                raise NameError('Wrong model name in yaml config')
        else:
            if 'dpr-topics' in name:
                runfile_nq = [f'runs/run.odqa.{name}.dpr-nq-test.hits-100.txt']
            else:
                runfile_nq = [f'runs/run.odqa.{name}.{NQ_TOPICS}.hits-100.txt']
            runfile_tqa = [f'runs/run.odqa.{name}.{TQA_TOPICS}.hits-100.txt']
            jsonfile_tqa = runfile_tqa[0].replace('.answers', '').replace('.txt', '.json')
            jsonfile_nq = runfile_nq[0].replace('.answers', '').replace('.txt', '.json')

        # TREC run file that the conversion command takes as input.
        display_runfile_tqa = jsonfile_tqa.replace('.json', '.txt')
        display_runfile_nq = jsonfile_nq.replace('.json', '.txt')

        # fusion commands
        if 'RRF' in name:
            if name == 'GarT5RRF-DKRR-RRF':
                # Fuses the DPR-DKRR run with the GarT5-RRF fused run, so the
                # GarT5-RRF condition must come earlier in the YAML.
                nq_runs = ' \\\n\t '.join([NQ_DKRR_RUN, nq_fused_run['GarT5-RRF']])
                tqa_runs = ' \\\n\t '.join([TQA_DKRR_RUN, tqa_fused_run['GarT5-RRF']])
            else:
                tqa_runs = ' \\\n\t '.join(runfile_tqa)
                nq_runs = ' \\\n\t '.join(runfile_nq)
            fusion_cmd_tqa.append(
                'python -m pyserini.fusion \\\n'
                f' --runs {tqa_runs} \\\n'
                f' --output {tqa_fused_run[name]} \\\n'
                ' --k 100')
            fusion_cmd_nq.append(
                'python -m pyserini.fusion \\\n'
                f' --runs {nq_runs} \\\n'
                f' --output {nq_fused_run[name]} \\\n'
                ' --k 100')

        # run commands (GarT5RRF-DKRR-RRF only fuses existing runs)
        if name != 'GarT5RRF-DKRR-RRF':
            hits = 1000 if name in HITS_1K else 100
            cmd_tqa = [Template(template).substitute(output=runfile_tqa[i]) + f' --hits {hits}'
                       for i, template in enumerate(cmd_template_tqa)]
            cmd_nq = [Template(template).substitute(output=runfile_nq[i]) + f' --hits {hits}'
                      for i, template in enumerate(cmd_template_nq)]
            if name == 'DPR-Hybrid':
                formatter = format_hybrid_search_command
            else:
                formatter = format_run_command
            commands[name][TQA_TOPICS].extend([formatter(cmd) for cmd in cmd_tqa])
            commands[name][NQ_TOPICS].extend([formatter(cmd) for cmd in cmd_nq])

        # conversion commands
        if 'dpr-topics' in name:
            temp_nq_topics = 'dpr-nq-test'
        else:
            temp_nq_topics = NQ_TOPICS
        convert_cmd_tqa = ('python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run '
                           f'--topics {TQA_TOPICS} '
                           '--index wikipedia-dpr '
                           f'--input {display_runfile_tqa} '
                           f'--output {jsonfile_tqa}')
        convert_cmd_nq = ('python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run '
                          f'--topics {temp_nq_topics} '
                          '--index wikipedia-dpr '
                          f'--input {display_runfile_nq} '
                          f'--output {jsonfile_nq}')
        convert_commands[name][TQA_TOPICS] = format_convert_command(convert_cmd_tqa)
        convert_commands[name][NQ_TOPICS] = format_convert_command(convert_cmd_nq)

        # eval commands
        eval_cmd_tqa = ('python -m pyserini.eval.evaluate_dpr_retrieval '
                        f'--retrieval {jsonfile_tqa} '
                        '--topk 20 100')
        eval_cmd_nq = ('python -m pyserini.eval.evaluate_dpr_retrieval '
                       f'--retrieval {jsonfile_nq} '
                       '--topk 20 100')
        eval_commands[name][TQA_TOPICS] = format_eval_command(eval_cmd_tqa)
        eval_commands[name][NQ_TOPICS] = format_eval_command(eval_cmd_nq)

        # expected scores: each entry is a mapping merged into the model's table
        for expected_tqa, expected_nq in zip(condition_tqa['scores'], condition_nq['scores']):
            table[name][TQA_TOPICS].update(expected_tqa)
            table[name][NQ_TOPICS].update(expected_nq)

    # Render the single "Models" table and print the finished page.
    tables_html = []
    html_rows = generate_table_rows(1)
    all_rows = '\n'.join(html_rows)
    tables_html.append(Template(table_template).substitute(desc='Models', rows=all_rows))
    print(Template(html_template).substitute(
        title='Retrieval for Open-Domain QA Datasets', tables=' '.join(tables_html)))