alice-hml committed on
Commit
741d69e
1 Parent(s): 3ee4a68

Upload classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +103 -0
classifier.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from transformers import BertTokenizerFast, BertForTokenClassification
3
+ import streamlit as st
4
+ from transformers import CharSpan
5
+ import torch
6
+
7
# Highlight colour (hex) for each error type; the keys match the e_type part
# of the model's "B-"/"I-" labels, split off in predict_error_spans and used
# by annotate to colour the highlighted span.
COLOURS = {
    "ADJ": '#1DB5D7',
    "ART": '#47A455',
    "ELI": '#FF5252',
    "FIN": '#9873E6',
    "NEG": '#FF914D'
}
14
+
15
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned token-classification model, cached by Streamlit
    so reruns of the app reuse the same instance."""
    # NOTE(review): st.cache is deprecated in recent Streamlit releases;
    # consider st.cache_resource — verify against the pinned version.
    print('Loading classification model')
    classifier = BertForTokenClassification.from_pretrained("classifier/")
    return classifier
19
+
20
+
21
@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Load the fast tokenizer matching the classification model, cached by
    Streamlit so reruns of the app reuse the same instance."""
    print("Loading tokenizer for classification model")
    bert_tokenizer = BertTokenizerFast.from_pretrained("classifier/")
    return bert_tokenizer
25
+
26
+
27
# Instantiate the (cached) model and tokenizer once at import time; every
# Streamlit rerun reuses the same objects.
model = load_model()
tokenizer = load_tokenizer()

# Tokens that do not map back to characters of the input sentence; skipped
# when collecting error spans in predict_error_spans.
special_tokens = [tokenizer.unk_token, tokenizer.sep_token, tokenizer.pad_token, tokenizer.cls_token, tokenizer.mask_token]
31
+
32
+
33
+ def predict(logits, threshold=0.3):
34
+ """Takes in output of model for a sentence as entry, returns one hot encoded prediction"""
35
+ probabilities = torch.nn.Softmax(dim=2)(logits)
36
+ predicted = (probabilities > threshold).long()
37
+ # if O is among the predicted labels, do not return other possible error types
38
+ for i,sent in enumerate(predicted):
39
+ for j,tok in enumerate(sent):
40
+ if tok[0] == 1:
41
+ predicted[i][j] = torch.tensor([1] + [0]*(model.num_labels - 1))
42
+ return predicted
43
+
44
def one_hot_label_decode(tok_labels):
    """Map a one-hot label vector back to the model's label names."""
    return [model.config.id2label[idx]
            for idx, flag in enumerate(tok_labels)
            if flag != 0]
50
+
51
def merge_char_spans(charspans: list):
    """Collapse a non-empty collection of CharSpans into one covering span.

    Args:
        charspans: non-empty iterable of CharSpan (raises ValueError when
            empty, via min/max).

    Returns:
        A CharSpan running from the smallest start to the largest end.
    """
    # min/max do not need the input sorted — the original sorted() call was
    # dead work (its result was only fed to min and max).
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)
57
+
58
def merge_error_spans(errors):
    """Merge contiguous character spans of the same error type.

    Args:
        errors: mapping of error type -> list of CharSpans in sentence order.

    Returns:
        defaultdict(list) mapping each error type to its merged spans.
    """
    merged_spans = defaultdict(list)
    for error_type, spans in errors.items():
        for charspan in spans:
            merged = merged_spans[error_type]
            # Contiguous (or separated by exactly one character) with the
            # last merged span: extend it.
            if merged and merged[-1].end in (charspan.start, charspan.start - 1):
                merged[-1] = merge_char_spans([merged[-1], charspan])
            else:
                # BUG FIX: the original had no else branch, so any span NOT
                # contiguous with the previous one was silently dropped.
                # Keep it as the start of a new merged span.
                merged.append(charspan)
    return merged_spans
66
+
67
def predict_error_spans(sentence):
    """Run the classifier on `sentence` and locate its error spans.

    Args:
        sentence: the raw input sentence (str).

    Returns:
        A tuple (sentence, merged_spans) where merged_spans maps each error
        type (e.g. "ADJ") to a list of merged CharSpans into `sentence`.
    """
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)

    for i, (t, p) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(p)
        if t in special_tokens or labels == ['O']:
            continue
        for error_type in labels:
            # Labels are BIO-style, e.g. "B-ADJ" / "I-ADJ".
            position, e_type = error_type.split("-")
            char_span = encoded_orig.token_to_chars(i)
            # Robustness fix: an "I-" tag with no open span for its type
            # used to raise IndexError on errors[e_type][-1]; treat it as
            # the start of a new span instead.
            if position == "B" or not errors[e_type]:
                errors[e_type].append(char_span)
            else:
                errors[e_type][-1] = merge_char_spans([errors[e_type][-1], char_span])
    return sentence, merge_error_spans(errors)
84
+
85
def annotate(sentence: str):
    """Split `sentence` into plain segments and highlighted error segments.

    Returns a list whose items are either plain substrings or tuples of
    (substring, error_type, colour) suitable for annotated display.
    """
    _, spans_by_type = predict_error_spans(sentence)
    # Flatten to (CharSpan, error_type) pairs, ordered by position.
    flat = sorted(
        ((span, etype) for etype, spans in spans_by_type.items() for span in spans),
        key=lambda pair: pair[0],
    )
    pieces = []
    cursor = 0
    for span, etype in flat:
        # Plain text between the previous span (or sentence start) and here.
        pieces.append(sentence[cursor:span.start])
        pieces.append((sentence[span.start:span.end], etype, COLOURS[etype]))
        cursor = span.end
    if flat:
        pieces.append(sentence[cursor:])
    else:
        pieces.append(sentence)
    return pieces