HawkClaws's picture
Update app.py
8b5c657 verified
raw
history blame
2.57 kB
import difflib
import html

import streamlit as st
import torch
from transformers import AutoModelForCausalLM
@st.cache_data
def get_model_structure(model_id):
    """Return a mapping of state-dict entry name -> tensor shape string.

    The full causal-LM model identified by *model_id* is loaded on the CPU in
    bfloat16 via ``AutoModelForCausalLM.from_pretrained``; only the names and
    ``str(tensor.shape)`` of its state-dict entries are kept.

    Because the function is wrapped in ``st.cache_data``, the resulting plain
    ``dict[str, str]`` is cached per *model_id* across Streamlit reruns.

    Args:
        model_id: HuggingFace Hub model identifier (or local path) passed
            straight to ``from_pretrained``.

    Returns:
        Dict keyed by state-dict entry name, valued by e.g. ``"torch.Size([...])"``.
    """
    loaded = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )
    shapes = {}
    for param_name, tensor in loaded.state_dict().items():
        shapes[param_name] = str(tensor.shape)
    return shapes
def compare_structures(struct1, struct2):
    """Diff two model structures line by line with ``difflib.ndiff``.

    Each structure is rendered as ``"<name>: <shape>"`` lines in the dict's
    iteration order before diffing.

    Args:
        struct1: Mapping of entry name -> shape string for the first model.
        struct2: Mapping of entry name -> shape string for the second model.

    Returns:
        The lazy ``ndiff`` generator of prefixed lines (``"- "``, ``"+ "``,
        ``"  "``, ``"? "``).
    """
    def _as_lines(structure):
        # One "name: shape" line per entry, preserving insertion order.
        return [f"{name}: {shape}" for name, shape in structure.items()]

    return difflib.ndiff(_as_lines(struct1), _as_lines(struct2))
def display_diff(diff):
    """Render ndiff-style lines as two row-aligned HTML strings.

    Args:
        diff: Iterable of ``difflib.ndiff``-style lines. A line starting with
            ``"- "`` is shown only on the left (red background), ``"+ "`` only
            on the right (green background), and a line starting with a space
            (ndiff's ``"  "`` common lines) on both sides. Any other line,
            such as ndiff's ``"? "`` intraline hints, is skipped.

    Returns:
        ``(left_html, right_html, diff_found)``: the two columns as strings
        with one row per kept line joined by ``<br>``, and whether any
        ``"- "`` or ``"+ "`` line was seen.
    """
    left_lines = []
    right_lines = []
    diff_found = False
    for line in diff:
        # The caller renders these strings with unsafe_allow_html=True, so the
        # payload must be escaped or a "<", ">" or "&" in an entry name would be
        # interpreted as markup.
        text = html.escape(line[2:], quote=False)
        if line.startswith("- "):
            left_lines.append(f'<span style="background-color: #ffdddd;">{text}</span>')
            # Empty placeholder keeps the two columns row-aligned.
            right_lines.append("")
            diff_found = True
        elif line.startswith("+ "):
            right_lines.append(f'<span style="background-color: #ddffdd;">{text}</span>')
            left_lines.append("")
            diff_found = True
        elif line.startswith(" "):
            left_lines.append(text)
            right_lines.append(text)
        # Anything else (ndiff "? " hint lines) carries no content to display.
    left_html = "<br>".join(left_lines)
    right_html = "<br>".join(right_lines)
    return left_html, right_html, diff_found
# Wide layout gives the side-by-side diff columns more horizontal room. It is
# the first Streamlit command because older Streamlit releases require
# set_page_config to precede all others.
st.set_page_config(layout="wide")

st.title("Model Structure Comparison Tool")

# Strip surrounding whitespace so a pasted ID such as "gpt2 " still resolves.
model_id1 = st.text_input("Enter the first HuggingFace Model ID").strip()
model_id2 = st.text_input("Enter the second HuggingFace Model ID").strip()

if model_id1 and model_id2:
    try:
        struct1 = get_model_structure(model_id1)
        struct2 = get_model_structure(model_id2)
    except Exception as err:  # UI boundary: report load failures instead of a raw traceback
        st.error(f"Failed to load model structure: {err}")
        st.stop()

    diff = compare_structures(struct1, struct2)
    left_html, right_html, diff_found = display_diff(diff)

    st.write("### Comparison Result")
    if not diff_found:
        st.success("The model structures are identical.")
    else:
        # Two equal-width columns: Model 1 on the left, Model 2 on the right.
        col1, col2 = st.columns(2)
        with col1:
            st.write("### Model 1")
            st.markdown(left_html, unsafe_allow_html=True)
        with col2:
            st.write("### Model 2")
            st.markdown(right_html, unsafe_allow_html=True)

# Custom CSS: widen the main container and preserve whitespace in markdown.
# NOTE(review): the .reportview-container selector targets older Streamlit
# DOM class names and may no longer match; set_page_config above is the
# supported way to get a wide layout.
st.markdown(
    """
<style>
.reportview-container .main .block-container {
    max-width: 90%;
    padding-left: 5%;
    padding-right: 5%;
}
.stMarkdown {
    white-space: pre-wrap;
}
</style>
""",
    unsafe_allow_html=True,
)