gsarti commited on
Commit
9652fe2
·
1 Parent(s): 0c38d9e

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.pyc
2
+ *.html
3
+ *.json
4
+ .DS_Store
README.md CHANGED
@@ -1,13 +1,21 @@
1
  ---
2
- title: Gemma Lxt Mirage
3
- emoji: 📚
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.41.0
8
  app_file: app.py
9
- pinned: false
10
  license: apache-2.0
 
 
 
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: MIRAGE
3
+ emoji: 🌴
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.42.0
8
  app_file: app.py
9
+ pinned: true
10
  license: apache-2.0
11
+ short_description: Model Internals to generate RAG citations
12
+ tags:
13
+ - answer-attribution
14
+ - interpretability
15
+ - context-usage
16
+ - language-modeling
17
+ - arxiv:2406.13663
18
+ - arxiv:2402.05602
19
  ---
20
 
21
+ Demo for the paper [Model Internals-based Answer Attribution for Trustworthy Retrieval-Augmented Generation](https://arxiv.org/abs/2406.13663) using the AttnLRP method from [AttnLRP: Attention-Aware Layer-Wise Relevance Propagation for Transformers](https://arxiv.org/abs/2402.05602).
app.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import bm25s
4
+ import spaces
5
+ import gradio as gr
6
+ import gradio_iframe
7
+ from bm25s.hf import BM25HF
8
+ from rerankers import Reranker
9
+
10
+ from inseq import register_step_function, load_model
11
+ from inseq.attr import StepFunctionArgs
12
+ from inseq.commands.attribute_context import visualize_attribute_context
13
+ from inseq.utils.contrast_utils import _setup_contrast_args
14
+ from lxt.models.llama import LlamaForCausalLM, attnlrp
15
+ from transformers import AutoTokenizer
16
+ from lxt.functional import softmax, add2, mul2
17
+ from inseq.commands.attribute_context.attribute_context import attribute_context_with_model, AttributeContextArgs
18
+
19
+ from style import custom_css
20
+ from citations import pecore_citation, mirage_citation, inseq_citation, lxt_citation
21
+ from examples import examples
22
+
23
+ model_id = "HuggingFaceTB/SmolLM-360M-Instruct"
24
+ ranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type='colbert')
25
+ retriever = BM25HF.load_from_hub("xhluca/bm25s-nq-index", load_corpus=True, mmap=True)
26
+ hf_model = LlamaForCausalLM.from_pretrained(model_id)
27
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
28
+ attnlrp.register(hf_model)
29
+ model = load_model(hf_model, "saliency", tokenizer=tokenizer)
30
+ # Needed since the <|im_start|> token is also the BOS
31
+ model.bos_token = "<|endoftext|>"
32
+ model.bos_token_id = 0
33
+
34
+
35
+ def lxt_probability_fn(args: StepFunctionArgs):
36
+ logits = args.attribution_model.output2logits(args.forward_output)
37
+ target_ids = args.target_ids.reshape(logits.shape[0], 1).to(logits.device)
38
+ logits = softmax(logits, dim=-1)
39
+ return logits.gather(-1, target_ids).squeeze(-1)
40
+
41
+ def lxt_contrast_prob_fn(
42
+ args: StepFunctionArgs,
43
+ contrast_sources = None,
44
+ contrast_targets = None,
45
+ contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
46
+ contrast_force_inputs: bool = False,
47
+ skip_special_tokens: bool = False,
48
+ ):
49
+ c_args = _setup_contrast_args(
50
+ args,
51
+ contrast_sources=contrast_sources,
52
+ contrast_targets=contrast_targets,
53
+ contrast_targets_alignments=contrast_targets_alignments,
54
+ contrast_force_inputs=contrast_force_inputs,
55
+ skip_special_tokens=skip_special_tokens,
56
+ )
57
+ return lxt_probability_fn(c_args)
58
+
59
+ def lxt_contrast_prob_diff_fn(
60
+ args: StepFunctionArgs,
61
+ contrast_sources = None,
62
+ contrast_targets = None,
63
+ contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
64
+ contrast_force_inputs: bool = False,
65
+ skip_special_tokens: bool = False,
66
+ ):
67
+ model_probs = lxt_probability_fn(args)
68
+ contrast_probs = lxt_contrast_prob_fn(
69
+ args=args,
70
+ contrast_sources=contrast_sources,
71
+ contrast_targets=contrast_targets,
72
+ contrast_targets_alignments=contrast_targets_alignments,
73
+ contrast_force_inputs=contrast_force_inputs,
74
+ skip_special_tokens=skip_special_tokens,
75
+ ).to(model_probs.device)
76
+ return add2(model_probs, mul2(contrast_probs, -1))
77
+
78
+
79
+ def set_interactive_settings(rag_setting, retrieve_k, top_k, custom_context):
80
+ if rag_setting in ("Retrieve with BM25", "Rerank with ColBERT"):
81
+ return gr.Slider(interactive=True), gr.Slider(interactive=True), gr.Textbox(placeholder="Context will be retrieved automatically. Change mode to 'Use Custom Context' to specify your own.", interactive=False)
82
+ elif rag_setting == "Use Custom Context":
83
+ return gr.Slider(interactive=False), gr.Slider(interactive=False), gr.Textbox(placeholder="Insert a custom context...", interactive=True)
84
+
85
+ @spaces.GPU()
86
+ def generate(query, max_new_tokens, top_p, temperature, retrieve_k, top_k, rag_setting, custom_context, model_size, progress=gr.Progress()):
87
+ global model, model_id
88
+ if rag_setting == "Use Custom Context":
89
+ docs = custom_context.split("\n\n")
90
+ progress(0.1, desc="Using custom context...")
91
+ else:
92
+ if not query:
93
+ raise gr.Error("Please enter a query.")
94
+ progress(0, desc="Retrieving with BM25...")
95
+ q = bm25s.tokenize(query)
96
+ results = retriever.retrieve(q, k=retrieve_k)
97
+ if rag_setting == "Rerank with ColBERT":
98
+ progress(0.1, desc="Reranking with ColBERT...")
99
+ docs = [x["text"] for x in results.documents[0]]
100
+ out = ranker.rank(query=query, docs=docs)
101
+ docs = [out.results[i].document.text for i in range(top_k)]
102
+ else:
103
+ docs = [results.documents[0][i]["text"] for i in range(top_k)]
104
+ docs = [re.sub(r"\[\d+\]", "", doc) for doc in docs]
105
+ curr_model_id = f"HuggingFaceTB/SmolLM-{model_size}-Instruct"
106
+ if model is None or model.model_name != curr_model_id:
107
+ progress(0.2, desc="Loading model...")
108
+ model_id = curr_model_id
109
+ hf_model = LlamaForCausalLM.from_pretrained(model_id)
110
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
111
+ attnlrp.register(hf_model)
112
+ model = load_model(hf_model, "saliency", tokenizer=tokenizer)
113
+ progress(0.3, desc="Attributing with LXT...")
114
+ lm_rag_prompting_example = AttributeContextArgs(
115
+ model_name_or_path=model_id,
116
+ input_context_text="\n\n".join(docs),
117
+ input_current_text=query,
118
+ output_template="{current}",
119
+ attributed_fn="lxt_contrast_prob_diff",
120
+ input_template="<|im_start|>user\n### Context\n{context}\n\n### Query\n{current}<|im_end|>\n<|im_start|>assistant\n",
121
+ contextless_input_current_text="<|im_start|>user\n### Query\n{current}<|im_end|>\n<|im_start|>assistant\n",
122
+ attribution_method="saliency",
123
+ show_viz=False,
124
+ show_intermediate_outputs=False,
125
+ context_sensitivity_std_threshold=1,
126
+ decoder_input_output_separator=" ",
127
+ special_tokens_to_keep=["<|im_start|>", "<|endoftext|>"],
128
+ generation_kwargs={"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature},
129
+ attribution_aggregators=["sum"],
130
+ rescale_attributions=True,
131
+ save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
132
+ viz_path=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
133
+ )
134
+ out = attribute_context_with_model(lm_rag_prompting_example, model)
135
+ html = visualize_attribute_context(out, show_viz=False, return_html=True)
136
+ return [
137
+ gradio_iframe.iFrame(html, height=500, visible=True),
138
+ gr.DownloadButton(
139
+ label="📂 Download output",
140
+ value=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
141
+ visible=True,
142
+ ),
143
+ gr.DownloadButton(
144
+ label="🔍 Download HTML",
145
+ value=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
146
+ visible=True,
147
+ )
148
+ ]
149
+
150
+
151
+ register_step_function(lxt_contrast_prob_diff_fn, "lxt_contrast_prob_diff", overwrite=True)
152
+
153
+
154
+ with gr.Blocks(css=custom_css) as demo:
155
+ with gr.Row():
156
+ with gr.Column(min_width=500):
157
+ gr.HTML(f'<h1><img src="file/img/mirage_logo_white_contour.png" width=300px /></h1>')
158
+ text = gr.Markdown(
159
+ "This demo showcases an end-to-end usage of model internals for RAG answer attribution with the <a href='https://openreview.net/forum?id=XTHfNGI3zT' target='_blank'>PECoRe</a> framework, as described in our <a href='https://arxiv.org/abs/2406.13663' target='_blank'>MIRAGE</a> paper.<br>"
160
+ "Insert a query to retrieve relevant contexts, generate an answer and attribute its context-sensitive components. An interactive <a href='https://github.com/google-deepmind/treescope' target='_blank'>Treescope</a> visualization will appear in the green square.<br>"
161
+ "📋 <i>Retrieval is performed on <a href='https://huggingface.co/datasets/google-research-datasets/natural_questions' target='_blank'>Natural Questions</a> using <a href='https://github.com/xhluca/bm25s' target='_blank'>BM25S</a>, with optional reranking via <a href='https://huggingface.co/answerdotai/answerai-colbert-small-v1' target='_blank'>ColBERT</a>."
162
+ " <a href='https://huggingface.co/blog/smollm' target='_blank'>SmolLM</a> models are used for generation, while <a href='https://github.com/inseq-team/inseq' target='_blank'>Inseq</a> and <a href='https://github.com/rachtibat/LRP-eXplains-Transformers' target='_blank'>LXT</a> are used for attribution.</i><br>"
163
+ "➡️ <i>For more details, see also our <a href='https://huggingface.co/spaces/gsarti/pecore' target='_blank'>PECoRe Demo</a>",
164
+ )
165
+ with gr.Row():
166
+ with gr.Column():
167
+ query = gr.Textbox(
168
+ placeholder="Insert a query for the language model...",
169
+ label="Model query",
170
+ interactive=True,
171
+ lines=2,
172
+ )
173
+ attribute_input_examples = gr.Examples(
174
+ examples,
175
+ inputs=[query],
176
+ examples_per_page=2,
177
+ )
178
+ with gr.Accordion("⚙️ Parameters", open=False):
179
+ with gr.Row():
180
+ model_size = gr.Radio(
181
+ ["135M", "360M", "1.7B"],
182
+ value="360M",
183
+ label="Model size",
184
+ interactive=True
185
+ )
186
+ with gr.Row():
187
+ rag_setting = gr.Radio(
188
+ ["Retrieve with BM25", "Rerank with ColBERT", "Use Custom Context"],
189
+ value="Rerank with ColBERT",
190
+ label="Mode",
191
+ interactive=True
192
+ )
193
+ with gr.Row():
194
+ retrieve_k = gr.Slider(1, 500, value=100, step=1, label="# Docs to Retrieve", interactive=True)
195
+ top_k = gr.Slider(1, 10, value=3, step=1, label="# Docs in Context", interactive=True)
196
+ custom_context = gr.Textbox(
197
+ placeholder="Context will be retrieved automatically. Change mode to 'Use Custom Context' to specify your own.",
198
+ label="Custom context",
199
+ interactive=False,
200
+ lines=4,
201
+ )
202
+ with gr.Row():
203
+ max_new_tokens = gr.Slider(0, 500, value=50, step=5.0, label="Max new tokens", interactive=True)
204
+ top_p = gr.Slider(0, 1, value=1, step=0.01, label="Top P", interactive=True)
205
+ temperature = gr.Slider(0, 1, value=0, step=0.01, label="Temperature", interactive=True)
206
+ with gr.Accordion("📝 Citation", open=False):
207
+ gr.Markdown("Using PECoRe for model internals-based RAG answer attribution is discussed in:")
208
+ gr.Code(mirage_citation, interactive=False, label="MIRAGE (Qi, Sarti et al., 2024)")
209
+ gr.Markdown("To refer to the original PECoRe paper, cite:")
210
+ gr.Code(pecore_citation, interactive=False, label="PECoRe (Sarti et al., 2024)")
211
+ gr.Markdown("The Inseq implementation used in this work (<a href=\"https://inseq.org/en/latest/main_classes/cli.html#attribute-context\"><code>inseq attribute-context</code></a>, including this demo) can be cited with:")
212
+ gr.Code(inseq_citation, interactive=False, label="Inseq (Sarti et al., 2023)")
213
+ gr.Markdown("The AttnLRP attribution method used in this demo via the LXT library can be cited with:")
214
+ gr.Code(lxt_citation, interactive=False, label="AttnLRP (Achtibat et al., 2024)")
215
+ btn = gr.Button("Submit", variant="primary")
216
+ with gr.Column():
217
+ attribute_context_out = gradio_iframe.iFrame(height=400, visible=True)
218
+ with gr.Row(equal_height=True):
219
+ download_output_file_button = gr.DownloadButton(
220
+ "📂 Download output",
221
+ visible=False,
222
+ )
223
+ download_output_html_button = gr.DownloadButton(
224
+ "🔍 Download HTML",
225
+ visible=False,
226
+ value=os.path.join(
227
+ os.path.dirname(__file__), "outputs/output.html"
228
+ ),
229
+ )
230
+ with gr.Row(elem_classes="footer-container"):
231
+ with gr.Column():
232
+ gr.Markdown("""<div class="footer-custom-block"><b>Powered by</b> <a href='https://github.com/inseq-team/inseq' target='_blank'><img src="file/img/inseq_logo_white_contour.png" width=150px /></a> <a href='https://github.com/rachtibat/LRP-eXplains-Transformers' target='_blank'><img src="file/img/lxt_logo.png" width=150px /></a></div>""")
233
+ with gr.Column():
234
+ with gr.Row(elem_classes="footer-custom-block"):
235
+ with gr.Column(scale=0.30, min_width=150):
236
+ gr.Markdown("""<b>Built by <a href="https://gsarti.com" target="_blank">Gabriele Sarti</a><br> with the support of</b>""")
237
+ with gr.Column(scale=0.30, min_width=120):
238
+ gr.Markdown("""<a href='https://www.rug.nl/research/clcg/research/cl/' target='_blank'><img src="file/img/rug_logo_white_contour.png" width=170px /></a>""")
239
+ with gr.Column(scale=0.30, min_width=120):
240
+ gr.Markdown("""<a href='https://projects.illc.uva.nl/indeep/' target='_blank'><img src="file/img/indeep_logo_white_contour.png" width=100px /></a>""")
241
+
242
+ rag_setting.change(
243
+ fn=set_interactive_settings,
244
+ inputs=[rag_setting, retrieve_k, top_k, custom_context],
245
+ outputs=[retrieve_k, top_k, custom_context],
246
+ )
247
+
248
+ btn.click(
249
+ fn=generate,
250
+ inputs=[
251
+ query,
252
+ max_new_tokens,
253
+ top_p,
254
+ temperature,
255
+ retrieve_k,
256
+ top_k,
257
+ rag_setting,
258
+ custom_context,
259
+ model_size,
260
+ ],
261
+ outputs=[
262
+ attribute_context_out,
263
+ download_output_file_button,
264
+ download_output_html_button,
265
+ ]
266
+ )
267
+
268
+ demo.queue(api_open=False, max_size=20).launch(allowed_paths=["img/", "outputs/"], show_api=False)
citations.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pecore_citation = """@inproceedings{sarti-etal-2023-quantifying,
2
+ title = "Quantifying the Plausibility of Context Reliance in Neural Machine Translation",
3
+ author = "Sarti, Gabriele and
4
+ Chrupa{\l}a, Grzegorz and
5
+ Nissim, Malvina and
6
+ Bisazza, Arianna",
7
+ booktitle = "The Twelfth International Conference on Learning Representations (ICLR 2024)",
8
+ month = may,
9
+ year = "2024",
10
+ address = "Vienna, Austria",
11
+ publisher = "OpenReview",
12
+ url = "https://openreview.net/forum?id=XTHfNGI3zT"
13
+ }"""
14
+
15
+ inseq_citation = """@inproceedings{sarti-etal-2023-inseq,
16
+ title = "Inseq: An Interpretability Toolkit for Sequence Generation Models",
17
+ author = "Sarti, Gabriele and
18
+ Feldhus, Nils and
19
+ Sickert, Ludwig and
20
+ van der Wal, Oskar and
21
+ Nissim, Malvina and
22
+ Bisazza, Arianna",
23
+ booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)",
24
+ month = jul,
25
+ year = "2023",
26
+ address = "Toronto, Canada",
27
+ publisher = "Association for Computational Linguistics",
28
+ url = "https://aclanthology.org/2023.acl-demo.40",
29
+ pages = "421--435",
30
+ }"""
31
+
32
+ mirage_citation = """@article{qi-sarti-etal-2024-mirage,
33
+ title = "Model Internals-based Answer Attribution for Trustworthy Retrieval-Augmented Generation",
34
+ author = "Qi, Jirui and Sarti, Gabriele and Fern{\'a}ndez, Raquel and Bisazza, Arianna",
35
+ journal = "ArXiv",
36
+ month = jun,
37
+ year = "2024",
38
+ volume = {abs/2406.13663},
39
+ url = {https://arxiv.org/abs/2406.13663},
40
+ }"""
41
+
42
+ lxt_citation = """@inproceedings{achtibat-etal-2024-attnlrp,
43
+ title = {{A}ttn{LRP}: Attention-Aware Layer-Wise Relevance Propagation for Transformers},
44
+ author = {Achtibat, Reduan and Hatefi, Sayed Mohammad Vakilzadeh and Dreyer, Maximilian and Jain, Aakriti and Wiegand, Thomas and Lapuschkin, Sebastian and Samek, Wojciech},
45
+ booktitle = {Proceedings of the 41st International Conference on Machine Learning},
46
+ pages = {135--168},
47
+ year = {2024},
48
+ editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
49
+ volume = {235},
50
+ series = {Proceedings of Machine Learning Research},
51
+ month = {21--27 Jul},
52
+ publisher = {PMLR}
53
+ }"""
examples.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ examples = [
2
+ "Who was the greek gooddess of spring growth?",
3
+ "When was the heir to the throne of the United Kingdom born?",
4
+ ]
img/indeep_logo_white_contour.png ADDED
img/inseq_logo_white_contour.png ADDED
img/lxt_logo.png ADDED
img/mirage_logo_white_contour.png ADDED
img/rug_logo_white_contour.png ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ spaces
2
+ git+https://github.com/inseq-team/inseq.git@main
3
+ bm25s
4
+ rerankers[transformers]
5
+ git+https://github.com/rachtibat/LRP-eXplains-Transformers
style.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ custom_css = """
2
+ h1 > img {
3
+ text-align: center;
4
+ display: block;
5
+ margin-bottom: 0;
6
+ font-size: 1.7em;
7
+ }
8
+
9
+ iframe {
10
+ overflow: scroll;
11
+ border: 2px solid green;
12
+ }
13
+
14
+ .summary-label {
15
+ display: inline;
16
+ }
17
+ .prose a:visited {
18
+ color: var(--link-text-color);
19
+ }
20
+ .footer-container {
21
+ align-items: center;
22
+ }
23
+ .footer-custom-block {
24
+ display: flex;
25
+ justify-content: center;
26
+ align-items: center;
27
+ }
28
+ .footer-custom-block b {
29
+ margin-right: 10px;
30
+ }
31
+ .footer-custom-block img {
32
+ margin-right: 15px;
33
+ }
34
+ ol {
35
+ padding-left: 30px;
36
+ }
37
+ """