HawkClaws commited on
Commit
2ddd1e5
1 Parent(s): 68c7679

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -15
app.py CHANGED
@@ -1,36 +1,40 @@
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM
 
4
 
 
5
  def get_model_structure(model_id):
6
  model = AutoModelForCausalLM.from_pretrained(
7
  model_id,
8
  torch_dtype=torch.bfloat16,
9
  device_map="cpu",
10
  )
11
- structure = {k: v.shape for k, v in model.state_dict().items()}
12
  return structure
13
 
14
  def compare_structures(struct1, struct2):
15
- keys1 = set(struct1.keys())
16
- keys2 = set(struct2.keys())
17
- all_keys = keys1.union(keys2)
18
-
19
- diff = []
20
- for key in all_keys:
21
- shape1 = struct1.get(key)
22
- shape2 = struct2.get(key)
23
- if shape1 != shape2:
24
- diff.append((key, shape1, shape2))
25
  return diff
26
 
27
  def display_diff(diff):
28
  left_lines = []
29
  right_lines = []
30
 
31
- for key, shape1, shape2 in diff:
32
- left_lines.append(f"{key}: {shape1}")
33
- right_lines.append(f"{key}: {shape2}")
 
 
 
 
 
 
 
 
 
34
 
35
  left_html = "<br>".join(left_lines)
36
  right_html = "<br>".join(right_lines)
@@ -49,7 +53,7 @@ if model_id1 and model_id2:
49
  left_html, right_html = display_diff(diff)
50
 
51
  st.write("### Comparison Result")
52
- col1, col2 = st.columns(2)
53
 
54
  with col1:
55
  st.write("### Model 1")
@@ -58,3 +62,20 @@ if model_id1 and model_id2:
58
  with col2:
59
  st.write("### Model 2")
60
  st.markdown(right_html, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM
4
+ import difflib
5
 
6
+ @st.cache_data
7
  def get_model_structure(model_id):
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_id,
10
  torch_dtype=torch.bfloat16,
11
  device_map="cpu",
12
  )
13
+ structure = {k: str(v.shape) for k, v in model.state_dict().items()}
14
  return structure
15
 
16
  def compare_structures(struct1, struct2):
17
+ struct1_lines = [f"{k}: {v}" for k, v in struct1.items()]
18
+ struct2_lines = [f"{k}: {v}" for k, v in struct2.items()]
19
+ diff = difflib.ndiff(struct1_lines, struct2_lines)
 
 
 
 
 
 
 
20
  return diff
21
 
22
  def display_diff(diff):
23
  left_lines = []
24
  right_lines = []
25
 
26
+ for line in diff:
27
+ if line.startswith('- '):
28
+ left_lines.append(f'<span style="background-color: #ffdddd;">{line[2:]}</span>')
29
+ right_lines.append('')
30
+ elif line.startswith('+ '):
31
+ right_lines.append(f'<span style="background-color: #ddffdd;">{line[2:]}</span>')
32
+ left_lines.append('')
33
+ elif line.startswith(' '):
34
+ left_lines.append(line[2:])
35
+ right_lines.append(line[2:])
36
+ else:
37
+ pass
38
 
39
  left_html = "<br>".join(left_lines)
40
  right_html = "<br>".join(right_lines)
 
53
  left_html, right_html = display_diff(diff)
54
 
55
  st.write("### Comparison Result")
56
+ col1, col2 = st.columns([1, 1]) # Adjust the ratio to make columns wider
57
 
58
  with col1:
59
  st.write("### Model 1")
 
62
  with col2:
63
  st.write("### Model 2")
64
  st.markdown(right_html, unsafe_allow_html=True)
65
+
66
+ # Apply custom CSS for wider layout
67
+ st.markdown(
68
+ """
69
+ <style>
70
+ .reportview-container .main .block-container {
71
+ max-width: 90%;
72
+ padding-left: 5%;
73
+ padding-right: 5%;
74
+ }
75
+ .stMarkdown {
76
+ white-space: pre-wrap;
77
+ }
78
+ </style>
79
+ """,
80
+ unsafe_allow_html=True
81
+ )