hylee719 committed on
Commit
813a1db
·
1 Parent(s): 726870b

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: gpl
3
+ ---
__pycache__/handler.cpython-310.pyc ADDED
Binary file (8.74 kB). View file
 
__pycache__/handler.cpython-39.pyc ADDED
Binary file (8.82 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.77 kB). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (6.53 kB). View file
 
handler.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from scipy.special import softmax
3
+ import numpy as np
4
+ import weakref
5
+
6
+ from utils import clean_str, clean_str_nopunct
7
+ import torch
8
+ from utils import MultiHeadModel, BertInputBuilder, get_num_words
9
+
10
+ import transformers
11
+ from transformers import BertTokenizer, BertForSequenceClassification
12
+
13
+
14
# Switch the transformers library logger to debug verbosity.
transformers.logging.set_verbosity_debug()

# Default ``path`` values handed to ``from_pretrained`` by the model wrappers
# below (presumably Hugging Face Hub repository ids -- TODO confirm).
QUESTION_MODEL = 'ddemszky/question-detection'
REASONING_MODEL = 'ddemszky/student-reasoning'
UPTAKE_MODEL = 'ddemszky/uptake-model'
19
+
20
+
21
class Utterance:
    """One utterance of a transcript plus the annotations attached to it.

    The ``uptake``, ``reasoning`` and ``question`` attributes start out as
    ``None`` and are assigned by the model wrappers' ``run_inference`` methods.
    """

    def __init__(self, speaker, text, uid=None,
                 transcript=None, starttime=None, endtime=None, **kwargs):
        """Store the utterance fields.

        Args:
            speaker: Speaker label of the utterance.
            text: Raw utterance text.
            uid: Optional identifier of the utterance.
            transcript: Optional owning transcript; only a weak reference to
                it is kept.
            starttime: Optional start time, stored as given.
            endtime: Optional end time, stored as given.
            **kwargs: Extra properties, kept in ``props`` and merged into the
                result of ``to_dict``.
        """
        self.speaker = speaker
        self.text = text
        self.uid = uid
        self.starttime = starttime
        self.endtime = endtime
        if transcript:
            self.transcript = weakref.ref(transcript)
        else:
            self.transcript = None
        self.props = kwargs

        # Annotation slots, filled in later by the inference classes.
        self.uptake = None
        self.reasoning = None
        self.question = None

    def get_clean_text(self, remove_punct=False):
        """Return the cleaned text, optionally with punctuation removed."""
        cleaner = clean_str_nopunct if remove_punct else clean_str
        return cleaner(self.text)

    def get_num_words(self):
        """Return the word count of the raw text."""
        return get_num_words(self.text)

    def to_dict(self):
        """Return the utterance as a plain dict; ``props`` entries come last
        and override same-named standard fields."""
        record = {
            'speaker': self.speaker,
            'text': self.text,
            'uid': self.uid,
            'starttime': self.starttime,
            'endtime': self.endtime,
            'uptake': self.uptake,
            'reasoning': self.reasoning,
            'question': self.question,
        }
        record.update(self.props)
        return record

    def __repr__(self):
        return (
            f"Utterance(speaker='{self.speaker}',"
            f"text='{self.text}', uid={self.uid},"
            f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"
        )
61
+
62
+
63
class Transcript:
    """Ordered collection of utterances plus free-form transcript parameters."""

    def __init__(self, **kwargs):
        """Create an empty transcript; ``kwargs`` are kept in ``params`` and
        merged into the result of ``to_dict``."""
        self.utterances = []
        self.params = kwargs

    def add_utterance(self, utterance):
        """Append ``utterance`` and point its ``transcript`` attribute at this
        transcript through a weak reference."""
        utterance.transcript = weakref.ref(self)
        self.utterances.append(utterance)

    def get_idx(self, idx):
        """Return the utterance at position ``idx``, or ``None`` if out of range.

        Negative indices count from the end, as for a list. An out-of-range
        negative index also yields ``None`` rather than raising ``IndexError``.
        """
        count = len(self.utterances)
        if idx >= count or idx < -count:
            return None
        return self.utterances[idx]

    def get_uid(self, uid):
        """Return the first utterance whose ``uid`` equals ``uid``, else ``None``."""
        return next((utt for utt in self.utterances if utt.uid == uid), None)

    def length(self):
        """Return the number of utterances."""
        return len(self.utterances)

    def to_dict(self):
        """Return the transcript as a plain dict: the serialized utterances
        under ``'utterances'`` followed by the entries of ``params``."""
        return {
            'utterances': [utterance.to_dict() for utterance in self.utterances],
            **self.params
        }

    def __repr__(self):
        return f"Transcript(utterances={self.utterances}, custom_params={self.params})"
