YaHi committed on
Commit 1122de1
1 Parent(s): 5969089

Upload utils.py

Files changed (1)
  utils.py +697 -0
utils.py ADDED
import torch
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
from torch import nn
from itertools import chain
from torch.nn import MSELoss, CrossEntropyLoss
from cleantext import clean
from num2words import num2words
import re
import string
import pandas as pd
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import json
import tqdm
from transformers import GPT2Tokenizer
from openai import OpenAI
import os
from difflib import SequenceMatcher
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
from sentence_transformers import SentenceTransformer, util

# Load a pre-trained model
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

device = "cuda" if torch.cuda.is_available() else "cpu"

punct_chars = list((set(string.punctuation) | {'’', '‘', '–', '—', '~', '|', '“', '”', '…', "'", "`", '_'}))
punct_chars.sort()
punctuation = ''.join(punct_chars)
replace = re.compile('[%s]' % re.escape(punctuation))

def get_num_words(text):
    if not isinstance(text, str):
        print("%s is not a string" % text)
        text = str(text)
    text = replace.sub(' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    text = re.sub(r'\[.+\]', " ", text)
    return len(text.split())


def number_to_words(num):
    try:
        return num2words(re.sub(",", "", num))
    except Exception:
        return num

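# Illustrative usage (example values are assumptions, not part of the original file):
#   get_num_words("Hi there, how are you?")   # -> 5
#   number_to_words("1,200")                  # -> "one thousand, two hundred"
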
clean_str = lambda s: clean(s,
                            fix_unicode=True,           # fix various unicode errors
                            to_ascii=True,              # transliterate to closest ASCII representation
                            lower=True,                 # lowercase text
                            no_line_breaks=True,        # fully strip line breaks as opposed to only normalizing them
                            no_urls=True,               # replace all URLs with a special token
                            no_emails=True,             # replace all email addresses with a special token
                            no_phone_numbers=True,      # replace all phone numbers with a special token
                            no_numbers=True,            # replace all numbers with a special token
                            no_digits=False,            # replace all digits with a special token
                            no_currency_symbols=False,  # replace all currency symbols with a special token
                            no_punct=False,             # fully remove punctuation
                            replace_with_url="<URL>",
                            replace_with_email="<EMAIL>",
                            replace_with_phone_number="<PHONE>",
                            replace_with_number=lambda m: number_to_words(m.group()),
                            replace_with_digit="0",
                            replace_with_currency_symbol="<CUR>",
                            lang="en"
                            )

clean_str_nopunct = lambda s: clean(s,
                                    fix_unicode=True,           # fix various unicode errors
                                    to_ascii=True,              # transliterate to closest ASCII representation
                                    lower=True,                 # lowercase text
                                    no_line_breaks=True,        # fully strip line breaks as opposed to only normalizing them
                                    no_urls=True,               # replace all URLs with a special token
                                    no_emails=True,             # replace all email addresses with a special token
                                    no_phone_numbers=True,      # replace all phone numbers with a special token
                                    no_numbers=True,            # replace all numbers with a special token
                                    no_digits=False,            # replace all digits with a special token
                                    no_currency_symbols=False,  # replace all currency symbols with a special token
                                    no_punct=True,              # fully remove punctuation
                                    replace_with_url="<URL>",
                                    replace_with_email="<EMAIL>",
                                    replace_with_phone_number="<PHONE>",
                                    replace_with_number=lambda m: number_to_words(m.group()),
                                    replace_with_digit="0",
                                    replace_with_currency_symbol="<CUR>",
                                    lang="en"
                                    )

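# Illustrative usage (example string is an assumption): URLs, emails and phone numbers are
# replaced with placeholder tokens and numerals are spelled out via number_to_words, e.g.
#   clean_str("Email me at foo@bar.com about the 2 problems!")
#   clean_str_nopunct("Email me at foo@bar.com about the 2 problems!")  # same, with punctuation stripped
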
class MultiHeadModel(BertPreTrainedModel):
    """Pre-trained BERT model that uses our loss functions"""

    def __init__(self, config, head2size):
        super(MultiHeadModel, self).__init__(config, head2size)
        config.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        module_dict = {}
        for head_name, num_labels in head2size.items():
            module_dict[head_name] = nn.Linear(config.hidden_size, num_labels)
        self.heads = nn.ModuleDict(module_dict)

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                head2labels=None, return_pooler_output=False, head2mask=None,
                nsp_loss_weights=None):

        # Get logits
        output = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_attentions=False, output_hidden_states=False, return_dict=True)
        pooled_output = self.dropout(output["pooler_output"]).to(device)

        head2logits = {}
        return_dict = {}
        for head_name, head in self.heads.items():
            head2logits[head_name] = head(pooled_output)
            head2logits[head_name] = head2logits[head_name].float()
            return_dict[head_name + "_logits"] = head2logits[head_name]

        if head2labels is not None:
            for head_name, labels in head2labels.items():
                num_classes = head2logits[head_name].shape[1]

                # Regression (e.g. for politeness)
                if num_classes == 1:

                    # Only consider positive examples
                    if head2mask is not None and head_name in head2mask:
                        num_positives = head2labels[head2mask[head_name]].sum()  # use certain labels as mask
                        if num_positives == 0:
                            return_dict[head_name + "_loss"] = torch.tensor([0]).to(device)
                        else:
                            loss_fct = MSELoss(reduction='none')
                            loss = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
                            return_dict[head_name + "_loss"] = loss.dot(head2labels[head2mask[head_name]].float().view(-1)) / num_positives
                    else:
                        loss_fct = MSELoss()
                        return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
                else:
                    # Classification; guard against a missing class-weight tensor
                    loss_fct = CrossEntropyLoss(weight=nsp_loss_weights.float() if nsp_loss_weights is not None else None)
                    return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name], labels.view(-1))

        if return_pooler_output:
            return_dict["pooler_output"] = output["pooler_output"]

        return return_dict

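# Minimal usage sketch (checkpoint, head names and sizes are assumptions, not from this file):
#   model = MultiHeadModel.from_pretrained("bert-base-uncased", head2size={"nsp": 2})
#   out = model(input_ids, attention_mask=attention_mask,
#               head2labels={"nsp": labels}, nsp_loss_weights=torch.ones(2))
#   out["nsp_logits"], out["nsp_loss"]
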
class InputBuilder(object):
    """Base class for building inputs from segments."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.mask = [tokenizer.mask_token_id]

    def build_inputs(self, history, reply, max_length):
        raise NotImplementedError

    def mask_seq(self, sequence, seq_id):
        sequence[seq_id] = self.mask
        return sequence

    @classmethod
    def _combine_sequence(cls, history, reply, max_length, flipped=False):
        # Trim all inputs to max_length
        history = [s[:max_length] for s in history]
        reply = reply[:max_length]
        if flipped:
            return [reply] + history
        return history + [reply]


class BertInputBuilder(InputBuilder):
    """Processor for BERT inputs"""

    def __init__(self, tokenizer):
        InputBuilder.__init__(self, tokenizer)
        self.cls = [tokenizer.cls_token_id]
        self.sep = [tokenizer.sep_token_id]
        self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
        self.padded_inputs = ["input_ids", "token_type_ids"]
        self.flipped = False

    def build_inputs(self, history, reply, max_length, input_str=True):
        """See base class."""
        if input_str:
            history = [self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(t)) for t in history]
            reply = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(reply))
        sequence = self._combine_sequence(history, reply, max_length, self.flipped)
        sequence = [s + self.sep for s in sequence]
        sequence[0] = self.cls + sequence[0]

        instance = {}
        instance["input_ids"] = list(chain(*sequence))
        last_speaker = 0
        other_speaker = 1
        seq_length = len(sequence)
        instance["token_type_ids"] = [last_speaker if ((seq_length - i) % 2 == 1) else other_speaker
                                      for i, s in enumerate(sequence) for _ in s]
        return instance

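# Illustrative usage (tokenizer checkpoint and example utterances are assumptions):
#   from transformers import BertTokenizer
#   bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#   builder = BertInputBuilder(bert_tokenizer)
#   example = builder.build_inputs(["how did you solve it?"], "i added the two numbers", max_length=120)
#   # example["input_ids"] is [CLS] history [SEP] reply [SEP]; example["token_type_ids"] alternates speakers
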
def preprocess_transcript_for_eliciting(transcript_json):
    transcript_df = pd.DataFrame(transcript_json)
    transcript_df.reset_index(drop=True, inplace=True)

    def break_into_sentences(text):
        return sent_tokenize(text)

    transcript_df['text'] = transcript_df['text'].apply(str)
    transcript_df['sentences'] = transcript_df['text'].apply(break_into_sentences)
    transcript_df.rename(columns={"startTimestamp": "starttime", "endTimestamp": "endtime"}, inplace=True)
    transcript_df.rename(columns={'is_chat?': 'is_chat'}, inplace=True)

    def create_sentence_df(row):
        sentences = row['sentences']
        speaker = row['speaker']
        df = pd.DataFrame({'sentence': sentences})
        df['speaker'] = speaker
        df['userId'] = row['userId']
        df['session_uuid'] = row['session_uuid']
        df['starttime'] = row['starttime']
        df['endtime'] = row['endtime']
        df['is_chat'] = row['is_chat']
        df['speaker_#'] = row['speaker_#']
        return df

    sentence_df = pd.concat(transcript_df.apply(create_sentence_df, axis=1).values)
    sentence_df.reset_index(drop=True, inplace=True)

    sentence_df.dropna(inplace=True)
    sentence_df.rename(columns={'sentence': 'text', 'userId': 'uid'}, inplace=True)

    # sentence_df['prev_utt'] = None

    # prev_utt = None
    # for index, row in sentence_df.iterrows():
    #     # Check if the current speaker is a student
    #     if row['speaker'] != 'tutor':
    #         # Store the current utterance as the previous one for the next iteration
    #         prev_utt = row['text']
    #     else:
    #         # If the current speaker is the tutor, update 'prev_utt' in the DataFrame
    #         if prev_utt is not None and index > 0:
    #             sentence_df.at[index, 'prev_utt'] = prev_utt
    #             prev_utt = None

    # # drop rows where speaker_# is not tutor
    # sentence_df = sentence_df[sentence_df['speaker_#'] == 'tutor']

    # drop speaker_#, is_chat and session_uuid columns
    sentence_df.drop(columns=['speaker_#', 'is_chat', 'session_uuid'], inplace=True)

    session_json = sentence_df.to_json(orient='records')
    session_json = json.loads(session_json)

    return session_json


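# Note: each speaker turn above is split into sentences with NLTK's sent_tokenize so that
# downstream scoring can operate on individual sentences rather than whole turns.
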
def preprocess_raw_files(input_json, params):
    """
    Preprocesses a raw session json and returns the speaker-aggregated utterances as json.

    Args:
        input_json (dict): raw session json with 'transcript' and 'chat' entries
        params (dict): must contain 'tutor_uuid' and 'session_uuid'

    Returns:
        list: aggregated utterances as json records
    """
    # convert raw json to dataframe
    tutor_uuid = params['tutor_uuid']
    session_uuid = params['session_uuid']

    chat_transcript_df = convert_json_to_df(input_json, tutor_uuid, session_uuid)

    # aggregate by speaker
    aggregate_df = aggregate_by_speaker_id(chat_transcript_df)

    # convert to json
    aggregate_json = aggregate_df.to_json(orient='records')
    aggregate_json = json.loads(aggregate_json)

    return aggregate_json

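# Illustrative call (the uuids and variable names are placeholders/assumptions):
#   params = {"tutor_uuid": "tutor-1234", "session_uuid": "session-5678"}
#   aggregated = preprocess_raw_files(raw_session_json, params)
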
def convert_json_to_df(input_json, tutor_uuid, session_uuid):
    """
    Extracts transcript and chat data from the raw json file, assigns speaker and speaker_# columns, and returns a dataframe.
    The dataframe contains the following columns:
    - startTimestamp
    - endTimestamp
    - text
    - userId
    - is_chat?
    - speaker
    - speaker_#

    Args:
        input_json (dict): raw session json
        tutor_uuid (str): tutor uuid
        session_uuid (str): session uuid

    Returns:
        pd.DataFrame: combined chat and transcript dataframe
    """
    data = input_json

    if data['transcript'] != []:
        transcript_df = pd.DataFrame(data['transcript'])
        transcript_df['is_chat?'] = 0
    else:
        raise ValueError("Transcript is empty")

    # include chat data as well
    if data['chat'] != []:
        chat_df = pd.DataFrame(data['chat'])
        chat_df.rename(
            columns={'timestamp': 'startTimestamp'}, inplace=True)
        chat_df['endTimestamp'] = chat_df['startTimestamp']
        chat_df['is_chat?'] = 1
    else:
        chat_df = pd.DataFrame(columns=list(transcript_df))

    chat_transcript_df = pd.concat([chat_df, transcript_df], ignore_index=True).sort_values(
        by='startTimestamp', ascending=True)

    chat_transcript_df['session_uuid'] = session_uuid

    # Add speaker column
    count_non_chat = 0
    for i, row in chat_transcript_df.iterrows():
        if row['userId'] == tutor_uuid:
            chat_transcript_df.loc[i, 'speaker'] = 'tutor'
        elif row['userId'] is None:
            if i == 0:  # first chat
                chat_transcript_df.loc[i, 'speaker'] = 'student'  # this is a heuristic that may not be true
            elif count_non_chat == 0:  # first non-chat
                chat_transcript_df.loc[i, 'speaker'] = 'tutor'  # this is a heuristic that may not be true
            else:
                chat_transcript_df.loc[i, 'speaker'] = chat_transcript_df.loc[i-1, 'speaker']  # this is a heuristic that may not be true
        else:
            chat_transcript_df.loc[i, 'speaker'] = 'student'
        if row['is_chat?'] == 0:
            count_non_chat += 1

    # Add speaker_# column, iterate through rows and assign speaker_# based on speaker
    studentId2studentNum = {}
    count_non_chat = 0
    for i, row in chat_transcript_df.iterrows():
        if row['speaker'] == 'tutor':
            chat_transcript_df.loc[i, 'speaker_#'] = 'tutor'
        elif row['userId'] is None:
            if i == 0:  # first chat
                chat_transcript_df.loc[i, 'speaker_#'] = 'student1'
            elif count_non_chat == 0:
                chat_transcript_df.loc[i, 'speaker_#'] = 'tutor'
            else:
                chat_transcript_df.loc[i, 'speaker_#'] = chat_transcript_df.loc[i-1, 'speaker_#']
        else:
            if row['userId'] in studentId2studentNum:
                chat_transcript_df.loc[i, 'speaker_#'] = 'student' + str(studentId2studentNum[row['userId']])
            else:
                studentId2studentNum[row['userId']] = len(studentId2studentNum) + 1
                chat_transcript_df.loc[i, 'speaker_#'] = 'student' + str(studentId2studentNum[row['userId']])
        if row['is_chat?'] == 0:
            count_non_chat += 1

    return chat_transcript_df

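# Note on the speaker heuristics above: rows with userId == None (chat messages without an id)
# are assigned by guesswork: the very first row is treated as a student, chats before any spoken
# turn as the tutor, and anything else inherits the previous row's speaker. As flagged inline,
# these guesses may not be true.
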
def aggregate_by_speaker_id(data):
    aggregate_df = []
    speaker_id = None
    speaker = None
    aggregate_key_value = None
    enumerated_speaker = None
    is_chat = None
    session = None
    curr_text = ""
    curr_starttime = None
    curr_endtime = None

    for _, row in tqdm.tqdm(data.iterrows()):
        is_same_speaker_id = (row['speaker_#'] == aggregate_key_value)
        is_same_type = (row['is_chat?'] == is_chat)

        if (is_same_type) and (is_same_speaker_id):
            # Concatenate text and update endtime
            if type(row['text']) == str:
                curr_text += " " + row['text']
            curr_endtime = row['endTimestamp']
        else:
            # Append previous speaker's text to aggregate_df
            aggregate_df.append({
                "userId": speaker_id,
                "is_chat": is_chat,
                "session_uuid": session,
                "starttime": curr_starttime,
                "endtime": curr_endtime,
                "text": curr_text,
                "speaker": speaker,
                "speaker_#": enumerated_speaker
            })

            # Update speaker, is_chat, session, curr_text, curr_starttime, curr_endtime
            speaker_id = row['userId']
            is_chat = row['is_chat?']
            session = row['session_uuid']
            curr_text = row['text'] if type(row['text']) == str else ""
            curr_starttime = row['startTimestamp']
            curr_endtime = row['endTimestamp']
            speaker = row['speaker']
            enumerated_speaker = row['speaker_#']
            aggregate_key_value = row['speaker_#']

    # Append last speaker's text to aggregate_df if it hasn't been appended yet
    if aggregate_df and aggregate_df[-1]['userId'] != speaker_id:
        aggregate_df.append({
            "userId": speaker_id,
            "is_chat": is_chat,
            "session_uuid": session,
            "starttime": curr_starttime,
            "endtime": curr_endtime,
            "text": curr_text,
            "speaker": speaker,
            "speaker_#": enumerated_speaker
        })

    # The first entry is an empty placeholder created on the first row; drop it
    aggregate_df = pd.DataFrame(aggregate_df[1:])
    return aggregate_df


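# Note: consecutive rows from the same speaker_# and the same channel (chat vs. talk) are merged
# into a single utterance whose text is their concatenation and whose start/end times span the group.
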
def post_processing_output_json(transcript_json, session_id, session_type):
    """
    Post-processes the scored utterances to only include rows that satisfy certain conditions.

    Args:
        transcript_json (dict): json with an 'utterances' list of scored utterances
        session_id (str): session uuid
        session_type (str): "eliciting" or "uptake"

    Returns:
        list: filtered utterances as json records
    """
    if session_type == "eliciting":
        eliciting_df = pd.DataFrame(transcript_json['utterances'])
        eliciting_df.rename(columns={"text": "utt"}, inplace=True)
        eliciting_df["session_uuid"] = session_id
        eliciting_df.drop(columns=["uid"], inplace=True)

        eliciting_df = eliciting_df[eliciting_df['speaker'] == 'tutor']

        # only take rows of eliciting_df that have utt longer than 5 words
        eliciting_df = eliciting_df[eliciting_df['utt'].str.split().str.len() > 5]

        # only take rows of eliciting_df that have question > 0.5
        eliciting_df = eliciting_df[eliciting_df['question'] > 0.5]

        # only take rows of eliciting_df that have eliciting = 1.0
        eliciting_df = eliciting_df[eliciting_df['eliciting'] == 1.0]
        eliciting_df['eliciting'] = eliciting_df['eliciting'].apply(lambda x: 1 if x == 1.0 else x)
        eliciting_df['eliciting'] = eliciting_df['eliciting'].astype('Int64')
        final_df = eliciting_df[["utt", "eliciting", "starttime", "endtime", "session_uuid"]]

    else:
        # convert uptake to dataframe
        uptake_df = pd.DataFrame(transcript_json['utterances'])
        uptake_df.rename(columns={"text": "utt"}, inplace=True)
        uptake_df.drop(columns=["uid", "userId", "is_chat", "speaker_#"], inplace=True)

        # only take rows of uptake_df that have utt longer than 5 words
        uptake_df = uptake_df[uptake_df['utt'].str.split().str.len() > 5]

        # only take rows of uptake_df that have question > 0.5
        uptake_df = uptake_df[uptake_df['question'] > 0.5]

        # only take rows of uptake_df that have uptake > 0.8
        uptake_df = uptake_df[uptake_df['uptake'] > 0.8]
        uptake_df['uptake'] = uptake_df['uptake'].apply(lambda x: 1 if x > 0.8 else x)
        uptake_df['uptake'] = uptake_df['uptake'].astype('Int64')
        final_df = uptake_df[["utt", "prev_utt", "uptake", "starttime", "endtime", "session_uuid"]]

    final_df = final_df.drop(columns=["session_uuid"]).copy()
    # convert to json
    final_output = final_df.to_json(orient='records')

    final_output = json.loads(final_output)

    return final_output

def compute_student_engagement(utterances):
    """
    Computes the number of students engaged in a session.

    Args:
        utterances: utterances json records

    Returns:
        tuple(int, int): students engaged overall, and students engaged in spoken (non-chat) turns
    """
    # convert to dataframe
    utterances_df = pd.DataFrame(utterances)

    # only take rows of utterances_df that have speaker = student
    utterances_df = utterances_df[utterances_df['speaker'] == 'student']
    utterances_talk_df = utterances_df[utterances_df['is_chat'] == False]

    # calculate number of students engaged
    num_students_engaged = utterances_df['userId'].nunique()

    # calculate number of students engaged in talk
    num_students_engaged_talk = utterances_talk_df['userId'].nunique()

    return num_students_engaged, num_students_engaged_talk

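# Illustrative usage (utterances is the aggregated speaker-level json produced above):
#   n_engaged, n_engaged_talk = compute_student_engagement(utterances)
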
def compute_talk_time(utterances):
    """
    Computes the talk time of a tutor in a session.

    Args:
        utterances: utterances json records

    Returns:
        float: the tutor's share of all GPT-2 tokens in the session
    """
    # convert to dataframe
    utterances_df = pd.DataFrame(utterances)

    # Filter out nan text
    utterances_df = utterances_df[~utterances_df['text'].isna()]

    # Calculate token ratio spoken: tokenize with GPT-2
    num_tokens = utterances_df['text'].apply(lambda x: len(tokenizer.encode(x)))
    total_tokens = num_tokens.sum()

    # Calculate total tokens for tutor
    tutor_tokens = num_tokens[utterances_df['speaker'] == 'tutor'].sum()

    # Return the tutor's share of tokens
    if total_tokens == 0:
        return 0
    else:
        return tutor_tokens / total_tokens

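# Illustrative usage: the returned value is the tutor's fraction of tokens, e.g. 0.62 would mean
# the tutor produced 62% of all tokens in the session (example number is an assumption).
#   tutor_talk_ratio = compute_talk_time(utterances)
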
def gpt4_filtering_selection(json_final_output, session_type, focus_concept):

    ELICITING_SYSTEM_PROMPT = """We want to extract the best moments of when a novice tutor asked questions that solicited learner ideas from looking at a copy of their session's transcript.
    Please review the following list of utterances from the transcript, each separated by a double-slash.
    Identify up to 3 utterances from the list that are the best examples of soliciting learner ideas, and if there are no examples then return “None”.
    Ensure that the selected examples are a clear and complete question that would elicit learner engagement.
    Prioritize questions that encourage students to reason out loud and elaborate on their problem-solving process, and avoid questions that may have a single-word answer.
    Return the selected examples in a json dictionary with the following format:
    {"model_outputs": [{"utt": "A1"}, {"utt": "A2"}, {"utt": "A3"}]}"""

    UPTAKE_SYSTEM_PROMPT = """We want to extract the best moments of when a novice tutor revoices and builds on learner ideas from looking at a copy of their session's transcript.
    Effective building on students’ ideas looks like positive and encouraging uptake of their ideas, repeating back a previous statement, or affirming a student’s contribution.
    Please review the following list of tuples in the form (A1 // B1) \n (A2 // B2) \n (A3 // B3)... where each tuple represents a pair of utterances from the transcript.
    The first element A in each tuple is the previous utterance from the student, and the second element B is the current utterance in response from the tutor.
    The A and B items in each tuple are separated by a double-slash.
    Please return up to three of the provided tuples that are the best instances of a tutor revoicing a student’s ideas.
    If there are no examples then return “None”. Please fix capitalization, punctuation, and blatant typos.
    Return the selected examples in a json dictionary with the following format:
    {"model_outputs": [{"prev_utt": "A1", "utt": "B1"}, {"prev_utt": "A2", "utt": "B2"}, {"prev_utt": "A3", "utt": "B3"}]}"""

    ELICITING_REASONING = """We want to extract the best moments of when a novice tutor prompts their students for reasoning from looking at a copy of their session's transcript.
    Effective prompting for reasoning looks like questions containing “why” and “how”, prompting students for their thoughts and explanations beyond a simple answer, and asking problem-specific questions.
    Please review the following list of utterances from the transcript, each separated by a double-slash.
    Identify up to 3 utterances from the list that are the best examples of soliciting learner ideas, and if there are no examples then return “None”.
    Ensure that the selected examples are a clear and complete question that would elicit learner engagement.
    Prioritize questions that encourage students to reason out loud and elaborate on their problem-solving process, and avoid questions that may have a single-word answer.
    Return the selected examples in a json dictionary with the following format:
    {"model_outputs": [{"utt": "A1"}, {"utt": "A2"}, {"utt": "A3"}]}"""

    if session_type == "eliciting":
        if focus_concept == "reasoning":
            system_prompt = ELICITING_REASONING
        else:
            system_prompt = ELICITING_SYSTEM_PROMPT
    else:
        system_prompt = UPTAKE_SYSTEM_PROMPT
    df = pd.DataFrame(json_final_output)
    client = OpenAI(
        # Read the key from the environment rather than hard-coding a secret in source control
        api_key=os.environ.get("OPENAI_API_KEY"),
    )

    if session_type == "eliciting":
        # clean text
        for i in range(len(df)):
            response = client.chat.completions.create(
                model="gpt-4-0125-preview",
                # response_format={ "type": "json_object" },
                messages=[
                    {"role": "system", "content": "Clean the following text: \n"},
                    {"role": "user", "content": f"{df['utt'].iloc[i]}"}
                ]
            )
            df.iloc[i, df.columns.get_loc('utt')] = response.choices[0].message.content

    list_of_utterances = df['utt'].tolist()
    # expand the list of utterances into a string
    expanded_utterances = ' ; '.join(list_of_utterances)
    if session_type == "uptake":
        expanded_utterances = ""
        for i in range(len(df)):
            df.iloc[i, df.columns.get_loc('utt')] = ' '.join(df['utt'].iloc[i].split()[:100]) + "[...]"
            if len(df['prev_utt'].iloc[i].split()) > 100:
                df.iloc[i, df.columns.get_loc('prev_utt')] = "[...]" + ' '.join(df['prev_utt'].iloc[i].split()[-100:])
            expanded_utterances += f"({df['prev_utt'].iloc[i]} // {df['utt'].iloc[i]}) \n"

    if len(list_of_utterances) > 0:
        response = client.chat.completions.create(
            model="gpt-4-0125-preview",
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"{expanded_utterances}"}
            ]
        )
        # place the selections back into the dataframe
        try:
            json_output = json.loads(response.choices[0].message.content)['model_outputs']
            chosen_utterances = [json_output[i]['utt'] for i in range(len(json_output))]
            chosen_prev_utterances = []
            if session_type == "uptake":
                chosen_prev_utterances = [json_output[i]['prev_utt'] for i in range(len(json_output))]
        except Exception:
            # Fall back to an empty selection if the model response cannot be parsed as JSON
            print("gpt4_filtering_selection: could not parse the model output as JSON")
            chosen_utterances = []
            chosen_prev_utterances = []

        def similar(a, b):
            # Encode sentences to get their embeddings
            embeddings_a = sentence_model.encode(a, convert_to_tensor=True)
            embeddings_b = sentence_model.encode(b, convert_to_tensor=True)

            # Compute cosine similarity
            cosine_similarity = util.pytorch_cos_sim(embeddings_a, embeddings_b)

            return cosine_similarity.item()

        # find the index of the chosen utterances in the original list (the match does not have to be exact)
        indices = []
        for j, chosen_sentence in enumerate(chosen_utterances):
            best_match_index = -1
            highest_similarity = 0.0

            for i, initial_sentence in enumerate(list_of_utterances):
                similarity = similar(chosen_sentence, initial_sentence)
                if similarity > highest_similarity:
                    highest_similarity = similarity
                    best_match_index = i

            if best_match_index == -1:
                continue

            # replace the best match utterance with the chosen utterance in df
            df.iloc[best_match_index, df.columns.get_loc('utt')] = chosen_sentence
            if session_type == "uptake":
                df.iloc[best_match_index, df.columns.get_loc('prev_utt')] = chosen_prev_utterances[j]
            indices.append(best_match_index)

        # check that the indices are unique
        try:
            assert len(indices) == len(set(indices))
        except AssertionError:
            # only take unique indices
            indices = list(set(indices))
            print("gpt4_filtering_selection: duplicate matches found; keeping unique indices only")
        # if len(indices) != len(set(indices)):
        #     raise ValueError("Indices are not unique")

        # filter the dataframe to only include the chosen utterances
        df = df.iloc[indices]
        df.reset_index(drop=True, inplace=True)

    else:
        pass  # no candidate utterances; leave df unchanged

    # convert to json
    final_output = df.to_json(orient='records')
    final_output = json.loads(final_output)

    return final_output
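
# End-to-end sketch of how these helpers might fit together (variable names and the session
# parameters are assumptions, not part of this file):
#   params = {"tutor_uuid": "tutor-1234", "session_uuid": "session-5678"}
#   aggregated = preprocess_raw_files(raw_session_json, params)
#   sentences = preprocess_transcript_for_eliciting(aggregated)
#   ...  # score the utterances with the eliciting / uptake models elsewhere
#   filtered = post_processing_output_json(scored_transcript_json, params["session_uuid"], "eliciting")
#   best = gpt4_filtering_selection(filtered, "eliciting", focus_concept=None)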