Charles Lin commited on
Commit
bb4bb43
1 Parent(s): 8f3eda5

Use session state instead of global vars, and add progress spinners for loading models

Browse files
Files changed (1) hide show
  1. app.py +14 -28
app.py CHANGED
@@ -17,10 +17,6 @@ EDIT_ALGS = [
17
  "LU: Lookup Cache",
18
  ]
19
 
20
- tokenizer = None
21
- model = None
22
- editable_model = None
23
-
24
  def reset():
25
  st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
26
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
@@ -28,28 +24,16 @@ def reset():
28
  selected_alg = st.session_state.alg_selector
29
  selected_alg_idx = EDIT_ALGS.index(selected_alg)
30
 
31
- ############# TODO: show progress spinner
32
- global tokenizer
33
- global model
34
- global editable_model
35
-
36
- if tokenizer is None:
37
- tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
38
- if model is None:
39
- model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").eval()
40
- del editable_model
41
-
42
- alg_name = st.session_state.alg_selector
43
- alg_abbrv = alg_name[:alg_name.index(":")]
44
- alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
45
- alg_class = getattr(alg_module, alg_abbrv.upper())
46
- cfg = getattr(config, f"{alg_abbrv.lower()}_config")
47
- editable_model = alg_class(
48
- model,
49
- cfg,
50
- lambda: copy.deepcopy(model),
51
- ).eval()
52
-
53
 
54
  def apply_edit():
55
  st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
@@ -71,8 +55,10 @@ if "init" not in st.session_state:
71
  st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
72
  st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
73
  st.session_state.init = True
74
- st.session_state.model = None ##############
75
-
 
 
76
 
77
  ########################
78
  #### Interface code ####
 
17
  "LU: Lookup Cache",
18
  ]
19
 
 
 
 
 
20
  def reset():
21
  st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
22
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
 
24
  selected_alg = st.session_state.alg_selector
25
  selected_alg_idx = EDIT_ALGS.index(selected_alg)
26
 
27
+ with st.spinner('Loading model...'):
28
+ alg_abbrv = selected_alg[:selected_alg.index(":")]
29
+ alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
30
+ alg_class = getattr(alg_module, alg_abbrv.upper())
31
+ cfg = getattr(config, f"{alg_abbrv.lower()}_config")
32
+ st.session_state.editable_model = alg_class(
33
+ st.session_state.model,
34
+ cfg,
35
+ lambda: copy.deepcopy(st.session_state.model),
36
+ ).eval()
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def apply_edit():
39
  st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
 
55
  st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
56
  st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
57
  st.session_state.init = True
58
+ with st.spinner('Loading model...'):
59
+ st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
60
+ st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").eval()
61
+ st.session_state.editable_model = None
62
 
63
  ########################
64
  #### Interface code ####