import time

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import spaces
import torch
from plotly.subplots import make_subplots
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the early-exit model and its base tokenizer
model_str = "valcore/Branchy-Phi-2"
tokenizer_str = "microsoft/Phi-2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)

# Per-token generation metrics, re-initialized for every request.
# Module-level state is acceptable here because the demo serves one generation at a time.
data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])

# Per-head exit thresholds calibrated for different epsilon values.
# Lower epsilon -> lower thresholds -> more aggressive early exiting.
epsilon_thresholds = {
    0.4: [1.0307843685150146, 0.8693032264709473, 0.6637287139892578, 0.3111608028411865],
    0.5: [1.505380630493164, 1.5712471008300781, 1.1971790790557861, 0.6908178329467773],
    0.6: [2.0270779132843018, 1.8969502449035645, 1.4789371490478516, 0.9875392913818359],
    0.7: [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703],
    0.8: [3.3786778450012207, 2.568857192993164, 2.5665550231933594, 2.006620407104492],
    0.9: [3.187114715576172, 3.442272663116455, 2.636230945587158, 2.460529088973999],
    1.0: [10.0, 10.0, 10.0, 10.0],  # Thresholds high enough to effectively disable early exits
}

# Global flag toggled by the Stop button to interrupt the generation loop
stop_generation = False


def create_plot():
    """Plot per-token latency and early-exit depth on a shared x-axis."""
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(
        go.Scatter(
            x=data.index,
            y=data["Time taken (in ms)"],
            name="Time taken (ms)",
            text=data["Token"],
            hovertemplate="Token: %{text}<br>Time: %{y:.2f} ms",
        ),
        secondary_y=False,
    )
    fig.add_trace(
        go.Scatter(
            x=data.index,
            y=data["Early exit depth"],
            name="Early exit depth",
            text=data["Token"],
            hovertemplate="Token: %{text}<br>Depth: %{y:.2f}",
        ),
        secondary_y=True,
    )
    fig.update_layout(
        title_text="Token Generation Metrics",
        xaxis_title="Token Index",
        yaxis_title="Time (ms)",
        yaxis2_title="Exit Depth",
        hovermode="closest",
    )
    fig.update_yaxes(range=[0, 1.1], secondary_y=True)
    return fig
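
# --- Illustrative sketch (not part of the model) -----------------------------
# The actual exit rule runs inside Branchy-Phi-2's remote code; this standalone
# sketch only illustrates the idea the thresholds above encode. The hypothetical
# `head_scores` stand in for whatever per-head confidence measure the model
# compares against `head_thresholds` (e.g. a top-logit margin); the real
# criterion may differ.
def illustrative_exit_head(head_scores, thresholds):
    """Return the index of the first head confident enough to exit,
    or None when every branch defers to the full model."""
    for i, (score, threshold) in enumerate(zip(head_scores, thresholds)):
        if score >= threshold:
            return i
    return None

# Example: with the epsilon=0.7 thresholds, a score of 2.8 clears the first
# head's threshold (~2.51), so generation would exit at the shallowest branch:
# illustrative_exit_head([2.8, 0.0, 0.0, 0.0], epsilon_thresholds[0.7])  # -> 0
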
def truncate_context(input_ids, max_length=2048):
    """Keep only the most recent tokens so the prompt fits Phi-2's 2048-token context window."""
    if len(input_ids[0]) > max_length:
        return input_ids[:, -max_length:]
    return input_ids


@spaces.GPU
def generate_response(message, chat_history, epsilon, max_new_tokens=512):
    global data, stop_generation
    data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])
    stop_generation = False

    # Select the per-head exit thresholds for the chosen epsilon
    model.head_thresholds = torch.tensor(epsilon_thresholds[epsilon])

    # Format the prompt with the chat history
    formatted_prompt = ""
    for user_msg, assistant_msg in chat_history:
        formatted_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    formatted_prompt += f"User: {message}\nAssistant:"

    full_response = ""
    inputs = tokenizer.encode(formatted_prompt, return_tensors="pt").to(device)

    # Greedy decoding, one token per forward pass. The max_new_tokens cap keeps
    # the loop from running unboundedly if the model never emits EOS.
    for _ in range(max_new_tokens):
        if stop_generation:
            break
        inputs = truncate_context(inputs)
        start = time.time()
        with torch.no_grad():
            outputs = model(inputs)
        stop = time.time()
        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)
        if next_token_id.item() == tokenizer.eos_token_id:
            break
        inputs = torch.cat([inputs, next_token_id.unsqueeze(0)], dim=-1)
        next_token = tokenizer.decode(next_token_id)
        full_response += next_token

        time_taken = (stop - start) * 1000  # Convert to milliseconds
        # Normalize the exit point to a depth in (0, 1]: exiting at the first
        # branch gives the smallest value; running the full model gives 1.0.
        branch_locations = model.config.branch_locations
        early_exit = (
            (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
            if outputs.head_indices in branch_locations
            else 1.0
        )

        new_row = pd.DataFrame({
            "Time taken (in ms)": [time_taken],
            "Early exit depth": [early_exit],
            "Token": [next_token],
        })
        data = pd.concat([data, new_row], ignore_index=True)

        # Stream the updated chat and the refreshed metrics plot
        new_history = chat_history + [(message, full_response)]
        yield new_history, gr.update(value=create_plot())


def stop_gen():
    global stop_generation
    stop_generation = True
    # Disable the button as feedback; it stays disabled until the page reloads
    return gr.update(interactive=False)


with gr.Blocks() as demo:
    gr.Markdown("# Multi-Head LLM Demo with Early Exit Capabilities 🤗")
    gr.Markdown("""This is a demo of a multi-head language model with early-exit capabilities. The model is based on the Phi-2 architecture and is available at https://huggingface.co/valcore/Branchy-Phi-2.

The model has four exit heads, each gated by its own threshold. The graph shows, for every generated token, how deep into the network the model went before exiting and how long the token took to generate.

Use the slider to adjust the early-exit threshold. Lower values allow more early exits, potentially speeding up generation at the cost of accuracy.
""")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    epsilon = gr.Slider(minimum=0.4, maximum=1.0, value=0.7, step=0.1, label="Epsilon")
    with gr.Row():
        send = gr.Button("Send")
        stop = gr.Button("Stop Generation")
    graph = gr.Plot()

    send.click(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, graph])
    msg.submit(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, graph])
    stop.click(stop_gen, outputs=[stop])

demo.queue().launch()
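
# --- Reference sketch (not wired into the demo) -------------------------------
# The loop in generate_response re-encodes the entire sequence on every step,
# which is quadratic in the generated length. Standard transformers causal LMs
# avoid this with a KV cache, feeding only the newest token once the cache is
# warm; whether Branchy-Phi-2's remote code supports `use_cache`/
# `past_key_values` is an assumption to verify before adopting this.
def cached_greedy_generate(model, input_ids, max_new_tokens=64):
    past = None
    generated = input_ids
    step_input = input_ids
    for _ in range(max_new_tokens):
        out = model(step_input, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if next_id.item() == tokenizer.eos_token_id:
            break
        generated = torch.cat([generated, next_id], dim=-1)
        step_input = next_id  # once the cache is warm, feed only the new token
    return generated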