Charles Lin committed on
Commit
8f3eda5
1 Parent(s): e56055d

Add logic for loading models

Browse files
Files changed (3) hide show
  1. app.py +32 -2
  2. config.py +130 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,7 +1,12 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import time
 
 
4
  import algs
 
 
 
5
 
6
  EDIT_ALGS = [
7
  "MEND: Model editor networks using gradient decomposition",
@@ -9,9 +14,13 @@ EDIT_ALGS = [
9
  "ENN: Editable neural networks",
10
  "KE: KnowledgeEditor",
11
  "FT: Fine-tuning",
12
- "LU: Lookup Cache"
13
  ]
14
 
 
 
 
 
15
  def reset():
16
  st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
17
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
@@ -19,7 +28,28 @@ def reset():
19
  selected_alg = st.session_state.alg_selector
20
  selected_alg_idx = EDIT_ALGS.index(selected_alg)
21
 
22
- ############# Need to reset the model here (and maybe show progress spinner?)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def apply_edit():
25
  st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
 
1
  import streamlit as st
2
  import pandas as pd
3
  import time
4
+ import importlib
5
+
6
  import algs
7
+ import config
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+
10
 
11
  EDIT_ALGS = [
12
  "MEND: Model editor networks using gradient decomposition",
 
14
  "ENN: Editable neural networks",
15
  "KE: KnowledgeEditor",
16
  "FT: Fine-tuning",
17
+ "LU: Lookup Cache",
18
  ]
19
 
20
+ tokenizer = None
21
+ model = None
22
+ editable_model = None
23
+
24
  def reset():
25
  st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
26
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
 
28
  selected_alg = st.session_state.alg_selector
29
  selected_alg_idx = EDIT_ALGS.index(selected_alg)
30
 
31
+ ############# TODO: show progress spinner
32
+ global tokenizer
33
+ global model
34
+ global editable_model
35
+
36
+ if tokenizer is None:
37
+ tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
38
+ if model is None:
39
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").eval()
40
+ del editable_model
41
+
42
+ alg_name = st.session_state.alg_selector
43
+ alg_abbrv = alg_name[:alg_name.index(":")]
44
+ alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
45
+ alg_class = getattr(alg_module, alg_abbrv.upper())
46
+ cfg = getattr(config, f"{alg_abbrv.lower()}_config")
47
+ editable_model = alg_class(
48
+ model,
49
+ cfg,
50
+ lambda: copy.deepcopy(model),
51
+ ).eval()
52
+
53
 
54
  def apply_edit():
55
  st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
config.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Per-algorithm hyperparameter configs for the model-editing demo.

Each ``*_config`` below is an OmegaConf node that app.py resolves by name
(``getattr(config, f"{abbrv}_config")``) and passes to the matching editor
class in ``algs``.  All of them embed the shared ``model_config`` describing
the base T5 seq2seq model being edited.
"""

from omegaconf import OmegaConf
from torch.cuda import is_available as use_cuda

# Resolve the target device once at import time instead of repeating the
# same conditional expression in every config below.
_device = "cuda" if use_cuda() else "cpu"

# Base model description: which HF checkpoint to load and which weight
# matrices the editors are allowed to modify (``inner_params``).
model_config = {
    "name": "google/t5-large-ssm-nq",
    "class_name": "AutoModelForSeq2SeqLM",
    "tokenizer_class": "AutoTokenizer",
    "tokenizer_name": "google/t5-large-ssm-nq",
    "inner_params": [
        "encoder.block.22.layer.1.DenseReluDense.wi.weight",
        "encoder.block.22.layer.1.DenseReluDense.wo.weight",
        "encoder.block.23.layer.1.DenseReluDense.wi.weight",
        "encoder.block.23.layer.1.DenseReluDense.wo.weight",
        "decoder.block.22.layer.2.DenseReluDense.wi.weight",
        "decoder.block.22.layer.2.DenseReluDense.wo.weight",
        "decoder.block.23.layer.2.DenseReluDense.wi.weight",
        "decoder.block.23.layer.2.DenseReluDense.wo.weight",
    ],
    "pt": None,
    "small_name": "t5-small",
}

# FT: fine-tuning.
ft_config = OmegaConf.create({
    "device": _device,
    "edit_lr": 5e-6,
    "train_base": False,
    "ft": {
        "verbose": False,
        "max_edit_steps": 100,
        "time_limit": None,
        "locality": {
            "enabled": False,
            "oracle": True,
            "cedit": 1e-2,
            "batch_size": 1,
        },
        "rank": None,
        "opt": "RMSprop",
        "init_std": 0.01,
    },
    "model": model_config,
})

# LU: lookup cache.
lu_config = OmegaConf.create({
    "device": _device,
    "lu": {
        "threshold": 2.75,
        "onehot_logit": 1,
    },
    "model": model_config,
})

# KE: KnowledgeEditor.
ke_config = OmegaConf.create({
    "device": _device,
    "train_base": False,
    "lr": 1e-5,
    "model": model_config,
})

# ENN: editable neural networks.
enn_config = OmegaConf.create({
    "device": _device,
    "lr": 1e-5,
    "edit_lr": 1e-2,
    "lr_lr": 1e-3,
    "train_base": True,
    "grad_clip": 100,
    "dropout": 0,
    "no_grad_layers": None,
    "enn": {
        "first_order": False,
        "n_edit_steps": 1,
    },
    "model": model_config,
})

# MEND: model editor networks using gradient decomposition
# ("gtn" is its editor-network section).
mend_config = OmegaConf.create({
    "device": _device,
    "lr": 1e-6,
    "edit_lr": 1e-4,
    "lr_lr": 1e-4,
    "train_base": True,
    "grad_clip": 100,
    "dropout": 0,
    "no_grad_layers": None,
    "gtn": {
        "one_sided": False,
        "n_hidden": 1,
        "hidden_dim": None,
        "init": "id",
        "norm": True,
        "combine": True,
        "x_only": False,
        "delta_only": False,
        "act": "relu",
        "rank": 1920,
        "mlp_class": "IDMLP",
        "shared": True,
        "descent": False,
    },
    "model": model_config,
})

# SERAC editor config ("rep" holds its classifier/representation settings).
serac_config = OmegaConf.create({
    "device": _device,
    "lr": 1e-5,
    "edit_lr": 1e-2,
    "lr_lr": 0,
    "train_base": False,
    "grad_clip": 100,
    "dropout": 0,
    "no_grad_layers": None,
    "rep": {
        "cls_name": "distilbert-base-cased",
        "cls_class": "AutoModel",
        # NOTE(review): "true" is a string while the other flags here are real
        # booleans — confirm the consumer really expects a string here.
        "supervised": "true",
        "cos": False,
        "freeze": None,
        "square": True,
        "bound_embeds": False,
        "use_all_negatives": False,
        "freeze_cntr": False,
        "dist_heads": 1,
        "cross_attend": False,
        "lora": None,
        "soft_weighting": False,
        "checkpoint_grad": False,
        "cache_embeds": True,
    },
    "model": model_config,
})
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  allennlp
2
  git+https://github.com/eric-mitchell/higher@master # For in-place functional models
 
3
  pandas
4
  streamlit
5
  torch
 
1
  allennlp
2
  git+https://github.com/eric-mitchell/higher@master # For in-place functional models
3
+ omegaconf
4
  pandas
5
  streamlit
6
  torch