# NOTE: removed non-code residue from a code-hosting page scrape
# (file-size line, commit hashes, and a line-number gutter) that
# preceded the actual Python source and made the file unparseable.
from collections import defaultdict
from transformers import BertTokenizerFast, BertForTokenClassification
import streamlit as st
from transformers import CharSpan
import torch

# Highlight colour (hex RGB) per error-type tag; used by `annotate` when
# building (text, label, colour) tuples for display.
COLOURS = {
    "ADJ": '#1DB5D7',
    "ART": '#47A455',
    "ELI": '#FF5252',
    "FIN": '#9873E6',
    "NEG": '#FF914D'
}

@st.cache(allow_output_mutation=True)
def load_model():
    """Load the grammatical-error tagging model, memoised by Streamlit's cache."""
    print('Loading classification model')
    classifier = BertForTokenClassification.from_pretrained(
        "aligator/mBERT_grammatical_error_tagger"
    )
    return classifier


@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Load the fast tokenizer paired with the tagging model, memoised by Streamlit."""
    print("Loading tokenizer for classification model")
    tok = BertTokenizerFast.from_pretrained(
        "aligator/mBERT_grammatical_error_tagger"
    )
    return tok


# Instantiate the (Streamlit-cached) model and tokenizer once at import time.
model = load_model()
tokenizer = load_tokenizer()

# Special tokens ([UNK], [SEP], [PAD], [CLS], [MASK]) never carry an error
# label and are skipped when extracting error spans.
special_tokens = [tokenizer.unk_token, tokenizer.sep_token, tokenizer.pad_token, tokenizer.cls_token, tokenizer.mask_token]


def predict(logits, threshold=0.3, num_labels=None):
    """One-hot, multi-label predictions from token-classification logits.

    Args:
        logits: float tensor of shape (batch, seq_len, num_labels).
        threshold: softmax-probability cutoff for a label to be predicted.
        num_labels: number of labels; defaults to the module-level
            ``model.num_labels`` (parameter added for testability/reuse).

    Returns:
        Long tensor of the same shape as ``logits`` with 1 where a label's
        probability exceeds ``threshold``. If the "O" label (index 0) is
        predicted for a token, every other label is suppressed for it.
    """
    if num_labels is None:
        num_labels = model.num_labels
    probabilities = torch.softmax(logits, dim=2)
    predicted = (probabilities > threshold).long()
    # Vectorized replacement for the original per-token Python loop:
    # wherever "O" cleared the threshold, overwrite the row with pure "O".
    o_predicted = predicted[:, :, 0] == 1
    o_row = torch.tensor([1] + [0] * (num_labels - 1), dtype=predicted.dtype)
    predicted[o_predicted] = o_row
    return predicted

def one_hot_label_decode(tok_labels, id2label=None):
    """Decode a one-hot label vector into a list of label names.

    Args:
        tok_labels: iterable of 0/1 values, one per label index.
        id2label: mapping from label index to label name; defaults to the
            module-level ``model.config.id2label`` (parameter added for
            testability/reuse).

    Returns:
        List of label names whose positions are non-zero.
    """
    if id2label is None:
        id2label = model.config.id2label
    return [id2label[i] for i, v in enumerate(tok_labels) if v != 0]

def merge_char_spans(charspans: list):
    """Return one CharSpan covering all given spans (min start to max end).

    The original sorted the spans first, but the sorted order was never
    used — ``min``/``max`` scan the whole list anyway — so the dead sort
    (and its stale "make groups" comment) has been removed.
    """
    start = min(span.start for span in charspans)
    end = max(span.end for span in charspans)
    return CharSpan(start, end)

def merge_error_spans(errors):
    """Coalesce adjacent character spans of the same error type.

    Two spans merge when the second starts exactly where the first ends, or
    one character after it (bridging a single separator such as a space).

    Bug fix: the original had no ``else`` branch, so a span that was NOT
    contiguous with the previous one was silently dropped instead of being
    kept as a separate span.

    Args:
        errors: mapping error type -> list of CharSpans in sentence order.

    Returns:
        defaultdict(list) with the same keys and merged CharSpan lists.
    """
    merged_spans = defaultdict(list)
    for error_type, spans in errors.items():
        for charspan in spans:
            previous = merged_spans[error_type][-1] if merged_spans[error_type] else None
            if previous is not None and charspan.start in (previous.end, previous.end + 1):
                # Contiguous (or separated by one char): extend the last span.
                merged_spans[error_type][-1] = merge_char_spans([previous, charspan])
            else:
                # Disjoint: start a new span (previously lost entirely).
                merged_spans[error_type].append(charspan)
    return merged_spans

def predict_error_spans(sentence):
    """Tag *sentence* with the model and collect character spans per error type.

    Args:
        sentence: raw input string.

    Returns:
        Tuple ``(sentence, merged_spans)`` where ``merged_spans`` maps an
        error type (e.g. "ADJ") to a list of merged CharSpans.
    """
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded_orig.input_ids[0])
    output = model(**encoded_orig)
    predicted_labels = predict(output.logits, threshold=0.2)
    errors = defaultdict(list)

    for i, (t, p) in enumerate(zip(tokens, predicted_labels[0])):
        labels = one_hot_label_decode(p)
        if t in special_tokens or labels == ['O']:
            continue
        for error_type in labels:
            position, e_type = error_type.split("-")
            # Robustness fix: an "I-" tag with no open span for this error
            # type (model never emitted the matching "B-") used to raise
            # IndexError on errors[e_type][-1]; treat it as a span start.
            if position == "B" or not errors[e_type]:
                errors[e_type].append(encoded_orig.token_to_chars(i))
            else:
                errors[e_type][-1] = merge_char_spans([errors[e_type][-1], encoded_orig.token_to_chars(i)])
    return sentence, merge_error_spans(errors)

def annotate(sentence: str):
    """Build the annotated-text sequence for display.

    Returns a list interleaving plain substrings of *sentence* with
    ``(substring, label, colour)`` tuples for each detected error span,
    in sentence order. With no detected spans, returns ``[sentence]``.

    Fix: the original called ``sorted`` INSIDE the per-error-type loop,
    re-sorting the accumulating list on every iteration; the sort is now
    done once after flattening. A running cursor replaces the index-based
    prefix logic — output is unchanged.
    """
    _, error_spans = predict_error_spans(sentence)
    # Flatten to (charspan, error_type) pairs ordered by position.
    spans_ordered = sorted(
        ((charspan, error_type)
         for error_type, charspans in error_spans.items()
         for charspan in charspans),
        key=lambda pair: pair[0],
    )
    annotated_sentence = []
    cursor = 0  # end of the last emitted segment
    for error_span, label in spans_ordered:
        annotated_sentence.append(sentence[cursor:error_span.start])
        annotated_sentence.append(
            (sentence[error_span.start:error_span.end], label, COLOURS[label])
        )
        cursor = error_span.end
    # Trailing plain text (the whole sentence when no spans were found).
    annotated_sentence.append(sentence[cursor:])
    return annotated_sentence