from annotated_text import annotated_text, parameters, annotation | |
from nltk.tokenize import word_tokenize | |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline | |
import streamlit as st | |
import torch | |
_MODEL_NAME = "yeshpanovrustem/xlm-roberta-large-ner-kazakh"

# Output-class index -> IOB2 tag for the token-classification head.
# NOTE(review): 25/26 (MISCELLANEOUS) and 35/36 (ORGANISATION) fill the gaps in
# the alphabetical sequence and match the classes listed in the sidebar; without
# them such a prediction raised KeyError. Confirm against model.config.id2label.
_ID2LABEL = {
    0: 'O',
    1: 'B-ADAGE', 2: 'I-ADAGE',
    3: 'B-ART', 4: 'I-ART',
    5: 'B-CARDINAL', 6: 'I-CARDINAL',
    7: 'B-CONTACT', 8: 'I-CONTACT',
    9: 'B-DATE', 10: 'I-DATE',
    11: 'B-DISEASE', 12: 'I-DISEASE',
    13: 'B-EVENT', 14: 'I-EVENT',
    15: 'B-FACILITY', 16: 'I-FACILITY',
    17: 'B-GPE', 18: 'I-GPE',
    19: 'B-LANGUAGE', 20: 'I-LANGUAGE',
    21: 'B-LAW', 22: 'I-LAW',
    23: 'B-LOCATION', 24: 'I-LOCATION',
    25: 'B-MISCELLANEOUS', 26: 'I-MISCELLANEOUS',
    27: 'B-MONEY', 28: 'I-MONEY',
    29: 'B-NON_HUMAN', 30: 'I-NON_HUMAN',
    31: 'B-NORP', 32: 'I-NORP',
    33: 'B-ORDINAL', 34: 'I-ORDINAL',
    35: 'B-ORGANISATION', 36: 'I-ORGANISATION',
    37: 'B-PERSON', 38: 'I-PERSON',
    39: 'B-PERCENTAGE', 40: 'I-PERCENTAGE',
    41: 'B-POSITION', 42: 'I-POSITION',
    43: 'B-PRODUCT', 44: 'I-PRODUCT',
    45: 'B-PROJECT', 46: 'I-PROJECT',
    47: 'B-QUANTITY', 48: 'I-QUANTITY',
    49: 'B-TIME', 50: 'I-TIME',
}


def _cache_resource(loader):
    """Cache the result of the zero-argument *loader* across Streamlit reruns.

    Streamlit re-executes this script on every interaction, so module globals
    are rebuilt each time; ``st.cache_resource`` survives reruns.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is None:
        # Older Streamlit releases lack st.cache_resource: load uncached.
        return loader
    return cache_resource(show_spinner=False)(loader)


@_cache_resource
def _load_model():
    """Load and return ``(tokenizer, model)`` for the Kazakh NER checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(_MODEL_NAME)
    return tokenizer, model


