"""Gradio demo: greedy generation with AttnLRP source attribution (via LXT).

Each generated token's relevance over the prompt is computed by
back-propagating the top logit to the input embeddings. Windows of high
relevance are turned into hoverable "notes" that show the supporting span of
the prompt.
"""
import html
import os

import gradio as gr
import numpy as np
import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from lxt.models.llama import LlamaForCausalLM, attnlrp
from lxt.utils import clean_tokens
from scipy.signal import convolve2d
from transformers import AutoTokenizer, BitsAndBytesConfig

# NOTE(review): defined but never passed to `from_pretrained` below, so the
# model is actually loaded in bfloat16, not 8-bit.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.bfloat16,
)

load_dotenv()
# Only log in when a token is configured: `login(None)` falls back to an
# interactive prompt, which cannot be answered on a headless server.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
print(f"Loading model {model_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    use_safetensors=True,
)
# Relevance is read only from the gradient on the input embeddings; weight
# gradients are never used, so disable them instead of allocating a full set
# of parameter gradients on every backward pass.
model.requires_grad_(False)
attnlrp.register(model)
print("Loaded model.")


def really_clean_tokens(tokens):
    """Turn raw tokenizer tokens into display strings.

    Runs lxt's ``clean_tokens``, maps tokenizer-specific space/newline markers
    to plain spaces, and decodes byte-fallback tokens such as ``<0x0A>`` into
    the character they encode. Returns one string per element produced by
    ``clean_tokens``, in order.
    """
    tokens = clean_tokens(tokens)
    cleaned_tokens = []
    for token in tokens:
        token = (
            token.replace("_", " ")
            .replace("▁", " ")  # SentencePiece word-boundary marker
            .replace("<s>", " ")  # SentencePiece-style BOS token
            .replace("Ċ", " ")  # byte-level BPE newline
            .replace("Ġ", " ")  # byte-level BPE space
        )
        if token.startswith("<0x") and token.endswith(">"):
            # Byte-fallback token: convert the hex code to its character.
            token = chr(int(token[3:-1], 16))
        cleaned_tokens.append(token)
    return cleaned_tokens


