Charles Lin committed
Commit 8335d0c
Parent: a9853a7

All algs except KE working.

Files changed (4)
  1. algs/lu.py +37 -53
  2. app.py +26 -12
  3. config.py +9 -6
  4. utils.py +7 -1
algs/lu.py CHANGED
@@ -15,56 +15,45 @@ class LU(EditableModel):
     def __init__(self, model, config, model_constructor, memory=None):
         super().__init__(model, config, model_constructor)
 
+        if "t5" not in self.config.model.name.lower():
+            raise NotImplementedError
         self.memory = memory
 
-    def forward(self, *inputs, **kwargs):
-        if "bert" in self.config.model.name.lower():
-            output, encoder_states = self.model(*inputs, **kwargs, output_hidden_states=True)
-        else:
-            model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
-            encoder_states = _last_encoder_state(model_output)
-            output = _logits(model_output)
-
-        if self.memory is not None:
-            for i, encoder_state in enumerate(encoder_states):
-                if "gpt2" in self.config.model.name.lower():
-                    # NOTE: broken
-                    memory_prefixes, memory_labels = self.memory
-                    prefix_means = encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1)
-                    dist_mat = (prefix_means.unsqueeze(1) - memory_prefixes.unsqueeze(0)).norm(2, dim=-1)
-
-                    min_dists, min_idxs = dist_mat.min(-1)
-                    memory_mask = (min_dists < self.config.lu.threshold)
-                    onehot_logits = self.config.lu.onehot_logit * F.one_hot(memory_labels[min_idxs], output.shape[-1]).float()
-                    output[i, memory_mask] = onehot_logits[memory_mask]
-                elif "bart" in self.config.model.name.lower() or "t5" in self.config.model.name.lower():
-                    avg_encoder_state = encoder_state.detach().mean(0)
-                    memory_keys, memory_labels = self.memory
-                    dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
-                    closest_dist = dists.min()
-                    closest_idx = dists.argmin()
-                    closest_v = memory_labels[closest_idx]
-
-                    if closest_dist < self.config.lu.threshold:
-                        output[i] = torch.zeros((1, kwargs['labels'].shape[1], output.shape[2]), device=output.device)
-                        for j, idx in enumerate(closest_v):
-                            if j >= output.shape[1]:
-                                break
-                            output[i, j, idx] = self.config.lu.onehot_logit
-                        if "t5" not in self.config.model.name.lower():
-                            # T5 does not shift targets in the loss
-                            output[i] = output[i].roll(-1, -2)
-                else:
-                    avg_encoder_state = encoder_state.detach().mean(0)
-                    memory_keys, memory_labels = self.memory
-                    dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
-                    closest_dist = dists.min()
-                    closest_idx = dists.argmin()
-                    closest_v = memory_labels[closest_idx]
-
-                    if closest_dist < self.config.lu.threshold:
-                        output[i] = self.config.lu.onehot_logit * (2 * closest_v - 1)  # Return onehot_logit or -onehot_logit
-
+    def lookup_replace(self, output, encoder_states):
+        for i, encoder_state in enumerate(encoder_states):
+            avg_encoder_state = encoder_state.detach().mean(0)
+            memory_keys, memory_labels = self.memory
+            dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
+            closest_dist = dists.min()
+            closest_idx = dists.argmin()
+            closest_v = memory_labels[closest_idx]
+
+            if closest_dist < self.config.lu.threshold:
+                output[i] = torch.zeros((1, output.shape[1], output.shape[2]), device=output.device)
+                for j, idx in enumerate(closest_v):
+                    if j >= output.shape[1]:
+                        break
+                    output[i, j, idx] = self.config.lu.onehot_logit
+                if "t5" not in self.config.model.name.lower():
+                    # T5 does not shift targets in the loss
+                    output[i] = output[i].roll(-1, -2)
+        return output
+
+    def generate(self, *inputs, **kwargs):
+        model_output = self.model.generate(*inputs, **kwargs, output_hidden_states=True,
+                                           output_scores=True, return_dict_in_generate=True)
+        encoder_states = _last_encoder_state(model_output)
+        output = _logits(model_output)
+        if self.memory is not None:
+            output = self.lookup_replace(output, encoder_states)
+        return output.argmax(-1)
+
+    def forward(self, *inputs, **kwargs):
+        model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
+        encoder_states = _last_encoder_state(model_output)
+        output = _logits(model_output)
+        if self.memory is not None:
+            output = self.lookup_replace(output, encoder_states)
         return output
 
     def edit(self, batch, condition=None, detach_history=False):
@@ -77,14 +66,9 @@ class LU(EditableModel):
         memory_keys = []
         memory_labels = []
         for encoder_state, label in zip(encoder_states, batch["labels"]):
-            if "gpt2" in self.config.model.name.lower():
-                # NOTE: broken
-                avg_encoder_states = (encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1))[-10:, :]
-                memory = (avg_encoder_states, label[-10:])
-            else:
-                avg_encoder_state = encoder_state.detach().mean(0)
-                memory_keys.append(avg_encoder_state)
-                memory_labels.append(label)
+            avg_encoder_state = encoder_state.detach().mean(0)
+            memory_keys.append(avg_encoder_state)
+            memory_labels.append(label)
 
         memory = (torch.stack(memory_keys), torch.stack(memory_labels))
         return LU(self.model.eval(), self.config, self.model_constructor, memory), {}
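
For intuition, here is a self-contained sketch of the cache mechanism that lookup_replace implements: edit() stores one mean-pooled encoder state per edit as a key, alongside the target token ids, and at inference any input whose pooled encoder state falls within lu.threshold of a cached key has its logits overwritten with a one-hot spelling of the cached label. All tensor sizes below are made up for illustration; only threshold and onehot_logit mirror lu_config.

import torch

threshold = 2.75    # mirrors lu_config["lu"]["threshold"]
onehot_logit = 1.0  # mirrors lu_config["lu"]["onehot_logit"]

# Cache built by edit(): one mean-pooled encoder state per edit, plus target ids.
memory_keys = torch.randn(3, 4)               # 3 cached edits, hidden dim 4 (toy)
memory_labels = torch.randint(0, 10, (3, 6))  # 6 target tokens each, vocab 10 (toy)

# One incoming example: encoder states (seq_len x hidden) and decoder logits.
encoder_state = torch.randn(7, 4)
output = torch.randn(6, 10)                   # (target_len, vocab)

avg_encoder_state = encoder_state.mean(0)     # pool over the sequence dimension
dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
if dists.min() < threshold:                   # cache hit: spell out the stored label
    closest_v = memory_labels[dists.argmin()]
    output = torch.zeros_like(output)
    for j, idx in enumerate(closest_v[: output.shape[0]]):
        output[j, idx] = onehot_logit
print(output.argmax(-1))                      # equals the cached label on a hit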
app.py CHANGED
@@ -8,6 +8,7 @@ from torch.cuda import is_available as use_cuda
 import algs
 import config
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import utils
 
 
 EDIT_ALGS = [
@@ -19,6 +20,26 @@ EDIT_ALGS = [
     "LU: Lookup Cache",
 ]
 
+def get_alg_class(alg_abbrv):
+    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
+    alg_class = getattr(alg_module, alg_abbrv.upper())
+    return alg_class
+
+def load_editable_model(alg_abbrv):
+    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
+    alg_class = getattr(alg_module, alg_abbrv.upper())
+    st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
+    with st.spinner('Loading model...'):
+        st.session_state.editable_model = alg_class(
+            st.session_state.model,
+            st.session_state.config,
+            lambda: copy.deepcopy(st.session_state.model),
+        ).eval()
+    if "archive" in st.session_state.config:
+        archive, st.session_state.config.archive = utils.load_archive(str(st.session_state.config.archive))
+        print(f"Loading archive from {st.session_state.config.archive}")
+        st.session_state.editable_model.load_state_dict(archive["model"])
+
 def generate(ids):
     output_ids = st.session_state.editable_model.generate(input_ids=ids, max_new_tokens=20, min_length=1,
                                                           num_return_sequences=1, num_beams=3)
@@ -30,15 +51,7 @@ def reset():
 
     selected_alg = st.session_state.alg_selector
     alg_abbrv = selected_alg[:selected_alg.index(":")]
-    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
-    alg_class = getattr(alg_module, alg_abbrv.upper())
-    st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
-    with st.spinner('Loading model...'):
-        st.session_state.editable_model = alg_class(
-            st.session_state.model,
-            st.session_state.config,
-            lambda: copy.deepcopy(st.session_state.model),
-        ).eval()
+    load_editable_model(alg_abbrv)
 
 def apply_edit():
     st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
@@ -67,12 +80,13 @@ if "init" not in st.session_state:
     st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
     st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
     st.session_state.init = True
-    st.session_state.config = None
-    st.session_state.device = "cuda" if use_cuda() else "cpu"
+    st.session_state.device = "cpu"  # "cuda" if use_cuda() else "cpu"
     with st.spinner('Loading model...'):
         st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
         st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").to(st.session_state.device).eval()
-    st.session_state.editable_model = None
+    # There is a "Loading model..." spinner in load_editable_model
+    alg_abbrv = "MEND"  # Default initial alg of dropdown selector
+    load_editable_model(alg_abbrv)
 
 ########################
 #### Interface code ####
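
The refactor pulls the dynamic import out of reset() so the startup path and the algorithm selector share one loader. The convention it relies on: the dropdown abbreviation lowercased is the module name under algs/, and uppercased it is the class name. A minimal sketch of that mapping, using the "LU: Lookup Cache" label from EDIT_ALGS above (running it requires the algs package on the path):

import importlib

def get_alg_class(alg_abbrv):
    # "LU" -> module algs.lu -> class LU, by the repo's naming convention.
    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
    return getattr(alg_module, alg_abbrv.upper())

selected_alg = "LU: Lookup Cache"                    # dropdown label
alg_abbrv = selected_alg[:selected_alg.index(":")]   # -> "LU"
alg_class = get_alg_class(alg_abbrv)                 # -> algs.lu.LU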
config.py CHANGED
@@ -21,7 +21,7 @@ model_config = {
 }
 
 ft_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "edit_lr": 5e-6,
     "train_base": False,
     "grad_clip": 100,
@@ -43,7 +43,7 @@ ft_config = OmegaConf.create({
 })
 
 lu_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "lu": {
         "threshold": 2.75,
         "onehot_logit": 1,
@@ -52,14 +52,14 @@ lu_config = OmegaConf.create({
 })
 
 ke_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "train_base": False,
     "lr": 1e-5,
     "model": model_config,
 })
 
 enn_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "lr": 1e-5,
     "edit_lr": 1e-2,
     "lr_lr": 1e-3,
@@ -72,10 +72,11 @@ enn_config = OmegaConf.create({
         "n_edit_steps": 1,
     },
     "model": model_config,
+    "archive": 8684705655,  # "/iris/u/clin/code/efk/outputs/2022-02-09_05-48-20_8684705655/models/t5-large-ssm-nq.2022-02-09_05-48-20_8684705655",
 })
 
 mend_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "lr": 1e-6,
     "edit_lr": 1e-4,
     "lr_lr": 1e-4,
@@ -99,10 +100,11 @@ mend_config = OmegaConf.create({
         "descent": False,
     },
     "model": model_config,
+    "archive": 5940349945,  # "/iris/u/clin/code/efk/outputs/2022-02-09_11-47-28_5940349945/models/t5-large-ssm-nq.2022-02-09_11-47-28_5940349945",
 })
 
 serac_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",  # "device": "cuda" if use_cuda() else "cpu",
     "lr": 1e-5,
     "edit_lr": 1e-2,
     "lr_lr": 0,
@@ -128,4 +130,5 @@ serac_config = OmegaConf.create({
         "cache_embeds": True,
     },
     "model": model_config,
+    "archive": 4719776130,  # "/iris/u/clin/code/efk/outputs/2022-02-09_14-05-56_4719776130/models/t5-large-ssm-nq.2022-02-09_14-05-56_4719776130",
 })
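
Note that load_editable_model in app.py only attempts to load a pretrained checkpoint when the config carries an "archive" key, which is why ft_config and lu_config (no trained editor weights needed) omit it. A small sketch of that check, with an illustrative run id rather than a real one:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"device": "cpu", "archive": 1234567890})  # illustrative id
if "archive" in cfg:  # DictConfig supports `in`, same as a plain dict
    # app.py passes str(cfg.archive) to utils.load_archive, which is expected to
    # resolve the run id to a checkpoint path (signature assumed from the diff).
    print(f"Loading archive from {cfg.archive}")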
utils.py CHANGED
@@ -156,12 +156,18 @@ def safe_backward(loss, parameters, accumulate=1, allow_unused=False, backward=F
 
 
 def _logits(x):
-    return x if not hasattr(x, "logits") else x.logits
+    if hasattr(x, "logits"):
+        return x.logits
+    elif hasattr(x, "scores"):
+        return torch.cat(x.scores).unsqueeze(0)
+    return x
 
 
 def _last_encoder_state(x):
     if hasattr(x, "encoder_last_hidden_state"):
         return x.encoder_last_hidden_state
+    elif hasattr(x, "encoder_hidden_states"):
+        return x.encoder_hidden_states[-1]
     else:
         return x.hidden_states[-1]
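
The new scores branch exists because generate(..., output_scores=True, return_dict_in_generate=True) returns per-step score tensors in a tuple rather than a single logits tensor; _logits stitches them back into the (batch, steps, vocab) shape that lookup_replace expects. A toy demonstration of the reshaping (sizes made up, and the unsqueeze(0) assumes a single returned sequence):

import torch

# generate() yields a tuple of per-step (num_sequences, vocab) score tensors.
scores = tuple(torch.randn(1, 10) for _ in range(5))  # 5 decode steps, vocab 10

logits = torch.cat(scores).unsqueeze(0)  # -> (1, 5, 10), like forward-pass logits
print(logits.shape)                      # torch.Size([1, 5, 10])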