Spaces:
Runtime error
Runtime error
Upload classifier.py
Browse files- classifier.py +103 -0
classifier.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from transformers import BertTokenizerFast, BertForTokenClassification
|
3 |
+
import streamlit as st
|
4 |
+
from transformers import CharSpan
|
5 |
+
import torch
|
6 |
+
|
7 |
+
# Highlight colour (hex) used for each error-type tag when rendering
# annotated output.
# NOTE(review): tag semantics inferred from the abbreviations
# (ADJ/ART/ELI/FIN/NEG) — confirm against the classifier's label set.
COLOURS = {
    "ADJ": '#1DB5D7',
    "ART": '#47A455',
    "ELI": '#FF5252',
    "FIN": '#9873E6',
    "NEG": '#FF914D'
}
|
14 |
+
|
15 |
+
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned token-classification model from the local
    "classifier/" directory.

    Cached by Streamlit so the weights are loaded only once per process.
    NOTE(review): ``st.cache`` is deprecated and removed in recent Streamlit
    releases (replaced by ``st.cache_resource``) — a likely cause of runtime
    errors on upgraded deployments; confirm the pinned Streamlit version.
    """
    print('Loading classification model')
    return BertForTokenClassification.from_pretrained("classifier/")
|
19 |
+
|
20 |
+
|
21 |
+
@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Load the fast tokenizer matching the classification model from the
    local "classifier/" directory (cached by Streamlit).

    NOTE(review): ``st.cache`` is removed in recent Streamlit releases
    (use ``st.cache_resource``); verify the deployed Streamlit version.
    """
    print("Loading tokenizer for classification model")
    return BertTokenizerFast.from_pretrained("classifier/")
|
25 |
+
|
26 |
+
|
27 |
+
# Instantiate the (cached) model and tokenizer once at import time; the
# prediction helpers below read these module-level globals.
model = load_model()
tokenizer = load_tokenizer()

# Tokenizer control tokens ([UNK], [SEP], [PAD], [CLS], [MASK]) — these are
# skipped when collecting error spans so they never get annotated.
special_tokens = [tokenizer.unk_token, tokenizer.sep_token, tokenizer.pad_token, tokenizer.cls_token, tokenizer.mask_token]
|
31 |
+
|
32 |
+
|
33 |
+
def predict(logits, threshold=0.3):
|
34 |
+
"""Takes in output of model for a sentence as entry, returns one hot encoded prediction"""
|
35 |
+
probabilities = torch.nn.Softmax(dim=2)(logits)
|
36 |
+
predicted = (probabilities > threshold).long()
|
37 |
+
# if O is among the predicted labels, do not return other possible error types
|
38 |
+
for i,sent in enumerate(predicted):
|
39 |
+
for j,tok in enumerate(sent):
|
40 |
+
if tok[0] == 1:
|
41 |
+
predicted[i][j] = torch.tensor([1] + [0]*(model.num_labels - 1))
|
42 |
+
return predicted
|
43 |
+
|
44 |
+
def one_hot_label_decode(tok_labels, id2label=None):
    """Decode a one-hot label vector into the list of label names.

    Args:
        tok_labels: sequence of 0/1 flags, one entry per label id.
        id2label: optional mapping from label id to label name; defaults to
            the loaded model's ``config.id2label`` (backward compatible).

    Returns:
        Names of every label whose flag is non-zero, in label-id order.
    """
    if id2label is None:
        id2label = model.config.id2label
    return [id2label[i] for i, flag in enumerate(tok_labels) if flag != 0]
|
50 |
+
|
51 |
+
def merge_char_spans(charspans: list):
    """Merge several character spans into one covering span.

    Args:
        charspans: non-empty list of ``CharSpan`` objects.

    Returns:
        A ``CharSpan`` running from the smallest start to the largest end.

    Raises:
        ValueError: if ``charspans`` is empty.
    """
    # min()/max() scan every span regardless of order, so the pre-sort the
    # original version performed was dead work and has been dropped.
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)
|
57 |
+
|
58 |
+
def merge_error_spans(errors):
    """Collapse touching or adjacent spans of the same error type into one.

    Two spans are merged when the earlier one ends exactly where the next
    starts, or one character before it (e.g. across a space).

    Args:
        errors: mapping of error type -> list of CharSpans in sentence order.

    Returns:
        defaultdict mapping each error type to its list of merged spans.

    Bug fix: the original version had no ``else`` branch, so any span NOT
    adjacent to the previous one was silently discarded; such spans are now
    kept as separate entries.
    """
    merged_spans = defaultdict(list)
    for error_type, spans in errors.items():
        for charspan in spans:
            merged = merged_spans[error_type]
            if merged and merged[-1].end in (charspan.start, charspan.start - 1):
                merged[-1] = merge_char_spans([merged[-1], charspan])
            else:
                # Not contiguous with the previous span: start a new one.
                merged.append(charspan)
    return merged_spans
|
66 |
+
|
67 |
+
def predict_error_spans(sentence):
    """Run the classifier on a sentence and collect character spans per error type.

    Args:
        sentence: raw input sentence (str).

    Returns:
        Tuple ``(sentence, spans)`` where ``spans`` maps each error type
        (e.g. "ADJ") to its merged list of CharSpans.
    """
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)

    for i, (token, one_hot) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(one_hot)
        # Skip tokenizer control tokens and tokens predicted as plain "O".
        if token in special_tokens or labels == ['O']:
            continue
        for label in labels:
            # Labels are BIO-style, e.g. "B-ADJ" / "I-ADJ".
            position, e_type = label.split("-")
            span = encoded_orig.token_to_chars(i)
            # Robustness fix: an "I-" tag with no open span previously hit
            # errors[e_type][-1] on an empty list and raised IndexError;
            # treat it as the start of a new span instead.
            if position == "B" or not errors[e_type]:
                errors[e_type].append(span)
            else:
                errors[e_type][-1] = merge_char_spans([errors[e_type][-1], span])
    return sentence, merge_error_spans(errors)
|
84 |
+
|
85 |
+
def annotate(sentence: str):
    """Build the annotated-text segments for a sentence.

    Runs the error-span predictor and returns a list alternating plain-text
    strings with ``(text, label, colour)`` tuples for each detected error
    span, in sentence order. A sentence with no errors yields ``[sentence]``.
    """
    _, spans_by_type = predict_error_spans(sentence)
    # Flatten {error_type: [spans]} into (span, label) pairs sorted by span.
    flat_spans = sorted(
        ((charspan, error_type)
         for error_type, charspans in spans_by_type.items()
         for charspan in charspans),
        key=lambda pair: pair[0],
    )
    segments = []
    cursor = 0
    for charspan, label in flat_spans:
        # Plain text between the previous span (or sentence start) and here.
        segments.append(sentence[cursor:charspan.start])
        segments.append((sentence[charspan.start:charspan.end], label, COLOURS[label]))
        cursor = charspan.end
    # Trailing plain text; this is the whole sentence when no spans exist.
    segments.append(sentence[cursor:])
    return segments
|