# Gradio Space: side-by-side analysis of Phi-3-mini vs. its "abliterated"
# variant — completions, logit differences, neuron activations, attention
# patterns, and per-layer activation patching.
import gradio as gr
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import plotly.graph_objects as go
from typing import List, Tuple
import spaces
# Checkpoints under comparison: the stock Phi-3-mini and a community
# "abliterated" variant.
model_name_original = "microsoft/Phi-3-mini-128k-instruct"
model_name_abliterated = "failspy/Phi-3-mini-128k-instruct-abliterated-v3"
# NOTE(review): a single tokenizer (from the original checkpoint) is used to
# encode/decode for BOTH models — assumes the abliterated checkpoint shares
# the original vocabulary; confirm.
tokenizer = AutoTokenizer.from_pretrained(model_name_original)
model_original = AutoModelForCausalLM.from_pretrained(model_name_original)
model_abliterated = AutoModelForCausalLM.from_pretrained(model_name_abliterated)
def get_neuron_acts(model, input_ids, layers, neuron_indices):
    """Run one forward pass and collect activations and attention maps.

    Returns a tuple (acts, attentions):
      acts       – list (one entry per requested layer) of numpy arrays
                   holding hidden-state values sliced to `neuron_indices`
                   along the last axis;
      attentions – list (one entry per requested layer) of the model's
                   attention tensors, left on-device as torch tensors.
    """
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True, output_attentions=True)
        # NOTE(review): for HF causal LMs, hidden_states has num_layers+1
        # entries (index 0 is the embedding output) while attentions has
        # num_layers — so the same `layer` index addresses slightly different
        # depths in the two lists; confirm this offset is intended.
        hidden_states = [outputs.hidden_states[layer] for layer in layers]
        attentions = [outputs.attentions[layer] for layer in layers]
        acts = [hidden_state[:, :, neuron_indices].cpu().numpy() for hidden_state in hidden_states]
        return acts, attentions
def logit_differences(model_original, model_abliterated, input_ids):
    """Compare the two models' logits at the final sequence position.

    Returns a 3-tuple of (batch, vocab) tensors:
    (original - abliterated, original, abliterated).
    """
    with torch.no_grad():
        last_original = model_original(input_ids).logits[:, -1, :]
        last_abliterated = model_abliterated(input_ids).logits[:, -1, :]
    return last_original - last_abliterated, last_original, last_abliterated
def generate_response(model, input_ids, max_length=200, temperature=0.7, top_k=50, top_p=0.95):
    """Sample one completion from `model` for the given prompt ids.

    Uses temperature / top-k / top-p sampling, padding and terminating on
    the tokenizer's EOS token. Returns (decoded_text, output_id_tensor)
    where the ids include the prompt and special tokens are stripped from
    the decoded text.
    """
    sampling_kwargs = dict(
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        generated = model.generate(input_ids, **sampling_kwargs)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True), first_sequence
def visualize_attention_patterns(attentions, tokens, head_indices, title, max_width=700): #removed most of this for now as it was leading to ZeroGPU timeouts
    """Render selected attention heads as overlaid Plotly heatmaps, or an
    error string when no requested head index is in range."""
    # NOTE(review): the only caller passes torch.stack(list-of-per-layer
    # attention tensors), i.e. shape (selected_layers, batch, num_heads,
    # seq, seq). shape[0] is then the number of *selected* layers and
    # shape[1] is the batch size (typically 1), not the head count — so
    # num_layers * num_heads is tiny and typical head indices (e.g. 108)
    # are filtered out, making this return the error string. Confirm the
    # intended input shape / head-index scheme before relying on this plot.
    num_layers = attentions.shape[0]
    num_heads = attentions.shape[1]
    valid_head_indices = [idx for idx in head_indices if 0 <= idx < num_layers * num_heads]
    if not valid_head_indices:
        return "No valid attention heads. Provide indices of relevant heads."
    # mean(dim=0) averages over the stacked first dimension; indexing with
    # valid_head_indices then selects along the following axis.
    attention_scores = attentions.mean(dim=0)[valid_head_indices].detach().cpu().numpy()
    # Trim attention matrices to the visible token count on both axes.
    attention_scores = attention_scores[:, :len(tokens), :len(tokens)]
    # One heatmap trace per selected score matrix; every trace reuses the
    # same full y-axis label list (one label per requested head).
    fig = go.Figure(data=[
        go.Heatmap(
            z=attention_score,
            x=tokens,
            y=[f"Head L{head_index // num_heads}H{head_index % num_heads}" for head_index in valid_head_indices],
            colorscale='Blues'
        )
        for attention_score in attention_scores
    ])
    fig.update_layout(
        title=title,
        xaxis_title="Tokens",
        yaxis_title="Attention Heads",
        width=max_width,
        height=max(300, 50*len(valid_head_indices))
    )
    return fig
