# french_gec/classifier.py
from collections import defaultdict

import streamlit as st
import torch
from transformers import BertForTokenClassification, BertTokenizerFast, CharSpan

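# Highlight colour (hex code) used for each error-type tag in the annotated output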
COLOURS = {
"ADJ": '#1DB5D7',
"ART": '#47A455',
"ELI": '#FF5252',
"FIN": '#9873E6',
"NEG": '#FF914D'
}
@st.cache(allow_output_mutation=True)  # pre-1.18 Streamlit API; newer versions use st.cache_resource
def load_model():
    print("Loading classification model")
    return BertForTokenClassification.from_pretrained("alice-hml/mBERT_grammatical_error_tagger")


@st.cache(allow_output_mutation=True)
def load_tokenizer():
    print("Loading tokenizer for classification model")
    return BertTokenizerFast.from_pretrained("alice-hml/mBERT_grammatical_error_tagger")

model = load_model()
tokenizer = load_tokenizer()
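# special tokens ([UNK], [SEP], [PAD], [CLS], [MASK]) are ignored when mapping
# predictions back onto the input sentence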
special_tokens = [tokenizer.unk_token, tokenizer.sep_token, tokenizer.pad_token, tokenizer.cls_token, tokenizer.mask_token]
def predict(logits, threshold=0.3):
    """Take the model's logits for a batch of sentences and return one-hot
    encoded predictions: a label is kept whenever its softmax probability
    exceeds `threshold`."""
    probabilities = torch.nn.Softmax(dim=2)(logits)
    predicted = (probabilities > threshold).long()
    # if O (label id 0) is among the predicted labels, do not return other possible error types
    for i, sent in enumerate(predicted):
        for j, tok in enumerate(sent):
            if tok[0] == 1:
                predicted[i][j] = torch.tensor([1] + [0] * (model.num_labels - 1))
    return predicted
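
# Illustration (hypothetical tensor, not real model output): for one sentence
# of 3 tokens, predict(torch.randn(1, 3, model.num_labels)) returns a
# (1, 3, model.num_labels) tensor of 0/1 flags; any token whose "O" flag
# (index 0) is set has all of its other flags cleared.
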
def one_hot_label_decode(tok_labels):
    """Convert a one-hot label vector for a single token back into label names."""
    labels = []
    for i, v in enumerate(tok_labels):
        if v != 0:
            labels.append(model.config.id2label[i])
    return labels
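
# Example (assuming model.config.id2label maps 0 -> "O", 1 -> "B-ADJ", ...):
#   one_hot_label_decode(torch.tensor([0, 1, 0, 0])) -> ["B-ADJ"]
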
def merge_char_spans(charspans: list):
    """Merge a list of CharSpans into the single span that covers them all."""
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)
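
# Example: merge_char_spans([CharSpan(5, 9), CharSpan(0, 4)]) -> CharSpan(0, 9)
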
def merge_error_spans(errors):
    """Merge spans of the same error type that are adjacent or separated by a
    single character (e.g. a space) into one span."""
    merged_spans = defaultdict(list)
    for error_type in errors:
        merged_spans[error_type] = errors[error_type][:1]
        for charspan in errors[error_type][1:]:
            last = merged_spans[error_type][-1]
            if charspan.start in (last.end, last.end + 1):
                merged_spans[error_type][-1] = merge_char_spans([last, charspan])
            else:
                # keep non-contiguous spans separate instead of silently dropping them
                merged_spans[error_type].append(charspan)
    return merged_spans
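
# Example: merge_error_spans({"NEG": [CharSpan(0, 2), CharSpan(3, 6), CharSpan(10, 12)]})
#   -> {"NEG": [CharSpan(0, 6), CharSpan(10, 12)]}
# (the first two spans are within one character of each other and get fused;
# the third is kept as a separate error)
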
def predict_error_spans(sentence):
    """Run the tagger on `sentence` and return the sentence together with the
    character spans of each detected error type."""
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)
    for i, (t, p) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(p)
        if t not in special_tokens and labels != ["O"]:
            for label in labels:
                position, e_type = label.split("-")
                if position == "B" or not errors[e_type]:
                    # "B-" opens a new span; an "I-" with no open span is treated the same
                    errors[e_type].append(encoded_orig.token_to_chars(i))
                else:
                    errors[e_type][-1] = merge_char_spans([errors[e_type][-1], encoded_orig.token_to_chars(i)])
    return sentence, merge_error_spans(errors)
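
# Shape of the result (hypothetical spans; real output depends on the model):
#   predict_error_spans("Il ne mange plus de viande")
#   -> ("Il ne mange plus de viande", {"NEG": [CharSpan(3, 16)]})
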
def annotate(sentence: str):
    """Split the sentence into plain-text fragments and (fragment, label,
    colour) tuples, ready for highlighted display."""
    sent_and_error_spans = predict_error_spans(sentence)
    error_spans_ordered = []
    for error_type, charspans in sent_and_error_spans[1].items():
        for charspan in charspans:
            error_spans_ordered.append((charspan, error_type))
    error_spans_ordered = sorted(error_spans_ordered, key=lambda x: x[0])
    annotated_sentence = []
    for i, (error_span, label) in enumerate(error_spans_ordered):
        # plain text between the previous error span (or the sentence start) and this one
        if i > 0:
            annotated_sentence.append(sentence[error_spans_ordered[i - 1][0].end:error_span.start])
        else:
            annotated_sentence.append(sentence[:error_span.start])
        annotated_sentence.append((sentence[error_span.start:error_span.end], label, COLOURS[label]))
    # trailing text after the last error span, or the whole sentence if no errors
    if error_spans_ordered:
        annotated_sentence.append(sentence[error_spans_ordered[-1][0].end:])
    else:
        annotated_sentence.append(sentence)
    return annotated_sentence
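
# Minimal usage sketch (illustrative sentence only). The (text, label, colour)
# tuples match the format expected by Streamlit annotated-text components,
# which is presumably how the app renders them:
if __name__ == "__main__":
    for fragment in annotate("Les enfants manges une pomme."):
        print(fragment)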