HawkClaws committed on
Commit 74debf3
1 Parent(s): 52e0217

Update app.py

Files changed (1)
  1. app.py +20 -26
app.py CHANGED
@@ -1,6 +1,6 @@
 import streamlit as st
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import difflib
 import requests
 import os
@@ -8,23 +8,20 @@ import json

 FIREBASE_URL = os.getenv("FIREBASE_URL")

-
-def fetch_from_firebase(model_id):
-    response = requests.get(f"{FIREBASE_URL}/model_structures/{model_id}.json")
+def fetch_from_firebase(model_id, data_type):
+    response = requests.get(f"{FIREBASE_URL}/{data_type}/{model_id}.json")
     if response.status_code == 200:
         return response.json()
     return None

-
-def save_to_firebase(model_id, structure):
+def save_to_firebase(model_id, data, data_type):
     response = requests.put(
-        f"{FIREBASE_URL}/model_structures/{model_id}.json", data=json.dumps(structure)
+        f"{FIREBASE_URL}/{data_type}/{model_id}.json", data=json.dumps(data)
     )
     return response.status_code == 200

-
 def get_model_structure(model_id) -> list[str]:
-    struct_lines = fetch_from_firebase(model_id)
+    struct_lines = fetch_from_firebase(model_id, "model_structures")
     if struct_lines:
         return struct_lines
     model = AutoModelForCausalLM.from_pretrained(
@@ -34,17 +31,22 @@ def get_model_structure(model_id) -> list[str]:
     )
     structure = {k: str(v.shape) for k, v in model.state_dict().items()}
     struct_lines = [f"{k}: {v}" for k, v in structure.items()]
-    save_to_firebase(model_id, struct_lines)
+    save_to_firebase(model_id, struct_lines, "model_structures")
     return struct_lines

+def get_tokenizer_vocab_size(model_id) -> int:
+    vocab_size = fetch_from_firebase(model_id, "tokenizer_vocab_sizes")
+    if vocab_size:
+        return vocab_size
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    vocab_size = tokenizer.vocab_size
+    save_to_firebase(model_id, vocab_size, "tokenizer_vocab_sizes")
+    return vocab_size

 def compare_structures(struct1_lines: list[str], struct2_lines: list[str]):
-    # struct1_lines = [f"{k}: {v}" for k, v in struct1.items()]
-    # struct2_lines = [f"{k}: {v}" for k, v in struct2.items()]
     diff = difflib.ndiff(struct1_lines, struct2_lines)
     return diff

-
 def display_diff(diff):
     left_lines = []
     right_lines = []
@@ -74,7 +76,6 @@ def display_diff(diff):

     return left_html, right_html, diff_found

-
 # Set Streamlit page configuration to wide mode
 st.set_page_config(layout="wide")

@@ -99,10 +100,7 @@ st.title("Model Structure Comparison Tool")
 model_id1 = st.text_input("Enter the first HuggingFace Model ID")
 model_id2 = st.text_input("Enter the second HuggingFace Model ID")

-if "compare_button_clicked" not in st.session_state:
-    st.session_state.compare_button_clicked = False
-
-if st.session_state.compare_button_clicked:
+if st.button("Compare Models"):
     with st.spinner('Comparing models and loading tokenizers...'):
         if model_id1 and model_id2:
             struct1 = get_model_structure(model_id1)
@@ -127,15 +125,11 @@ if st.session_state.compare_button_clicked:

             # Tokenizer verification
             try:
-                tokenizer1 = AutoTokenizer.from_pretrained(model_id1)
-                tokenizer2 = AutoTokenizer.from_pretrained(model_id2)
-                st.write(f"**{model_id1} Tokenizer Vocab Size**: {tokenizer1.vocab_size}")
-                st.write(f"**{model_id2} Tokenizer Vocab Size**: {tokenizer2.vocab_size}")
+                vocab_size1 = get_tokenizer_vocab_size(model_id1)
+                vocab_size2 = get_tokenizer_vocab_size(model_id2)
+                st.write(f"**{model_id1} Tokenizer Vocab Size**: {vocab_size1}")
+                st.write(f"**{model_id2} Tokenizer Vocab Size**: {vocab_size2}")
             except Exception as e:
                 st.error(f"Error loading tokenizers: {e}")
         else:
             st.error("Please enter both model IDs.")
-    st.session_state.compare_button_clicked = False
-else:
-    if st.button("Compare Models"):
-        st.session_state.compare_button_clicked = True