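# Gradio demo for Branchy-Phi-2: a Phi-2 variant with early-exit heads.
# Streams tokens into a chat window while plotting per-token latency and
# the depth at which each token exited the network.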
import gradio as gr
import torch
import pandas as pd
import plotly.graph_objects as go
import spaces
from plotly.subplots import make_subplots
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import numpy as np

# Load the model and tokenizer
model_str = "valcore/Branchy-Phi-2"
tokenizer_str = "microsoft/Phi-2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)

# Initialize dataframe for storing token generation data
data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])

# Early-exit thresholds (one value per exit head) for each epsilon setting
epsilon_thresholds = {
    0.4: [1.0307843685150146, 0.8693032264709473, 0.6637287139892578, 0.3111608028411865],
    0.5: [1.505380630493164, 1.5712471008300781, 1.1971790790557861, 0.6908178329467773],
    0.6: [2.0270779132843018, 1.8969502449035645, 1.4789371490478516, 0.9875392913818359],
    0.7: [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703],
    0.8: [3.3786778450012207, 2.568857192993164, 2.5665550231933594, 2.006620407104492],
    0.9: [3.187114715576172, 3.442272663116455, 2.636230945587158, 2.460529088973999],
    1.0: [10.0, 10.0, 10.0, 10.0],  # Effectively disable early exits
}

# Global flag used to interrupt token generation from the Stop button
stop_generation = False
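
# Build a two-axis Plotly figure: per-token generation time (left axis, ms)
# and normalized early-exit depth (right axis) for the current response.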
def create_plot():
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(
        go.Scatter(
            x=data.index,
            y=data["Time taken (in ms)"],
            name="Time taken (ms)",
            text=data["Token"],
            hovertemplate="<b>Token:</b> %{text}<br><b>Time:</b> %{y:.2f} ms<extra></extra>",
        ),
        secondary_y=False,
    )
    fig.add_trace(
        go.Scatter(
            x=data.index,
            y=data["Early exit depth"],
            name="Early exit depth",
            text=data["Token"],
            hovertemplate="<b>Token:</b> %{text}<br><b>Depth:</b> %{y:.2f}<extra></extra>",
        ),
        secondary_y=True,
    )
    fig.update_layout(
        title_text="Token Generation Metrics",
        xaxis_title="Token Index",
        yaxis_title="Time (ms)",
        yaxis2_title="Exit Depth",
        hovermode="closest",
    )
    fig.update_yaxes(range=[0, 1.1], secondary_y=True)
    return fig
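
# Keep only the most recent tokens so the prompt stays within the
# model's 2048-token context window.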
def truncate_context(input_ids, max_length=2048):
    if len(input_ids[0]) > max_length:
        return input_ids[:, -max_length:]
    return input_ids
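
# Generator callback for the chat UI: greedily decodes one token at a time,
# records per-token latency and exit depth, and yields the updated chat
# history together with a refreshed metrics plot.
@spaces.GPU  # Assumption: the Space runs on ZeroGPU (hence `import spaces`); a no-op elsewhere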
def generate_response(message, chat_history, epsilon):
    global data, stop_generation
    data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])
    stop_generation = False

    # Set the per-head early-exit thresholds for the chosen epsilon
    # (round() guards against imprecise float values coming from the slider)
    model.head_thresholds = torch.tensor(epsilon_thresholds[round(epsilon, 1)])

    # Format the prompt with the chat history
    formatted_prompt = ""
    for user_msg, assistant_msg in chat_history:
        formatted_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    formatted_prompt += f"User: {message}\nAssistant:"

    full_response = ""
    inputs = tokenizer.encode(formatted_prompt, return_tensors="pt").to(device)

    while not stop_generation:
        inputs = truncate_context(inputs)

        start = time.time()
        with torch.no_grad():  # inference only, no gradients needed
            outputs = model(inputs)
        stop = time.time()

        # Greedy decoding: pick the most likely next token
        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)
        if next_token_id.item() == tokenizer.eos_token_id:
            break
        inputs = torch.cat([inputs, next_token_id.unsqueeze(0)], dim=-1)
        next_token = tokenizer.decode(next_token_id)
        full_response += next_token

        time_taken = (stop - start) * 1000  # Convert to milliseconds
        # Normalized exit depth: fraction of exit heads traversed (1.0 = ran the full model)
        branch_locations = model.config.branch_locations
        early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations) if outputs.head_indices in branch_locations else 1.0
        # Append this token's metrics and stream the partial response to the UI
        new_row = pd.DataFrame({
            "Time taken (in ms)": [time_taken],
            "Early exit depth": [early_exit],
            "Token": [next_token],
        })
        data = pd.concat([data, new_row], ignore_index=True)

        new_history = chat_history + [(message, full_response)]
        yield new_history, new_history, gr.update(value=create_plot())
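
# Callback for the Stop button: flips the flag checked by the generation loop
# and disables the button while the current generation winds down.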
def stop_gen():
    global stop_generation
    stop_generation = True
    return gr.update(interactive=False)
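
# Assemble the Gradio UI: chat window, epsilon slider, Send/Stop buttons,
# and a live plot of per-token generation metrics.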
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Head LLM Demo with Early Exit Capabilities 🤗")
    gr.Markdown("""This is a demo of a multi-head language model with early exit capabilities.
The model is based on the Phi-2 architecture and is available here: https://huggingface.co/valcore/Branchy-Phi-2.
The model has four exit heads, and generation can exit early at each of them based on a threshold. The graph shows the depth of early exit for each token and the time taken to generate each token.
Use the slider to adjust the early exit threshold. Lower values allow more early exits, potentially speeding up generation at the cost of accuracy.
""")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    epsilon = gr.Slider(minimum=0.4, maximum=1.0, value=0.7, step=0.1, label="Epsilon")
    with gr.Row():
        send = gr.Button("Send")
        stop = gr.Button("Stop Generation")
    graph = gr.Plot()

    send.click(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, chatbot, graph])
    msg.submit(generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, chatbot, graph])
    stop.click(stop_gen, outputs=[stop])

demo.queue().launch()