HawkClaws commited on
Commit
62e947b
1 Parent(s): cd81b99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -2,17 +2,36 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM
4
  import difflib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- @st.cache_data
7
  def get_model_structure(model_id):
 
 
 
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_id,
10
  torch_dtype=torch.bfloat16,
11
  device_map="cpu",
12
  )
13
  structure = {k: str(v.shape) for k, v in model.state_dict().items()}
 
14
  return structure
15
 
 
16
  def compare_structures(struct1, struct2):
17
  struct1_lines = [f"{k}: {v}" for k, v in struct1.items()]
18
  struct2_lines = [f"{k}: {v}" for k, v in struct2.items()]
@@ -53,8 +72,8 @@ st.markdown(
53
  <style>
54
  .reportview-container .main .block-container {
55
  max-width: 100%;
56
- padding-left: 0%;
57
- padding-right: 0%;
58
  }
59
  .stMarkdown {
60
  white-space: pre-wrap;
 
2
  import torch
3
  from transformers import AutoModelForCausalLM
4
  import difflib
5
+ import requests
6
+ import os
7
+ import json
8
+
9
+ FIREBASE_URL = os.getenv("FIREBASE_URL")
10
+
11
+ def fetch_from_firebase(model_id):
12
+ response = requests.get(f"{FIREBASE_URL}/model_structures/{model_id}.json")
13
+ if response.status_code == 200:
14
+ return response.json()
15
+ return None
16
+
17
+ def save_to_firebase(model_id, structure):
18
+ response = requests.put(f"{FIREBASE_URL}/model_structures/{model_id}.json", data=json.dumps(structure))
19
+ return response.status_code == 200
20
 
 
21
  def get_model_structure(model_id):
22
+ structure = fetch_from_firebase(model_id)
23
+ if structure:
24
+ return structure
25
  model = AutoModelForCausalLM.from_pretrained(
26
  model_id,
27
  torch_dtype=torch.bfloat16,
28
  device_map="cpu",
29
  )
30
  structure = {k: str(v.shape) for k, v in model.state_dict().items()}
31
+ save_to_firebase(model_id, structure)
32
  return structure
33
 
34
+
35
  def compare_structures(struct1, struct2):
36
  struct1_lines = [f"{k}: {v}" for k, v in struct1.items()]
37
  struct2_lines = [f"{k}: {v}" for k, v in struct2.items()]
 
72
  <style>
73
  .reportview-container .main .block-container {
74
  max-width: 100%;
75
+ padding-left: 10%;
76
+ padding-right: 10%;
77
  }
78
  .stMarkdown {
79
  white-space: pre-wrap;