def calculate_color(val, max_val, min_val):
    """Map an activation value to a white→red CSS color string.

    `val` is min-max normalized into [0, 1] (clipped): 0 maps to white
    ("rgb(255, 255, 255)"), 1 to pure red ("rgb(255, 0, 0)").

    Fixes: a flat range (max_val == min_val) previously divided by zero;
    it now maps to white. The per-call debug print was removed.
    """
    span = max_val - min_val
    if span == 0:
        normalized_val = 0.0  # degenerate range: treat every value as minimal
    else:
        normalized_val = float((val - min_val) / span)
    normalized_val = np.clip(normalized_val, 0, 1)
    channel = int((1 - normalized_val) * 255)
    return f"rgb(255, {channel}, {channel})"
def visualize_activations(model_name, acts, output_ids, layers, neuron_indices, max_val=None, min_val=None):
    """Build a tokens-by-layers heatmap of mean neuron activations.

    Parameters
    ----------
    model_name : str — used in the plot title.
    acts : list of arrays, one per entry in `layers`, each shaped
        (batch, seq_len, len(neuron_indices)) as produced by get_neuron_acts.
    output_ids : token ids of the generated sequence (x-axis labels).
    layers : layer indices matching `acts` (y-axis labels).
    neuron_indices : documented for context; `acts` are already sliced.
    max_val, min_val : optional explicit color-scale bounds; when None they
        are computed across ALL layers (previously they locked onto the
        first layer's range only).

    Fixes: the per-layer mean now runs over the neuron axis (axis=2) as the
    original inline comment intended — axis=1 averaged over sequence
    positions and then indexed neurons as if they were tokens. Debug prints
    removed.
    """
    tokens = tokenizer.convert_ids_to_tokens(output_ids, skip_special_tokens=True)
    # Global color-scale bounds over every layer's raw activations, unless
    # the caller pinned them explicitly.
    if max_val is None:
        max_val = max(float(acts_layer.max()) for acts_layer in acts)
    if min_val is None:
        min_val = min(float(acts_layer.min()) for acts_layer in acts)
    if max_val == min_val:
        max_val = min_val + 1e-6  # avoid a zero-width color range
    activation_data = []
    for acts_layer in acts:
        # Average across the neuron axis -> (batch, seq_len); keep the
        # first (and only) batch, trimmed to the decoded token count so
        # columns line up with labels.
        mean_acts = np.mean(acts_layer, axis=2)
        activation_data.append(mean_acts[0, :len(tokens)])
    fig_acts = go.Figure(data=go.Heatmap(
        z=activation_data,
        x=tokens[:len(activation_data[0])],
        y=[f"Layer {layer}" for layer in layers],
        colorscale='RdBu',
        zmid=0,
        zmin=min_val,
        zmax=max_val,
        colorbar=dict(title="Activation")
    ))
    fig_acts.update_layout(
        title=f"{model_name} Neuron Activations",
        xaxis_title="Tokens",
        yaxis_title="Layers",
        width=800,
        height=200 + len(layers) * 20
    )
    return fig_acts
