Ashlee Kupor committed
Commit 5cefadd
Parent(s): 04000c5
Add model
- config.json +28 -0
- eval_results.txt +12 -0
- handler.py +138 -0
- handler.py~ +27 -0
- merges.txt +0 -0
- model_args.json +1 -0
- pytorch_model.bin +3 -0
- requirements.txt +5 -0
- special_tokens_map.json +15 -0
- test_run_handler.py +13 -0
- tokenizer.json +0 -0
- tokenizer_config.json +16 -0
- training_args.bin +3 -0
- training_progress_scores.csv +8 -0
- vocab.json +0 -0
config.json
ADDED
@@ -0,0 +1,28 @@
{
  "_name_or_path": "roberta-base",
  "architectures": [
    "RobertaForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.28.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
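The config describes a stock roberta-base encoder with a single-label sequence-classification head, saved in the standard transformers layout. A minimal sketch of loading the checkpoint directly with Hugging Face transformers, independent of simpletransformers (the local path "." is an assumption):

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

config = AutoConfig.from_pretrained(".")        # reads this config.json
tokenizer = AutoTokenizer.from_pretrained(".")  # uses vocab.json / merges.txt / tokenizer.json
model = AutoModelForSequenceClassification.from_pretrained(".", config=config)
model.eval()  # inference mode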
eval_results.txt
ADDED
@@ -0,0 +1,12 @@
accuracy = 0.9996004794246903
auprc = 0.9997711321330485
auroc = 0.9999888917688007
eval_loss = 0.002696692644352147
f1 = 0.9955555555555555
fn = 1
fp = 0
mcc = 0.9953571764069896
precision = 0.9911504424778761
recall = 1.0
tn = 2390
tp = 112
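These numbers are mutually consistent: with tp=112, tn=2390, fp=0, fn=1 there are 2503 eval examples, and accuracy, F1, and MCC all follow from the confusion counts. A minimal sketch re-deriving them (note that simpletransformers' precision/recall orientation depends on which label it treats as positive):

import math

tp, tn, fp, fn = 112, 2390, 0, 1
accuracy = (tp + tn) / (tp + tn + fp + fn)  # 2502/2503 = 0.9996004794246903
f1 = 2 * tp / (2 * tp + fp + fn)            # 224/225  = 0.9955555555555555
mcc = (tp * tn - fp * fn) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))  # ~0.99536
print(accuracy, f1, mcc)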
handler.py
ADDED
@@ -0,0 +1,138 @@
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from typing import Dict, List, Any, Union
import pandas as pd
import webvtt
from datetime import datetime
import torch
import spacy

nlp = spacy.load("en_core_web_sm")
tokenizer = nlp.tokenizer
token_limit = 200


class Utterance(object):

    def __init__(self, starttime, endtime, speaker, text,
                 idx, prev_utterance, prev_prev_utterance):
        self.starttime = starttime
        self.endtime = endtime
        self.speaker = speaker
        self.text = text
        self.idx = idx
        self.prev = prev_utterance
        self.prev_prev = prev_prev_utterance


class EndpointHandler():
    def __init__(self, path="."):
        print("Loading models...")
        cuda_available = torch.cuda.is_available()
        self.model = ClassificationModel(
            "roberta", path, use_cuda=cuda_available
        )

    def utterance_to_str(self, utterance: Utterance) -> Union[str, List[str]]:
        # connecting only uses text
        doc = nlp(utterance.text)
        if len(doc) > token_limit:
            return self.handle_long_utterances(doc)
        return utterance.text

    def handle_long_utterances(self, doc: spacy.tokens.Doc) -> List[str]:
        """Split a long utterance into segments of whole sentences, flushing a
        segment once its running token count reaches token_limit."""
        total_sent = len(list(doc.sents))
        sent_count = 0
        token_count = 0
        split_utterance = ''
        utterances = []
        for sent in doc.sents:
            # add a sentence to the current segment
            split_utterance = split_utterance + ' ' + sent.text
            token_count += len(sent)
            sent_count += 1
            if token_count >= token_limit or sent_count == total_sent:
                # save utterance segment
                utterances.append(split_utterance)

                # restart count
                split_utterance = ''
                token_count = 0

        return utterances

    def convert_time(self, time_str):
        # "HH:MM:SS.mmm" -> milliseconds, e.g. "00:01:02.500" -> 62500.0
        time = datetime.strptime(time_str, "%H:%M:%S.%f")
        return 1000 * (3600 * time.hour + 60 * time.minute + time.second) + time.microsecond / 1000

    def process_vtt_transcript(self, vttfile) -> List[Utterance]:
        """Process raw vtt file."""

        utterances_list = []
        text = ""
        prev_speaker = None
        prev_start = "00:00:00.000"
        prev_end = "00:00:00.000"
        idx = 0
        prev_utterance = None
        prev_prev_utterance = None
        for caption in webvtt.read(vttfile):

            # Get speaker
            check_for_speaker = caption.text.split(":")
            if len(check_for_speaker) > 1:  # the speaker changed or was restated
                speaker = check_for_speaker[0]
            else:
                speaker = prev_speaker

            # Get utterance
            new_text = check_for_speaker[1] if len(check_for_speaker) > 1 else check_for_speaker[0]

            # If the speaker changed, close out the previous utterance
            if (prev_speaker is not None) and (speaker != prev_speaker):
                utterance = Utterance(starttime=self.convert_time(prev_start),
                                      endtime=self.convert_time(prev_end),
                                      speaker=prev_speaker,
                                      text=text.strip(),
                                      idx=idx,
                                      prev_utterance=prev_utterance,
                                      prev_prev_utterance=prev_prev_utterance)

                utterances_list.append(utterance)

                # Start new batch
                prev_start = caption.start
                text = ""
                prev_prev_utterance = prev_utterance
                prev_utterance = utterance
                idx += 1
            text += new_text + " "
            prev_end = caption.end
            prev_speaker = speaker

        # Append the last utterance
        if prev_speaker is not None:
            utterance = Utterance(starttime=self.convert_time(prev_start),
                                  endtime=self.convert_time(prev_end),
                                  speaker=prev_speaker,
                                  text=text.strip(),
                                  idx=idx,
                                  prev_utterance=prev_utterance,
                                  prev_prev_utterance=prev_prev_utterance)
            utterances_list.append(utterance)

        print(utterances_list)
        return utterances_list

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        ''' data_file is a str pointing to a filename of type .vtt '''

        utterances_list = []
        for utterance in self.process_vtt_transcript(data_file):
            # TODO: filter out to only have SL utterances
            result = self.utterance_to_str(utterance)
            # long utterances come back as a list of segments; flatten them
            # so model.predict receives a flat list of strings
            if isinstance(result, list):
                utterances_list.extend(result)
            else:
                utterances_list.append(result)

        predictions, raw_outputs = self.model.predict(utterances_list)

        return predictions
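The handler expects standard WebVTT cues whose text optionally begins with "Speaker:"; cues with no speaker prefix are attributed to the previous speaker, and consecutive cues from the same speaker are merged into a single Utterance. A minimal sketch of a compatible input (the file name and contents are made up):

WEBVTT

00:00:00.000 --> 00:00:04.000
Teacher: Good morning, everyone. Let's get started.

00:00:04.000 --> 00:00:07.500
Student: Good morning!

Saved as e.g. sample.vtt, EndpointHandler(path=".")("sample.vtt") returns one model prediction per (possibly split) utterance.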
handler.py~
ADDED
@@ -0,0 +1,27 @@
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from typing import Dict, List, Any
import pandas as pd
import webvtt
from datetime import datetime
import torch
import spacy

nlp = spacy.load("en_core_web_sm")
tokenizer = nlp.tokenizer
token_limit = 200

class EndpointHandler():
    def __init__(self, path="."):
        print("Loading models...")
        cuda_available = torch.cuda.is_available()
        self.model = ClassificationModel(
            "roberta", path, use_cuda=cuda_available
        )

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        ''' data_file is a str pointing to filename of type .vtt '''

        utterances_list = []
        predictions, raw_outputs = self.model.predict(utterances_list)

        return predictions
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
model_args.json
ADDED
@@ -0,0 +1 @@
{"adafactor_beta1": null, "adafactor_clip_threshold": 1.0, "adafactor_decay_rate": -0.8, "adafactor_eps": [1e-30, 0.001], "adafactor_relative_step": true, "adafactor_scale_parameter": true, "adafactor_warmup_init": true, "adam_betas": [0.9, 0.999], "adam_epsilon": 1e-08, "best_model_dir": "outputs/roberta/connecting_FINAL_MODEL/best_model_all_transcripts", "cache_dir": "outputs/roberta/connecting_FINAL_MODEL/cache", "config": {}, "cosine_schedule_num_cycles": 0.5, "custom_layer_parameters": [], "custom_parameter_groups": [], "dataloader_num_workers": 0, "do_lower_case": false, "dynamic_quantize": false, "early_stopping_consider_epochs": false, "early_stopping_delta": 0, "early_stopping_metric": "eval_loss", "early_stopping_metric_minimize": true, "early_stopping_patience": 3, "encoding": null, "eval_batch_size": 8, "evaluate_during_training": true, "evaluate_during_training_silent": true, "evaluate_during_training_steps": 348, "evaluate_during_training_verbose": false, "evaluate_each_epoch": true, "fp16": false, "gradient_accumulation_steps": 2, "learning_rate": 4e-05, "local_rank": -1, "logging_steps": 50, "loss_type": null, "loss_args": {}, "manual_seed": null, "max_grad_norm": 1.0, "max_seq_length": 512, "model_name": "roberta-base", "model_type": "roberta", "multiprocessing_chunksize": -1, "n_gpu": 1, "no_cache": false, "no_save": false, "not_saved_args": [], "num_train_epochs": 5, "optimizer": "AdamW", "output_dir": "outputs/roberta/connecting_FINAL_MODEL", "overwrite_output_dir": true, "polynomial_decay_schedule_lr_end": 1e-07, "polynomial_decay_schedule_power": 1.0, "process_count": 1, "quantized_model": false, "reprocess_input_data": true, "save_best_model": true, "save_eval_checkpoints": false, "save_model_every_epoch": false, "save_optimizer_and_scheduler": true, "save_steps": 2000, "scheduler": "linear_schedule_with_warmup", "silent": false, "skip_special_tokens": true, "tensorboard_dir": "outputs/roberta/connecting_FINAL_MODEL/tensorboard", "thread_count": null, "tokenizer_name": "roberta-base", "tokenizer_type": null, "train_batch_size": 8, "train_custom_parameters_only": false, "use_cached_eval_features": false, "use_early_stopping": false, "use_hf_datasets": false, "use_multiprocessing": false, "use_multiprocessing_for_evaluation": false, "wandb_kwargs": {"reinit": true}, "wandb_project": "connecting_all_transcripts", "warmup_ratio": 0.06, "warmup_steps": 53, "weight_decay": 0.0, "model_class": "ClassificationModel", "labels_list": [0, 1], "labels_map": {}, "lazy_delimiter": "\t", "lazy_labels_column": 1, "lazy_loading": false, "lazy_loading_start_line": 1, "lazy_text_a_column": null, "lazy_text_b_column": null, "lazy_text_column": 0, "onnx": false, "regression": false, "sliding_window": false, "special_tokens_list": [], "stride": 0.8, "tie_value": 1}
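simpletransformers saves this model_args.json next to the weights and reads it back when the model directory is passed as the model path, so the handler's ClassificationModel("roberta", path) should pick up these settings (e.g. max_seq_length=512, eval_batch_size=8) automatically. A minimal sketch of overriding individual settings at load time (the override values shown are assumptions):

from simpletransformers.classification import ClassificationModel, ClassificationArgs

args = ClassificationArgs(eval_batch_size=32, silent=True)  # hypothetical overrides
model = ClassificationModel("roberta", ".", args=args, use_cuda=False)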
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cbb13cc66ca8a6202d9de0c2f7f7e060d79764bfe005a269ef31074120544e15
size 498662069
requirements.txt
ADDED
@@ -0,0 +1,5 @@
pandas==1.1.1
scikit_learn==1.1.3
scipy==1.7.1
simpletransformers==0.63.6
torch==1.6.0
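Note that handler.py also imports webvtt and spacy and loads the en_core_web_sm pipeline, none of which are pinned here. Deploying the handler likely also needs the following entries (package names are the usual PyPI ones; versions left unpinned as an assumption):

webvtt-py
spacy

plus a one-time python -m spacy download en_core_web_sm to fetch the pipeline.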
special_tokens_map.json
ADDED
@@ -0,0 +1,15 @@
{
  "bos_token": "<s>",
  "cls_token": "<s>",
  "eos_token": "</s>",
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "unk_token": "<unk>"
}
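These follow the usual RoBERTa conventions: <s> doubles as BOS/CLS (id 0 in config.json) and </s> as EOS/SEP (id 2). A minimal sketch of how the tokenizer applies them (the local path "." is an assumption):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".")
ids = tok("Good morning, everyone.")["input_ids"]
print(ids[0], ids[-1])  # 0 (<s>) and 2 (</s>)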
test_run_handler.py
ADDED
@@ -0,0 +1,13 @@
from handler import EndpointHandler

# init handler
my_handler = EndpointHandler(path=".")

# prepare sample payload
test_payload = 'test.transcript.vtt'

# test the handler
test_pred = my_handler(test_payload)

# show results
print("test_pred", test_pred)
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer_config.json
ADDED
@@ -0,0 +1,16 @@
{
  "add_prefix_space": false,
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": true,
  "cls_token": "<s>",
  "do_lower_case": false,
  "eos_token": "</s>",
  "errors": "replace",
  "mask_token": "<mask>",
  "model_max_length": 512,
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "tokenizer_class": "RobertaTokenizer",
  "trim_offsets": true,
  "unk_token": "<unk>"
}
training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7a7eb9bd523295c1f076863d6e51ca1a99210220944f2a8573aa5e6a556be9c
size 3451
training_progress_scores.csv
ADDED
@@ -0,0 +1,8 @@
global_step,train_loss,mcc,tp,tn,fp,fn,auroc,auprc,accuracy,precision,recall,f1,eval_loss
174,0.336158812046051,0.0,0,2390,0,113,0.9098048654052654,0.499127409757406,0.954854174990012,0.0,0.0,0.0,0.16167433786030394
348,0.010986842215061188,0.785426877747522,104,2343,47,9,0.9911689561965416,0.9236932727451965,0.9776268477826608,0.9203539823008849,0.6887417218543046,0.7878787878787877,0.08293315396103383
348,0.07023407518863678,0.785426877747522,104,2343,47,9,0.9911689561965416,0.9236932727451965,0.9776268477826608,0.9203539823008849,0.6887417218543046,0.7878787878787877,0.08293315396103383
522,0.0005581587320193648,0.9180271585164762,110,2374,16,3,0.9986596067686155,0.9899529045270361,0.9924091090691171,0.9734513274336283,0.873015873015873,0.9205020920502092,0.027681814245459364
696,0.0003248823923058808,0.9775600823633134,113,2385,5,0,0.9999814862813344,0.9996191363565474,0.9980023971234518,1.0,0.9576271186440678,0.9783549783549783,0.005484335490429146
696,0.00022003523190505803,0.9775600823633134,113,2385,5,0,0.9999814862813344,0.9996191363565474,0.9980023971234518,1.0,0.9576271186440678,0.9783549783549783,0.005484335490429146
870,0.00024747333372943103,0.9953571764069896,112,2390,0,1,0.9999888917688007,0.9997711321330485,0.9996004794246903,0.9911504424778761,1.0,0.9955555555555555,0.002696692644352147
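Across the five epochs eval_loss falls from 0.162 to 0.0027 while f1 climbs from 0.0 (the step-174 model predicts all negatives: tp=0, fp=0) to 0.9956, and the final row matches eval_results.txt; the repeated rows at steps 348 and 696 are likely the per-epoch and per-eval-step evaluations coinciding (evaluate_during_training_steps is 348 in model_args.json). A minimal sketch for inspecting the curve with pandas (already pinned in requirements.txt):

import pandas as pd

df = pd.read_csv("training_progress_scores.csv")
print(df[["global_step", "train_loss", "eval_loss", "f1", "mcc"]])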
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff