import torch
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
from torch import nn
from itertools import chain
from torch.nn import MSELoss, CrossEntropyLoss
from cleantext import clean
from num2words import num2words
import re
import string
import pandas as pd
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import json
import tqdm
from transformers import GPT2Tokenizer
from openai import OpenAI
import os
from difflib import SequenceMatcher
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
from sentence_transformers import SentenceTransformer, util

# Load a pre-trained model
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')


device = "cuda" if torch.cuda.is_available() else "cpu"

punct_chars = list((set(string.punctuation) | {'’', '‘', '–', '—', '~', '|', '“', '”', '…', "'", "`", '_'}))
punct_chars.sort()
punctuation = ''.join(punct_chars)
replace = re.compile('[%s]' % re.escape(punctuation))

def get_num_words(text):
    if not isinstance(text, str):
        print("%s is not a string" % text)
        text = str(text)
    # Drop bracketed annotations (e.g. "[inaudible]") before punctuation is stripped,
    # so the square brackets can still be matched.
    text = re.sub(r'\[.+\]', " ", text)
    text = replace.sub(' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return len(text.split())
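# Illustrative usage (commented out; not part of the pipeline): punctuation is replaced with
# spaces before counting whitespace-delimited tokens.
#   get_num_words("Okay, so... what do you think?")  # -> 6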

def number_to_words(num):
    """Spell out a numeric string with num2words; return the input unchanged if conversion fails."""
    try:
        return num2words(re.sub(",", "", num))
    except Exception:
        return num


clean_str = lambda s: clean(s,
                            fix_unicode=True,  # fix various unicode errors
                            to_ascii=True,  # transliterate to closest ASCII representation
                            lower=True,  # lowercase text
                            no_line_breaks=True,  # fully strip line breaks as opposed to only normalizing them
                            no_urls=True,  # replace all URLs with a special token
                            no_emails=True,  # replace all email addresses with a special token
                            no_phone_numbers=True,  # replace all phone numbers with a special token
                            no_numbers=True,  # replace all numbers with a special token
                            no_digits=False,  # replace all digits with a special token
                            no_currency_symbols=False,  # replace all currency symbols with a special token
                            no_punct=False,  # fully remove punctuation
                            replace_with_url="<URL>",
                            replace_with_email="<EMAIL>",
                            replace_with_phone_number="<PHONE>",
                            replace_with_number=lambda m: number_to_words(m.group()),
                            replace_with_digit="0",
                            replace_with_currency_symbol="<CUR>",
                            lang="en"
                            )

clean_str_nopunct = lambda s: clean(s,
                            fix_unicode=True,  # fix various unicode errors
                            to_ascii=True,  # transliterate to closest ASCII representation
                            lower=True,  # lowercase text
                            no_line_breaks=True,  # fully strip line breaks as opposed to only normalizing them
                            no_urls=True,  # replace all URLs with a special token
                            no_emails=True,  # replace all email addresses with a special token
                            no_phone_numbers=True,  # replace all phone numbers with a special token
                            no_numbers=True,  # replace all numbers with a special token
                            no_digits=False,  # replace all digits with a special token
                            no_currency_symbols=False,  # replace all currency symbols with a special token
                            no_punct=True,  # fully remove punctuation
                            replace_with_url="<URL>",
                            replace_with_email="<EMAIL>",
                            replace_with_phone_number="<PHONE>",
                            replace_with_number=lambda m: number_to_words(m.group()),
                            replace_with_digit="0",
                            replace_with_currency_symbol="<CUR>",
                            lang="en"
                            )
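# Example behaviour (illustrative; the exact substitutions depend on the installed clean-text
# and num2words versions):
#   clean_str("Email me at TUTOR@EXAMPLE.COM")   # -> roughly "email me at <email>"
#   clean_str_nopunct("Well... that's RIGHT!")   # -> roughly "well thats right"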



class MultiHeadModel(BertPreTrainedModel):
  """Pre-trained BERT model that uses our loss functions"""

  def __init__(self, config, head2size):
    super(MultiHeadModel, self).__init__(config)
    config.num_labels = 1
    self.bert = BertModel(config)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    module_dict = {}
    for head_name, num_labels in head2size.items():
      module_dict[head_name] = nn.Linear(config.hidden_size, num_labels)
    self.heads = nn.ModuleDict(module_dict)

    self.init_weights()

  def forward(self, input_ids, token_type_ids=None, attention_mask=None,
              head2labels=None, return_pooler_output=False, head2mask=None,
              nsp_loss_weights=None):

    # Get logits
    output = self.bert(
      input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
      output_attentions=False, output_hidden_states=False, return_dict=True)
    pooled_output = self.dropout(output["pooler_output"]).to(device)

    head2logits = {}
    return_dict = {}
    for head_name, head in self.heads.items():
      head2logits[head_name] = self.heads[head_name](pooled_output)
      head2logits[head_name] = head2logits[head_name].float()
      return_dict[head_name + "_logits"] = head2logits[head_name]


    if head2labels is not None:
      for head_name, labels in head2labels.items():
        num_classes = head2logits[head_name].shape[1]

        # Regression (e.g. for politeness)
        if num_classes == 1:

          # Only consider positive examples
          if head2mask is not None and head_name in head2mask:
            num_positives = head2labels[head2mask[head_name]].sum()  # use certain labels as mask
            if num_positives == 0:
              return_dict[head_name + "_loss"] = torch.tensor([0]).to(device)
            else:
              loss_fct = MSELoss(reduction='none')
              loss = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
              return_dict[head_name + "_loss"] = loss.dot(head2labels[head2mask[head_name]].float().view(-1)) / num_positives
          else:
            loss_fct = MSELoss()
            return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
        else:
          loss_fct = CrossEntropyLoss(weight=nsp_loss_weights.float())
          return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name], labels.view(-1))


    if return_pooler_output:
      return_dict["pooler_output"] = output["pooler_output"]

    return return_dict
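# Minimal usage sketch (illustrative, commented out). It assumes a stock bert-base-uncased
# checkpoint and a single placeholder head named "nsp" with two classes; the real pipeline
# may load a fine-tuned checkpoint and a different head2size mapping.
#
#   from transformers import BertConfig, BertTokenizer
#   config = BertConfig.from_pretrained("bert-base-uncased")
#   model = MultiHeadModel(config, head2size={"nsp": 2}).to(device).eval()
#   bert_tok = BertTokenizer.from_pretrained("bert-base-uncased")
#   enc = bert_tok("do you think so?", "yes, i think so", return_tensors="pt").to(device)
#   with torch.no_grad():
#       out = model(enc["input_ids"], token_type_ids=enc["token_type_ids"],
#                   attention_mask=enc["attention_mask"])
#   # out["nsp_logits"] has shape (1, 2); passing head2labels would also add "nsp_loss"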

class InputBuilder(object):
  """Base class for building inputs from segments."""

  def __init__(self, tokenizer):
      self.tokenizer = tokenizer
      self.mask = [tokenizer.mask_token_id]

  def build_inputs(self, history, reply, max_length):
      raise NotImplementedError

  def mask_seq(self, sequence, seq_id):
      sequence[seq_id] = self.mask
      return sequence

  @classmethod
  def _combine_sequence(cls, history, reply, max_length, flipped=False):
      # Trim all inputs to max_length
      history = [s[:max_length] for s in history]
      reply = reply[:max_length]
      if flipped:
          return [reply] + history
      return history + [reply]


class BertInputBuilder(InputBuilder):
  """Processor for BERT inputs"""

  def __init__(self, tokenizer):
      InputBuilder.__init__(self, tokenizer)
      self.cls = [tokenizer.cls_token_id]
      self.sep = [tokenizer.sep_token_id]
      self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
      self.padded_inputs = ["input_ids", "token_type_ids"]
      self.flipped = False


  def build_inputs(self, history, reply, max_length, input_str=True):
    """See base class."""
    if input_str:
        history = [self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(t)) for t in history]
        reply = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(reply))
    sequence = self._combine_sequence(history, reply, max_length, self.flipped)
    sequence = [s + self.sep for s in sequence]
    sequence[0] = self.cls + sequence[0]

    instance = {}
    instance["input_ids"] = list(chain(*sequence))
    last_speaker = 0
    other_speaker = 1
    seq_length = len(sequence)
    instance["token_type_ids"] = [last_speaker if ((seq_length - i) % 2 == 1) else other_speaker
                                  for i, s in enumerate(sequence) for _ in s]
    return instance
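  # Usage sketch (illustrative, commented out): by default build_inputs expects raw strings
  # and returns BERT-ready id lists. The tokenizer below is an assumption; the real pipeline
  # may use a fine-tuned checkpoint's tokenizer instead.
  #
  #   from transformers import BertTokenizer
  #   builder = BertInputBuilder(BertTokenizer.from_pretrained("bert-base-uncased"))
  #   instance = builder.build_inputs(history=["what do you get for x?"],
  #                                   reply="i think x equals four", max_length=120)
  #   # instance["input_ids"] and instance["token_type_ids"] have the same length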
  
def preprocess_transcript_for_eliciting(transcript_json):
    """Splits each aggregated speaker turn into sentences and returns one json record per sentence."""
    transcript_df = pd.DataFrame(transcript_json)
    transcript_df.reset_index(drop=True, inplace=True)
    def break_into_sentences(text):
        return sent_tokenize(text)
    transcript_df['text'] = transcript_df['text'].apply(str)
    transcript_df['sentences'] = transcript_df['text'].apply(break_into_sentences)
    transcript_df.rename(columns={"startTimestamp": "starttime", "endTimestamp": "endtime"}, inplace=True)
    transcript_df.rename(columns={'is_chat?':'is_chat'}, inplace=True)

    def create_sentence_df(row):
        sentences = row['sentences']
        speaker = row['speaker']
        df = pd.DataFrame({'sentence':sentences})
        df['speaker'] = speaker
        df['userId'] = row['userId']
        df['session_uuid'] = row['session_uuid']
        df['starttime'] = row['starttime']
        df['endtime'] = row['endtime']
        df['is_chat'] = row['is_chat']
        df['speaker_#'] = row['speaker_#']
        return df

    sentence_df = pd.concat(transcript_df.apply(create_sentence_df, axis=1).values)
    sentence_df.reset_index(drop=True, inplace=True)

    sentence_df.dropna(inplace=True)
    sentence_df.rename(columns={'sentence':'text', 'userId':'uid'}, inplace=True)

    # sentence_df['prev_utt'] = None

    # prev_utt = None
    # for index, row in sentence_df.iterrows():
    #     # Check if the current speaker is a student
    #     if row['speaker'] != 'tutor':
    #         # Store the current utterance as the previous one for the next iteration
    #         prev_utt = row['text']
    #     else:
    #         # If the current speaker is the tutor, update 'prev_utt' in the DataFrame
    #         if prev_utt is not None and index > 0:
    #             sentence_df.at[index, 'prev_utt'] = prev_utt
    #             prev_utt = None

    # # drop rows where speaker_# is not tutor
    # sentence_df = sentence_df[sentence_df['speaker_#'] == 'tutor']

    # drop speaker_#, is_chat and session_uuid columns
    sentence_df.drop(columns=['speaker_#', 'is_chat', 'session_uuid'], inplace=True)

    session_json = sentence_df.to_json(orient='records')
    session_json = json.loads(session_json)

    return session_json   
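# Expected input (illustrative): the aggregated records produced by preprocess_raw_files below,
# i.e. one dict per speaker turn with the fields read by create_sentence_df:
#   [{"userId": "uuid-1", "is_chat": 0, "session_uuid": "uuid-s", "starttime": 0, "endtime": 5000,
#     "text": "Nice work. Why does that step give you four?", "speaker": "tutor", "speaker_#": "tutor"}]
# Each turn is split into one record per sentence; the output keeps 'text', 'uid', 'speaker',
# 'starttime', and 'endtime'.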



def preprocess_raw_files(input_json, params):
    """
    Preprocesses raw json file and returns another json file

    Args:
        input_json (str): input json file

    Returns:
        _type_: output json file
    
    """
    # convert raw json to dataframe
    tutor_uuid = params['tutor_uuid']
    session_uuid = params['session_uuid']

    chat_transcript_df = convert_json_to_df(input_json, tutor_uuid, session_uuid)

    # aggregate by speaker
    aggregate_df = aggregate_by_speaker_id(chat_transcript_df)

    # convert to json
    aggregate_json = aggregate_df.to_json(orient='records')
    aggregate_json = json.loads(aggregate_json)

    return aggregate_json
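# Illustrative call (uuids are placeholders; field names follow convert_json_to_df below):
#   params = {"tutor_uuid": "uuid-tutor", "session_uuid": "uuid-session"}
#   raw = {"transcript": [{"startTimestamp": 0, "endTimestamp": 4000,
#                          "text": "hi everyone", "userId": "uuid-tutor"}],
#          "chat": []}
#   utterances = preprocess_raw_files(raw, params)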


def convert_json_to_df(input_json, tutor_uuid, session_uuid):
    """
    Extracts transcript and chat data from raw json file, assigns speaker and speaker_# columns, and returns a dataframe.
    The dataframe contains the following columns:
    - startTimestamp
    - endTimestamp
    - text
    - userId
    - is_chat?
    - speaker
    - speaker_#

    Args:
        input_json (str): input json file
        tutor_uuid (str): tutor uuid

    Returns:
        _type_: dataframe
    """
    data = input_json

    if data['transcript'] != []:
        transcript_df = pd.DataFrame(data['transcript'])
        transcript_df['is_chat?'] = 0
    else:
        raise ValueError("Transcript is empty")

    # include chat messages as well
    if data['chat'] != []:
        chat_df = pd.DataFrame(data['chat'])
        chat_df.rename(
            columns={'timestamp': 'startTimestamp'}, inplace=True)
        chat_df['endTimestamp'] = chat_df['startTimestamp']
        chat_df['is_chat?'] = 1
    else:
        chat_df = pd.DataFrame(columns=list(transcript_df))

    chat_transcript_df = pd.concat([chat_df, transcript_df], ignore_index=True).sort_values(
        by='startTimestamp', ascending=True)
    
    chat_transcript_df['session_uuid'] = session_uuid

    # Add speaker column
    count_non_chat = 0
    for i, row in chat_transcript_df.iterrows():
        if row['userId'] == tutor_uuid:
            chat_transcript_df.loc[i, 'speaker'] = 'tutor'
        elif row['userId'] is None:
            if i == 0: # first chat
                chat_transcript_df.loc[i, 'speaker'] = 'student' # this is a heuristic that may not be true
            elif count_non_chat == 0: # first non-chat
                chat_transcript_df.loc[i, 'speaker'] = 'tutor' # this is a heuristic that may not be true
            else:
                chat_transcript_df.loc[i, 'speaker'] = chat_transcript_df.loc[i-1, 'speaker'] # this is a heuristic that may not be true
        else:
            chat_transcript_df.loc[i, 'speaker'] = 'student'
        if row['is_chat?'] == 0:
            count_non_chat += 1

    # Add speaker_# column, iterate through rows and assign speaker_# based on speaker
    studentId2studentNum = {}
    count_non_chat = 0
    for i, row in chat_transcript_df.iterrows():
        if row ['speaker'] == 'tutor':
            chat_transcript_df.loc[i, 'speaker_#'] = 'tutor'
        elif row['userId'] is None:
            if i == 0: # first chat
                chat_transcript_df.loc[i, 'speaker_#'] = 'student1'
            elif count_non_chat == 0:
                chat_transcript_df.loc[i, 'speaker_#'] = 'tutor'
            else:
                chat_transcript_df.loc[i, 'speaker_#'] = chat_transcript_df.loc[i-1, 'speaker_#']
        else:
            if row['userId'] in studentId2studentNum:
                chat_transcript_df.loc[i, 'speaker_#'] = 'student' + str(studentId2studentNum[row['userId']])
            else:
                studentId2studentNum[row['userId']] = len(studentId2studentNum) + 1
                chat_transcript_df.loc[i, 'speaker_#'] = 'student' + str(studentId2studentNum[row['userId']])
        if row['is_chat?'] == 0:
            count_non_chat += 1
    
    return chat_transcript_df

def aggregate_by_speaker_id(data):
    """Collapses consecutive rows from the same speaker_# and the same channel (chat vs. talk) into single utterances."""
    aggregate_df = []
    speaker_id = None
    speaker = None
    aggregate_key_value = None
    enumerated_speaker = None
    is_chat = None
    session = None
    curr_text = ""
    curr_starttime = None
    curr_endtime = None

    for _, row in tqdm.tqdm(data.iterrows()):
        is_same_speaker_id = (row['speaker_#'] == aggregate_key_value) 
        is_same_type = (row['is_chat?'] == is_chat)

        if (is_same_type) and (is_same_speaker_id):
            # Concatenate text and update endtime
            if isinstance(row['text'], str):
                curr_text += " " + row['text']
            curr_endtime = row['endTimestamp']
        else: 
            # Append previous speaker's text to aggregate_df
            aggregate_df.append({
                "userId": speaker_id,
                "is_chat": is_chat,
                "session_uuid": session,
                "starttime": curr_starttime,
                "endtime": curr_endtime,
                "text": curr_text,
                "speaker": speaker,
                "speaker_#": enumerated_speaker
            })

            # Update speaker, is_chat, session, curr_text, curr_starttime, curr_endtime
            speaker_id = row['userId']
            is_chat = row['is_chat?']
            session = row['session_uuid']
            curr_text = row['text'] if isinstance(row['text'], str) else ""
            curr_starttime = row['startTimestamp']
            curr_endtime = row['endTimestamp']
            speaker = row['speaker']
            enumerated_speaker = row['speaker_#']
            aggregate_key_value = row['speaker_#']

    # Append the final speaker's buffered text; the loop above only flushes a group once the next one starts
    if len(data) > 0:
        aggregate_df.append({
            "userId": speaker_id,
            "is_chat": is_chat,
            "session_uuid": session,
            "starttime": curr_starttime,
            "endtime": curr_endtime,
            "text": curr_text,
            "speaker": speaker,
            "speaker_#": enumerated_speaker
        })

    aggregate_df = pd.DataFrame(aggregate_df[1:])
    return aggregate_df
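# Behaviour sketch (illustrative): consecutive rows with the same speaker_# and the same
# is_chat? value are merged, e.g. two back-to-back tutor rows
#   "so first we distribute"  and  "then we combine like terms"
# become a single record with text "so first we distribute then we combine like terms",
# starttime taken from the first row and endtime from the last.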

    
def post_processing_output_json(transcript_json, session_id, session_type):
    """
    Post-processes the uptake and eliciting dataframes to ony include rows that satisfy certain conditions.

    Args:
        uptake_json (str): uptake json file
        eliciting_json (str): eliciting json file

    Returns:
        _type_: output json file
    """
    if session_type == "eliciting":
        eliciting_df = pd.DataFrame(transcript_json['utterances'])
        eliciting_df.rename(columns={"text": "utt"}, inplace=True)
        eliciting_df["session_uuid"] = session_id
        eliciting_df.drop(columns=["uid"], inplace=True)

        eliciting_df = eliciting_df[eliciting_df['speaker'] == 'tutor']

        # only take rows of eliciting_df that have utt longer than 5 words
        eliciting_df = eliciting_df[eliciting_df['utt'].str.split().str.len() > 5]

        # only take rows of eliciting_df that have question > 0.5
        eliciting_df = eliciting_df[eliciting_df['question'] > 0.5]

        # only take rows of eliciting_df that have eliciting = 1.0
        eliciting_df = eliciting_df[eliciting_df['eliciting'] == 1.0]
        eliciting_df['eliciting'] = eliciting_df['eliciting'].apply(lambda x: 1 if x == 1.0 else x)
        eliciting_df['eliciting'] = eliciting_df['eliciting'].astype('Int64')
        final_df = eliciting_df[["utt", "eliciting", "starttime", "endtime", "session_uuid"]]

    else:
        # convert uptake to dataframe
        uptake_df = pd.DataFrame(transcript_json['utterances'])
        uptake_df.rename(columns={"text": "utt"}, inplace=True)
        uptake_df.drop(columns=["uid", "userId", "is_chat", "speaker_#"], inplace=True)

        # only take rows of uptake_df that have utt longer than 5 words
        uptake_df = uptake_df[uptake_df['utt'].str.split().str.len() > 5]

        # only take rows of uptake_df that have question > 0.5
        uptake_df = uptake_df[uptake_df['question'] > 0.5]

        # only take rows of uptake_df that have uptake > 0.8
        uptake_df = uptake_df[uptake_df['uptake'] > 0.8]
        uptake_df['uptake'] = uptake_df['uptake'].apply(lambda x: 1 if x > 0.8 else x)
        uptake_df['uptake'] = uptake_df['uptake'].astype('Int64')
        final_df = uptake_df[["utt", "prev_utt", "uptake", "starttime", "endtime", "session_uuid"]]
        
    final_df = final_df.drop(columns=["session_uuid"]).copy()
    # convert to json
    final_output = final_df.to_json(orient='records')

    final_output = json.loads(final_output)

    return final_output
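# Expected input (illustrative): transcript_json["utterances"] holds one dict per sentence with
# model scores attached, e.g. for an uptake session:
#   {"text": "right, so you multiplied both sides by two", "prev_utt": "i multiplied both sides",
#    "question": 0.7, "uptake": 0.93, "starttime": 0, "endtime": 4000, "session_uuid": "uuid-s",
#    "uid": "...", "userId": "...", "is_chat": 0, "speaker_#": "tutor"}
# Rows survive only if the utterance is longer than 5 words, question > 0.5, and
# uptake > 0.8 (or eliciting == 1.0 and speaker == 'tutor' for eliciting sessions).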

def compute_student_engagement(utterances):
    """
    Computes the number of students engaged in a session.

    Args:
        utterances json file

    Returns:
        _type_: int

    """
    # convert to dataframe
    utterances_df = pd.DataFrame(utterances)

    # only take rows of utterances_df that have speaker = student
    utterances_df = utterances_df[utterances_df['speaker'] == 'student']
    utterances_talk_df = utterances_df[utterances_df['is_chat'] == False]

    # calculate number of students engaged
    num_students_engaged = utterances_df['userId'].nunique()

    # calculate number of students engaged in talk
    num_students_engaged_talk = utterances_talk_df['userId'].nunique()

    return num_students_engaged, num_students_engaged_talk
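# Example (illustrative, commented out):
#   utts = [{"speaker": "student", "userId": "s1", "is_chat": 0, "text": "four?"},
#           {"speaker": "student", "userId": "s2", "is_chat": 1, "text": "i think 4"},
#           {"speaker": "tutor",   "userId": "t1", "is_chat": 0, "text": "why four?"}]
#   compute_student_engagement(utts)  # -> (2, 1): two students overall, one through speech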

def compute_talk_time(utterances):
    """
    Computes the talk time of a tutor in a session.

    Args:
        utterances json file

    Returns:
        _type_: float
    """
    # convert to dataframe
    utterances_df = pd.DataFrame(utterances)

    # Filter out nan text
    utterances_df = utterances_df[~utterances_df['text'].isna()]

    # Calculate token ratio spoken
    # Tokenize with GPT2 for talk
    num_tokens = utterances_df['text'].apply(lambda x: len(tokenizer.encode(x)))
    total_tokens = num_tokens.sum()

    # Calculate total tokens for tutor
    tutor_tokens = num_tokens[utterances_df['speaker'] == 'tutor'].sum()

    # Return the tutor's share of tokens (0 if the session contains no text)
    if total_tokens == 0:
        return 0
    else:
        return tutor_tokens / total_tokens
    
def gpt4_filtering_selection(json_final_output, session_type, focus_concept):
    """Asks GPT-4 to pick up to three of the best candidate utterances and maps them back onto the original rows."""

    ELICITING_SYSTEM_PROMPT = """We want to extract the best moments of when a novice tutor asked questions that solicited learner ideas from looking at a copy of their session's transcript. 
    Please review the following list of utterances from the transcript, each separated by a double-slash. 
    Identify up to 3 utterances from the list that are the best examples of soliciting learner ideas, and if there are no examples then return “None”. 
    Ensure that the selected examples are a clear and complete question that would elicit learner engagement. 
    Prioritize questions that encourage students to reason out loud and elaborate on their problem-solving process, and avoid questions that may have a single-word answer. 
    Return the selected examples in a json dictionary with the following format:
    {"model_outputs": [{"utt": "A1"}, {"utt": "A2"}, {"utt": "A3"}]}"""


    UPTAKE_SYSTEM_PROMPT = """We want to extract the best moments of when a novice tutor revoices and builds on learner ideas from looking at a copy of their session's transcript. 
    Effective building on students’ ideas looks like positive and encouraging uptake of their ideas, repeating back a previous statement, or affirming a student’s contribution. 
    Please review the following list of tuples in the form (A1 // B1) \n (A2 // B2) \n (A3 // B3)... where each tuple represents a pair of utterances from the transcript.
    The first element A in each tuple is the previous utterance from the student, and the second element B is the current utterance in response from the tutor. 
    The A and B items in each tuple are separated by a double-slash.
    Please return up to three of the provided tuples that are the best instances of a tutor revoicing a student’s ideas. 
    If there are no examples then return “None”. Please fix capitalization, punctuation, and blatant typos. 
    Return the selected examples in a json dictionary with the following format: 
    {"model_outputs": [{"prev_utt": "A1", "utt": "B1"}, {"prev_utt": "A2", "utt": "B2"}, {"prev_utt": "A3", "utt": "B3"}]}"""
    
    ELICITING_REASONING = """We want to extract the best moments of when a novice tutor prompts their students for reasoning from looking at a copy of their session's transcript. 
    Effective prompting for reasoning looks like questions containing “why” and “how”, prompting students for their thoughts and explanations beyond a simple answer, and asking problem-specific questions. 
    Please review the following list of utterances from the transcript, each separated by a double-slash. 
    Identify up to 3 utterances from the list that are the best examples of prompting students for reasoning, and if there are no examples then return “None”. 
    Ensure that the selected examples are a clear and complete question that would elicit learner engagement. 
    Prioritize questions that encourage students to reason out loud and elaborate on their problem-solving process, and avoid questions that may have a single-word answer. 
    Return the selected examples in a json dictionary with the following format:
    {"model_outputs": [{"utt": "A1"}, {"utt": "A2"}, {"utt": "A3"}]}"""

    # breakpoint()
    if session_type == "eliciting":
        if focus_concept == "reasoning":
            system_prompt = ELICITING_REASONING
        else:
            system_prompt = ELICITING_SYSTEM_PROMPT
    else:
        system_prompt = UPTAKE_SYSTEM_PROMPT
    df = pd.DataFrame(json_final_output)
    # Read the API key from the environment instead of hardcoding a secret in source.
    client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
    )

    if session_type == "eliciting":
            # clean text
            for i in range(len(df)):
                response = client.chat.completions.create(
                    model="gpt-4-0125-preview",
                    # response_format={ "type": "json_object" }, 
                    messages=[
                        {"role": "system", "content": "Clean the following text: \n"},
                        {"role": "user", "content": f"{df['utt'].iloc[i]}"}
                    ]
                )
                df.iloc[i, df.columns.get_loc('utt')] = response.choices[0].message.content

    # breakpoint()
    list_of_utterances = df['utt'].tolist()
    # join the utterances into a single string, separated by double-slashes as described in the system prompts
    expanded_utterances = ' // '.join(list_of_utterances)
    if session_type == "uptake":
        expanded_utterances = ""
        for i in range(len(df)):
            df.iloc[i, df.columns.get_loc('utt')] = ' '.join(df['utt'].iloc[i].split()[:100])+ "[...]"
            if len(df['prev_utt'].iloc[i].split()) > 100:
                df.iloc[i, df.columns.get_loc('prev_utt')] = "[...]" + ' '.join(df['prev_utt'].iloc[i].split()[-100:])
            expanded_utterances += f"({df['prev_utt'].iloc[i]} // {df['utt'].iloc[i]}) \n"
      

    if len(list_of_utterances) > 0:
        response = client.chat.completions.create(
            model="gpt-4-0125-preview",
            response_format={ "type": "json_object" }, 
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"{expanded_utterances}"}
            ]
        )
        # place back into the dataframe
        chosen_utterances = []
        chosen_prev_utterances = []
        try:
            json_output = json.loads(response.choices[0].message.content)['model_outputs']
            chosen_utterances = [item['utt'] for item in json_output]
            if session_type == "uptake":
                chosen_prev_utterances = [item['prev_utt'] for item in json_output]
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            # fall back to an empty selection instead of failing on an undefined name below
            print(f"Failed to parse GPT-4 selection output in gpt4_filtering_selection: {e}")

        def similar(a, b):
            # Encode sentences to get their embeddings
            embeddings_a = sentence_model.encode(a, convert_to_tensor=True)
            embeddings_b = sentence_model.encode(b, convert_to_tensor=True)
            
            # Compute cosine similarity
            cosine_similarity = util.pytorch_cos_sim(embeddings_a, embeddings_b)
            
            return cosine_similarity.item()

        # find the index of each chosen utterance in the original list (embedding similarity, so matches need not be exact)
        indices = []
        for j, chosen_sentence in enumerate(chosen_utterances):
            best_match_index = -1
            highest_similarity = 0.0
            
            for i, initial_sentence in enumerate(list_of_utterances):
                similarity = similar(chosen_sentence, initial_sentence)
                if similarity > highest_similarity:
                    highest_similarity = similarity
                    best_match_index = i

            # replace the best match utterance with the chosen utterance in df
            df.iloc[best_match_index, df.columns.get_loc('utt')] = chosen_sentence
            if session_type == "uptake":
                df.iloc[best_match_index, df.columns.get_loc('prev_utt')] = chosen_prev_utterances[j]
            indices.append(best_match_index)

        # keep only unique indices (two chosen utterances can match the same source row)
        if len(indices) != len(set(indices)):
            indices = list(set(indices))
            print("Duplicate matches found in gpt4_filtering_selection; keeping unique indices only")

        # filter the dataframe to only include the chosen utterances
        df = df.iloc[indices]
        df.reset_index(drop=True, inplace=True)

    # if there are no candidate utterances, df is returned unchanged

    # convert to json
    final_output = df.to_json(orient='records')
    final_output = json.loads(final_output)

    return final_output
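# End-to-end sketch (illustrative; raw_session_json and scored_sentences are placeholders, and the
# model-scoring step that attaches 'question'/'uptake'/'eliciting' scores lives outside this module):
#   params = {"tutor_uuid": "uuid-tutor", "session_uuid": "uuid-session"}
#   utterances = preprocess_raw_files(raw_session_json, params)            # aggregate transcript + chat
#   sentences = preprocess_transcript_for_eliciting(utterances)            # one record per sentence
#   ...score the sentences...
#   filtered = post_processing_output_json({"utterances": scored_sentences}, "uuid-session", "eliciting")
#   best = gpt4_filtering_selection(filtered, "eliciting", focus_concept="reasoning")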