def patch_representation(model, input_ids, layer, position, representation):
    """Activation patching: override one MLP output position for a forward pass.

    Registers a forward hook on `model.model.layers[layer].mlp` that replaces
    the hidden vector at sequence index `position` with `representation`,
    runs a forward pass, and returns the logits at the last position
    (shape (batch, vocab)).

    Fix: the hook handle is now removed in a `finally` block so the model is
    not left permanently patched if the forward pass raises.
    """
    def hook(module, input, output):
        # Overwrite the MLP output at the target position in place.
        output[:, position, :] = representation
        return output
    handle = model.model.layers[layer].mlp.register_forward_hook(hook)
    try:
        patched_outputs = model(input_ids)
    finally:
        handle.remove()  # always detach, even on error
    return patched_outputs.logits[:, -1, :]
@spaces.GPU(duration=120)
def compare_models(text, layers, neuron_indices, top_k, max_length, att_heads, temperature, top_k_sampling, top_p_sampling):
    """Main Gradio callback: compare the original and abliterated models.

    Generates a completion from each model, compares last-position logits,
    collects per-layer activations/attentions, and runs activation patching
    per layer. Returns the 6-tuple of outputs wired in the Interface:
    (html_summary, logit_plot, attention_plot_or_error_string,
     per_layer_plot, original_activation_plot, abliterated_activation_plot).
    """
    # Parse the comma-separated textbox inputs into integer lists/scalars.
    neuron_indices = [int(idx) for idx in neuron_indices.split(',')]
    layers = [int(layer) for layer in layers.split(',')]
    att_heads = [int(head) for head in att_heads.split(',')]
    top_k = int(top_k)
    max_length = int(max_length)
    input_ids = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # Sample a completion from each model with identical sampling settings.
    original_response, original_output_ids = generate_response(model_original, input_ids, max_length, temperature, top_k_sampling, top_p_sampling)
    abliterated_response, abliterated_output_ids = generate_response(model_abliterated, input_ids, max_length, temperature, top_k_sampling, top_p_sampling)
    # Last-position logit comparison computed on the prompt alone.
    logit_diffs, original_logits, abliterated_logits = logit_differences(model_original, model_abliterated, input_ids)
    # Activations/attentions over each model's own generated sequence.
    acts_original, attentions_original = get_neuron_acts(model_original, original_output_ids.unsqueeze(0), layers, neuron_indices)
    acts_abliterated, attentions_abliterated = get_neuron_acts(model_abliterated, abliterated_output_ids.unsqueeze(0), layers, neuron_indices)
    fig_acts_original = visualize_activations("Original Model", acts_original, original_output_ids, layers, neuron_indices)
    fig_acts_abliterated = visualize_activations("Abliterated Model", acts_abliterated, abliterated_output_ids, layers, neuron_indices)
    # Logit scatter: differences shown by default, raw logits toggleable.
    fig_logits = go.Figure()
    fig_logits.add_trace(go.Scatter(y=logit_diffs.squeeze().cpu().numpy(), mode='lines+markers', name="Logit Differences"))
    fig_logits.add_trace(go.Scatter(y=original_logits.squeeze().cpu().numpy(), mode='lines+markers', name="Original Logits", visible='legendonly'))
    fig_logits.add_trace(go.Scatter(y=abliterated_logits.squeeze().cpu().numpy(), mode='lines+markers', name="Abliterated Logits", visible='legendonly'))
    fig_logits.update_layout(title="Logit Analysis", xaxis_title="Token Index", yaxis_title="Logit Value")
    # Top-k tokens where original - abliterated is largest (tokens the
    # original model favors most relative to the abliterated one).
    top_indices = logit_diffs.topk(top_k).indices.squeeze().tolist()
    top_tokens = [tokenizer.decode(idx) for idx in top_indices]
    top_values_orig = original_logits.squeeze()[top_indices].tolist()
    top_values_ablit = abliterated_logits.squeeze()[top_indices].tolist()
    top_diffs = logit_diffs.squeeze()[top_indices].tolist()
    token_diffs = [f'<b>{token}</b>: Original {orig:.3f} | Abliterated {ablit:.3f} | Diff {diff:.3f}'
                   for token, orig, ablit, diff in zip(top_tokens, top_values_orig, top_values_ablit, top_diffs)]
    attention_html = visualize_attention_patterns(torch.stack(attentions_original), input_tokens, att_heads, "Attention Heads")
    # Activation patching: swap each model's layer-`layer` representation
    # into the other model and compare the resulting mean logits.
    layer_logit_diffs = []
    with torch.no_grad():
        for layer in layers:
            # NOTE(review): these two forward passes do not depend on `layer`
            # and could be hoisted out of the loop.
            orig_outputs = model_original(input_ids, output_hidden_states=True)
            ablit_outputs = model_abliterated(input_ids, output_hidden_states=True)
            orig_rep = orig_outputs.hidden_states[layer][:, -1, :].detach()
            ablit_rep = ablit_outputs.hidden_states[layer][:, -1, :].detach()
            patched_orig_logits = patch_representation(model_original, input_ids, layer, -1, ablit_rep)
            patched_ablit_logits = patch_representation(model_abliterated, input_ids, layer, -1, orig_rep)
            layer_logit_diff = patched_orig_logits.mean().item() - patched_ablit_logits.mean().item()
            layer_logit_diffs.append(layer_logit_diff)
    fig_per_layer = go.Figure(data=[go.Bar(x=layers, y=layer_logit_diffs)])
    fig_per_layer.update_layout(
        title="Per-Layer Logit Differences",
        xaxis_title="Layer",
        yaxis_title="Logit Difference"
    )
    # HTML summary panel (markup kept byte-identical; lines stay at column 0
    # so the rendered string is unchanged).
    result = f"""
<div style="font-size:16px">
<h3>Original Model Completion:</h3>
<p>{original_response}</p>
<h3>Abliterated Model Completion:</h3>
<p>{abliterated_response}</p>
<br>
<h3>Logit Differences (Original - Abliterated):</h3>
<p>Mean Logit Difference: {logit_diffs.mean().item():.3f}</p>
<h4>Top {top_k} Tokens:</h4>
<ol>
{'<li>' + '</li><li>'.join(token_diffs) + '</li>'}
</ol>
<i>Hover over the plots for more details.</i>
</div>
"""
    return result, fig_logits, attention_html, fig_per_layer, fig_acts_original, fig_acts_abliterated