94
+
95
+
96
class QuestionModel:
    """Labels each utterance of a transcript as question / non-question.

    Wraps a ``MultiHeadModel`` with a single two-way ``is_question`` head.
    """

    def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
        """Load the model from ``path`` and move it to ``device``.

        Args:
            device: Device the model and its inputs are placed on.
            tokenizer: Tokenizer, stored on the instance.
            input_builder: Object whose ``build_inputs`` turns text into
                ``input_ids`` / ``token_type_ids``.
            max_length: ``max_length`` forwarded to ``build_inputs``.
            path: Model location passed to ``from_pretrained``.
        """
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(
            path, head2size={"is_question": 2})
        self.model.to(self.device)

    def run_inference(self, transcript):
        """Set ``utt.question`` on every utterance of ``transcript``.

        Utterances containing a literal ``?`` are labelled ``1`` without
        running the model; all others get the argmax of the model's
        ``is_question`` logits.
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if "?" in utt.text:
                    utt.question = 1
                    continue
                text = utt.get_clean_text(remove_punct=True)
                instance = self.input_builder.build_inputs([], text,
                                                           max_length=self.max_length,
                                                           input_str=True)
                output = self.get_prediction(instance)
                print(output)
                # int(): keep the label a plain Python int, like the "?"
                # shortcut above, instead of a numpy.int64.
                utt.question = int(np.argmax(
                    output["is_question_logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run the model on one built instance and return its output dict.

        ``instance`` is modified in place: its id lists are replaced by
        batch-of-one tensors on ``self.device`` and an all-ones
        ``attention_mask`` is added.
        """
        # 2-D (batch, seq_len) mask once batched, the shape BERT documents.
        instance["attention_mask"] = [1] * len(instance["input_ids"])
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            # Tensor.to() returns a new tensor, so the result must be kept;
            # otherwise the inputs stay on CPU while the model is on the device.
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0).to(self.device)  # Batch size = 1

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"],
                            return_pooler_output=False)
        return output
135
+
136
+
137
class ReasoningModel:
    """Assigns a reasoning label to sufficiently long utterances.

    Wraps a ``BertForSequenceClassification`` model.
    """

    def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
        """Load the model from ``path`` and move it to ``device``.

        Args:
            device: Device the model and its inputs are placed on.
            tokenizer: Tokenizer, stored on the instance.
            input_builder: Object whose ``build_inputs`` turns text into
                ``input_ids`` / ``token_type_ids``.
            max_length: ``max_length`` forwarded to ``build_inputs``.
            path: Model location passed to ``from_pretrained``.
        """
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)

    def run_inference(self, transcript, min_num_words=8):
        """Set ``utt.reasoning`` to the argmax class of the model logits.

        Only utterances with at least ``min_num_words`` words are scored;
        shorter ones keep their current ``reasoning`` value.
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if utt.get_num_words() < min_num_words:
                    continue
                instance = self.input_builder.build_inputs([], utt.text,
                                                           max_length=self.max_length,
                                                           input_str=True)
                output = self.get_prediction(instance)
                # int(): store a plain Python int rather than a numpy.int64.
                utt.reasoning = int(np.argmax(output["logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run the model on one built instance and return its output.

        ``instance`` is modified in place: its id lists are replaced by
        batch-of-one tensors on ``self.device`` and an all-ones
        ``attention_mask`` is added.
        """
        # 2-D (batch, seq_len) mask once batched, the shape BERT documents.
        instance["attention_mask"] = [1] * len(instance["input_ids"])
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            # Tensor.to() returns a new tensor, so the result must be kept;
            # otherwise the inputs stay on CPU while the model is on the device.
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0).to(self.device)  # Batch size = 1

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"])
        return output
