import gradio as gr
import torch
import pandas as pd
import plotly.graph_objects as go
import spaces
from plotly.subplots import make_subplots
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
# Load the early-exit model (custom code via trust_remote_code) and the base Phi-2 tokenizer
model_str = "valcore/Branchy-Phi-2"
tokenizer_str = "microsoft/Phi-2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
# Initialize dataframe for storing token generation data
data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])
# Calibrated per-head confidence thresholds for each epsilon value (one threshold per exit head);
# lower epsilon means lower thresholds and therefore more early exits
epsilon_thresholds = {
0.4: [1.0307843685150146, 0.8693032264709473, 0.6637287139892578, 0.3111608028411865],
0.5: [1.505380630493164, 1.5712471008300781, 1.1971790790557861, 0.6908178329467773],
0.6: [2.0270779132843018, 1.8969502449035645, 1.4789371490478516, 0.9875392913818359],
0.7: [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703],
0.8: [3.3786778450012207, 2.568857192993164, 2.5665550231933594, 2.006620407104492],
0.9: [3.187114715576172, 3.442272663116455, 2.636230945587158, 2.460529088973999],
1.0: [10.0, 10.0, 10.0, 10.0] # Effectively disable early exits
}
# Global variable to control generation
stop_generation = False
def create_plot():
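    """Build a dual-axis Plotly figure: per-token generation time (left axis)
    and normalized early-exit depth (right axis) against the token index."""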
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(
go.Scatter(
x=data.index,
y=data["Time taken (in ms)"],
name="Time taken (ms)",
text=data["Token"],
hovertemplate="Token: %{text}
Time: %{y:.2f} ms",
),
secondary_y=False,
)
fig.add_trace(
go.Scatter(
x=data.index,
y=data["Early exit depth"],
name="Early exit depth",
text=data["Token"],
hovertemplate="Token: %{text}
Depth: %{y:.2f}",
),
secondary_y=True,
)
fig.update_layout(
title_text="Token Generation Metrics",
xaxis_title="Token Index",
yaxis_title="Time (ms)",
yaxis2_title="Exit Depth",
hovermode="closest",
)
fig.update_yaxes(range=[0, 1.1], secondary_y=True)
return fig
def truncate_context(input_ids, max_length=2048):
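    """Keep only the most recent max_length tokens so the prompt stays within the model's context window."""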
if len(input_ids[0]) > max_length:
return input_ids[:, -max_length:]
return input_ids
@spaces.GPU
def generate_response(message, chat_history, epsilon):
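    """Greedily generate a response token by token, streaming the updated chat
    history and a refreshed metrics plot after every token."""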
global data, stop_generation
data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])
stop_generation = False
    # Set per-head exit thresholds for the chosen epsilon
    # (round to guard against floating-point drift in the slider value)
    model.head_thresholds = torch.tensor(epsilon_thresholds[round(epsilon, 1)])
# Format the prompt with chat history
formatted_prompt = ""
for user_msg, assistant_msg in chat_history:
formatted_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
formatted_prompt += f"User: {message}\nAssistant:"
full_response = ""
inputs = tokenizer.encode(formatted_prompt, return_tensors="pt").to(device)
    while not stop_generation:
        inputs = truncate_context(inputs)
        start = time.time()
        with torch.no_grad():  # inference only; no gradients needed
            outputs = model(inputs)
        stop = time.time()
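        # Greedy decoding: pick the highest-probability token at the last position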
next_token_logits = outputs.logits[:, -1, :]
next_token_id = torch.argmax(next_token_logits, dim=-1)
if next_token_id.item() == tokenizer.eos_token_id:
break
inputs = torch.cat([inputs, next_token_id.unsqueeze(0)], dim=-1)
next_token = tokenizer.decode(next_token_id)
full_response += next_token
time_taken = (stop - start) * 1000 # Convert to milliseconds
        # Normalize exit depth to (0, 1]: earlier heads give smaller values,
        # 1.0 means the deepest head (or no early exit) was used
        branch_locations = model.config.branch_locations
        if outputs.head_indices in branch_locations:
            early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
        else:
            early_exit = 1.0
new_row = pd.DataFrame({
"Time taken (in ms)": [time_taken],
"Early exit depth": [early_exit],
"Token": [next_token]
})
data = pd.concat([data, new_row], ignore_index=True)
        new_history = chat_history + [(message, full_response)]
        yield new_history, gr.update(value=create_plot())
def stop_gen():
    """Signal the generation loop to stop after the current token."""
    global stop_generation
    stop_generation = True
with gr.Blocks() as demo:
gr.Markdown("# Multi-Head LLM Demo with Early Exit Capabilities 🤗")
gr.Markdown("""This is a demo of a multi-head language model with early exit capabilities.
The model is based on the Phi-2 architecture and is available here: https://huggingface.co/valcore/Branchy-Phi-2.
The model has four heads, each of which can be exited early based on a threshold. The graph shows the depth of early exit for each token and the time taken to generate each token.
Use the slider to adjust the early exit threshold. Lower values allow for more early exits, potentially speeding up generation at the cost of accuracy.
""")
chatbot = gr.Chatbot()
msg = gr.Textbox(label="Message")
epsilon = gr.Slider(minimum=0.4, maximum=1.0, value=0.7, step=0.1, label="Epsilon")
with gr.Row():
send = gr.Button("Send")
stop = gr.Button("Stop Generation")
graph = gr.Plot()
    send.click(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, graph])
    msg.submit(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, graph])
    stop.click(stop_gen)
demo.queue().launch()