Orion Weller committed on
Commit
a09b56d
1 Parent(s): 56649db

saliency maps

Files changed (5)
  1. .gitignore +3 -1
  2. analysis.py +93 -1
  3. app.py +88 -11
  4. dataset_loading.py +11 -2
  5. requirements.txt +3 -1
.gitignore CHANGED
@@ -1,3 +1,5 @@
 datasets/
 __pycache__/
-env/
+env/
+.ipynb_checkpoints/
+*.ipynb
analysis.py CHANGED
@@ -1,8 +1,21 @@
 import pandas as pd
 import numpy as np
+import os
+import torch
+from transformers import pipeline
+import streamlit as st
+
 import plotly.express as px
 import plotly.figure_factory as ff
 
+from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
+from captum.attr import visualization as viz
+from captum import attr
+from captum.attr._utils.visualization import format_word_importances, format_special_tokens, _get_color
+
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
 
 def results_to_df(results: dict, metric_name: str):
     metric_scores = []
@@ -38,4 +51,83 @@ def create_boxplot_diff(results1, results2, metric_name):
 
     x_axis = f"Difference in {metric_name} from 1 to 2"
     fig = px.histogram(pd.DataFrame({x_axis: diff}), x=x_axis, marginal="box")
-    return fig
+    return fig
+
+
+def summarize_attributions(attributions):
+    attributions = attributions.sum(dim=-1).squeeze(0)
+    attributions = attributions / torch.norm(attributions)
+    return attributions
+
+
+def get_words(words, importances):
+    words_colored = []
+    for word, importance in zip(words, importances[: len(words)]):
+        word = format_special_tokens(word)
+        color = _get_color(importance)
+        unwrapped_tag = '<span style="background-color: {color}; opacity:1.0; line-height:1.75">{word}</span>'.format(
+            color=color, word=word
+        )
+        words_colored.append(unwrapped_tag)
+    return words_colored
+
+@st.cache_resource
+def get_model(model_name: str):
+    if model_name == "MonoT5":
+        pipe = pipeline('text2text-generation',
+                        model='castorini/monot5-small-msmarco-10k',
+                        tokenizer='castorini/monot5-small-msmarco-10k',
+                        device='cpu')
+        def formatter(query, doc):
+            return f"Query: {query} Document: {doc} Relevant:"
+
+    return pipe, formatter
+
+def prep_func(pipe, formatter):
+    # variables that only need to be run once
+    decoder_input_ids = pipe.tokenizer(["<pad>"], return_tensors="pt", add_special_tokens=False, truncation=True).input_ids.to('cpu')
+    decoder_embedding_layer = pipe.model.base_model.decoder.embed_tokens
+    decoder_inputs_emb = decoder_embedding_layer(decoder_input_ids)
+
+    token_false_id = pipe.tokenizer.get_vocab()['▁false']
+    token_true_id = pipe.tokenizer.get_vocab()["▁true"]
+
+    # this function needs to be run for each combination
+    @st.cache_data
+    def get_saliency(query, doc):
+        input_ids = pipe.tokenizer(
+            [formatter(query, doc)],
+            padding=False,
+            truncation=True,
+            return_tensors="pt",
+            max_length=pipe.tokenizer.model_max_length,
+        )["input_ids"].to('cpu')
+
+        embedding_layer = pipe.model.base_model.encoder.embed_tokens
+        inputs_emb = embedding_layer(input_ids)

+
+        def forward_from_embeddings(inputs_embeds, decoder_inputs_embeds):
+            logits = pipe.model.forward(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds)['logits'][:, -1, :]
+            batch_scores = logits[:, [token_false_id, token_true_id]]
+            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+            scores = batch_scores[:, 1].exp()  # relevant token
+            return scores
+
+        lig = attr.Saliency(forward_from_embeddings)
+        attributions_ig, delta = lig.attribute(
+            inputs=(inputs_emb, decoder_inputs_emb)
+        )
+        attributions_normed = summarize_attributions(attributions_ig)
+        return "\n".join(get_words(pipe.tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).tolist()), attributions_normed))
+
+    return get_saliency
+
+
+if __name__ == "__main__":
+    query = "how to add dll to visual studio?"
+    doc = "StackOverflow In the days of 16-bit Windows, a WPARAM was a 16-bit word, while LPARAM was a 32-bit long. These distinctions went away in Win32; they both became 32-bit values. ... WPARAM is defined as UINT_PTR , which in 64-bit Windows is an unsigned, 64-bit value."
    model, formatter = get_model("MonoT5")
+    get_saliency = prep_func(model, formatter)
+    print(get_saliency(query, doc))
+
+
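Note: get_saliency scores relevance as the probability MonoT5 assigns to its "▁true" token and takes the gradient of that score with respect to the encoder input embeddings via Captum's Saliency, then sums over the hidden dimension and L2-normalizes per token (summarize_attributions). A minimal, self-contained sketch of the same attribution pattern on a toy scorer (the linear model and names below are illustrative, not the MonoT5 pipeline):

import torch
from captum.attr import Saliency

torch.manual_seed(0)
linear = torch.nn.Linear(8, 1)  # toy stand-in for the reranker head

def score(inputs_embeds):
    # (batch, seq_len, hidden) -> relevance score in (0, 1), shape (batch,)
    pooled = inputs_embeds.mean(dim=1)
    return torch.sigmoid(linear(pooled)).squeeze(-1)

embeddings = torch.randn(1, 5, 8, requires_grad=True)  # one "document" of 5 tokens
saliency = Saliency(score)
attributions = saliency.attribute(embeddings)  # gradient of the score w.r.t. each embedding entry (abs by default)
token_importance = attributions.sum(dim=-1).squeeze(0)  # collapse the hidden dim, as summarize_attributions does
token_importance = token_importance / torch.norm(token_importance)
print(token_importance)  # one importance value per token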
app.py CHANGED
@@ -13,9 +13,10 @@ import plotly.express as px
 
 from constants import ALL_DATASETS, ALL_METRICS
 from dataset_loading import get_dataset, load_run, load_local_qrels, load_local_corpus, load_local_queries
-from analysis import create_boxplot_1df, create_boxplot_2df, create_boxplot_diff
+from analysis import create_boxplot_1df, create_boxplot_2df, create_boxplot_diff, get_model, prep_func
 
 
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 st.set_page_config(layout="wide")
 
 
@@ -41,6 +42,7 @@ def check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus)
         return True
     return False
 
+
 def validate(config_option, file_loaded):
     if config_option != "None" and file_loaded is None:
         st.error("Please upload a file for " + config_option)
@@ -90,6 +92,14 @@ with st.sidebar:
     incorrect_only = st.checkbox("Show only incorrect instances", value=False)
     one_better_than_two = st.checkbox("Show only instances where run 1 is better than run 2", value=False)
     two_better_than_one = st.checkbox("Show only instances where run 2 is better than run 1", value=False)
+    use_model_saliency = st.checkbox("Use model saliency (slow!)", value=False)
+    if use_model_saliency:
+        # choose from a list of models
+        model_name = st.selectbox("Choose from a list of models", ["MonoT5"])
+        model, formatter = get_model("MonoT5")
+        get_saliency = prep_func(model, formatter)
+
+
     advanced_options1 = st.checkbox("Show advanced options for Run 1", value=False)
     doc_expansion1 = doc_expansion2 = None
     query_expansion1 = query_expansion2 = None
@@ -307,9 +317,16 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
             if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel:
                 alt_text = doc_expansion1[docid]["text"]
                 text = combine(text, alt_text, run1_uses_doc_expansion)
-            st.text_area(f"{docid}:", text)
 
-
+            if use_model_saliency:
+                if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency", value=False):
+                    st.markdown(get_saliency(query_text, doc_texts),unsafe_allow_html=True)
+                else:
+                    st.text_area(f"{docid}:", text)
+
+            else:
+                st.text_area(f"{docid}:", text)
+
 
         pred_doc = run1_pandas[run1_pandas.doc_id.isin(relevant_docs)]
         rank_pred = pred_doc[pred_doc.qid == str(inst_num)]["rank"].tolist()
@@ -320,6 +337,7 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
             ranking_str = "--"
         rank_col.metric(f"Rank of Relevant Doc(s)", ranking_str)
 
+
        st.divider()
 
         # top ranked
@@ -336,10 +354,22 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                 for d_idx, doc in enumerate(run1_top_n_docs):
                     alt_text = run1_top_n_docs_alt[d_idx]["text"]
                     doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
-                    st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}")
+                    if use_model_saliency:
+                        if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency", value=False):
+                            st.markdown(get_saliency(query_text, doc_text),unsafe_allow_html=True)
+                        else:
+                            st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}")
+                    else:
+                        st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}")
             else:
                 for d_idx, doc in enumerate(run1_top_n_docs):
-                    st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}")
+                    if use_model_saliency:
+                        if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked", value=False):
+                            st.markdown(get_saliency(query_text, doc),unsafe_allow_html=True)
+                        else:
+                            st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}")
+                    else:
+                        st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}")
         st.divider()
 
         # none checked
@@ -384,20 +414,28 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
            combined_text2 = combine(query_text_og, alt_text2, run2_uses_query_expansion)
            col_run1.markdown(combined_text1)
            col_run2.markdown(combined_text2)
+           query_text1 = combined_text1
+           query_text2 = combined_text2
        elif run1_uses_query_expansion != "None":
            alt_text = query_expansion1[str(inst_num)]
            combined_text1 = combine(query_text_og, alt_text, run1_uses_query_expansion)
            col_run1.markdown(combined_text1)
            col_run2.markdown(query_text_og)
+           query_text1 = combined_text1
+           query_text2 = query_text_og
        elif run2_uses_query_expansion != "None":
            alt_text = query_expansion2[str(inst_num)]
            combined_text2 = combine(query_text_og, alt_text, run2_uses_query_expansion)
            col_run1.markdown(query_text_og)
            col_run2.markdown(combined_text2)
+           query_text1 = query_text_og
+           query_text2 = combined_text2
        else:
            query_text = query_text_og
            col_run1.markdown(query_text)
            col_run2.markdown(query_text)
+           query_text1 = query_text
+           query_text2 = query_text
 
        st.divider()
 
@@ -420,13 +458,27 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel1:
                    alt_text = doc_expansion1[docid]["text"]
                    text = combine(text, alt_text, run1_uses_doc_expansion)
-               col_run1.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}1")
+
+               if use_model_saliency:
+                   if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{docid}relevant", value=False):
+                       col_run1.markdown(get_saliency(query_text1, text),unsafe_allow_html=True)
+                   else:
+                       col_run1.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}1")
+               else:
+                   col_run1.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}1")
 
            for (docid, title, text) in doc_texts:
                if doc_expansion2 is not None and run2_uses_doc_expansion != "None" and not show_orig_rel2:
                    alt_text = doc_expansion2[docid]["text"]
                    text = combine(text, alt_text, run2_uses_doc_expansion)
-               col_run2.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}2")
+
+               if use_model_saliency:
+                   if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{docid}relevant2", value=False):
+                       col_run2.markdown(get_saliency(query_text2, text),unsafe_allow_html=True)
+                   else:
+                       col_run2.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}2")
+               else:
+                   col_run2.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}2")
 
            # top ranked
            # NOTE: BEIR calls trec_eval which ranks by score, then doc_id for ties
@@ -474,10 +526,23 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                for d_idx, doc in enumerate(run1_top_n_docs):
                    alt_text = run1_top_n_docs_alt[d_idx]["text"]
                    doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
-                   col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}1")
+                   if use_model_saliency:
+                       if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked1", value=False):
+                           col_run1.markdown(get_saliency(query_text1, doc_text),unsafe_allow_html=True)
+                       else:
+                           col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}1")
+                   else:
+                       col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}1")
            else:
                for d_idx, doc in enumerate(run1_top_n_docs):
-                   col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}1")
+                   if use_model_saliency:
+                       if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked1", value=False):
+                           col_run1.markdown(get_saliency(query_text1, doc),unsafe_allow_html=True)
+                       else:
+                           col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}1")
+                   else:
+                       col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}1")
+
 
        if col_run2.checkbox('Show top ranked documents for Run 2', key=f"{inst_index}top-2run"):
            col_run2.subheader("Top N Ranked Documents")
@@ -492,10 +557,22 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
                for d_idx, doc in enumerate(run2_top_n_docs):
                    alt_text = run2_top_n_docs_alt[d_idx]["text"]
                    doc_text = combine(doc["text"], alt_text, run2_uses_doc_expansion)
-                   col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}2")
+                   if use_model_saliency:
+                       if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked2", value=False):
+                           col_run2.markdown(get_saliency(query_text2, doc_text),unsafe_allow_html=True)
+                       else:
+                           col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}2")
+                   else:
+                       col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}2")
            else:
                for d_idx, doc in enumerate(run2_top_n_docs):
-                   col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}2")
+                   if use_model_saliency:
+                       if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked2", value=False):
+                           col_run2.markdown(get_saliency(query_text2, doc),unsafe_allow_html=True)
+                       else:
+                           col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}2")
+                   else:
+                       col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}2")
 
            st.divider()
 
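Note: in each branch above, a per-document "Show Model Saliency" checkbox swaps the plain st.text_area for the HTML string returned by get_saliency, rendered with st.markdown(..., unsafe_allow_html=True); Streamlit needs a distinct key= for every repeated widget. A minimal sketch of that toggle pattern (the document dict and helper below are illustrative, not the app's real data):

import streamlit as st

docs = {"d1": "first document text", "d2": "second document text"}

def fake_saliency_html(text: str) -> str:
    # stand-in for get_saliency(query, doc); the real helper returns colored <span> tags per token
    return " ".join(f'<span style="background-color: #ffd54f">{tok}</span>' for tok in text.split())

for docid, text in docs.items():
    # unique key= per document so the repeated widgets do not collide
    if st.checkbox("Show Model Saliency", key=f"saliency-{docid}"):
        st.markdown(fake_saliency_html(text), unsafe_allow_html=True)
    else:
        st.text_area(f"{docid}:", text, key=f"text-{docid}")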
dataset_loading.py CHANGED
@@ -14,6 +14,8 @@ import ir_datasets
 
 from constants import BEIR, IR_DATASETS, LOCAL_DATASETS
 
+
+@st.cache_data
 def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
     if corpus_file is None:
         return None
@@ -39,6 +41,8 @@ def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
     }
     return did2text
 
+
+@st.cache_data
 def load_local_queries(queries_file):
     if queries_file is None:
         return None
@@ -60,6 +64,8 @@ def load_local_queries(queries_file):
         qid2text[inst[id_key]] = inst["text"]
     return qid2text
 
+
+@st.cache_data
 def load_local_qrels(qrels_file):
     if qrels_file is None:
         return None
@@ -84,6 +90,7 @@ def load_local_qrels(qrels_file):
     return qid2did2label
 
 
+@st.cache_data
 def load_run(f_run):
     run = pytrec_eval.parse_run(copy.deepcopy(f_run))
     # convert bytes to strings for keys
@@ -102,7 +109,7 @@ def load_run(f_run):
     return new_run, run_pandas
 
 
-
+@st.cache_data
 def load_jsonl(f):
     did2text = defaultdict(list)
     sub_did2text = {}
@@ -126,7 +133,7 @@ def load_jsonl(f):
     return did2text, sub_did2text
 
 
-
+@st.cache_data
 def get_beir(dataset: str):
     url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
     out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
@@ -134,6 +141,7 @@ def get_beir(dataset: str):
     return GenericDataLoader(data_folder=data_path).load(split="test")
 
 
+@st.cache_data
 def get_ir_datasets(dataset_name: str):
     dataset = ir_datasets.load(dataset_name)
     queries = {}
@@ -145,6 +153,7 @@ def get_ir_datasets(dataset_name: str):
     return dataset.doc_store(), queries, dataset.qrels_dict()
 
 
+@st.cache_data
 def get_dataset(dataset_name: str):
     if dataset_name == "":
         return {}, {}, {}
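Note: the loaders above are now wrapped in st.cache_data, which memoizes picklable return values across Streamlit reruns, while the MonoT5 pipeline in analysis.py uses st.cache_resource so one model object is shared rather than re-created. A small illustrative sketch of that split (function names are hypothetical):

import streamlit as st
import pandas as pd
import torch

@st.cache_data
def load_table(path: str) -> pd.DataFrame:
    # cached by argument value; callers get a copy, so mutating it cannot corrupt the cache
    return pd.read_csv(path)

@st.cache_resource
def get_scorer() -> torch.nn.Module:
    # heavyweight or unpicklable objects (models, connections) are created once and shared
    return torch.nn.Linear(4, 1)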
requirements.txt CHANGED
@@ -5,4 +5,6 @@ streamlit==1.24.1
 ir_datasets==0.5.5
 pyserini==0.21.0
 torch==2.0.1
-plotly==5.15.0
+plotly==5.15.0
+captum==0.6.0
+protobuf==4.21.11
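Note: the captum pin supports the new saliency code, and the protobuf pin pairs with the PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION variable set in analysis.py and app.py; forcing the pure-Python protobuf backend is a common workaround for descriptor errors from transformers/sentencepiece tokenizers, though the commit itself does not state the reason. A hedged sketch of the intended ordering:

import os

# The backend choice is read when protobuf initializes, so set the variable early
# in the process (assumption: this is why analysis.py and app.py set it at module level).
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from transformers import pipeline  # tokenizer dependencies then pick up the pure-Python backend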