169
+
170
+
171
class UptakeModel:
    """Scores whether an utterance takes up the utterance right before it.

    Wraps a ``MultiHeadModel`` with a single two-way ``nsp`` head.
    """

    def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
        """Load the model from ``path`` and move it to ``device``.

        Args:
            device: Device the model and its inputs are placed on.
            tokenizer: Tokenizer, stored on the instance.
            input_builder: Object whose ``build_inputs`` turns text into
                ``input_ids`` / ``token_type_ids``.
            max_length: ``max_length`` forwarded to ``build_inputs``.
            path: Model location passed to ``from_pretrained``.
        """
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)

    def run_inference(self, transcript, min_prev_words, uptake_speaker=None):
        """Set ``utt.uptake`` (0 or 1) on eligible utterances.

        An utterance is scored when it has a predecessor with at least
        ``min_prev_words`` words and, if ``uptake_speaker`` is given, when it
        was spoken by that speaker. ``uptake`` is 1 when the softmax
        probability at index 1 of the ``nsp`` logits exceeds 0.8.

        Args:
            transcript: Transcript whose utterances are annotated in place.
            min_prev_words: Minimum word count of the preceding utterance.
            uptake_speaker: If not ``None``, only this speaker's utterances
                are scored.
        """
        self.model.eval()
        prev_utt = None
        prev_num_words = 0
        with torch.no_grad():
            for utt in transcript.utterances:
                speaker_ok = uptake_speaker is None or utt.speaker == uptake_speaker
                # The explicit prev_utt check keeps the first utterance from
                # being scored against nothing when min_prev_words <= 0.
                if speaker_ok and prev_utt is not None and prev_num_words >= min_prev_words:
                    textA = prev_utt.get_clean_text(remove_punct=False)
                    textB = utt.get_clean_text(remove_punct=False)
                    instance = self.input_builder.build_inputs([textA], textB,
                                                               max_length=self.max_length,
                                                               input_str=True)
                    output = self.get_prediction(instance)

                    utt.uptake = int(
                        softmax(output["nsp_logits"][0].tolist())[1] > .8)
                # Every utterance becomes the "previous" one for the next
                # iteration, whoever spoke it.
                prev_num_words = utt.get_num_words()
                prev_utt = utt

    def get_prediction(self, instance):
        """Run the model on one built instance and return its output dict.

        ``instance`` is modified in place: its id lists are replaced by
        batch-of-one tensors on ``self.device`` and an all-ones
        ``attention_mask`` is added.
        """
        # 2-D (batch, seq_len) mask once batched, the shape BERT documents.
        instance["attention_mask"] = [1] * len(instance["input_ids"])
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            # Tensor.to() returns a new tensor, so the result must be kept;
            # otherwise the inputs stay on CPU while the model is on the device.
            instance[key] = torch.tensor(
                instance[key]).unsqueeze(0).to(self.device)  # Batch size = 1

        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"],
                            return_pooler_output=False)
        return output