# --- Gradio UI wiring -------------------------------------------------------
# Input widgets, in the positional order expected by compare_models:
# (text, layers, neuron_indices, top_k, max_length, att_heads,
#  temperature, top_k_sampling, top_p_sampling).
inputs = [
    gr.Textbox(label="Prompt", placeholder="Enter a prompt to test the model's robustness"),
    gr.Textbox(label="Layers", value="9,10,11", placeholder="e.g. 9,10,11"),
    gr.Textbox(label="Neuron Indices", value="100,200,300,400", placeholder="e.g. 100,200,300,400"),
    gr.Slider(minimum=1, maximum=20, step=1, value=10, label="Number of Top Tokens"),
    gr.Slider(minimum=50, maximum=500, step=1, value=70, label="Max Response Length"),
    gr.Textbox(label="Attention Heads", value="108,120,132", placeholder="e.g. 108,120,132 (Layer 9 Heads 0,1,2)"),
    gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.7, label="Temperature"),
    gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top-k Sampling"),
    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label="Top-p Sampling")
]
# Output widgets, matching compare_models' 6-tuple return in order.
outputs = [
    gr.HTML(label="Analysis Results"),
    gr.Plot(label="Logit Analysis"),
    gr.HTML(label="Attention Heads"),
    gr.Plot(label="Per-Layer Logit Differences"),
    gr.Plot(label="Original Model Activation Heatmap"),
    gr.Plot(label="Abliterated Model Activation Heatmap")
]
title = "Phi-3 Analysis"
description = """
Compare the original phi-3 model with its ablated counterpart to scrutinize its inner workings and identify differences- suggestion: try prompts where refusal would be expected (i.e. How do I torrent a movie online?), patterns of letters/characters such as repetitions, or number sequences.
The plots and results will update based on your selection, hover over them for details.
"""
# Build and launch the app; debug=True streams server logs to the console.
gr.Interface(fn=compare_models, inputs=inputs, outputs=outputs, title=title, description=description, theme="SixOpen/catmocha").launch(debug=True)