Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import os | |
import io | |
def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference between two weight tensors.

    The tensors are promoted to at least float32 before subtracting, so the
    result is not rounded to a low-precision input dtype such as bfloat16
    (which the caller in this app loads models with). The computation runs
    under ``torch.no_grad()`` so no autograd graph is built when the inputs
    are trainable parameters.

    Args:
        base_weight: Weight tensor from the base model.
        chat_weight: Weight tensor from the chat model; must have the same
            shape as ``base_weight``.

    Returns:
        The mean of ``|base_weight - chat_weight|`` as a Python float.

    Raises:
        ValueError: If the two tensors have different shapes.
    """
    if base_weight.shape != chat_weight.shape:
        # Without this check, broadcasting could silently compare unrelated elements.
        raise ValueError(
            f"Shape mismatch: {tuple(base_weight.shape)} vs {tuple(chat_weight.shape)}"
        )
    # At least float32; float64 inputs stay float64.
    compute_dtype = torch.promote_types(
        torch.promote_types(base_weight.dtype, chat_weight.dtype), torch.float32
    )
    with torch.no_grad():
        base = base_weight.to(dtype=compute_dtype)
        chat = chat_weight.to(device=base.device, dtype=compute_dtype)
        return torch.abs(base - chat).mean().item()
def calculate_layer_diffs(base_model, chat_model):
    """Compute per-layer weight differences between two decoder models.

    Both models must expose their decoder layers as ``model.layers``, and each
    layer must have ``input_layernorm``, ``post_attention_layernorm`` and
    ``self_attn.{q,k,v,o}_proj`` submodules with a ``weight`` attribute.
    This is the layout read below; presumably Llama-style models such as the
    app's defaults — verify for other architectures.

    Args:
        base_model: The base model.
        chat_model: The chat model, with the same layout and layer count.

    Returns:
        A list with one dict per layer. Each dict maps a component name to the
        mean absolute difference of that component's weights. All dicts use
        the same key order, which the visualization uses as its panel order.

    Raises:
        ValueError: If the two models have a different number of layers.
    """
    from operator import attrgetter

    # (result key, dotted attribute path of the weight inside one decoder layer).
    components = (
        ("input_layernorm", "input_layernorm.weight"),
        ("post_attention_layernorm", "post_attention_layernorm.weight"),
        ("self_attn_q_proj", "self_attn.q_proj.weight"),
        ("self_attn_k_proj", "self_attn.k_proj.weight"),
        ("self_attn_v_proj", "self_attn.v_proj.weight"),
        ("self_attn_o_proj", "self_attn.o_proj.weight"),
    )
    getters = [(key, attrgetter(path)) for key, path in components]

    base_layers = base_model.model.layers
    chat_layers = chat_model.model.layers
    if len(base_layers) != len(chat_layers):
        # zip() would otherwise silently drop the extra layers of the deeper model.
        raise ValueError(
            f"Layer count mismatch: base has {len(base_layers)} layers, "
            f"chat has {len(chat_layers)}"
        )

    return [
        {
            key: calculate_weight_diff(get_weight(base_layer), get_weight(chat_layer))
            for key, get_weight in getters
        }
        for base_layer, chat_layer in zip(base_layers, chat_layers)
    ]
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Draw one single-column heatmap per weight component, with layers as rows.

    Args:
        layer_diffs: One dict per layer, all with the same component keys
            (as produced by ``calculate_layer_diffs``).
        base_model_name: Base model name, shown in the figure title.
        chat_model_name: Chat model name, shown in the figure title.

    Returns:
        The matplotlib Figure. pyplot keeps a reference to it, so the caller
        should ``plt.close(fig)`` when done with it.

    Raises:
        ValueError: If ``layer_diffs`` is empty.
    """
    if not layer_diffs:
        raise ValueError("layer_diffs is empty; nothing to visualize")

    num_layers = len(layer_diffs)
    components = list(layer_diffs[0])
    # squeeze=False keeps `axs` 2-D even when there is only a single component.
    fig, axs = plt.subplots(1, len(components), figsize=(24, 8), squeeze=False)
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    for ax, component in zip(axs[0], components):
        # num_layers x 1 matrix: row i holds layer i's difference.
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs, annot=True, fmt=".6f", cmap="YlGnBu", ax=ax, cbar_kws={"shrink": 0.8})
        ax.set_title(component)
        # Layers run along the y-axis; cell colour (and the colorbar) encode the difference.
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        # Heatmap cells are centred at i + 0.5, so place each layer label there.
        ax.set_yticks([row + 0.5 for row in range(num_layers)])
        ax.set_yticklabels(range(num_layers))
        # seaborn draws row 0 at the top; flip so layer 0 sits at the bottom.
        ax.invert_yaxis()

    fig.tight_layout()
    return fig
def main():
    """Render the Streamlit UI and, when requested, compare two models' weights.

    The sidebar collects the base and chat model names. Pressing
    "Compare Models" does the following:
      1. Loads both models in bfloat16.
      2. Computes per-layer weight differences.
      3. Renders the heatmap figure.
      4. Offers the figure as a PNG download.
    Any exception during that pipeline is shown to the user as an error message.
    """
    st.set_page_config(
        page_title="Model Weight Comparator",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("LLM Weight Comparator")

    # Config sidebar for input parameters
    with st.sidebar:
        st.header("Configuration")
        # Strip whitespace so a stray space doesn't break the model lookup.
        base_model_name = st.text_input(
            "Base Model Name",
            value="meta-llama/Llama-3.1-8B",
            help="Enter the name of the base model",
        ).strip()
        chat_model_name = st.text_input(
            "Chat Model Name",
            value="meta-llama/Llama-3.1-8B-Instruct",
            help="Enter the name of the chat model",
        ).strip()

    if not st.button("Compare Models"):
        return
    if not base_model_name or not chat_model_name:
        st.error("Please enter both model names")
        return

    try:
        # Spinners disappear when their step finishes, unlike st.info,
        # whose messages would stay on the page after the work is done.
        with st.spinner("Loading models... This might take some time."):
            base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
            chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)

        with st.spinner("Calculating weight differences..."):
            layer_diffs = calculate_layer_diffs(base_model, chat_model)
        # Only the scalar diffs are needed from here on; drop the model
        # references so their memory can be reclaimed before plotting.
        del base_model, chat_model

        with st.spinner("Generating visualization..."):
            fig = visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
        try:
            st.pyplot(fig)
            # visualization
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=300, bbox_inches="tight")
        finally:
            # pyplot holds every figure until it is closed. Streamlit reruns
            # this script in the same process, so unclosed figures accumulate.
            plt.close(fig)

        st.download_button(
            label="Download Visualization",
            data=buf.getvalue(),
            file_name="model_comparison.png",
            mime="image/png",
        )
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
if __name__ == "__main__":
    # Start the app only when this file is run as the main script,
    # not when it is imported as a module.
    main()