YaHi committed on
Commit
c7828f5
1 Parent(s): fbb0b87

Upload handler.py

Files changed (1)
  1. handler.py +389 -0
handler.py ADDED
@@ -0,0 +1,389 @@
+ from typing import Dict, List, Any
+ from scipy.special import softmax
+ import weakref
+ import requests
+ from utils import (
+     clean_str,
+     clean_str_nopunct,
+     MultiHeadModel,
+     BertInputBuilder,
+     get_num_words,
+     preprocess_transcript_for_eliciting,
+     preprocess_raw_files,
+     post_processing_output_json,
+     compute_student_engagement,
+     compute_talk_time,
+     gpt4_filtering_selection
+ )
+ import torch
+ from transformers import BertTokenizer, AutoModelForSequenceClassification, AutoTokenizer
+
+ UPTAKE_MODEL = 'ddemszky/uptake-model'
+ QUESTION_MODEL = 'ddemszky/question-detection'
+ ELICITING_MODEL = 'YaHi/teacher_electra_small'
+
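+ # The three checkpoints above are Hugging Face model IDs. Judging by how they are used below,
+ # UPTAKE_MODEL and QUESTION_MODEL are MultiHeadModel checkpoints (with "nsp" and "is_question"
+ # heads respectively), and ELICITING_MODEL is an ELECTRA-small sequence classifier for tutor turns.
+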
+ class UptakeUtterance:
+     def __init__(self, speaker, text, uid=None,
+                  transcript=None, starttime=None, endtime=None, **kwargs):
+         self.speaker = speaker
+         self.text = text
+         self.prev_utt = None
+         self.uid = uid
+         self.starttime = starttime
+         self.endtime = endtime
+         self.transcript = weakref.ref(transcript) if transcript else None
+         self.props = kwargs
+
+         self.uptake = None
+         self.question = None
+
+     def get_clean_text(self, remove_punct=False):
+         if remove_punct:
+             return clean_str_nopunct(self.text)
+         return clean_str(self.text)
+
+     def get_num_words(self):
+         if self.text is None:
+             return 0
+         return get_num_words(self.text)
+
+     def to_dict(self):
+         return {
+             'speaker': self.speaker,
+             'text': self.text,
+             'prev_utt': self.prev_utt,
+             'uid': self.uid,
+             'starttime': self.starttime,
+             'endtime': self.endtime,
+             'uptake': self.uptake,
+             'question': self.question,
+             **self.props
+         }
+
+     def __repr__(self):
+         return f"Utterance(speaker='{self.speaker}', " \
+                f"text='{self.text}', prev_utt='{self.prev_utt}', uid={self.uid}, " \
+                f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"
+
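+ # A minimal construction sketch (values are illustrative): extra keyword arguments land in
+ # `props` and are passed through by to_dict().
+ #   utt = UptakeUtterance(speaker="tutor", text="Why does ice float?", uid="u1",
+ #                         starttime=0.0, endtime=2.5, grade="5")
+ #   utt.to_dict()["grade"]  # -> "5"
+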
+ class UptakeTranscript:
+     def __init__(self, **kwargs):
+         self.utterances = []
+         self.params = kwargs
+
+     def add_utterance(self, utterance):
+         utterance.transcript = weakref.ref(self)
+         self.utterances.append(utterance)
+
+     def get_idx(self, idx):
+         if idx >= len(self.utterances):
+             return None
+         return self.utterances[idx]
+
+     def get_uid(self, uid):
+         for utt in self.utterances:
+             if utt.uid == uid:
+                 return utt
+         return None
+
+     def length(self):
+         return len(self.utterances)
+
+     def to_dict(self):
+         return {
+             'utterances': [utterance.to_dict() for utterance in self.utterances],
+             **self.params
+         }
+
+     def __repr__(self):
+         return f"Transcript(utterances={self.utterances}, custom_params={self.params})"
+
+ class ElicitingUtterance:
+     def __init__(self, speaker, text, starttime, endtime, uid=None, transcript=None, prev_utt=None):
+         self.speaker = speaker
+         self.text = clean_str_nopunct(text)
+         self.uid = uid
+         self.transcript = transcript if transcript else None
+         self.prev_utt = prev_utt
+         self.eliciting = None
+         self.question = None
+         self.starttime = starttime
+         self.endtime = endtime
+
+     def __setitem__(self, key, value):
+         self.__dict__[key] = value
+
+     def get_clean_text(self, remove_punct=False):
+         if remove_punct:
+             return clean_str_nopunct(self.text)
+         return clean_str(self.text)
+
+     def to_dict(self):
+         return {
+             'speaker': self.speaker,
+             'text': self.text,
+             'uid': self.uid,
+             'prev_utt': self.prev_utt,
+             'eliciting': self.eliciting,
+             'question': self.question,
+             'starttime': self.starttime,
+             'endtime': self.endtime,
+         }
+
+     def __repr__(self):
+         return f"Utterance(speaker='{self.speaker}', " \
+                f"text='{self.text}', uid={self.uid}, prev_utt={self.prev_utt}, " \
+                f"eliciting={self.eliciting}, question={self.question}, " \
+                f"starttime={self.starttime}, endtime={self.endtime})"
+
+ class ElicitingTranscript:
+     def __init__(self, utterances: List[ElicitingUtterance], tokenizer=None):
+         self.tokenizer = tokenizer
+         self.utterances = []
+         prev_utt = ""
+         prev_utt_teacher = ""
+         prev_speaker = None
+         for utterance in utterances:
+             try:
+                 # collapse speaker labels like "student_1" into a single 'student' role
+                 if 'student' in utterance["speaker"]:
+                     utterance["speaker"] = 'student'
+             except (KeyError, TypeError):
+                 # skip malformed utterances that have no usable speaker field
+                 continue
+             if (prev_speaker == 'tutor') and (utterance["speaker"] == 'student'):
+                 utterance = ElicitingUtterance(**utterance, transcript=self, prev_utt=prev_utt.text)
+             elif (prev_speaker == 'student') and (utterance["speaker"] == 'tutor'):
+                 utterance = ElicitingUtterance(**utterance, transcript=self, prev_utt=prev_utt.text)
+                 prev_utt_teacher = utterance.text
+             elif (prev_speaker == 'student') and (utterance["speaker"] == 'student'):
+                 try:
+                     utterance = ElicitingUtterance(**utterance, transcript=self, prev_utt=prev_utt_teacher)
+                 except TypeError:
+                     print("Error constructing ElicitingUtterance from:")
+                     print(utterance)
+                     continue
+             else:
+                 utterance = ElicitingUtterance(**utterance, transcript=self, prev_utt="")
+             if utterance.speaker == 'tutor':
+                 prev_utt_teacher = utterance.text
+             prev_utt = utterance
+             prev_speaker = utterance.speaker
+             self.utterances.append(utterance)
+
+     def __len__(self):
+         return len(self.utterances)
+
+     def __getitem__(self, index):
+         output = self.tokenizer([(self.utterances[index].prev_utt, self.utterances[index].text)], truncation=True)
+         output["speaker"] = self.utterances[index].speaker
+         output["uid"] = self.utterances[index].uid
+         output["prev_utt"] = self.utterances[index].prev_utt
+         output["text"] = self.utterances[index].text
+         return output
+
+     def to_dict(self):
+         return {
+             'utterances': [utterance.to_dict() for utterance in self.utterances]
+         }
+
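+ # How prev_utt is assigned, on an illustrative four-turn exchange (raw strings shown for
+ # readability; texts are actually cleaned with clean_str_nopunct on construction):
+ #   tutor:   "What is 2 + 2?"      -> prev_utt = ""                   (first turn)
+ #   student: "Four."               -> prev_utt = "What is 2 + 2?"     (tutor -> student)
+ #   student: "Wait, I mean four."  -> prev_utt = "What is 2 + 2?"     (student -> student: last tutor turn)
+ #   tutor:   "Right!"              -> prev_utt = "Wait, I mean four." (student -> tutor)
+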
+ class QuestionModel:
+     def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
+         print("Loading models...")
+         self.device = device
+         self.tokenizer = tokenizer
+         self.input_builder = input_builder
+         self.max_length = max_length
+         self.model = MultiHeadModel.from_pretrained(path, head2size={"is_question": 2})
+         self.model.to(self.device)
+
+     def run_inference(self, transcript):
+         self.model.eval()
+         with torch.no_grad():
+             for i, utt in enumerate(transcript.utterances):
+                 if utt.text is None:
+                     utt.question = None
+                     continue
+                 if "?" in utt.text:
+                     utt.question = 1
+                 else:
+                     text = utt.get_clean_text(remove_punct=True)
+                     instance = self.input_builder.build_inputs([], text,
+                                                                max_length=self.max_length,
+                                                                input_str=True)
+                     output = self.get_prediction(instance)
+                     utt.question = softmax(output["is_question_logits"][0].tolist())[1]
+
+     def get_prediction(self, instance):
+         instance["attention_mask"] = [[1] * len(instance["input_ids"])]
+         for key in ["input_ids", "token_type_ids", "attention_mask"]:
+             # batch size = 1; move the tensor to the model's device
+             instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device)
+
+         output = self.model(input_ids=instance["input_ids"],
+                             attention_mask=instance["attention_mask"],
+                             token_type_ids=instance["token_type_ids"],
+                             return_pooler_output=False)
+         return output
+
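+ # The "is_question" head is binary, so softmax(logits)[1] above reads as P(utterance is a question);
+ # utterances that already contain "?" are short-circuited to 1 without running the model.
+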
+ class UptakeModel:
+     def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
+         print("Loading models...")
+         self.device = device
+         self.tokenizer = tokenizer
+         self.input_builder = input_builder
+         self.max_length = max_length
+         self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
+         self.model.to(self.device)
+
+     def run_inference(self, transcript, min_prev_words, uptake_speaker=None):
+         self.model.eval()
+         prev_num_words = 0
+         prev_utt = None
+         with torch.no_grad():
+             for i, utt in enumerate(transcript.utterances):
+                 if ((uptake_speaker is None) or (utt.speaker == uptake_speaker)) and (prev_num_words >= min_prev_words):
+                     textA = prev_utt.get_clean_text(remove_punct=False)
+                     textB = utt.get_clean_text(remove_punct=False)
+                     instance = self.input_builder.build_inputs([textA], textB,
+                                                                max_length=self.max_length,
+                                                                input_str=True)
+                     output = self.get_prediction(instance)
+
+                     utt.uptake = softmax(output["nsp_logits"][0].tolist())[1]
+                     utt.prev_utt = prev_utt.text
+                 prev_num_words = utt.get_num_words()
+                 prev_utt = utt
+
+     def get_prediction(self, instance):
+         instance["attention_mask"] = [[1] * len(instance["input_ids"])]
+         for key in ["input_ids", "token_type_ids", "attention_mask"]:
+             # batch size = 1; move the tensor to the model's device
+             instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device)
+
+         output = self.model(input_ids=instance["input_ids"],
+                             attention_mask=instance["attention_mask"],
+                             token_type_ids=instance["token_type_ids"],
+                             return_pooler_output=False)
+         return output
+
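+ # Uptake is scored with the "nsp" (next-sentence-prediction-style) head: softmax(nsp_logits)[1]
+ # is the probability that the current turn builds on the previous one. A turn is only scored when
+ # the preceding turn has at least min_prev_words words and the speaker matches uptake_speaker.
+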
+ class ElicitingModel:
+     def __init__(self, device, tokenizer, path=ELICITING_MODEL):
+         print("Loading teacher models...")
+         self.device = device
+         self.tokenizer = tokenizer
+         self.model = AutoModelForSequenceClassification.from_pretrained(path).to(self.device)
+
+     def run_inference(self, dataset):
+         batch_size = 64
+         current_batch = 0
+         while current_batch < len(dataset):
+             # clamp the final batch to the number of remaining samples
+             end = min(current_batch + batch_size, len(dataset))
+             to_pad = [{"input_ids": dataset[i]["input_ids"][0],
+                        "attention_mask": dataset[i]["attention_mask"][0]}
+                       for i in range(current_batch, end)]
+             batch = self.tokenizer.pad(
+                 to_pad,
+                 padding=True,
+                 max_length=None,
+                 pad_to_multiple_of=None,
+                 return_tensors="pt",
+             )
+             inputs = batch["input_ids"].to(self.device)
+             attention_mask = batch["attention_mask"].to(self.device)
+             with torch.no_grad():
+                 outputs = self.model(inputs, attention_mask=attention_mask)
+             predictions = outputs.logits.argmax(dim=-1).cpu().numpy()
+
+             for i, prediction in enumerate(predictions):
+                 if dataset.utterances[current_batch + i].speaker == 'tutor':
+                     dataset.utterances[current_batch + i]["eliciting"] = prediction
+             current_batch = end
+
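+ # The eliciting label written back is the argmax class index from the classifier; student turns
+ # keep eliciting=None and only serve as context through prev_utt.
+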
308
+ class EndpointHandler():
309
+ def __init__(self, path="."):
310
+ print("Loading models...")
311
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
312
+
313
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
314
+ self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
315
+ self.uptake_model = UptakeModel(self.device, self.tokenizer, self.input_builder)
316
+ self.question_model = QuestionModel(self.device, self.tokenizer, self.input_builder)
317
+ self.eliciting_tokenizer = AutoTokenizer.from_pretrained(ELICITING_MODEL)
318
+ self.eliciting_model = ElicitingModel(self.device, self.tokenizer, path=ELICITING_MODEL)
319
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`list`):
+                 List of dicts, where each dict represents an utterance; each utterance must have a
+                 `speaker`, `text`, and `uid`, and can include custom properties.
+             parameters (:obj:`dict`)
+         Return:
+             A :obj:`list` | `dict`: will be serialized and returned
+         """
+
+         # get inputs
+         utterances = data.pop("inputs", data)
+         params = data.pop("parameters", None)  # TODO: make sure that it includes everything required
+
+         print("Processing session:", params["session_uuid"])
+
+         # pre-processing
+         utterances = preprocess_raw_files(utterances, params)
+
+         # compute student engagement and talk time metrics
+         num_students_engaged, num_students_engaged_talk_only = compute_student_engagement(utterances)
+         tutor_talk_time = compute_talk_time(utterances)
+
+         # TODO: make sure there is some routing going on here based on what session we are at
+         if params["session_type"] == "eliciting":
+             # pre-processing for eliciting
+             utterances_eliciting = preprocess_transcript_for_eliciting(utterances)
+             # tokenize with the ELECTRA tokenizer so inputs match the eliciting model
+             eliciting_transcript = ElicitingTranscript(utterances_eliciting, tokenizer=self.eliciting_tokenizer)
+             self.eliciting_model.run_inference(eliciting_transcript)
+
+             # Question
+             self.question_model.run_inference(eliciting_transcript)
+
+             transcript_output = eliciting_transcript
+         else:
+             uptake_transcript = UptakeTranscript(filename=params.pop("filename", None))
+             for utt in utterances:
+                 uptake_transcript.add_utterance(UptakeUtterance(**utt))
+
+             # Uptake
+             self.uptake_model.run_inference(uptake_transcript, min_prev_words=params["uptake_min_num_words"],
+                                             uptake_speaker=params.pop("uptake_speaker", None))
+
+             # Question
+             self.question_model.run_inference(uptake_transcript)
+             transcript_output = uptake_transcript
+
+         # post-processing
+         model_outputs = post_processing_output_json(transcript_output.to_dict(), params["session_uuid"], params["session_type"])
+
+         final_output = {}
+         final_output["metrics"] = {"num_students_engaged": num_students_engaged,
+                                    "num_students_engaged_talk_only": num_students_engaged_talk_only,
+                                    "tutor_talk_time": tutor_talk_time}
+
+         if len(model_outputs) > 0:
+             model_outputs = gpt4_filtering_selection(model_outputs, params["session_type"], params["focus_concept"])
+
+         final_output["model_outputs"] = model_outputs
+         final_output["event_id"] = params["event_id"]
+
+         webhooks_url = 'https://schoolhouse.world/api/webhooks/stanford-ai-feedback-highlights'
+         # a timeout keeps a slow webhook from hanging the endpoint (added safeguard)
+         response = requests.post(webhooks_url, json=final_output, timeout=30)
+
+         print("Post request sent, here is the response:", response)
+
+         return final_output
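+
+ # --- Usage sketch (hypothetical payload; field values are illustrative, not from this repo) ---
+ # handler = EndpointHandler()
+ # payload = {
+ #     "inputs": [
+ #         {"speaker": "tutor", "text": "What do you think happens next?", "uid": "u1",
+ #          "starttime": 0.0, "endtime": 3.2},
+ #         {"speaker": "student_1", "text": "The water evaporates.", "uid": "u2",
+ #          "starttime": 3.5, "endtime": 5.0},
+ #     ],
+ #     "parameters": {"session_uuid": "...", "session_type": "eliciting", "focus_concept": "...",
+ #                    "event_id": "...", "uptake_min_num_words": 5, "uptake_speaker": "tutor"},
+ # }
+ # result = handler(payload)  # also POSTs the summary to the webhook above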