Charles Lin commited on
Commit
9b78f9c
1 Parent(s): bb4bb43

Generation working. Todo: model edits; add model checkpoints. Also, we are only loading an editable model upon switching algs but we should load it when the page loads

Browse files
Files changed (2) hide show
  1. algs/serac.py +4 -2
  2. app.py +15 -4
algs/serac.py CHANGED
@@ -306,13 +306,15 @@ class SERAC(EditableModel):
306
 
307
  def generate(self, *args, **kwargs):
308
  # input_text = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
 
 
309
  base_generate_fn = (
310
  self.model.forward if type(self.model) == BertClassifier
311
- else lambda *args, **kwargs: self.model.generate(*args, **kwargs, max_new_tokens=20)
312
  )
313
  cntr_generate_fn = (
314
  self.replacement.forward if type(self.replacement) == BertClassifier
315
- else lambda *args, **kwargs: self.replacement.generate(*args, **kwargs, max_new_tokens=20)
316
  )
317
 
318
  # assert len(args) == 0, "Should only pass named arguments to generate()"
 
306
 
307
  def generate(self, *args, **kwargs):
308
  # input_text = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
309
+ if "max_new_tokens" not in kwargs:
310
+ kwargs["max_new_tokens"] = 20
311
  base_generate_fn = (
312
  self.model.forward if type(self.model) == BertClassifier
313
+ else lambda *args, **kwargs: self.model.generate(*args, **kwargs)
314
  )
315
  cntr_generate_fn = (
316
  self.replacement.forward if type(self.replacement) == BertClassifier
317
+ else lambda *args, **kwargs: self.replacement.generate(*args, **kwargs)
318
  )
319
 
320
  # assert len(args) == 0, "Should only pass named arguments to generate()"
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  import pandas as pd
3
  import time
4
  import importlib
 
5
 
6
  import algs
7
  import config
@@ -17,6 +18,11 @@ EDIT_ALGS = [
17
  "LU: Lookup Cache",
18
  ]
19
 
 
 
 
 
 
20
  def reset():
21
  st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
22
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
@@ -28,10 +34,10 @@ def reset():
28
  alg_abbrv = selected_alg[:selected_alg.index(":")]
29
  alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
30
  alg_class = getattr(alg_module, alg_abbrv.upper())
31
- cfg = getattr(config, f"{alg_abbrv.lower()}_config")
32
  st.session_state.editable_model = alg_class(
33
  st.session_state.model,
34
- cfg,
35
  lambda: copy.deepcopy(st.session_state.model),
36
  ).eval()
37
 
@@ -42,7 +48,10 @@ def apply_edit():
42
 
43
  def sample_model():
44
  input_str = str(test_input)
45
- model_output = "blah blah blah" ############## Actually sample the model
 
 
 
46
  n_edits = len(st.session_state.edits)
47
  alg_name = st.session_state.alg_selector
48
  alg_abbrv = alg_name[:alg_name.index(":")]
@@ -55,9 +64,11 @@ if "init" not in st.session_state:
55
  st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
56
  st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
57
  st.session_state.init = True
 
 
58
  with st.spinner('Loading model...'):
59
  st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
60
- st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").eval()
61
  st.session_state.editable_model = None
62
 
63
  ########################
 
2
  import pandas as pd
3
  import time
4
  import importlib
5
+ from torch.cuda import is_available as use_cuda
6
 
7
  import algs
8
  import config
 
18
  "LU: Lookup Cache",
19
  ]
20
 
21
+ def generate(ids):
22
+ output_ids = st.session_state.editable_model.generate(input_ids=ids, max_new_tokens=20, min_length=1,
23
+ num_return_sequences=1, num_beams=3)
24
+ return st.session_state.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
25
+
26
  def reset():
27
  st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
28
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
 
34
  alg_abbrv = selected_alg[:selected_alg.index(":")]
35
  alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
36
  alg_class = getattr(alg_module, alg_abbrv.upper())
37
+ st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
38
  st.session_state.editable_model = alg_class(
39
  st.session_state.model,
40
+ st.session_state.config,
41
  lambda: copy.deepcopy(st.session_state.model),
42
  ).eval()
43
 
 
48
 
49
  def sample_model():
50
  input_str = str(test_input)
51
+ with st.spinner('Generating completion...'):
52
+ encoding = st.session_state.tokenizer(input_str, return_tensors="pt")
53
+ ids = encoding["input_ids"].to(st.session_state.device)
54
+ model_output = generate(ids)
55
  n_edits = len(st.session_state.edits)
56
  alg_name = st.session_state.alg_selector
57
  alg_abbrv = alg_name[:alg_name.index(":")]
 
64
  st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
65
  st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
66
  st.session_state.init = True
67
+ st.session_state.config = None
68
+ st.session_state.device = "cuda" if use_cuda() else "cpu"
69
  with st.spinner('Loading model...'):
70
  st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
71
+ st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").to(st.session_state.device).eval()
72
  st.session_state.editable_model = None
73
 
74
  ########################