npops committed
Commit 765e08e
1 Parent(s): b31c45a
.gitignore ADDED
@@ -0,0 +1,5 @@
+ venv/
+ .DS_Store
+ /transformers/
+ /data/
+ /examples/
README.md CHANGED
@@ -1,12 +1,36 @@
  ---
- title: Stoke
- emoji: 😻
- colorFrom: blue
- colorTo: indigo
+ title: STOKE playground demo
+ emoji: 🐢
+ colorFrom: gray
+ colorTo: red
  sdk: streamlit
- sdk_version: 1.32.2
- app_file: app.py
+ sdk_version: 1.31.1
+ app_file: stoke/playground/app.py
  pinned: false
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # STOKE: A Toolkit for Streaming Token Classification
+
+ [Huggingface Space](https://huggingface.co/spaces/nicpopovic/stoke)
+
+ [Related publication](https://arxiv.org/abs/2403.11747)
+
+ *Note: This code is still being cleaned up.*
+
+ ## Quick start
+ You can use pip to install the required dependencies (including the transformers fork):
+ ```
+ python3 -m venv venv
+ source venv/bin/activate
+ pip install -r requirements.txt
+ streamlit run stoke/playground/app.py
+ ```
+
+ This will launch the playground, shown below:
+
+ ![](stoke/docs/images/playground.png)
+
+ ## Get the custom transformers fork
+ ```
+ git clone -b STOKE https://github.com/nicpopovic/transformers.git
+ ```
example_generate.py ADDED
@@ -0,0 +1,26 @@
+ from stoke.src.data.util import GenerationConfig, split_data, conll_prompts
+ from stoke.src.data.generation import DataGenerator, FlairNERModel
+
+ # generation parameters
+ generation_kwargs = {"max_new_tokens": 100, "repetition_penalty": 1.2}
+
+ # create GenerationConfig object with default values
+ config = GenerationConfig(language_model="gpt2", output_path="data/", dataset_name="test", cuda=False, generation_kwargs=generation_kwargs)
+
+ # create annotation model
+ reference_model = FlairNERModel(config.language_model, "flair/ner-english-ontonotes-large")
+
+ # create DataGenerator
+ generator = DataGenerator(config, reference_model)
+
+ # run generator
+ generated_texts = generator.generate_text(conll_prompts()[:10], generation_kwargs)
+
+ # annotate text with reference model
+ annotated_texts = generator.annotate_text(generated_texts)
+
+ # save data in correct format
+ generator.save_data(annotated_texts)
+
+ # split dataset
+ split_data(config.path_data)
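For orientation, `generator.save_data` writes a plain JSON list with one record per generated text. Judging from `annotate_text` in `stoke/src/data/generation.py` below, each record stores tokenizer ids under `tokens`, one prefix-free tag per token under `ner_tags`, and inclusive `[start, end]` token spans under `mentions`; the concrete values here are made up for illustration:

```
[
 {
  "tokens": [464, 1578, 4916, 318, 257, 1499, 13],
  "ner_tags": ["O", "GPE", "GPE", "O", "O", "O", "O"],
  "mentions": [[1, 2]]
 }
]
```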
example_train.py ADDED
@@ -0,0 +1,16 @@
+ # imports
+ from stoke.src.trainer.util import TrainConfig
+ from stoke.src.trainer.trainer import Trainer
+ from stoke.src.selection.simple import create_config_for_path
+
+ # create TrainConfig object with default values
+ config = TrainConfig('data/gpt2/test', n_steps_per_epoch=10, n_epochs=10)
+
+ # create Trainer
+ trainer = Trainer(config)
+
+ # run training
+ trainer.train()
+
+ # create basic config for playground
+ create_config_for_path(config.path, "basic")
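Taken together, the two example scripts should leave behind roughly the following directory layout (inferred from `GenerationConfig`, `TrainConfig`, the trainer's checkpointing code, and `create_config_for_path` below; the checkpoint identifiers are random strings):

```
data/gpt2/test/
├── config.json                  # written by DataGenerator
├── data.json                    # full generated dataset
├── data_train.json              # splits written by split_data
├── data_validation.json
├── data_test.json
├── data_<split>_stats.json      # per-split mention counts
├── checkpoints/
│   ├── token_classifier/<id>/   # checkpoint.pt, config.json, config_train.json
│   └── span_classifier/<id>/
└── stoke_config.json            # written by create_config_for_path
```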
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ git+https://github.com/nicpopovic/transformers.git@STOKE
+ streamlit
+ torch
+ matplotlib
+ flair
+ nltk
+ datasets
+ torcheval
stoke/__init__.py ADDED
File without changes
stoke/docs/images/playground.png ADDED
stoke/src/classifier/__init__.py ADDED
File without changes
stoke/src/classifier/probes.py ADDED
@@ -0,0 +1,38 @@
+ import torch
+
+
+ class MLP(torch.nn.Module):
+     """Probe that flattens all dimensions after the batch dimension."""
+     def __init__(self, input_dim, output_dim, hidden_dim=1024, cuda=False):
+         super(MLP, self).__init__()
+         self.fc1 = torch.nn.Linear(input_dim, hidden_dim)   # input layer to hidden layer
+         self.fc3 = torch.nn.Linear(hidden_dim, output_dim)  # hidden layer to output layer
+         if cuda:
+             self.device = "cuda"
+         else:
+             self.device = "cpu"
+         self.to(self.device)
+
+     def forward(self, x):
+         x = torch.flatten(x, start_dim=1)
+         x = torch.relu(self.fc1(x))
+         x = self.fc3(x)
+         return x
+
+
+ class MLPProbe(torch.nn.Module):
+     """Probe applied position-wise (no flattening of the input)."""
+     def __init__(self, input_dim, output_dim, hidden_dim=1024, cuda=False):
+         super(MLPProbe, self).__init__()
+         self.fc1 = torch.nn.Linear(input_dim, hidden_dim)   # input layer to hidden layer
+         self.fc3 = torch.nn.Linear(hidden_dim, output_dim)  # hidden layer to output layer
+         if cuda:
+             self.device = "cuda"
+         else:
+             self.device = "cpu"
+         self.to(self.device)
+
+     def forward(self, x):
+         x = torch.relu(self.fc1(x))
+         x = self.fc3(x)
+         return x
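A minimal sketch of the shape difference between the two probes (all dimensions here are made up for the example): `MLPProbe` is applied position-wise, while `MLP` flattens everything after the batch dimension first.

```
import torch
from stoke.src.classifier.probes import MLP, MLPProbe

hidden = torch.randn(4, 16, 768)   # (batch, seq_len, hidden_size)
token_probe = MLPProbe(768, 5)     # one prediction per token position
print(token_probe(hidden).shape)   # torch.Size([4, 16, 5])

flat_probe = MLP(16 * 768, 5)      # flattens to (batch, 16*768) before fc1
print(flat_probe(hidden).shape)    # torch.Size([4, 5])
```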
stoke/src/data/__init__.py ADDED
File without changes
stoke/src/data/generation.py ADDED
@@ -0,0 +1,288 @@
+ from transformers import pipeline
+ from tqdm import tqdm
+ import json
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from flair.models import SequenceTagger
+ from flair.data import Sentence
+
+
+ class AnnotationModel:
+     def __init__(self, model_id_for_tokenizer):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id_for_tokenizer, use_fast=True)
+         self.pipe = pipeline("token-classification", model="FacebookAI/xlm-roberta-large-finetuned-conll03-english", aggregation_strategy="simple")
+
+     def annotate_text(self, text):
+         iob_tags = ['O'] * len(text)
+         mentions = []
+         text_str = self.tokenizer.decode(text)
+         ner_tags = self.pipe(text_str)
+
+         # character offset at which each token starts in the decoded string
+         offsets = []
+         offset = 0
+         for i, token_id in enumerate(text):
+             offsets.append(offset)
+             offset = len(self.tokenizer.decode(text[:i+1]))
+         offsets.append(offset)
+
+         for tag in ner_tags:
+             try:
+                 start = self.get_token_for_char(tag["start"], offsets)
+                 end = self.get_token_for_char(tag["end"]-1, offsets)
+                 mentions.append([start, end])
+                 # tags are stored without B-/I- prefixes
+                 for i in range(start, end+1):
+                     iob_tags[i] = tag["entity_group"]
+             except Exception as e:
+                 print(e)
+
+         return {"tokens": text, "ner_tags": iob_tags, "mentions": mentions}
+
+     def get_token_for_char(self, i, offsets):
+         for off in range(len(offsets)):
+             if i < offsets[off]:
+                 return off - 1
+         return len(offsets) - 1
+
+
+ class FlairNERModel:
+     def __init__(self, model_id_for_tokenizer, flair_model_name):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id_for_tokenizer, use_fast=True)
+         self.tagger = SequenceTagger.load(flair_model_name)
+         self.name = flair_model_name  # recorded by DataGenerator in the dataset config
+
+     def annotate_text(self, text):
+         iob_tags = ['O'] * len(text)
+         mentions = []
+         text_str = self.tokenizer.decode(text)
+         sentence = Sentence(text_str)
+
+         # predict NER tags
+         self.tagger.predict(sentence)
+         ner_tags = sentence.get_spans('ner')
+
+         offsets = []
+         offset = 0
+         for i, token_id in enumerate(text):
+             offsets.append(offset)
+             offset = len(self.tokenizer.decode(text[:i+1]))
+         offsets.append(offset)
+
+         for tag in ner_tags:
+             try:
+                 start = self.get_token_for_char(tag.start_position, offsets)
+                 end = self.get_token_for_char(tag.end_position-1, offsets)
+                 mentions.append([start, end])
+                 # tags are stored without B-/I- prefixes
+                 for i in range(start, end+1):
+                     iob_tags[i] = tag.get_labels('ner')[0].to_dict()['value']
+             except Exception as e:
+                 print(tag)
+                 print(e)
+
+         return {"tokens": text, "ner_tags": iob_tags, "mentions": mentions}
+
+     def get_token_for_char(self, i, offsets):
+         for off in range(len(offsets)):
+             if i < offsets[off]:
+                 return off - 1
+         return len(offsets) - 1
+
+
+ class FlairChunkingModel:
+     def __init__(self, model_id_for_tokenizer, flair_model_name):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id_for_tokenizer, use_fast=True)
+         self.tagger = SequenceTagger.load(flair_model_name)
+         self.name = flair_model_name  # recorded by DataGenerator in the dataset config
+
+     def annotate_text(self, text):
+         iob_tags = ['O'] * len(text)
+         mentions = []
+         text_str = self.tokenizer.decode(text)
+         sentence = Sentence(text_str)
+
+         # predict chunk tags
+         self.tagger.predict(sentence)
+         ner_tags = sentence.get_spans('np')
+
+         offsets = []
+         offset = 0
+         for i, token_id in enumerate(text):
+             offsets.append(offset)
+             offset = len(self.tokenizer.decode(text[:i+1]))
+         offsets.append(offset)
+
+         for tag in ner_tags:
+             try:
+                 start = self.get_token_for_char(tag.start_position, offsets)
+                 end = self.get_token_for_char(tag.end_position-1, offsets)
+                 mentions.append([start, end])
+                 # tags are stored without B-/I- prefixes
+                 for i in range(start, end+1):
+                     iob_tags[i] = tag.get_labels('np')[0].to_dict()['value']
+             except Exception as e:
+                 print(tag)
+                 print(e)
+
+         return {"tokens": text, "ner_tags": iob_tags, "mentions": mentions}
+
+     def get_token_for_char(self, i, offsets):
+         for off in range(len(offsets)):
+             if i < offsets[off]:
+                 return off - 1
+         return len(offsets) - 1
+
+
+ class FlairFrameModel:
+     def __init__(self, model_id_for_tokenizer, flair_model_name):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id_for_tokenizer, use_fast=True)
+         self.tagger = SequenceTagger.load(flair_model_name)
+         self.name = flair_model_name  # recorded by DataGenerator in the dataset config
+
+     def annotate_text(self, text):
+         iob_tags = ['O'] * len(text)
+         mentions = []
+         text_str = self.tokenizer.decode(text)
+         sentence = Sentence(text_str)
+
+         # predict frame tags
+         self.tagger.predict(sentence)
+         ner_tags = sentence.get_labels('frame')
+
+         offsets = []
+         offset = 0
+         for i, token_id in enumerate(text):
+             offsets.append(offset)
+             offset = len(self.tokenizer.decode(text[:i+1]))
+         offsets.append(offset)
+
+         for tag in ner_tags:
+             try:
+                 start = self.get_token_for_char(tag.data_point.start_position, offsets)
+                 end = self.get_token_for_char(tag.data_point.end_position-1, offsets)
+                 mentions.append([start, end])
+                 # tags are stored without B-/I- prefixes
+                 for i in range(start, end+1):
+                     iob_tags[i] = tag.to_dict()['value']
+             except Exception as e:
+                 print(tag)
+                 print(e)
+
+         return {"tokens": text, "ner_tags": iob_tags, "mentions": mentions}
+
+     def get_token_for_char(self, i, offsets):
+         for off in range(len(offsets)):
+             if i < offsets[off]:
+                 return off - 1
+         return len(offsets) - 1
+
+
+ class FlairPOSModel:
+     def __init__(self, model_id_for_tokenizer, flair_model_name):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id_for_tokenizer, use_fast=True)
+         self.tagger = SequenceTagger.load(flair_model_name)
+         self.name = flair_model_name  # recorded by DataGenerator in the dataset config
+
+     def annotate_text(self, text):
+         iob_tags = ['O'] * len(text)
+         mentions = []
+         text_str = self.tokenizer.decode(text)
+         sentence = Sentence(text_str)
+
+         # predict POS tags
+         self.tagger.predict(sentence)
+         ner_tags = sentence.get_labels('pos')
+
+         offsets = []
+         offset = 0
+         for i, token_id in enumerate(text):
+             offsets.append(offset)
+             offset = len(self.tokenizer.decode(text[:i+1]))
+         offsets.append(offset)
+
+         for tag in ner_tags:
+             try:
+                 start = self.get_token_for_char(tag.data_point.start_position, offsets)
+                 end = self.get_token_for_char(tag.data_point.end_position-1, offsets)
+                 mentions.append([start, end])
+                 # tags are stored without B-/I- prefixes
+                 for i in range(start, end+1):
+                     iob_tags[i] = tag.to_dict()['value']
+             except Exception as e:
+                 print(tag)
+                 print(e)
+
+         return {"tokens": text, "ner_tags": iob_tags, "mentions": mentions}
+
+     def get_token_for_char(self, i, offsets):
+         for off in range(len(offsets)):
+             if i < offsets[off]:
+                 return off - 1
+         return len(offsets) - 1
+
+
+ class DataGenerator(object):
+     def __init__(self, config, reference_model):
+         self.config = config
+         self.model_id = self.config.language_model
+         self.reference_model = reference_model
+         self.output_path = self.config.path_data
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.language_model, use_fast=True)
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         if self.config.cuda:
+             device = "cuda"
+         else:
+             device = "cpu"
+
+         self.model = AutoModelForCausalLM.from_pretrained(self.config.language_model).to(device)
+
+         json.dump({
+             "generation_kwargs": self.config.generation_kwargs,
+             "model_id": self.config.language_model,
+             "flair_model_name": reference_model.name,
+         }, open(self.config.path_config, "w"), indent=1)
+
+     def generate_text(self, prompts, generation_kwargs):
+         generated_texts = []
+         for prompt in tqdm(prompts, desc="Generating text"):
+             input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+             generated_text_ids = self.model.generate(input_ids=input_ids.to(self.model.device), pad_token_id=self.tokenizer.pad_token_id, **generation_kwargs)
+             generated_text = generated_text_ids[0].tolist()
+             prompt_token_ids = input_ids[0].tolist()
+             generated_texts.append({"prompt": prompt_token_ids, "full": generated_text})
+         return generated_texts
+
+     def annotate_text(self, texts):
+         annotated_texts = []
+         for text in tqdm(texts, desc="Annotating text"):
+             annotated_text = self.reference_model.annotate_text(text["full"])
+             annotated_texts.append(annotated_text)
+         return annotated_texts
+
+     def save_data(self, data):
+         with open(self.output_path, 'w') as f:
+             json.dump(data, f, indent=1)
+
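The subtle part of the annotation models above is the char-to-token alignment: token start offsets are recovered by repeatedly decoding prefixes of the id sequence, and `get_token_for_char` then maps a character position back to a token index. A toy illustration with made-up offsets:

```
# offsets[i] is the character position where token i starts;
# the final entry is the total string length.
offsets = [0, 4, 9, 15]

def get_token_for_char(i, offsets):
    for off in range(len(offsets)):
        if i < offsets[off]:
            return off - 1
    return len(offsets) - 1

assert get_token_for_char(0, offsets) == 0    # char 0 is in token 0
assert get_token_for_char(5, offsets) == 1    # char 5 is in token 1
assert get_token_for_char(14, offsets) == 2   # char 14 is in the last token
```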
stoke/src/data/util.py ADDED
@@ -0,0 +1,233 @@
+ import re
+ from nltk.tokenize.treebank import TreebankWordDetokenizer as Detok
+ from torch.utils.data import Dataset
+ import json
+ import torch
+ import torch.nn.functional as F
+ import os
+ import datasets
+ import random
+
+
+ class Detokenizer(object):
+     # https://stackoverflow.com/a/46311499
+     def __init__(self) -> None:
+         self.detokenizer = Detok()
+
+     def __call__(self, tokens, return_offsets=False):
+         text = self.detokenizer.detokenize(tokens)
+         text = re.sub(r'\s*,\s*', ', ', text)
+         text = re.sub(r'\s*\.\s*', '. ', text)
+         text = re.sub(r'\s*\?\s*', '? ', text)
+         text = text.replace(" --", "--")
+
+         if return_offsets:
+             offsets = [0]
+             for i in range(1, len(tokens)):
+                 offsets.append(len(self(tokens[:i])))
+
+             """
+             # verify offsets
+             for i, offset in enumerate(offsets):
+                 if i == 0:
+                     continue
+                 check = text[:offset]
+                 target = self(tokens[:i])
+                 try:
+                     assert target == check
+                 except AssertionError:
+                     print(tokens)
+                     print(f"'{check}' != '{target}'")
+                     raise AssertionError
+             """
+
+             return text.strip(), offsets
+         return text.strip()
+
+
+ class JSONDataset(Dataset):
+
+     def __init__(self, path):
+         super().__init__()
+         self.samples = json.load(open(path, "r"))
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         return self.samples[idx]
+
+
+ def create_mask_for_len(seq_len, pad_to=None, skip_start=0, window=None):
+     # lower-triangular (causal) mask of candidate span positions
+     mask = (-1 * (torch.triu(torch.ones(seq_len, seq_len), diagonal=1) - 1)).bool()
+
+     if skip_start != 0:
+         mask[:skip_start, :] = False
+         mask[:, :skip_start] = False
+
+     if window is not None:
+         for i in range(window, seq_len):
+             mask[i, :max(i-window, 0)] = False
+
+     if pad_to is None:
+         return mask
+
+     return F.pad(mask, (0, pad_to-seq_len, 0, pad_to-seq_len))
+
+
+ def collate_function_with_label_map(batch, label_map):
+     # prepare token ids
+     sequences = [torch.tensor(x['tokens']) for x in batch]
+     input_ids = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
+
+     # prepare labels
+     labels_tokens = []
+     for batchitem in batch:
+         labels_tokens.append(torch.tensor([label_map.index(x) for x in batchitem['ner_tags']]))
+     labels_tokens = torch.nn.utils.rnn.pad_sequence(labels_tokens, batch_first=True, padding_value=0).to(torch.long)
+
+     labels_spans = torch.zeros((len(batch), input_ids.shape[-1], input_ids.shape[-1]))
+     for i, batchitem in enumerate(batch):
+         for mnt in batchitem['mentions']:
+             start, end = mnt
+             try:
+                 labels_spans[i, end+1, start] = 1.0
+             except IndexError:
+                 # mention ends at the final token, so there is no position after it
+                 pass
+
+     # prepare masks
+     masks_tokens = [torch.tensor([False]+([True]*(len(x)-1))) for x in sequences]
+     masks_tokens = torch.nn.utils.rnn.pad_sequence(masks_tokens, batch_first=True, padding_value=False)
+     mask_spans = torch.stack([create_mask_for_len(len(x['tokens']), input_ids.shape[-1], skip_start=0, window=15) for x in batch])
+
+     # mask labels
+     labels_tokens = torch.masked_select(labels_tokens, masks_tokens).long()
+     labels_spans = torch.masked_select(labels_spans, mask_spans).long()
+
+     return {
+         'input_ids': input_ids,
+         'labels_tokens': labels_tokens,
+         'labels_spans': labels_spans,
+         'mask_tokens': masks_tokens.unsqueeze(-1),
+         'mask_spans': mask_spans.unsqueeze(-1)
+     }
+
+
+ def print_metric(metric, class_labels, return_classwise=False, verbose=False):
+
+     f_ner = metric.compute()
+     p_ner = torch.nan_to_num(metric.num_tp / metric.num_prediction)
+     r_ner = torch.nan_to_num(metric.num_tp / metric.num_label)
+
+     if verbose:
+         print(f"{' '.ljust(10)}   P      R      F      S")
+
+     sum_support = 0
+     weighted_scores = [0, 0, 0]
+
+     classwise = {}
+     for ner_class, p, r, f, s in zip(class_labels, p_ner, r_ner, f_ner, metric.num_label):
+         if ner_class in ("NONE", "O", "no_relation", "no_span"):
+             continue
+         if verbose:
+             print(f"{ner_class.ljust(10)} - {p:.2f} - {r:.2f} - {f:.2f} - {int(s)}")
+         weighted_scores[0] += p*s
+         weighted_scores[1] += r*s
+         weighted_scores[2] += f*s
+         sum_support += s
+
+         classwise[ner_class] = {"p": p.item(), "r": r.item(), "f": f.item(), "s": s.item()}
+
+     p_micro = weighted_scores[0]/sum_support
+     r_micro = weighted_scores[1]/sum_support
+     f_micro = weighted_scores[2]/sum_support
+
+     classwise["macro"] = {"p": torch.mean(p_ner[1:]).item(), "r": torch.mean(r_ner[1:]).item(), "f": torch.mean(f_ner[1:]).item()}
+
+     if verbose:
+         print("")
+         print(f"MICRO - {p_micro:.2f} - {r_micro:.2f} - {f_micro:.2f}")
+         print(f"MACRO - {torch.mean(p_ner[1:]):.2f} - {torch.mean(r_ner[1:]):.2f} - {torch.mean(f_ner[1:]):.2f}")
+         print("")
+
+     if return_classwise:
+         return (p_micro.item(), r_micro.item(), f_micro.item()), classwise
+
+     return p_micro.item(), r_micro.item(), f_micro.item()
+
+
+ class GenerationConfig:
+
+     def __init__(self, language_model, output_path, dataset_name, cuda=False, generation_kwargs=None):
+         self.language_model = language_model
+         self.output_path = output_path
+         self.dataset_name = dataset_name
+         self.cuda = cuda
+         self.generation_kwargs = generation_kwargs if generation_kwargs is not None else {}
+
+         self.path_data = os.path.join(output_path, f"{language_model}/{dataset_name}/data.json")
+         self.path_config = os.path.join(output_path, f"{language_model}/{dataset_name}/config.json")
+
+         if not os.path.exists(os.path.join(output_path, f"{language_model}/{dataset_name}")):
+             os.makedirs(os.path.join(output_path, f"{language_model}/{dataset_name}"))
+
+
+ def conll_prompts():
+     ds = datasets.load_dataset("conll2003")["validation"]
+     dtk = Detokenizer()
+     prompts = [dtk(x["tokens"]) for x in ds]
+     ds = datasets.load_dataset("conll2003")["train"]
+     prompts += [dtk(x["tokens"]) for x in ds]
+     return prompts
+
+
+ def partition_dataset(data, split_sizes):
+     random.shuffle(data)
+
+     total_size = len(data)
+     split_points = [int(total_size * size) for size in split_sizes[:-1]]
+
+     parts = []  # renamed from `datasets` to avoid shadowing the imported module
+     start_idx = 0
+     for split_point in split_points:
+         parts.append(data[start_idx: start_idx + split_point])
+         start_idx += split_point
+     parts.append(data[start_idx:])
+
+     return parts
+
+
+ def stats(ds, keys=None):
+     mentions_total = 0
+     mentions_per_type = {}
+     if keys is not None:
+         for key in keys:
+             mentions_per_type[key] = 0
+     for sample in ds:
+         mentions_total += len(sample['mentions'])
+         for mnt in sample['mentions']:
+             tag = sample['ner_tags'][mnt[0]]
+             if tag not in mentions_per_type.keys():
+                 mentions_per_type[tag] = 0
+             mentions_per_type[tag] += 1
+     return len(mentions_per_type.keys()), mentions_total, mentions_per_type
+
+
+ def split_data(path_to_data, split_names=["train", "validation", "test"], split_sizes=[0.8, 0.1, 0.1]):
+     with open(path_to_data, 'r') as file:
+         data = json.load(file)
+
+     annotation_types = sorted(list(stats(data)[-1].keys()))
+     splits = partition_dataset(data, split_sizes)
+
+     for i, dataset in enumerate(splits):
+         ds = []
+         for x in dataset:
+             out = {}
+             out["tokens"] = x["tokens"]
+             out["ner_tags"] = [y.replace("I-", "").replace("B-", "") for y in x["ner_tags"]]
+             out["mentions"] = x["mentions"]
+             ds.append(out)
+
+         print(f"Size of dataset {split_names[i]}: {len(ds)}")
+         json.dump(ds, open(f"{path_to_data.split('.json')[0]}_{split_names[i]}.json", "w"))
+         json.dump(stats(dataset, annotation_types)[-1], open(f"{path_to_data.split('.json')[0]}_{split_names[i]}_stats.json", "w"), indent=1)
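For intuition, `create_mask_for_len` marks which (end, start) position pairs count as candidate spans: a lower-triangular mask, optionally restricted to a window of recent positions. A small check (run from the repository root):

```
import torch
from stoke.src.data.util import create_mask_for_len

print(create_mask_for_len(5, window=2).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]])
```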
stoke/src/playground/__init__.py ADDED
File without changes
stoke/src/playground/app.py ADDED
@@ -0,0 +1,442 @@
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer
+ from threading import Thread
+ import json
+ import torch
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import to_hex
+ import numpy as np
+ import os
+ import urllib.request
+ import zipfile
+
+
+ class MLP(torch.nn.Module):
+     def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False):
+         super(MLP, self).__init__()
+         self.fc1 = torch.nn.Linear(input_dim, hidden_dim)   # input layer to hidden layer
+         self.fc3 = torch.nn.Linear(hidden_dim, output_dim)  # hidden layer to output layer
+         self.layer_id = layer_id
+         if cuda:
+             self.device = "cuda"
+         else:
+             self.device = "cpu"
+         self.to(self.device)
+
+     def forward(self, x):
+         x = torch.flatten(x, start_dim=1)
+         x = torch.relu(self.fc1(x))
+         x = self.fc3(x)
+         return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach()
+
+
+ def map_value_to_color(value, colormap_name='tab20c'):
+     """
+     Map a value between 0 and 1 to a CSS color using a Python colormap.
+
+     Args:
+         value (float): A value between 0 and 1.
+         colormap_name (str): The name of the colormap to use (e.g., 'viridis').
+
+     Returns:
+         str: A CSS hex color string with an alpha component appended.
+     """
+     # Ensure the value is within the range [0, 1]
+     value = np.clip(value, 0.0, 1.0)
+
+     # Get the colormap
+     colormap = plt.get_cmap(colormap_name)
+
+     # Map the value to a color
+     rgba_color = colormap(value)
+
+     # Convert the RGBA color to CSS hex format and append alpha
+     css_color = to_hex(rgba_color)
+
+     return css_color + "88"
+
+
+ @st.cache_resource
+ def get_model_and_tokenizer(name):
+     # Load pre-trained model and tokenizer
+     tok = AutoTokenizer.from_pretrained(name)
+     model = AutoModelForCausalLM.from_pretrained(name)
+     return model, tok
+
+
+ @st.cache_resource
+ def get_classifiers_for_model(att_size, emb_size, device, config_paths):
+     config = {
+         "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")),
+         "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r"))
+     }
+
+     layer_id = config["classifier_token"]["layer"]
+
+     classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device)
+     classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device))
+
+     classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device)
+     classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device))
+
+     print(sum(p.numel() for p in classifier_span.parameters()), sum(p.numel() for p in classifier_token.parameters()))
+
+     return classifier_span, classifier_token, config["classifier_token"]["label_map"]
+
+
+ def get_available_models():
+     available_models = []
+     for model_name in ["gpt2", "gpt2-xl"]:
+         if os.path.isfile(f"checkpoints/{model_name}/config.json"):
+             available_models.append(model_name)
+     return available_models
+
+
+ def get_available_datasets(model_name):
+     available_datasets = []
+     config_path = f"checkpoints/{model_name}/config.json"
+     if os.path.isfile(config_path):
+         with open(config_path, "r") as f:
+             config = json.load(f)
+         # Datasets are keys in config.json
+         available_datasets = list(config.keys())
+     return available_datasets
+
+
+ def download_and_extract_zip(url, extract_dir):
+     # Determine the parent directory
+     parent_dir = os.path.split(os.path.dirname(extract_dir))[0]
+     print(parent_dir)
+
+     # Download the zip file to the parent directory
+     zip_file_path = os.path.join(parent_dir, "data.zip")
+     urllib.request.urlretrieve(url, zip_file_path)
+
+     # Extract the zip file
+     with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+         zip_ref.extractall(parent_dir)
+
+     # Remove the zip file
+     os.remove(zip_file_path)
+
+
+ def find_datasets_and_model_ids(root_dir):
+     datasets = {}
+
+     # If the root directory doesn't exist, download and unpack the demo data
+     if not os.path.exists(root_dir):
+         print("Root directory doesn't exist. Downloading zip file...")
+         url = "https://drive.usercontent.google.com/download?id=1dHjH_J0zuPS-SDVrh49tMpIx5ramu_hc&export=download&authuser=0&confirm=t&uuid=4efcec77-571c-44c7-82f1-f39ddae50eb5&at=APZUnTW8g-Ab4PUT0-B9mh4jQSc-%3A1711040271924"  # Replace with your actual download URL
+         download_and_extract_zip(url, root_dir)
+         print("Zip file downloaded and unpacked successfully.")
+
+     for root, dirs, files in os.walk(root_dir):
+         if 'config.json' in files and 'stoke_config.json' in files:
+             config_path = os.path.join(root, 'config.json')
+             stoke_config_path = os.path.join(root, 'stoke_config.json')
+
+             with open(config_path, 'r') as f:
+                 config_data = json.load(f)
+             model_id = config_data.get('model_id')
+
+             if model_id:
+                 with open(stoke_config_path, 'r') as f:
+                     stoke_config_data = json.load(f)
+                 dataset_name = os.path.basename(os.path.dirname(stoke_config_path))
+                 datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data
+
+     return datasets
+
+
+ # Main content
+ st.title("Playground")
+
+ # Sidebar for model and dataset selection
+ with st.sidebar:
+     st.subheader("Model and Dataset Selection")
+     datasets = find_datasets_and_model_ids("data/")
+     available_models = datasets.keys()
+     if available_models:
+         model_selection = st.selectbox("Select Model", available_models)
+     else:
+         st.error("No models available. Please check the file paths.")
+
+     # Select dataset based on selected model
+     available_datasets = datasets[model_selection]
+     if available_datasets:
+         dataset_selection = st.selectbox("Select Dataset", available_datasets)
+     else:
+         st.error("No datasets available for the selected model.")
+
+     # Select config based on selected dataset
+     available_configs = datasets[model_selection][dataset_selection]
+     if available_configs:
+         config_selection = st.selectbox("Select Config", available_configs.keys())
+     else:
+         st.error("No configs available for the selected dataset.")
+
+ # Load model and streamer based on selections
+ model, tok = get_model_and_tokenizer(model_selection)
+ if torch.cuda.is_available():
+     model.cuda()
+ classifier_span, classifier_token, label_map = get_classifiers_for_model(model.config.n_head*model.config.n_layer, model.config.n_embd, model.device, datasets[model_selection][dataset_selection][config_selection])
+ streamer = STOKEStreamer(tok, classifier_token, classifier_span)
+
+ new_tags = label_map
+
+
+ def filter_spans(spans_and_values):
+     if spans_and_values == []:
+         return [], []
+     # Create a dictionary to store spans based on their end indices
+     span_dict = {}
+
+     spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values]
+
+     # Iterate through the spans and keep the highest value per end index
+     for span, value in zip(spans, values):
+         start, end = span
+         if start > end or end - start > 15 or start == 0:
+             continue
+         current_value = span_dict.get(end, None)
+
+         if current_value is None or current_value[1] < value:
+             span_dict[end] = (span, value)
+
+     if span_dict == {}:
+         return [], []
+     # Extract the filtered spans and values
+     filtered_spans, filtered_values = zip(*span_dict.values())
+
+     return list(filtered_spans), list(filtered_values)
+
+
+ def remove_overlapping_spans(spans):
+     # Sort the spans based on their end points
+     sorted_spans = sorted(spans, key=lambda x: x[0][1])
+
+     non_overlapping_spans = []
+     last_end = float('-inf')
+
+     # Iterate through the sorted spans
+     for span in sorted_spans:
+         start, end = span[0]
+         value = span[1]
+
+         # If the current span does not overlap with the previous one
+         if start >= last_end:
+             non_overlapping_spans.append(span)
+             last_end = end
+         else:
+             # If it overlaps, choose the one with the highest value
+             existing_span_index = -1
+             for i, existing_span in enumerate(non_overlapping_spans):
+                 if existing_span[0][1] <= start:
+                     existing_span_index = i
+                     break
+             if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value:
+                 non_overlapping_spans[existing_span_index] = span
+
+     return non_overlapping_spans
+
+
+ def generate_html_no_overlap(tokenized_text, spans):
+     current_index = 0
+     html_content = ""
+
+     for (span_start, span_end), value in spans:
+         # Add text before the span
+         html_content += "".join(tokenized_text[current_index:span_start])
+
+         # Add the span with underlining
+         html_content += "<b><u>"
+         html_content += "".join(tokenized_text[span_start:span_end])
+         html_content += "</u></b> "
+
+         current_index = span_end
+
+     # Add any remaining text after the last span
+     html_content += "".join(tokenized_text[current_index:])
+
+     return html_content
+
+
+ css = """
+ <style>
+ .highlight {
+     display: inline;
+ }
+ .highlight::after {
+     background-color: var(data-color);
+ }
+ .spanhighlight {
+     padding: 2px 5px;
+     border-radius: 5px;
+ }
+ .tooltip {
+     position: relative;
+     display: inline-block;
+ }
+
+ .tooltip::after {
+     content: attr(data-tooltip-text); /* Set content from data-tooltip-text attribute */
+     display: none;
+     position: absolute;
+     background-color: #333;
+     color: #fff;
+     padding: 5px;
+     border-radius: 5px;
+     bottom: 100%; /* Position it above the element */
+     left: 50%;
+     transform: translateX(-50%);
+     width: auto;
+     min-width: 120px;
+     margin: 0 auto;
+     text-align: center;
+ }
+
+ .tooltip:hover::after {
+     display: block; /* Show the tooltip on hover */
+ }
+
+ .small-text {
+     padding: 2px 5px;
+     background-color: white;
+     border-radius: 5px;
+     font-size: xx-small;
+     margin-left: 0.5em;
+     vertical-align: 0.2em;
+     font-weight: bold;
+     color: grey;
+ }
+ </style>"""
+
+
+ def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer):
+
+     # spanwise annotated text, built back to front
+     annotated = []
+     span_ends = -1
+     in_span = False
+
+     out_of_span_tokens = []
+     for i in reversed(range(len(tokenwise_preds))):
+
+         if in_span:
+             if i >= span_ends:
+                 continue
+             else:
+                 in_span = False
+
+         predicted_class = ""
+         style = ""
+
+         span = None
+         for s in spans:
+             if s[1] == i+1:
+                 span = s
+
+         if tokenwise_preds[i] != 0 and span is not None:
+             predicted_class = "highlight spanhighlight"
+             style = f"background-color: {map_value_to_color((tokenwise_preds[i]-1)/(len(new_tags)-1))}"
+             if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "):
+                 annotated.append("Ġ")
+
+             span_opener = f"Ġ<span class='{predicted_class}' data-tooltip-text='{new_tags[tokenwise_preds[i]]}' style='{style}'>".replace(" ", "Ġ")
+             span_end = f"<span class='small-text'>{new_tags[tokenwise_preds[i]]}</span></span>"
+             annotated.extend(out_of_span_tokens)
+             out_of_span_tokens = []
+             span_ends = span[0]
+             in_span = True
+             annotated.append(span_end)
+             annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))])
+             annotated.append(span_opener)
+         else:
+             out_of_span_tokens.append(token_strings[i])
+
+     annotated.extend(out_of_span_tokens)
+
+     return [x for x in reversed(annotated)]
+
+
+ # Define function to generate text based on input
+ def generate_text(generation_kwargs, output_field):
+
+     # Function to generate text in a separate thread
+     def generate_async():
+         model.generate(**generation_kwargs)
+
+     # Start text generation in a separate thread
+     thread = Thread(target=generate_async)
+     thread.start()
+
+     # Display generated text as it becomes available
+     text_tokenwise = ""
+     text_spans = ""
+     removed_spans = ""
+     tags = []
+     spans = []
+     for new_text in streamer:
+         if new_text[1] is not None and new_text[2] != ['']:
+             text_tokenwise = ""
+             tags.extend(new_text[1])
+             spans.extend(new_text[-1])
+
+             # Tokenwise classification
+             for tk, pred in zip(new_text[2], tags):
+                 if pred != 0:
+                     style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
+                     if tk.startswith(" "):
+                         text_tokenwise += " "
+                     text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
+                 else:
+                     text_tokenwise += tk
+
+             # Span classification
+             text_spans = ""
+             if len(spans) > 0:
+                 filtered_spans = remove_overlapping_spans(spans)
+                 text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
+                 if len(spans) - len(filtered_spans) > 0:
+                     removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
+             else:
+                 for tk in new_text[2]:
+                     text_spans += f"{tk}"
+
+             # Spanwise classification
+             annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok)
+             generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "")
+
+             output_field.empty()
+             output = f"{css}"
+             output += generated_text_spanwise.replace("\n", " ").replace("$", "\\$") + "\n<br>"
+             output += "<details><summary>Show tokenwise classification</summary>\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$")
+             if removed_spans != "":
+                 output += f"<br><br><i>({removed_spans})</i>"
+             output += "</details>"
+             output_field.write(output, unsafe_allow_html=True)
+
+
+ # Input field
+ input_text = st.text_area("Enter prompt for completion", "")
+
+ # Sidebar for customizing generation parameters
+ with st.sidebar:
+     st.subheader("Generation Parameters")
+     max_new_tokens = st.slider("Max New Tokens", min_value=1, max_value=100, value=30)
+     repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2)
+     do_sample = st.checkbox("Do Sample", value=True)
+     temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.0)
+     top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.3)
+     top_k = st.slider("Top-k", min_value=10, max_value=100, value=50)
+     typical_p = st.slider("Typical P", min_value=0.1, max_value=1.0, value=1.0)
+
+ # Button to generate text
+ if st.button("Generate"):
+     if input_text:
+         output_field = st.empty()
+         inputs = tok([" " + input_text], return_tensors="pt").to(model.device)
+         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens,
+                                  repetition_penalty=repetition_penalty, temperature=temperature,
+                                  top_p=top_p, top_k=top_k, do_sample=do_sample, typical_p=typical_p)
+         generate_text(generation_kwargs, output_field)
+     else:
+         st.warning("Please enter some text first.")
stoke/src/selection/__init__.py ADDED
File without changes
stoke/src/selection/simple.py ADDED
@@ -0,0 +1,60 @@
+ import json
+ import os
+
+
+ def find_best_checkpoint(path):
+     checkpoints_path = os.path.join(path, "checkpoints")
+     token_classifier_path = os.path.join(checkpoints_path, "token_classifier")
+     span_classifier_path = os.path.join(checkpoints_path, "span_classifier")
+
+     best_token_checkpoint = find_best_checkpoint_in_folder(token_classifier_path)
+     best_span_checkpoint = find_best_checkpoint_in_folder(span_classifier_path)
+
+     return best_token_checkpoint, best_span_checkpoint
+
+
+ def find_best_checkpoint_in_folder(folder_path):
+     best_checkpoint = None
+     best_f1_validation = -1
+
+     for subfolder in os.listdir(folder_path):
+         subfolder_path = os.path.join(folder_path, subfolder)
+         config_path = os.path.join(subfolder_path, "config.json")
+         checkpoint_path = os.path.join(subfolder_path, "checkpoint.pt")
+
+         if os.path.exists(config_path) and os.path.exists(checkpoint_path):
+             with open(config_path, 'r') as config_file:
+                 config_data = json.load(config_file)
+             if "best_f1_validation" in config_data:
+                 f1_validation = config_data["best_f1_validation"]
+                 if f1_validation > best_f1_validation:
+                     best_f1_validation = f1_validation
+                     best_checkpoint = subfolder_path
+
+     return best_checkpoint
+
+
+ def create_config_for_path(path, name="default"):
+
+     best_token_checkpoint, best_span_checkpoint = find_best_checkpoint(path)
+
+     print("Best token classifier checkpoint:", best_token_checkpoint)
+     print("Best span classifier checkpoint:", best_span_checkpoint)
+
+     config = {
+         "classifier_token": best_token_checkpoint,
+         "classifier_span": best_span_checkpoint
+     }
+
+     configs_path = os.path.join(path, "stoke_config.json")
+
+     if os.path.exists(configs_path):
+         with open(configs_path, 'r') as configs_file:
+             existing_configs = json.load(configs_file)
+     else:
+         existing_configs = {}
+
+     existing_configs[name] = config
+
+     with open(configs_path, 'w') as configs_file:
+         json.dump(existing_configs, configs_file, indent=4)
+
+     print(f"Config '{name}' saved successfully.")
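The resulting `stoke_config.json` maps each config name to the two selected checkpoint directories; the playground reads it via `find_datasets_and_model_ids`. An illustrative example (paths and identifiers are invented):

```
{
    "basic": {
        "classifier_token": "data/gpt2/test/checkpoints/token_classifier/aB3xY9QwZ1",
        "classifier_span": "data/gpt2/test/checkpoints/span_classifier/Kp2Lm8Rt0V"
    }
}
```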
stoke/src/trainer/__init__.py ADDED
File without changes
stoke/src/trainer/trainer.py ADDED
@@ -0,0 +1,284 @@
+ from .util import TrainConfig
+ from ..data.util import JSONDataset, collate_function_with_label_map, print_metric
+ from ..classifier.probes import MLPProbe as MLP
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ from torch.utils.data import DataLoader
+ from torcheval.metrics import MulticlassF1Score
+ from transformers.optimization import get_linear_schedule_with_warmup
+ from torch.optim import AdamW
+ import json
+ import os
+ import string
+ import random
+ from tqdm import tqdm
+
+
+ class Trainer:
+
+     def __init__(self, config: TrainConfig):
+         self.config = config
+         self._load_model()
+         self._load_data()
+         self._load_probes_and_optimizers()
+         print("Trainer is ready.")
+
+     def _load_model(self):
+         "Loads language model and tokenizer"
+         print(f"Loading model '{self.config.config_dataset['model_id']}'")
+
+         # check if a custom huggingface cache was selected
+         kwds = {}
+         if self.config.hfcache != "":
+             kwds["cache_dir"] = self.config.hfcache
+
+         # load model and tokenizer
+         self.model = AutoModelForCausalLM.from_pretrained(self.config.config_dataset['model_id'], output_attentions=True, output_hidden_states=True, return_dict=True, device_map="auto", **kwds).half()
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.config_dataset['model_id'], use_fast=True, **kwds)
+
+         if self.config.cuda:
+             self.model.cuda()
+
+         print("model and tokenizer loaded")
+
+     def _load_data(self):
+         "Loads datasets"
+         def collate_function(batch):
+             return collate_function_with_label_map(batch, self.config.label_map)
+
+         datasets = {}
+         self.dataloaders = {}
+         num_classes = None
+         self.config.label_map = []
+         for split in self.config.splits:
+             datasets[split] = JSONDataset(os.path.join(self.config.path, f"data_{split}.json"))
+             shuffle = False
+             if split == "train":
+                 shuffle = True
+             self.dataloaders[split] = DataLoader(datasets[split], batch_size=self.config.batch_size, shuffle=shuffle, collate_fn=collate_function)
+
+             dataset_classes = json.load(open(os.path.join(self.config.path, f"data_{split}_stats.json"), "r"))
+             if num_classes is None:
+                 num_classes = len(dataset_classes.keys())
+                 self.config.label_map = ["O"] + list(dataset_classes.keys())
+             else:
+                 assert len(dataset_classes.keys()) == num_classes
+             print(f"Loaded {split} dataset with {len(datasets[split])} samples and {num_classes} classes")
+
+     def _load_probes_and_optimizers(self):
+         "Loads probes and optimizers."
+         print("Preparing probes and optimizers")
+         n_layers = self.model.config.num_hidden_layers
+         n_heads = self.model.config.num_attention_heads
+         dim_hidden = self.model.config.hidden_size
+
+         print(f"Model has {n_layers} layers, hidden state size {dim_hidden}, and {n_heads} attention heads per layer")
+
+         if self.config.layers is None:
+             self.config.layers = [x for x in range(n_layers)]
+
+         print(f"Training tokenwise classifiers for {len(self.config.layers)} layer(s), {len(self.config.learning_rates)} learning rate(s), {len(self.config.classifier_dims)} hidden dim(s).")
+
+         self.classifier_device = "cpu"
+         if self.config.cuda:
+             self.classifier_device = "cuda"
+
+         if self.config.balance_loss:
+             # count class frequencies over the raw dataset samples
+             # (collated batches no longer carry 'ner_tags')
+             class_frequency = [0 for _ in self.config.label_map]
+             for sample in self.dataloaders["train"].dataset:
+                 labels = [self.config.label_map.index(x) for x in sample['ner_tags']]
+                 for x in labels:
+                     class_frequency[x] += 1
+
+             class_weights = [sum(class_frequency)/x for x in class_frequency]
+         else:
+             class_weights = [1.0 for _ in self.config.label_map]
+
+         self.token_classifiers = []
+         for layer in self.config.layers:
+             for lr in self.config.learning_rates:
+                 for dim_c in self.config.classifier_dims:
+                     # set up classifier, optimizer, scheduler, and config
+                     _classifier = MLP(dim_hidden, len(self.config.label_map), hidden_dim=dim_c, cuda=self.config.cuda)
+                     _optimizer = AdamW(_classifier.parameters(), lr=lr, eps=1e-6)
+                     _scheduler = get_linear_schedule_with_warmup(_optimizer, self.config.n_steps_per_epoch, self.config.n_epochs*self.config.n_steps_per_epoch)
+                     _config = {
+                         "layer": layer,
+                         "model": self.config.config_dataset['model_id'],
+                         "type": "token_classifier",
+                         "label_map": self.config.label_map,
+                         "learning_rate": lr,
+                         "classifier_dim": dim_c,
+                         "loss_weights": class_weights,
+                         "identifier": ''.join(random.SystemRandom().choice(string.ascii_letters + string.digits) for _ in range(10)),
+                         "best_f1_validation": -1,
+                         "best_f1_validation_classwise": 0,
+                     }
+                     self.token_classifiers.append({
+                         "config_train": self.config,
+                         "config": _config,
+                         "classifier": _classifier,
+                         "optimizer": _optimizer,
+                         "lr_scheduler": _scheduler,
+                         "metric": MulticlassF1Score(num_classes=len(self.config.label_map), average=None, device=_classifier.device),
+                         "criterion": torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(self.classifier_device))
+                     })
+
+         print(f"Total tokenwise classifiers: {len(self.token_classifiers)}")
+
+         print(f"Training span detectors for {len(self.config.learning_rates)} learning rate(s), {len(self.config.loss_weights_span)} loss weight(s), {len(self.config.classifier_dims)} hidden dim(s).")
+
+         self.span_classifiers = []
+         for lr in self.config.learning_rates:
+             for dim_c in self.config.classifier_dims:
+                 for loss_weight in self.config.loss_weights_span:
+                     # set up classifier, optimizer, scheduler, and config
+                     _classifier = MLP(n_layers*n_heads, 2, hidden_dim=dim_c, cuda=self.config.cuda)
+                     _optimizer = AdamW(_classifier.parameters(), lr=lr, eps=1e-6)
+                     _scheduler = get_linear_schedule_with_warmup(_optimizer, self.config.n_steps_per_epoch, self.config.n_epochs*self.config.n_steps_per_epoch)
+                     _config = {
+                         "model": self.config.config_dataset['model_id'],
+                         "type": "span_classifier",
+                         "label_map": ["no_span", "span"],
+                         "learning_rate": lr,
+                         "classifier_dim": dim_c,
+                         "loss_weights": loss_weight,
+                         "identifier": ''.join(random.SystemRandom().choice(string.ascii_letters + string.digits) for _ in range(10)),
+                         "best_f1_validation": -1,
+                         "best_f1_validation_classwise": 0,
+                     }
+                     self.span_classifiers.append({
+                         "config_train": self.config,
+                         "config": _config,
+                         "classifier": _classifier,
+                         "optimizer": _optimizer,
+                         "lr_scheduler": _scheduler,
+                         "metric": MulticlassF1Score(num_classes=2, average=None, device=_classifier.device),
+                         "criterion": torch.nn.CrossEntropyLoss(weight=torch.tensor(loss_weight).to(self.classifier_device))
+                     })
+         print(f"Total span detectors: {len(self.span_classifiers)}")
+
+     def train(self):
+         data_iter_train = iter(self.dataloaders["train"])
+
+         self.best_f1 = {"token_classifier": -1, "span_classifier": -1}
+         self.best_config = {"token_classifier": None, "span_classifier": None}
+
+         for epoch in range(self.config.n_epochs):
+
+             # TRAIN
+             for item in self.token_classifiers + self.span_classifiers:
+                 item['classifier'].train()
+                 item['metric'].reset()
+
+             for step in tqdm(range(self.config.n_steps_per_epoch)):
+
+                 # Get data
+                 try:
+                     sample = next(data_iter_train)
+                 except StopIteration:
+                     data_iter_train = iter(self.dataloaders["train"])
+                     sample = next(data_iter_train)
+
+                 with torch.no_grad():
+                     input_ids = sample['input_ids']
+                     labels_tokens = sample['labels_tokens']
+                     labels_spans = sample['labels_spans']
+
+                     outputs = self.model(input_ids.to(self.model.device), output_hidden_states=True, output_attentions=True)
+
+                     hidden_states = {}
+                     for layer in self.config.layers:
+                         hidden_states[layer] = outputs.hidden_states[layer].to(self.classifier_device)
+
+                     # get attentions and labels
+                     attentions = torch.stack(outputs.attentions).swapaxes(0, 1)
+                     attentions = attentions.reshape(attentions.size(0), -1, attentions.size(-2), attentions.size(-1)).permute(0, 2, 3, 1)
+                     attentions = torch.masked_select(attentions, sample['mask_spans'].to(self.classifier_device)).view(-1, attentions.size(-1))
+
+                 # training step for each classifier
+                 for item in self.span_classifiers + self.token_classifiers:
+                     if item['config']['type'] == "span_classifier":
+                         _preds = item['classifier'](attentions.to(item['classifier'].fc1.weight.dtype).to(self.classifier_device))
+                         _labels = labels_spans.to(self.classifier_device)
+                     elif item['config']['type'] == "token_classifier":
+                         _preds = item['classifier'](hidden_states[item['config']['layer']].to(item['classifier'].fc1.weight.dtype))
+                         _preds = torch.masked_select(_preds, sample['mask_tokens'].to(self.classifier_device))
+                         _labels = labels_tokens.to(self.classifier_device)
+
+                     loss = item['criterion'](_preds.view(-1, len(item['config']['label_map'])), _labels.view(-1))
+                     item['metric'].update(_preds.view(-1, len(item['config']['label_map'])), _labels.view(-1))
+
+                     item['optimizer'].zero_grad(set_to_none=True)
+                     loss.backward()
+                     item['optimizer'].step()
+                     item['lr_scheduler'].step()
+
+             hidden_states = {}
+             attentions = None
+
+             # EVAL
+             for item in self.span_classifiers + self.token_classifiers:
+                 item['classifier'].eval()
+                 item['metric'].reset()
+
+             with torch.no_grad():
+
+                 for sample in tqdm(self.dataloaders["validation"]):
+                     input_ids = sample['input_ids']
+                     labels_tokens = sample['labels_tokens']
+                     labels_spans = sample['labels_spans']
+
+                     # language model forward pass
+                     outputs = self.model(input_ids.to(self.model.device), output_hidden_states=True, output_attentions=True)
+
+                     # get internal representations into correct shapes
+                     for layer in self.config.layers:
+                         hidden_states[layer] = outputs.hidden_states[layer].to(self.classifier_device)
+                     attentions = torch.stack(outputs.attentions).swapaxes(0, 1)
+                     attentions = attentions.reshape(attentions.size(0), -1, attentions.size(-2), attentions.size(-1)).permute(0, 2, 3, 1)
+                     attentions = torch.masked_select(attentions, sample['mask_spans'].to(self.classifier_device)).view(-1, attentions.size(-1))
+
+                     # classifier inference
+                     for item in self.span_classifiers + self.token_classifiers:
+                         if item['config']['type'] == "span_classifier":
+                             _preds = item['classifier'](attentions.to(item['classifier'].fc1.weight.dtype).to(self.classifier_device))
+                             _labels = labels_spans.to(self.classifier_device)
+                         elif item['config']['type'] == "token_classifier":
+                             _preds = item['classifier'](hidden_states[item['config']['layer']].to(item['classifier'].fc1.weight.dtype))
+                             _preds = torch.masked_select(_preds, sample['mask_tokens'].to(self.classifier_device))
+                             _labels = labels_tokens.to(self.classifier_device)
+                         item['metric'].update(_preds.view(-1, len(item['config']['label_map'])), _labels.view(-1))
+
+             # logging and saving of checkpoints
+             for item in self.span_classifiers + self.token_classifiers:
+                 (p_micro, r_micro, f_micro), classwise = print_metric(item['metric'], item['config']['label_map'], return_classwise=True, verbose=False)
+                 if f_micro > item['config']['best_f1_validation']:
+                     item['config']['best_f1_validation'] = f_micro
+                     item['config']['best_f1_validation_classwise'] = classwise
+
+                     ckp_path = os.path.join(self.config.checkpoint_path, f"{item['config']['type']}/{item['config']['identifier']}/")
+                     os.makedirs(ckp_path, exist_ok=True)
+                     torch.save(item['classifier'].state_dict(), os.path.join(ckp_path, "checkpoint.pt"))
+                     json.dump(item['config'], open(os.path.join(ckp_path, "config.json"), "w"), indent=1)
+                     json.dump(item['config_train'].to_dict(), open(os.path.join(ckp_path, "config_train.json"), "w"), indent=1)
+
+                     if f_micro > self.best_f1[item['config']['type']]:
+                         self.best_f1[item['config']['type']] = f_micro
+                         self.best_config[item['config']['type']] = item['config']
+
+             # print current best for each classifier type
+             for clf_type in self.best_config.keys():
+                 print(f"--- Best {clf_type} config after epoch {epoch+1} ---")
+                 if self.best_config[clf_type] is not None:
+                     for key, value in self.best_config[clf_type].items():
+                         print(key, value)
+         return self.best_config
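The span features above are built by stacking the per-layer attention maps and flattening layers and heads into one feature axis, so every (query, key) position pair gets a `layers*heads`-dimensional vector. A shape walkthrough under assumed dimensions (2 layers, 3 heads, sequence length 5):

```
import torch

batch, layers, heads, seq = 1, 2, 3, 5
# outputs.attentions is a tuple of per-layer tensors of shape (batch, heads, seq, seq)
attentions = torch.stack([torch.rand(batch, heads, seq, seq) for _ in range(layers)])
attentions = attentions.swapaxes(0, 1)                # (batch, layers, heads, seq, seq)
attentions = attentions.reshape(batch, -1, seq, seq)  # (batch, layers*heads, seq, seq)
attentions = attentions.permute(0, 2, 3, 1)           # (batch, seq, seq, layers*heads)
print(attentions.shape)                               # torch.Size([1, 5, 5, 6])
```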
stoke/src/trainer/util.py ADDED
@@ -0,0 +1,44 @@
+ import os
+ import time
+ import json
+
+
+ class TrainConfig:
+
+     def __init__(self, path, splits=['train', 'validation'],
+                  layers=[9, 10, 11], hfcache='', classifier_dims=[4096], learning_rates=[1e-4],
+                  cuda=False, n_steps_per_epoch=1000, n_epochs=2, batch_size=8, balance_loss=False,
+                  loss_weights_span=[[1.0, 1.0], [1.0, 50.0], [1.0, 100.0]]):
+         self.path = path
+         self.checkpoint_path = os.path.join(self.path, "checkpoints/")
+         self.splits = splits
+         self.layers = layers
+         self.hfcache = hfcache
+         self.classifier_dims = classifier_dims
+         self.learning_rates = learning_rates
+         self.cuda = cuda
+         self.n_steps_per_epoch = n_steps_per_epoch
+         self.n_epochs = n_epochs
+         self.batch_size = batch_size
+         self.balance_loss = balance_loss
+         self.loss_weights_span = loss_weights_span
+         self.time = time.time()
+         self.config_dataset = json.load(open(os.path.join(path, "config.json"), "r"))
+
+     def to_dict(self):
+         return {
+             "path": self.path,
+             "splits": self.splits,
+             "layers": self.layers,
+             "hfcache": self.hfcache,
+             "classifier_dims": self.classifier_dims,
+             "learning_rates": self.learning_rates,
+             "cuda": self.cuda,
+             "n_steps_per_epoch": self.n_steps_per_epoch,
+             "n_epochs": self.n_epochs,
+             "batch_size": self.batch_size,
+             "balance_loss": self.balance_loss,
+             "loss_weights_span": self.loss_weights_span,
+             "time": self.time,
+             "config_dataset": self.config_dataset
+         }
stoke/tests/__init__.py ADDED
File without changes