HawkClaws committed on
Commit
d28e9af
·
verified ·
1 Parent(s): 5654dac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -62
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import difflib
5
  import requests
6
  import os
@@ -8,60 +8,73 @@ import json
8
 
9
  FIREBASE_URL = os.getenv("FIREBASE_URL")
10
 
 
11
  def fetch_from_firebase(model_id):
12
  response = requests.get(f"{FIREBASE_URL}/model_structures/{model_id}.json")
13
  if response.status_code == 200:
14
  return response.json()
15
  return None
16
 
 
17
  def save_to_firebase(model_id, structure):
18
- response = requests.put(f"{FIREBASE_URL}/model_structures/{model_id}.json", data=json.dumps(structure))
 
 
19
  return response.status_code == 200
20
 
21
- def get_model_structure(model_id):
22
- structure = fetch_from_firebase(model_id)
23
- if structure:
24
- return structure
 
25
  model = AutoModelForCausalLM.from_pretrained(
26
  model_id,
27
  torch_dtype=torch.bfloat16,
28
  device_map="cpu",
29
  )
30
  structure = {k: str(v.shape) for k, v in model.state_dict().items()}
31
- save_to_firebase(model_id, structure)
32
- return structure
 
33
 
34
- def compare_structures(struct1, struct2):
35
- struct1_lines = [f"{k}: {v}" for k, v in struct1.items()]
36
- struct2_lines = [f"{k}: {v}" for k, v in struct2.items()]
 
37
  diff = difflib.ndiff(struct1_lines, struct2_lines)
38
  return diff
39
 
 
40
  def display_diff(diff):
41
  left_lines = []
42
  right_lines = []
43
  diff_found = False
44
-
45
  for line in diff:
46
- if line.startswith('- '):
47
- left_lines.append(f'<span style="background-color: #ffdddd;">{line[2:]}</span>')
48
- right_lines.append('')
 
 
49
  diff_found = True
50
- elif line.startswith('+ '):
51
- right_lines.append(f'<span style="background-color: #ddffdd;">{line[2:]}</span>')
52
- left_lines.append('')
 
 
53
  diff_found = True
54
- elif line.startswith(' '):
55
  left_lines.append(line[2:])
56
  right_lines.append(line[2:])
57
  else:
58
  pass
59
-
60
  left_html = "<br>".join(left_lines)
61
  right_html = "<br>".join(right_lines)
62
-
63
  return left_html, right_html, diff_found
64
 
 
65
  # Set Streamlit page configuration to wide mode
66
  st.set_page_config(layout="wide")
67
 
@@ -79,50 +92,30 @@ st.markdown(
79
  }
80
  </style>
