# french_gec/classifier.py
from collections import defaultdict

import streamlit as st
import torch
from transformers import BertForTokenClassification, BertTokenizerFast, CharSpan

# Highlight colour used for each error tag.
COLOURS = {
    "ADJ": '#1DB5D7',
    "ART": '#47A455',
    "ELI": '#FF5252',
    "FIN": '#9873E6',
    "NEG": '#FF914D',
}


@st.cache(allow_output_mutation=True)
def load_model():
    print('Loading classification model')
    return BertForTokenClassification.from_pretrained("aligator/mBERT_grammatical_error_tagger")


@st.cache(allow_output_mutation=True)
def load_tokenizer():
    print("Loading tokenizer for classification model")
    return BertTokenizerFast.from_pretrained("aligator/mBERT_grammatical_error_tagger")


model = load_model()
tokenizer = load_tokenizer()
special_tokens = [tokenizer.unk_token, tokenizer.sep_token, tokenizer.pad_token,
                  tokenizer.cls_token, tokenizer.mask_token]


def predict(logits, threshold=0.3):
    """Take the model logits for a sentence and return a one-hot encoded prediction.

    A label is kept for a token whenever its probability exceeds the threshold,
    so a single token can carry several error labels at once.
    """
    probabilities = torch.nn.Softmax(dim=2)(logits)
    predicted = (probabilities > threshold).long()
    # if O is among the predicted labels, do not return other possible error types
    for i, sent in enumerate(predicted):
        for j, tok in enumerate(sent):
            if tok[0] == 1:
                predicted[i][j] = torch.tensor([1] + [0] * (model.num_labels - 1))
    return predicted
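
# Illustrative sketch (hypothetical numbers, not from the original file): with
# id2label = {0: "O", 1: "B-NEG", 2: "I-NEG"} and softmax probabilities
# (0.1, 0.5, 0.4) for one token, threshold=0.3 yields the multi-hot row
# [0, 1, 1]; a row whose "O" probability clears the threshold is reset to
# [1, 0, 0] by the loop above.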


def one_hot_label_decode(tok_labels):
    """Turn a one-hot (or multi-hot) label vector back into a list of label names."""
    labels = []
    for i, v in enumerate(tok_labels):
        if v != 0:
            labels.append(model.config.id2label[i])
    return labels


def merge_char_spans(charspans: list):
    """Merge a list of character spans into a single span covering their full range."""
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)
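
# For example, merge_char_spans([CharSpan(0, 3), CharSpan(4, 7)]) returns CharSpan(0, 7).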


def merge_error_spans(errors):
    """Merge contiguous spans of the same error type into single spans."""
    merged_spans = defaultdict(list)
    for error_type in errors:
        merged_spans[error_type] = errors[error_type][:1]
        for charspan in errors[error_type][1:]:
            last = merged_spans[error_type][-1]
            # merge spans that touch or are separated by a single character (e.g. a space)
            if last.end == charspan.start or last.end + 1 == charspan.start:
                merged_spans[error_type][-1] = merge_char_spans([last, charspan])
            else:
                # keep non-contiguous spans separate instead of silently dropping them
                merged_spans[error_type].append(charspan)
    return merged_spans
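
# For example, {"NEG": [CharSpan(0, 2), CharSpan(3, 7), CharSpan(10, 12)]} becomes
# {"NEG": [CharSpan(0, 7), CharSpan(10, 12)]}: the first two spans touch across a
# one-character gap and merge, while the third stays separate.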


def predict_error_spans(sentence):
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    with torch.no_grad():
        output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)
    for i, (t, p) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(p)
        if t not in special_tokens and labels != ['O']:
            for error_type in labels:
                position, e_type = error_type.split("-")
                if position == "B" or not errors[e_type]:
                    # start a new span; the second condition guards against an I- tag
                    # arriving without a preceding B- tag (an IndexError otherwise)
                    errors[e_type].append(encoded_orig.token_to_chars(i))
                else:
                    errors[e_type][-1] = merge_char_spans([errors[e_type][-1],
                                                           encoded_orig.token_to_chars(i)])
    return sentence, merge_error_spans(errors)
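
# Hypothetical illustration (actual spans depend on the model's predictions):
#     predict_error_spans("Ils mange une pomme.")
# could return something like
#     ("Ils mange une pomme.", {"FIN": [CharSpan(4, 9)]})
# i.e. the original sentence plus, per error type, the merged character spans.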


def annotate(sentence: str):
    """Split the sentence into plain substrings and (substring, label, colour) tuples."""
    sent_and_error_spans = predict_error_spans(sentence)
    error_spans_ordered = []
    for error_type, charspans in sent_and_error_spans[1].items():
        for charspan in charspans:
            error_spans_ordered.append((charspan, error_type))
    error_spans_ordered = sorted(error_spans_ordered, key=lambda x: x[0])
    annotated_sentence = []
    for i, (error_span, label) in enumerate(error_spans_ordered):
        # plain text between the previous error span (or sentence start) and this one
        if i > 0:
            annotated_sentence.append(sentence[error_spans_ordered[i - 1][0].end:error_span.start])
        else:
            annotated_sentence.append(sentence[:error_span.start])
        annotated_sentence.append((sentence[error_span.start:error_span.end], label, COLOURS[label]))
    if error_spans_ordered:
        # trailing text after the last error span
        annotated_sentence.append(sentence[error_spans_ordered[-1][0].end:])
    else:
        annotated_sentence.append(sentence)
    return annotated_sentence
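

# A minimal usage sketch (an assumption, not part of the original app): running
# this module directly, outside Streamlit, should still work since st.cache only
# wraps the two model loaders. The example sentence is illustrative.
if __name__ == "__main__":
    for fragment in annotate("Les chat noirs mangent une pomme."):
        if isinstance(fragment, tuple):
            text, label, colour = fragment
            print(f"[{label}] {text!r} ({colour})")
        else:
            print(repr(fragment))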