212
+
213
+
214
class EndpointHandler:
    """Inference entry point: annotates a transcript with uptake, reasoning
    and question labels."""

    def __init__(self, path="."):
        """Pick the device and build the shared tokenizer and input builder.

        ``path`` is accepted but not used by this implementation.
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Annotate the utterances in ``data`` and return the transcript dict.

        data args:
            inputs (:obj: `list`):
                List of dicts, where each dict represents an utterance; each
                utterance object must have a `speaker` and `text`, should have
                a `uid`, and can include custom properties.
            parameters (:obj: `dict`):
                `uptake_min_num_words` (required), plus optional
                `uptake_speaker` and `filename`.
        Return:
            A :obj:`dict` as produced by ``Transcript.to_dict``; it will be
            serialized and returned.
        Raises:
            KeyError: if `uptake_min_num_words` is missing from the parameters.
        """
        # get inputs
        utterances = data.pop("inputs", data)
        # A request without "parameters" must not crash on None below.
        params = data.pop("parameters", None) or {}

        print("EXAMPLES")
        for utt in utterances[:3]:
            print("speaker %s: %s" % (utt["speaker"], utt["text"]))

        # .get() rather than .pop(): leave the caller's parameter dict intact.
        transcript = Transcript(filename=params.get("filename"))
        for utt in utterances:
            transcript.add_utterance(Utterance(**utt))

        print("Running inference on %d examples..." % transcript.length())

        # NOTE(review): each model is loaded per call and dropped right after
        # use -- presumably so only one model is resident at a time; confirm
        # before caching them on the instance.

        # Uptake
        uptake_model = UptakeModel(
            self.device, self.tokenizer, self.input_builder)
        uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
                                   uptake_speaker=params.get("uptake_speaker"))
        del uptake_model
        # Reasoning
        reasoning_model = ReasoningModel(
            self.device, self.tokenizer, self.input_builder)
        reasoning_model.run_inference(transcript)
        del reasoning_model
        # Question
        question_model = QuestionModel(
            self.device, self.tokenizer, self.input_builder)
        question_model.run_inference(transcript)
        del question_model
        return transcript.to_dict()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ clean-text==0.6.0
2
+ num2words==0.5.10
3
+ numpy==1.22.4
4
+ scipy==1.7.3
5
+ torch==1.10.2
6
+ transformers==4.29.1
test.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from handler import EndpointHandler
3
+
4
# Sample payload: three utterances plus the parameters the handler reads.
example = {
    "inputs": [
        {"uid": "1", "speaker": "Alice", "text": "How much is the fish?"},
        {
            "uid": "2",
            "speaker": "Bob",
            "text": "I do not know about the fish. Because you put a long side and it’s a long side. What do you think.",
        },
        {"uid": "3", "speaker": "Alice", "text": "OK, thank you Bob."},
    ],
    "parameters": {
        "uptake_min_num_words": 5,
        "uptake_speaker": "Bob",
        "filename": "sample.csv",
    },
}

# Build the handler, run it on the sample payload and show the result.
my_handler = EndpointHandler()
print(my_handler(example))
utils.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
3
+ from torch import nn
4
+ from itertools import chain
5
+ from torch.nn import MSELoss, CrossEntropyLoss
6
+ from cleantext import clean
7
+ from num2words import num2words
8
+ import re
9
+ import string
10
+
11
# Characters treated as punctuation when counting words: ASCII punctuation
# plus typographic quotes, dashes and the ellipsis.
punct_chars = sorted(set(string.punctuation) | {'’', '‘', '–', '—', '~', '|', '“', '”', '…', "'", "`", '_'})
punctuation = ''.join(punct_chars)
# Character class matching any single punctuation character.
replace = re.compile('[%s]' % re.escape(punctuation))


