# NerRoB-czech / website_script.py
import json
import copy
import pickle
import torch
from simplemma import lemmatize
from transformers import AutoTokenizer
from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens

# code originally from data_manipulation.creation_gazetteers
def lemmatizing(x):
    """Lemmatize a single Czech token with simplemma; empty strings pass through."""
    if x == "":
        return ""
    return lemmatize(x, lang="cs")

# code originally from data_manipulation.creation_gazetteers
def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    """
    Invert a {label: [entities]} gazetteer into an {entity: label} lookup.
    With apply_lemmatizing=True, the lemma of each entity is added as an extra key.
    """
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary
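
# Illustrative example (hypothetical data): the reverse dictionary gives O(1)
# entity -> label lookups during gazetteer matching.
# >>> build_reverse_dictionary({"PER": ["Karel Čapek"], "LOC": ["Praha", "Praze"]})
# {'Karel Čapek': 'PER', 'Praha': 'LOC', 'Praze': 'LOC'}
# With apply_lemmatizing=True the lemmatized forms would be added as extra keys
# (e.g. the lemma of "Praze", assuming simplemma resolves it to "Praha").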
def load_json(path):
"""
Load gazetteers from a file
:param path: path to the gazetteer file
:return: a dict of gazetteers
"""
    with open(path, 'r', encoding='utf-8') as file:
data = json.load(file)
return data
def load_pickle(path):
"""
    Load pickled gazetteers from a file
:param path: path to the gazetteer file
:return: a dict of gazetteers
"""
with open(path, 'rb') as file:
data = pickle.load(file)
return data
def load():
"""
Load the tokenizer, model, and gazetteers for named entity recognition.
Returns:
tokenizer (AutoTokenizer): The tokenizer for tokenizing input text.
model (ExtendedEmbeddigsRobertaForTokenClassification): The pre-trained model for named entity recognition.
        gazetteers_for_matching (list): A list of reverse gazetteer dictionaries (entity -> label), one per entity type.
"""
model_name = "ufal/robeczech-base"
model_path = "bettystr/NerRoB-czech"
gazetteers_path = "gazetteers.pkl"
model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
gazetteers_for_matching = load_pickle(gazetteers_path)
    # Invert each gazetteer type into its own entity -> label lookup.
    gazetteers_for_matching = [
        build_reverse_dictionary({key: values})
        for key, values in gazetteers_for_matching.items()
    ]
return tokenizer, model, gazetteers_for_matching
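
# The structure load() returns for matching, sketched with illustrative values
# (assuming gazetteers.pkl maps labels to entity lists, as the code above implies):
# [
#     {"Karel Čapek": "PER", ...},   # one reverse dictionary per entity type
#     {"Praha": "LOC", ...},
#     {"Microsoft": "ORG", ...},
# ]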
def add_additional_gazetteers(gazetteers_for_matching, file_names):
    """
    Merge user-supplied gazetteer files into copies of the existing reverse dictionaries.
    Args:
        gazetteers_for_matching (list): The list of reverse gazetteer dictionaries to extend.
        file_names (list): File names of JSON gazetteers to merge in.
    Returns:
        list: Deep copies of the input dictionaries, extended with the new entities.
    """
    if not file_names:
        return gazetteers_for_matching
    # Work on deep copies so the loaded gazetteers stay untouched between runs.
    temp = [copy.deepcopy(dictionary) for dictionary in gazetteers_for_matching]
    for file_name in file_names:
        with open(file_name, 'r', encoding='utf-8') as file:
            data = json.load(file)
        for key, value_lst in data.items():
            key = key.upper()
            # Route each entity to the reverse dictionary that already holds this label.
            for dictionary in temp:
                if key in dictionary.values():
                    for value in value_lst:
                        dictionary[value] = key
    return temp
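
# A sketch of the JSON format the merge above expects (file name and contents
# are hypothetical): an object mapping each label to its surface forms.
# extra_gazetteer.json:
#     {"loc": ["Brno", "Ostrava"], "org": ["Masarykova univerzita"]}
# Keys are upper-cased on load, so "loc" and "LOC" are equivalent.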
def run(tokenizer, model, gazetteers, text, file_names=None):
"""
Runs the named entity recognition (NER) model on the given text.
Args:
tokenizer (Tokenizer): The tokenizer used to tokenize the input text.
model (Model): The NER model used for prediction.
gazetteers (list): A list of gazetteers used for matching entities in the text.
text (str): The input text to perform NER on.
file_names (list, optional): A list of file names to be used as additional gazetteers.
Returns:
list: A list of dictionaries representing the predicted entities in the text. Each dictionary contains the following keys:
- "start" (int): The starting position of the entity in the text.
- "end" (int): The ending position of the entity in the text.
- "entity" (str): The type of the entity.
- "score" (float): The confidence score of the entity prediction.
- "word" (str): The actual word representing the entity.
- "count" (int): The number of tokens in the entity.
"""
gazetteers_for_matching = add_additional_gazetteers(gazetteers, file_names)
tokenized_inputs = tokenizer(
text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
)
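    # Mark gazetteer hits in the raw text before aligning them to the subword tokens.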
matches = gazetteer_matching(text, gazetteers_for_matching)
    # Split the aligned per-token gazetteer triples into PER / ORG / LOC feature sequences.
    word_ids = tokenized_inputs.word_ids()
    aligned = align_gazetteers_with_tokens(matches, word_ids)
    p = [[x[0] for x in aligned]]
    o = [[x[1] for x in aligned]]
    l = [[x[2] for x in aligned]]
input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0)
attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0)
per = torch.tensor(p, device="cpu")
org = torch.tensor(o, device="cpu")
loc = torch.tensor(l, device="cpu")
output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
predictions = torch.argmax(output, dim=2).tolist()
predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
softmax = torch.nn.Softmax(dim=2)
scores = softmax(output).squeeze(0).tolist()
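    # scores[i] is the softmax distribution over labels for token i; max(score)
    # below is the model's confidence in that token's predicted label.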
    result = []
    temp = {
        "start": 0,
        "end": 0,
        "entity": "O",
        "score": 0,
        "word": "",
        "count": 0
    }
    for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
        # Skip special tokens (empty offsets) and tokens predicted as non-entities.
        if pos[0] == pos[1] or entity == "O":
            continue
        if "I-" + temp["entity"] == entity:  # continuation of the open entity
            temp["word"] += text[temp["end"]:pos[0]] + text[pos[0]:pos[1]]
            temp["end"] = pos[1]
            temp["count"] += 1
            temp["score"] += max(score)
        else:  # new entity: close the open one, then start a fresh span
            if temp["count"] > 0:
                temp["score"] /= temp.pop("count")  # average confidence over the span's tokens
                result.append(temp)
            temp = {
                "start": pos[0],
                "end": pos[1],
                "entity": entity[2:],
                "score": max(score),  # seed the average with this token's confidence
                "word": text[pos[0]:pos[1]],
                "count": 1
            }
    if temp["count"] > 0:  # flush the last open entity
        temp["score"] /= temp.pop("count")
        result.append(temp)
    return result
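
# A minimal usage sketch, assuming gazetteers.pkl is available in the working
# directory and the model can be fetched from the Hugging Face Hub; the sample
# sentence is illustrative only.
if __name__ == "__main__":
    tokenizer, model, gazetteers = load()
    entities = run(tokenizer, model, gazetteers,
                   "Karel Čapek se narodil v Malých Svatoňovicích.")
    for entity in entities:
        print(f'{entity["entity"]}: {entity["word"]} '
              f'({entity["score"]:.2f}) [{entity["start"]}:{entity["end"]}]')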