# Page header captured from the Hugging Face file viewer (not part of the program):
# gaverfraxz's picture
# Update app.py
# 6acbf7b verified
# raw
# history blame
# 4.08 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
import io
def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference between two weight tensors.

    The difference is computed in float32 even for lower-precision inputs.
    A bfloat16 mean keeps only about 3 significant digits, far fewer than the
    six decimals the heatmap annotations print.

    Args:
        base_weight: Tensor from the reference model.
        chat_weight: Tensor of the same shape from the compared model.

    Returns:
        The mean of ``|base_weight - chat_weight|`` as a Python float.

    Raises:
        ValueError: If the two tensors have different shapes. Without this
            check, broadcasting would produce either a meaningless number or
            an opaque error.
    """
    if base_weight.shape != chat_weight.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(base_weight.shape)} vs {tuple(chat_weight.shape)}"
        )
    # Model parameters usually require grad; no autograd graph is needed here.
    with torch.no_grad():
        return (base_weight.float() - chat_weight.float()).abs().mean().item()
def calculate_layer_diffs(base_model, chat_model, components=None):
    """Compute per-layer weight differences between two models of one architecture.

    Both models must expose their decoder layers as ``model.model.layers``
    (as e.g. Llama-family ``transformers`` models do). Layers are paired
    one-to-one, in order.

    Args:
        base_model: Reference model.
        chat_model: Model compared against ``base_model``.
        components: Optional mapping from output key to the dotted attribute
            path of a weight tensor inside one decoder layer (e.g.
            ``"self_attn.q_proj.weight"``). Defaults to the two layer norms
            and the four attention projections, in that order.

    Returns:
        A list with one dict per layer, in layer order. Each dict maps a
        component key to the value returned by ``calculate_weight_diff``
        for that pair of tensors.

    Raises:
        ValueError: If the two models have a different number of layers.
    """
    from operator import attrgetter

    if components is None:
        components = {
            'input_layernorm': 'input_layernorm.weight',
            'post_attention_layernorm': 'post_attention_layernorm.weight',
            'self_attn_q_proj': 'self_attn.q_proj.weight',
            'self_attn_k_proj': 'self_attn.k_proj.weight',
            'self_attn_v_proj': 'self_attn.v_proj.weight',
            'self_attn_o_proj': 'self_attn.o_proj.weight',
        }

    base_layers = base_model.model.layers
    chat_layers = chat_model.model.layers
    # zip() would silently stop at the shorter model and hide the mismatch.
    if len(base_layers) != len(chat_layers):
        raise ValueError(
            f"Layer count mismatch: base model has {len(base_layers)} layers, "
            f"chat model has {len(chat_layers)}"
        )

    getters = {key: attrgetter(path) for key, path in components.items()}
    # Parameters typically require grad; skip autograd bookkeeping entirely.
    with torch.no_grad():
        return [
            {
                key: calculate_weight_diff(get(base_layer), get(chat_layer))
                for key, get in getters.items()
            }
            for base_layer, chat_layer in zip(base_layers, chat_layers)
        ]
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Draw one single-column heatmap per component, with one row per layer.

    Args:
        layer_diffs: List of per-layer dicts mapping a component name to a
            scalar difference, as produced by ``calculate_layer_diffs``.
            Every dict is expected to have the keys of the first one.
        base_model_name: Shown on the left of the figure title.
        chat_model_name: Shown on the right of the figure title.

    Returns:
        The matplotlib ``Figure``. The caller owns it and should close it
        with ``plt.close(fig)`` when done.

    Raises:
        ValueError: If ``layer_diffs`` is empty.
    """
    if not layer_diffs:
        raise ValueError("layer_diffs is empty; nothing to visualize")

    num_layers = len(layer_diffs)
    components = list(layer_diffs[0])
    # squeeze=False keeps `axs` 2-D even with a single component, so the
    # per-component indexing below works for any number of subplots.
    fig, axs = plt.subplots(1, len(components), figsize=(24, 8), squeeze=False)
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    # Heatmap cell i spans [i, i + 1]; label each layer at its cell centre.
    tick_positions = [layer + 0.5 for layer in range(num_layers)]
    for ax, component in zip(axs[0], components):
        # (num_layers, 1) matrix: rows are layers, the single column is the diff.
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs, annot=True, fmt=".6f", cmap="YlGnBu", ax=ax, cbar_kws={"shrink": 0.8})
        ax.set_title(component)
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(range(num_layers))
        # seaborn draws row 0 at the top; flip so layer 0 sits at the bottom.
        ax.invert_yaxis()

    fig.tight_layout()
    return fig
def main():
    """Streamlit entry point: read two model names, compare them and plot the result.

    When the button is pressed:
    - Both models are loaded with ``AutoModelForCausalLM.from_pretrained`` in
      bfloat16.
    - Their per-layer weight differences are computed with
      ``calculate_layer_diffs``.
    - The heatmap from ``visualize_layer_diffs`` is shown and offered as a
      PNG download.

    Any exception raised along the way is reported on the page.
    """
    st.set_page_config(
        page_title="Model Weight Comparator",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    st.title("LLM Weight Comparator")

    # Config sidebar for input parameters
    with st.sidebar:
        st.header("Configuration")
        base_model_name = st.text_input(
            "Base Model Name",
            value="meta-llama/Llama-3.1-8B",
            help="Enter the name of the base model"
        )
        chat_model_name = st.text_input(
            "Chat Model Name",
            value="meta-llama/Llama-3.1-8B-Instruct",
            help="Enter the name of the chat model"
        )

    if st.button("Compare Models"):
        # Surrounding whitespace (e.g. from a paste) is never part of a repo id,
        # and a whitespace-only entry should count as empty.
        base_model_name = base_model_name.strip()
        chat_model_name = chat_model_name.strip()
        if not base_model_name or not chat_model_name:
            st.error("Please enter both model names")
            return
        try:
            st.info("Loading models... This might take some time.")
            base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
            chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)

            st.info("Calculating weight differences...")
            layer_diffs = calculate_layer_diffs(base_model, chat_model)
            # Only the scalar diffs are needed from here on; drop the references
            # to both full models so their memory can be reclaimed before rendering.
            del base_model, chat_model

            st.info("Generating visualization...")
            fig = visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
            try:
                st.pyplot(fig)

                # Offer the figure as a high-resolution PNG download.
                buf = io.BytesIO()
                fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
                buf.seek(0)
                st.download_button(
                    label="Download Visualization",
                    data=buf,
                    file_name="model_comparison.png",
                    mime="image/png"
                )
            finally:
                # pyplot keeps every figure alive until it is closed; without this
                # each comparison leaks one figure.
                plt.close(fig)
        except Exception as e:
            # UI boundary: surface any failure (download, OOM, architecture
            # mismatch) to the user instead of crashing the app.
            st.error(f"An error occurred: {str(e)}")
if __name__ == "__main__":
    # Build the UI only when this file is executed as a script
    # (e.g. via `streamlit run`), not when it is imported.
    main()