def get_num_words(text):
    """Return the number of whitespace-separated words in ``text``.

    Punctuation characters are replaced by spaces before splitting, so
    ``"it's"`` counts as two words.

    A non-string ``text`` (for example ``None``) is reported on stdout and
    counted as 0 words instead of failing inside the regex substitution.
    """
    if not isinstance(text, str):
        # One-element tuple keeps %-formatting safe when ``text`` is a tuple.
        print("%s is not a string" % (text,))
        return 0
    text = replace.sub(' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    # NOTE(review): "[" and "]" are in ``punctuation`` and were already turned
    # into spaces above, so this pattern cannot match here. If bracketed
    # annotations are meant to be excluded from the count, this has to run
    # before the punctuation pass -- confirm the intent before changing it.
    text = re.sub(r'\[.+\]', " ", text)
    return len(text.split())
24
+
25
def number_to_words(num):
    """Spell out ``num`` (a numeric string, thousands separators allowed).

    Commas are stripped before the value is handed to ``num2words``. The
    conversion is best effort: if it fails for any reason, ``num`` is
    returned unchanged.
    """
    try:
        return num2words(re.sub(",", "", num))
    except Exception:
        # Deliberately broad, but unlike a bare ``except`` this no longer
        # swallows KeyboardInterrupt / SystemExit.
        return num
30
+
31
+
32
def _clean_text(s, no_punct):
    """Run ``clean`` on ``s`` with the settings shared by both public cleaners.

    ``no_punct`` is the only option that differs between ``clean_str`` and
    ``clean_str_nopunct``.
    """
    return clean(
        s,
        fix_unicode=True,            # fix various unicode errors
        to_ascii=True,               # transliterate to closest ASCII representation
        lower=True,                  # lowercase text
        no_line_breaks=True,         # fully strip line breaks as opposed to only normalizing them
        no_urls=True,                # replace all URLs with a special token
        no_emails=True,              # replace all email addresses with a special token
        no_phone_numbers=True,       # replace all phone numbers with a special token
        no_numbers=True,             # replace all numbers with a special token
        no_digits=False,             # replace all digits with a special token
        no_currency_symbols=False,   # replace all currency symbols with a special token
        no_punct=no_punct,           # fully remove punctuation
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        # A callable taking a match object: presumably forwarded by cleantext
        # as a regex replacement function -- TODO confirm.
        replace_with_number=lambda m: number_to_words(m.group()),
        replace_with_digit="0",
        replace_with_currency_symbol="<CUR>",
        lang="en",
    )


def clean_str(s):
    """Return ``s`` cleaned with punctuation kept."""
    return _clean_text(s, no_punct=False)


def clean_str_nopunct(s):
    """Return ``s`` cleaned with punctuation removed."""
    return _clean_text(s, no_punct=True)
73
+
74
+
75
+
76
class MultiHeadModel(BertPreTrainedModel):
    """Pre-trained BERT model that uses our loss functions.

    A BERT encoder followed by dropout and one linear head per entry of
    ``head2size`` (head name -> number of outputs), each applied to the
    pooled output.
    """

    def __init__(self, config, head2size):
        """Build the encoder and one ``nn.Linear`` head per ``head2size`` item.

        Note that ``config.num_labels`` is overwritten with 1.
        """
        super(MultiHeadModel, self).__init__(config, head2size)
        config.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.heads = nn.ModuleDict({
            head_name: nn.Linear(config.hidden_size, num_labels)
            for head_name, num_labels in head2size.items()
        })

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                head2labels=None, return_pooler_output=False, head2mask=None,
                nsp_loss_weights=None):
        """Compute per-head logits and, when labels are given, per-head losses.

        Args:
            input_ids, token_type_ids, attention_mask: Forwarded to the BERT
                encoder.
            head2labels: Optional mapping head name -> label tensor. A head
                with one output uses MSE loss, any other head cross-entropy.
            return_pooler_output: If true, add the raw ``pooler_output``
                (before dropout) to the result.
            head2mask: Optional mapping head name -> name of the entry in
                ``head2labels`` whose labels weight that head's regression
                loss.
            nsp_loss_weights: Optional class-weight tensor for the
                cross-entropy loss; ``None`` means unweighted.

        Returns:
            dict with ``"<head>_logits"`` for every head, ``"<head>_loss"``
            for every labelled head and optionally ``"pooler_output"``.
        """
        # Get logits
        output = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_attentions=False, output_hidden_states=False, return_dict=True)
        pooled_output = self.dropout(output["pooler_output"])
        # Use the device the activations are actually on. Deriving it from
        # torch.cuda.is_available() breaks a CPU-resident model on a CUDA
        # machine, and any model that is not on the default CUDA device.
        device = pooled_output.device

        head2logits = {}
        return_dict = {}
        for head_name, head in self.heads.items():
            logits = head(pooled_output).float()
            head2logits[head_name] = logits
            return_dict[head_name + "_logits"] = logits

        if head2labels is not None:
            for head_name, labels in head2labels.items():
                logits = head2logits[head_name]
                num_classes = logits.shape[1]

                # Regression (e.g. for politeness)
                if num_classes == 1:

                    # Only consider positive examples
                    if head2mask is not None and head_name in head2mask:
                        mask_labels = head2labels[head2mask[head_name]]  # use certain labels as mask
                        num_positives = mask_labels.sum()
                        if num_positives == 0:
                            return_dict[head_name + "_loss"] = torch.tensor([0], device=device)
                        else:
                            loss_fct = MSELoss(reduction='none')
                            loss = loss_fct(logits.view(-1), labels.float().view(-1))
                            return_dict[head_name + "_loss"] = loss.dot(mask_labels.float().view(-1)) / num_positives
                    else:
                        loss_fct = MSELoss()
                        return_dict[head_name + "_loss"] = loss_fct(logits.view(-1), labels.float().view(-1))
                else:
                    # The parameter defaults to None: fall back to unweighted
                    # cross-entropy instead of failing on ``None.float()``.
                    weight = None if nsp_loss_weights is None else nsp_loss_weights.float()
                    loss_fct = CrossEntropyLoss(weight=weight)
                    return_dict[head_name + "_loss"] = loss_fct(logits, labels.view(-1))

        if return_pooler_output:
            return_dict["pooler_output"] = output["pooler_output"]

        return return_dict