def _predict_word_labels(words):
    """Return one IOB2 tag per word, taken from the word's first sub-token.

    Raises:
        RuntimeError: if the number of tags differs from the number of words.
    """
    tokenizer, model = _load_model()
    encoded = tokenizer(words, is_split_into_words = True, return_tensors = "pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    predicted_ids = torch.argmax(logits, dim = 2)[0].tolist()
    labels = []
    previous_word_id = None
    for word_id, label_id in zip(encoded.word_ids(batch_index = 0), predicted_ids):
        # Special tokens have word id None; later sub-tokens repeat the word id.
        if word_id is None or word_id == previous_word_id:
            continue
        labels.append(_ID2LABEL[label_id])
        previous_word_id = word_id
    if len(labels) != len(words):
        raise RuntimeError("Mismatch between input token and label sizes!")
    return labels


def _group_entities(words, labels):
    """Merge words into ``(group_words, first_label)`` groups.

    A word joins the preceding group only when its tag starts with "I-" and the
    previous tag starts with "B-" or "I-"; every other word opens a new group.
    An "I-" tag right after "O" therefore starts its own group instead of
    dropping the preceding word.
    """
    groups = []
    previous_label = None
    for word, label in zip(words, labels):
        continues_entity = (
            label.startswith("I-")
            and previous_label is not None
            and previous_label.startswith(("B-", "I-"))
        )
        if continues_entity:
            groups[-1][0].append(word)
        else:
            groups.append(([word], label))
        previous_label = label
    return groups


def _to_annotated_segments(groups):
    """Turn word groups into space-padded strings / ``(text, class)`` tuples."""
    segments = []
    for group_words, first_label in groups:
        segment_text = " " + " ".join(group_words) + " "
        if first_label == "O":
            segments.append(segment_text)
        else:
            # "B-PERSON" -> "PERSON"
            segments.append((segment_text, first_label.split("-", 1)[1]))
    return segments


def label_text(text):
    """Run Kazakh NER over *text* and build input for ``annotated_text``.

    Args:
        text: Raw user input.

    Returns:
        A list of plain strings (unlabelled words) and ``(entity_text, class)``
        tuples, each text padded with one space on both sides. For empty or
        whitespace-only input a warning is rendered and ``[]`` is returned.

    Raises:
        RuntimeError: if the model yields a different number of word tags
            than there are words.
    """
    if not text.strip():
        st.markdown("<p id = 'warning'>PLEASE INSERT YOUR TEXT</p>", unsafe_allow_html = True)
        return []
    words = word_tokenize(text)
    with st.spinner("Labelling your text..."):
        labels = _predict_word_labels(words)
    return _to_annotated_segments(_group_entities(words, labels))
#########################
#### CREATE SIDEBAR #####
#########################

# Sidebar entries in display order; keys match the entity classes that
# label_text emits (e.g. "PERCENTAGE", not "PERCENTAGES").
_ENTITY_CLASS_DESCRIPTIONS = {
    "ADAGE": "Well-known Kazakh proverbs and sayings",
    "ART": "Titles of books, songs, television programmes, etc.",
    "CARDINAL": "Cardinal numbers, including whole numbers, fractions, and decimals",
    "CONTACT": "Addresses, emails, phone numbers, URLs",
    "DATE": "Dates or periods of 24 hours or more",
    "DISEASE": "Diseases or medical conditions",
    "EVENT": "Named events and phenomena",
    "FACILITY": "Names of man-made structures",
    "GPE": "Names of geopolitical entities",
    "LANGUAGE": "Named languages",
    "LAW": "Named legal documents",
    "LOCATION": "Names of geographical locations other than GPEs",
    "MISCELLANEOUS": "Entities of interest but hard to assign a proper tag to",
    "MONEY": "Monetary values",
    "NON_HUMAN": "Names of pets, animals or non-human creatures",
    "NORP": "Adjectival forms of GPE and LOCATION; named religions, etc.",
    "ORDINAL": "Ordinal numbers, including adverbials",
    "ORGANISATION": "Names of companies, government agencies, etc.",
    "PERCENTAGE": "Percentages",
    "PERSON": "Names of persons",
    "POSITION": "Names of posts and job titles",
    "PRODUCT": "Names of products",
    "PROJECT": "Names of projects, policies, plans, etc.",
    "QUANTITY": "Length, distance, etc. measurements",
    "TIME": "Times of day and time duration less than 24 hours",
}

# Inject the custom stylesheet (the original assignment was left empty).
with open("style.css", encoding = "utf-8") as f:
    css = f.read()
st.sidebar.markdown(f'<style>{css}</style>', unsafe_allow_html = True)
st.sidebar.markdown("<h1>Kazakh NER</h1>", unsafe_allow_html = True)
st.sidebar.markdown("<h2>Named entity classes</h2>", unsafe_allow_html = True)
for entity_class, description in _ENTITY_CLASS_DESCRIPTIONS.items():
    with st.sidebar.expander(entity_class):
        st.write(description)
######################
#### CREATE FORM #####
######################
text_field = st.form(key = 'text_field')
# Widgets created inside the form context belong to the form, so the text is
# only submitted when the button is pressed.
with text_field:
    form_text = st.text_input('Insert your text here')
    submit = st.form_submit_button('Submit')
st.markdown('Press **Submit** to have your text labelled')
if submit:
    labelled_segments = label_text(form_text)
    annotated_text(labelled_segments)