@spaces.GPU
def generate_and_visualize(prompt, num_tokens=10):
    """Greedily generate up to ``num_tokens`` tokens, recording input relevance.

    At every step the whole current sequence is re-embedded and run through the
    LXT-patched model, and the top logit is back-propagated to the input
    embeddings. The relevance of each position is that embedding gradient
    summed over the hidden dimension.

    Returns:
        input_tokens: cleaned prompt tokens.
        all_relevances: one 1-D CPU float tensor per generated token, covering
            every position of the sequence fed to the model at that step.
        generated_tokens: cleaned generated tokens (stops early on EOS).
    """
    input_ids = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=True
    ).input_ids.to(model.device)
    input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))

    generated_tokens_ids = []
    all_relevances = []
    # Gradio sliders may deliver floats; range() needs an int.
    for _ in range(int(num_tokens)):
        # Detach so the embeddings are a leaf tensor whose .grad gets populated.
        input_embeds = model.get_input_embeddings()(input_ids).detach().requires_grad_()
        output_logits = model(inputs_embeds=input_embeds).logits
        max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
        # Seed the backward pass with the logit value itself (LXT convention).
        max_logits.backward(max_logits)

        if input_embeds.grad is None:
            # Best effort: keep a correctly-shaped zero row so downstream
            # processing still works instead of crashing on a scalar.
            relevance = torch.zeros(input_ids.shape[1])
        else:
            relevance = input_embeds.grad.float().sum(-1).cpu()[0]
        all_relevances.append(relevance)

        next_token = max_indices.view(1, 1)
        next_token_id = next_token.item()
        generated_tokens_ids.append(next_token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        if next_token_id == tokenizer.eos_token_id:
            print("EOS token generated, stopping generation.")
            break

    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
    return input_tokens, all_relevances, generated_tokens


def _find_largest_contiguous_patch(array):
    """Return ``(width, start)`` of the longest run of truthy values in ``array``.

    Ties keep the earliest run. Returns ``(None, None)`` if nothing is truthy.
    """
    best_width, best_start = None, None
    run_start, run_width = None, 0
    for i, flag in enumerate(array):
        if flag:
            if run_width == 0:
                run_start = i
            run_width += 1
            # Compare against None, not truthiness, so runs at index 0 count.
            if best_width is None or run_width > best_width:
                best_width, best_start = run_width, run_start
        else:
            run_width = 0
    return best_width, best_start


def process_relevances(input_tokens, all_relevances, generated_tokens):
    """Group generated tokens into notes pointing at relevant prompt spans.

    A ``kernel_width`` x ``kernel_width`` moving average over the
    (generated token x prompt token) relevance matrix marks windows above
    ``threshold_per_token``. For each generated row the longest marked run is
    kept, overlapping notes of nearby rows are fused, and the result is turned
    into text slices of the prompt.

    Returns:
        One entry per generated token: ``(token, None, None)`` for plain tokens,
        or ``(token, (before, highlighted, after), width)`` where ``width`` is
        the number of generated tokens the note spans.
    """
    threshold_per_token = 0.2
    kernel_width = 6
    context_width = 20  # Number of tokens to include as context on each side

    if len(generated_tokens) < kernel_width or not all_relevances:
        return [(token, None, None) for token in generated_tokens]

    # Keep only the prompt positions (the first step's sequence length) so
    # every row has the same length.
    prompt_len = len(all_relevances[0])
    attention_matrix = np.array(
        [np.asarray(el[:prompt_len], dtype=float) for el in all_relevances]
    )
    if attention_matrix.shape[1] < kernel_width:
        # A 'valid' convolution needs at least kernel_width prompt tokens.
        return [(token, None, None) for token in generated_tokens]

    ### FIND ZONES OF INTEREST
    kernel = np.ones((kernel_width, kernel_width))
    # Mean relevance over each window (column j averages prompt tokens
    # j .. j + kernel_width - 1).
    rolled_sum = convolve2d(attention_matrix, kernel, mode="valid") / kernel_width**2
    significant_areas = rolled_sum > threshold_per_token
    print(
        f"Found {significant_areas.sum()} relevant tokens: lower threshold to find more. "
        f"Max was {rolled_sum.max()}"
    )
    print("LENGTHS:", len(input_tokens), significant_areas.shape, len(generated_tokens))

    n_rows = len(generated_tokens) - kernel_width + 1
    output_with_notes = []
    for row in range(n_rows):
        best_width, best_patch_start = _find_largest_contiguous_patch(significant_areas[row])
        coords = None if best_width is None else (best_width, best_patch_start)
        output_with_notes.append((generated_tokens[row], coords))
    # The last kernel_width - 1 tokens have no full window of their own.
    output_with_notes += [(el, None) for el in generated_tokens[n_rows:]]

    # Fuse the notes for consecutive output tokens if necessary.
    for i in range(len(output_with_notes)):
        token, coords = output_with_notes[i]
        if coords is None:
            output_with_notes[i] = (token, None, None)
            continue
        best_width, best_patch_start = coords
        note_width_generated = kernel_width
        for next_id in range(i + 1, min(i + 2 * kernel_width, len(output_with_notes))):
            next_token, next_coords = output_with_notes[next_id]
            if next_coords is None:
                continue
            next_width, next_patch_start = next_coords
            if best_patch_start + best_width >= next_patch_start:
                # Notes overlap: drop the later one and widen this one if needed.
                output_with_notes[next_id] = (next_token, None)
                larger_end = max(best_patch_start + best_width, next_patch_start + next_width)
                best_width = larger_end - best_patch_start
                note_width_generated = kernel_width + (next_id - i)
        output_with_notes[i] = (token, (best_width, best_patch_start), note_width_generated)

    # Convert to text slices of the prompt.
    for i, (token, coords, width) in enumerate(output_with_notes):
        if coords is None:
            continue
        best_width, best_patch_start = coords
        significant_start = max(0, best_patch_start)
        # Window columns start .. start + width - 1 cover prompt tokens up to
        # start + width + kernel_width - 2 inclusive.
        significant_end = best_patch_start + best_width + kernel_width - 1
        context_start = max(0, significant_start - context_width)
        context_end = min(len(input_tokens), significant_end + context_width)
        first_part = "".join(input_tokens[context_start:significant_start])
        significant_part = "".join(input_tokens[significant_start:significant_end])
        final_part = "".join(input_tokens[significant_end:context_end])
        output_with_notes[i] = (token, (first_part, significant_part, final_part), width)
    return output_with_notes


def create_html_with_hover(output_with_notes):
    """Render processed tokens as HTML with hoverable source notes.

    Plain tokens become ``<span>`` elements. A note entry opens a ``.hoverable``
    span grouping the next ``width`` tokens, adds a superscript note number, and
    nests a ``.hover-note`` popup with the prompt context, the relevant span in
    bold. All text is HTML-escaped since it comes from the user's prompt and
    the model's output.
    """
    parts = ['<div id="output-container">']
    note_number = 0
    i = 0
    while i < len(output_with_notes):
        token, notes, width = output_with_notes[i]
        if notes is None:
            parts.append(f"<span>{html.escape(token)}</span>")
            i += 1
            continue
        text = "".join(
            really_clean_tokens([element[0] for element in output_with_notes[i : i + width]])
        )
        print(text)
        first_part, significant_part, final_part = notes
        note_number += 1
        formatted_note = (
            f"{html.escape(first_part)}"
            f"<strong>{html.escape(significant_part)}</strong>"
            f"{html.escape(final_part)}"
        )
        parts.append(
            f'<span class="hoverable">{html.escape(text)}<sup>[{note_number}]</sup>'
            f'<span class="hover-note">{formatted_note}</span></span>'
        )
        i += width
    parts.append("</div>")
    return "".join(parts)


@spaces.GPU
def on_generate(prompt, num_tokens):
    """Gradio callback: generate, attribute, and render the result as HTML."""
    input_tokens, all_relevances, generated_tokens = generate_and_visualize(prompt, num_tokens)
    output_with_notes = process_relevances(input_tokens, all_relevances, generated_tokens)
    html_output = create_html_with_hover(output_with_notes)
    return html_output


css = """
#output-container {
    font-size: 18px;
    line-height: 1.5;
    position: relative;
}
.hoverable {
    color: var(--primary-500);
    position: relative;
    display: inline-block;
}
.hover-note {
    display: none;
    position: absolute;
    padding: 5px;
    border-radius: 5px;
    bottom: 100%;
    left: 0;
    white-space: normal;
    background-color: var(--input-background-fill);
    max-width: 600px;
    width: 500px;
    word-wrap: break-word;
    z-index: 100;
}
.hoverable:hover .hover-note {
    display: block;
}
"""

examples = [
    """Context: The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960. Question: How many meters above 8000 did the 1922 expedition go?
Answer:""",
    """Context: Hurricane Katrina killed hundreds of people as it made landfall on New Orleans in 2005 - many of these deaths could have been avoided if alerts had been given one day earlier. Accurate weather forecasts are really life-saving. 🔥 Now, NASA and IBM just dropped a game-changing new model: the first ever foundation model for weather! This means, it's the first time we have a generalist model not restricted to one task, but able to predict 160 weather variables! Prithvi WxC (Prithvi, "पृथ्वी", is the Sanskrit name for Earth) - is a 2.3 billion parameter model, with an architecture close to previous vision transformers like Hiera. 💡 But it comes with some important tweaks: under the hood, Prithvi WxC uses a clever transformer-based architecture with 25 encoder and 5 decoder blocks. It alternates between "local" and "global" attention to capture both regional and global weather patterns. Question: How many weather variables can Prithvi predict? Answer:""",
    """Context: Transformers v4.45.0 released: includes a lightning-fast method to build tools! ⚡️ During user research with colleagues @MoritzLaurer and @Jofthomas , we discovered that the class definition currently in used to define a Tool in transformers.agents is a bit tedious to use, because it goes in great detail. ➡️ So I've made an easier way to build tools: just make a function with type hints + a docstring, and add a @tool decorator in front. ✅ Voilà, you're good to go! Question: How can you build tools simply in transformers?
Answer:""",
]

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# RAG with source linking using Source attribution with "
        "[LXT](https://lxt.readthedocs.io/en/latest/quickstart.html#tinyllama)"
    )

    input_text = gr.Textbox(label="Enter your prompt:", lines=10, value=examples[0])
    num_tokens = gr.Slider(
        minimum=1,
        maximum=100,
        value=20,
        step=1,
        label="Number of tokens to generate (while no EOS token)",
    )
    generate_button = gr.Button("Generate")

    output_html = gr.HTML(label="Generated Output")

    generate_button.click(
        on_generate,
        inputs=[input_text, num_tokens],
        outputs=[output_html],
    )

    gr.Markdown(
        "Hover over the blue text with superscript numbers to see the important input tokens for that group."
    )

    # Add clickable examples
    gr.Examples(
        examples=examples,
        inputs=[input_text],
    )

if __name__ == "__main__":
    demo.launch()