139
+
140
class InputBuilder(object):
    """Base class for building model inputs from token-id segments."""

    def __init__(self, tokenizer):
        """Keep the tokenizer and a one-element list with its mask token id."""
        self.tokenizer = tokenizer
        self.mask = [tokenizer.mask_token_id]

    def build_inputs(self, history, reply, max_length):
        """Build one model input; must be implemented by subclasses."""
        raise NotImplementedError

    def mask_seq(self, sequence, seq_id):
        """Replace segment ``seq_id`` of ``sequence`` with the mask segment.

        ``sequence`` is modified in place and also returned.
        """
        sequence[seq_id] = self.mask
        return sequence

    @classmethod
    def _combine_sequence(cls, history, reply, max_length, flipped=False):
        """Return the history segments and the reply as one list of segments.

        Every segment is cut to at most ``max_length`` items. The reply comes
        last unless ``flipped`` is true, in which case it comes first.
        """
        trimmed_history = [segment[:max_length] for segment in history]
        trimmed_reply = reply[:max_length]
        if not flipped:
            return trimmed_history + [trimmed_reply]
        return [trimmed_reply] + trimmed_history
162
+
163
+
164
class BertInputBuilder(InputBuilder):
    """Builds BERT ``input_ids`` and ``token_type_ids`` from text segments."""

    def __init__(self, tokenizer):
        """Cache the CLS/SEP token ids of ``tokenizer`` as one-element lists."""
        super().__init__(tokenizer)
        self.cls = [tokenizer.cls_token_id]
        self.sep = [tokenizer.sep_token_id]
        self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
        self.padded_inputs = ["input_ids", "token_type_ids"]
        self.flipped = False

    def build_inputs(self, history, reply, max_length, input_str=True):
        """See base class.

        Args:
            history: List of earlier segments (strings when ``input_str`` is
                true, otherwise lists of token ids).
            reply: The final segment, in the same form as ``history`` items.
            max_length: Per-segment length limit applied before the special
                tokens are added.
            input_str: Whether the segments are raw strings that still need
                to be tokenized.

        Returns:
            dict with flat ``input_ids`` and ``token_type_ids`` lists; no
            attention mask is produced here.
        """
        if input_str:
            def to_ids(text):
                return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))

            history = [to_ids(segment_text) for segment_text in history]
            reply = to_ids(reply)

        segments = self._combine_sequence(history, reply, max_length, self.flipped)
        # Every segment ends with SEP; the first one additionally starts with CLS.
        segments = [segment + self.sep for segment in segments]
        segments[0] = self.cls + segments[0]

        last_speaker = 0
        other_speaker = 1
        num_segments = len(segments)

        input_ids = []
        token_type_ids = []
        for position, segment in enumerate(segments):
            # Segments alternate type ids counting back from the end: the
            # last segment gets ``last_speaker``, the one before it
            # ``other_speaker``, and so on.
            if (num_segments - position) % 2 == 1:
                segment_type = last_speaker
            else:
                segment_type = other_speaker
            input_ids.extend(segment)
            token_type_ids.extend([segment_type] * len(segment))

        instance = {}
        instance["input_ids"] = input_ids
        instance["token_type_ids"] = token_type_ids
        return instance