81
  """,
82
- unsafe_allow_html=True
83
  )
84
 
85
  st.title("Model Structure Comparison Tool")
86
  model_id1 = st.text_input("Enter the first HuggingFace Model ID")
87
  model_id2 = st.text_input("Enter the second HuggingFace Model ID")
88
 
89
- if "compare_button_clicked" not in st.session_state:
90
- st.session_state.compare_button_clicked = False
91
-
92
- if st.session_state.compare_button_clicked:
93
- with st.spinner('Comparing models and loading tokenizers...'):
94
- if model_id1 and model_id2:
95
- struct1 = get_model_structure(model_id1)
96
- struct2 = get_model_structure(model_id2)
97
-
98
- diff = compare_structures(struct1, struct2)
99
- left_html, right_html, diff_found = display_diff(diff)
100
-
101
- st.write("### Comparison Result")
102
- if not diff_found:
103
- st.success("The model structures are identical.")
104
-
105
- col1, col2 = st.columns([1.5, 1.5]) # Adjust the ratio to make columns wider
106
-
107
- with col1:
108
- st.write("### Model 1")
109
- st.markdown(left_html, unsafe_allow_html=True)
110
-
111
- with col2:
112
- st.write("### Model 2")
113
- st.markdown(right_html, unsafe_allow_html=True)
114
-
115
- # Tokenizer verification
116
- try:
117
- tokenizer1 = AutoTokenizer.from_pretrained(model_id1)
118
- tokenizer2 = AutoTokenizer.from_pretrained(model_id2)
119
- st.write(f"**{model_id1} Tokenizer Vocab Size**: {tokenizer1.vocab_size}")
120
- st.write(f"**{model_id2} Tokenizer Vocab Size**: {tokenizer2.vocab_size}")
121
- except Exception as e:
122
- st.error(f"Error loading tokenizers: {e}")
123
- else:
124
- st.error("Please enter both model IDs.")
125
- st.session_state.compare_button_clicked = False
126
- else:
127
- if st.button("Compare Models"):
128
- st.session_state.compare_button_clicked = True
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import AutoModelForCausalLM
4
  import difflib
5
  import requests
6
  import os
 
8
 
9
  FIREBASE_URL = os.getenv("FIREBASE_URL")
10
 
11
+
def fetch_from_firebase(model_id):
    """Fetch the cached structure for ``model_id`` from Firebase.

    Sends a GET request to
    ``{FIREBASE_URL}/model_structures/{model_id}.json``.

    Args:
        model_id: Identifier used as the path below ``model_structures``.

    Returns:
        The decoded JSON body when the server answers with HTTP 200,
        otherwise ``None``. An unset ``FIREBASE_URL``, a transport error, a
        timeout, or a body that is not valid JSON are also reported as
        ``None`` so that callers can treat every failure as a cache miss.
    """
    # os.getenv() yields None when the variable is unset; building a URL from
    # it would make requests raise instead of simply missing the cache.
    if not FIREBASE_URL:
        return None

    try:
        response = requests.get(
            f"{FIREBASE_URL}/model_structures/{model_id}.json",
            timeout=10,  # never let a stalled connection hang the app
        )
    except requests.RequestException:
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # The body was not JSON (requests' JSON errors derive from ValueError).
        return None
17
 
18
+
def save_to_firebase(model_id, structure):
    """Store ``structure`` for ``model_id`` in Firebase.

    Sends a PUT request to
    ``{FIREBASE_URL}/model_structures/{model_id}.json`` with ``structure``
    serialized as JSON.

    Args:
        model_id: Identifier used as the path below ``model_structures``.
        structure: JSON-serializable value to store.

    Returns:
        ``True`` when the server answers with HTTP 200, ``False`` otherwise.
        An unset ``FIREBASE_URL``, a transport error, or a timeout also yield
        ``False``, because writing the cache is best-effort.
    """
    # Without a configured endpoint there is nowhere to write to.
    if not FIREBASE_URL:
        return False

    try:
        response = requests.put(
            f"{FIREBASE_URL}/model_structures/{model_id}.json",
            data=json.dumps(structure),
            timeout=10,  # never let a stalled connection hang the app
        )
    except requests.RequestException:
        return False

    return response.status_code == 200
24
 
25
+
def get_model_structure(model_id) -> list[str]:
    """Return the parameter layout of a model as ``"name: shape"`` lines.

    The Firebase cache is consulted first. On a miss the model is loaded on
    the CPU, one line is built per ``state_dict`` entry, and the result is
    written back to the cache.

    Args:
        model_id: Model identifier handed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A list of ``"<parameter name>: <shape>"`` strings.
    """
    cached = fetch_from_firebase(model_id)
    if isinstance(cached, dict) and cached:
        # The previous revision of this app cached a name -> shape mapping
        # rather than a list of lines; convert it so callers always receive
        # the list that compare_structures() needs.
        return [f"{name}: {shape}" for name, shape in cached.items()]
    if cached:
        return cached

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    struct_lines = [
        f"{name}: {str(tensor.shape)}"
        for name, tensor in model.state_dict().items()
    ]
    # Best-effort write; a failed save only costs a reload next time.
    save_to_firebase(model_id, struct_lines)
    return struct_lines
39
 
40
+
def compare_structures(struct1_lines: list[str], struct2_lines: list[str]):
    """Compute a line-level delta between two model structures.

    Args:
        struct1_lines: ``"name: shape"`` lines of the first model.
        struct2_lines: ``"name: shape"`` lines of the second model.

    Returns:
        The iterator produced by ``difflib.ndiff``. Each item is one input
        line prefixed with a two-character code (``"  "``, ``"- "``, ``"+ "``
        or ``"? "``). The iterator can be consumed only once.
    """
    # ndiff indexes its inputs and takes their length, so materialize them;
    # this lets callers pass any iterable, not just lists.
    return difflib.ndiff(list(struct1_lines), list(struct2_lines))
46
 
47
+
def display_diff(diff):
    """Turn an ``ndiff``-style delta into two side-by-side HTML fragments.

    Args:
        diff: Iterable of delta lines, each starting with a two-character
            code as produced by ``difflib.ndiff``.

    Returns:
        A tuple ``(left_html, right_html, diff_found)``. ``left_html`` and
        ``right_html`` hold one entry per rendered delta line joined with
        ``<br>``. ``diff_found`` is ``True`` when at least one line was
        removed or added.
    """
    # Imported here so the block brings its own dependency into scope.
    from html import escape

    left_lines = []
    right_lines = []
    diff_found = False

    for line in diff:
        # The fragments are rendered as raw HTML by the caller, so the text
        # (which comes from model files or the remote cache) must be escaped.
        text = escape(line[2:])
        if line.startswith("- "):
            # Present only in the first input: highlight on the left and
            # keep the right column aligned with an empty row.
            left_lines.append(
                f'<span style="background-color: #ffdddd;">{text}</span>'
            )
            right_lines.append("")
            diff_found = True
        elif line.startswith("+ "):
            # Present only in the second input: mirror image of the above.
            right_lines.append(
                f'<span style="background-color: #ddffdd;">{text}</span>'
            )
            left_lines.append("")
            diff_found = True
        elif line.startswith(" "):
            # Common to both inputs: shown unhighlighted on both sides.
            left_lines.append(text)
            right_lines.append(text)
        # Any other line (e.g. the "? " hint lines) is not rendered.

    left_html = "<br>".join(left_lines)
    right_html = "<br>".join(right_lines)

    return left_html, right_html, diff_found
76
 
77
+
78
  # Set Streamlit page configuration to wide mode
79
  st.set_page_config(layout="wide")
80
 
 
92
  }
93
  </style>
94
  """,
95
+ unsafe_allow_html=True,
96
  )
97
 
st.title("Model Structure Comparison Tool")
model_id1 = st.text_input("Enter the first HuggingFace Model ID")
model_id2 = st.text_input("Enter the second HuggingFace Model ID")

# The comparison runs only once both identifiers have been entered.
if all((model_id1, model_id2)):
    struct1 = get_model_structure(model_id1)
    struct2 = get_model_structure(model_id2)

    diff = compare_structures(struct1, struct2)
    left_html, right_html, diff_found = display_diff(diff)

    st.write("### Comparison Result")
    if not diff_found:
        st.success("The model structures are identical.")

    # Equal weights give two equally wide columns.
    col1, col2 = st.columns([1.5, 1.5])

    # Render each model's fragment under its own heading, left to right.
    for column, heading, fragment in (
        (col1, "### Model 1", left_html),
        (col2, "### Model 2", right_html),
    ):
        with column:
            st.write(heading)
            st.markdown(fragment, unsafe_allow_html=True)