# Cybersecurity-Knowledge-Graph / event_realis_predict.py
# Uploaded by cpi-connect ("Upload model", commit 303b1b2)
import os
import spacy
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from .utils import get_idxs_from_text
import streamlit as st
from annotated_text import annotated_text
from .nugget_model_utils import CustomRobertaWithPOS
from .event_nugget_predict import get_event_nuggets
from .realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
# Label vocabulary of the event-nugget tagger. The position of each entry is its
# integer class id, so the order must not be changed.
event_nugget_list = [
    "B-Phishing",
    "I-Phishing",
    "O",
    "B-DiscoverVulnerability",
    "B-Ransom",
    "I-Ransom",
    "B-Databreach",
    "I-DiscoverVulnerability",
    "B-PatchVulnerability",
    "I-PatchVulnerability",
    "I-Databreach",
]

# Realis classes; the index of an entry is the class id it is looked up by.
realis_list = [
    "O",
    "Generic",
    "Other",
    "Actual",
]

# Explicitly opt in to parallelism in the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def find_dep_depth(token):
    """
    Count how many ``.head`` hops separate *token* from the root of its tree.

    The root is the node whose ``.head`` compares equal to itself. The walk
    always goes all the way to the root; the returned value is capped at 16.

    Inputs:
    - token: Any object exposing a ``.head`` attribute (e.g. a spaCy token).
    Output:
    - The number of hops to the root, at most 16.
    """
    hops = 0
    node = token
    parent = node.head
    while parent != node:
        hops += 1
        node, parent = parent, parent.head
    return hops if hops < 16 else 16
# Shared spaCy pipeline; its POS, NER and dependency annotations are the extra
# features fed to the realis model.
nlp = spacy.load("en_core_web_sm")

# Closed tag vocabularies. The position of a tag in its list is its class id.
pos_spacy_tag_list = [
    "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET",
    "INTJ", "NOUN", "NUM", "PART", "PRON", "PROPN",
    "PUNCT", "SCONJ", "SYM", "VERB", "SPACE", "X",
]
# Each entity type of the spaCy NER pipe expands to a "B-"/"I-" pair (in that
# order); "O" is appended last.
ner_spacy_tag_list = [
    prefix + entity_type
    for entity_type in nlp.get_pipe("ner").labels
    for prefix in ("B-", "I-")
] + ["O"]
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)

# Use the GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_checkpoint = "ehsanaghaei/SecureBERT"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

# NOTE: loading of the realis model is disabled below, yet predict() reads a
# module-level `model_realis`; it has to be supplied some other way.
# from .realis_model_utils import CustomRobertaWithPOS as RealisModel
# model_realis = RealisModel(num_classes_realis=4)
# model_realis.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/realis_model_state_dict.pth", map_location=device))
# model_realis.eval()
def create_dataloader(model_nugget, text_input):
    """
    Function: create_dataloader(model_nugget, text_input)
    Description: Prepares a DataLoader for realis inference. Event nuggets are
    predicted with model_nugget, every word located in the text is tagged with
    its nugget label and with spaCy POS / NER / dependency-label / dependency-depth
    features, and the result is tokenized with label alignment. When the text
    produces more sub-word ids than tokenizer.model_max_length, the word list is
    cut into chunks that each end on a "." token; the chunks themselves are not
    re-checked against the limit.
    Inputs:
    - model_nugget: The event-nugget model, handed to get_event_nuggets as is.
    - text_input: The input text to be processed.
    Output:
    - dataloader: A DataLoader (batch size 4) over the tokenized chunks.
    - tokenized_dataset_ner: The tokenized dataset behind the DataLoader, in
      torch format and without the "tokens" column.
    """
    event_nuggets = get_event_nuggets(model_nugget, text_input)
    doc = nlp(text_input)

    # Normalise quote tokens and drop "$" before locating the words in the text.
    normalised_words = [
        tok.text.replace("``", '"').replace("''", '"').replace("$", "")
        for tok in doc
    ]
    word_spans = get_idxs_from_text(text_input, normalised_words)

    # Per-token spaCy features.
    pos_spacy = [tok.pos_ for tok in doc]
    ner_spacy = [
        tok.ent_iob_ if tok.ent_iob_ == "O" else tok.ent_iob_ + "-" + tok.ent_type_
        for tok in doc
    ]
    dep_spacy = [tok.dep_ for tok in doc]
    depth_spacy = [find_dep_depth(tok) for tok in doc]

    # Words and the event-nugget tag covering each of them.
    words = []
    nugget_ner_tags = []
    for span in word_spans:
        tag = get_entity_for_realis_from_idx(span["start_idx"], span["end_idx"], event_nuggets)
        words.append(span["word"])
        nugget_ner_tags.append(tag)

    columns = {
        "tokens": words,
        "ner_tags": nugget_ner_tags,
        "pos_spacy": pos_spacy,
        "ner_spacy": ner_spacy,
        "dep_spacy": dep_spacy,
        "depth_spacy": depth_spacy,
    }

    def take(start, stop=None):
        # One example holding the same slice of every column.
        return {name: values[start:stop] for name, values in columns.items()}

    data = []
    subword_count = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
    max_len = tokenizer.model_max_length
    if subword_count > max_len:
        part_count = subword_count // max_len + 2
        stride = len(words) // part_count + 1
        chunk_start = 0
        next_cut = stride
        for position, word in enumerate(words):
            # Cut only on a sentence-final "." once past the current threshold.
            if position > next_cut and word == ".":
                data.append(take(chunk_start, position + 1))
                chunk_start = position + 1
                next_cut += stride
        # Whatever follows the last cut forms the final chunk.
        data.append(take(chunk_start))
    else:
        data.append(dict(columns))

    def label_sequence(names):
        # Variable-length sequence of class labels drawn from `names`.
        return Sequence(
            feature=ClassLabel(num_classes=len(names), names=names, names_file=None, id=None),
            length=-1,
            id=None,
        )

    ner_features = Features({
        'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
        'ner_tags': label_sequence(event_nugget_list),
        'pos_spacy': label_sequence(pos_spacy_tag_list),
        'ner_spacy': label_sequence(ner_spacy_tag_list),
        'dep_spacy': label_sequence(dep_spacy_tag_list),
        # Depths 0..16, matching the cap applied by find_dep_depth.
        'depth_spacy': label_sequence(list(range(17))),
    })

    dataset = Dataset.from_list(data, features=ner_features)
    tokenized_dataset_ner = dataset.map(
        tokenize_and_align_labels_with_pos_ner_realis,
        fn_kwargs={'tokenizer': tokenizer, 'ner_names': event_nugget_list},
        batched=True,
        load_from_cache_file=False,
    )
    tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
    tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")

    batch_size = 4  # Number of input texts
    dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
    return dataloader, tokenized_dataset_ner
def predict(dataloader, model=None):
    """
    Function: predict(dataloader, model=None)
    Description: Runs inference over every batch of a DataLoader and returns the
    predicted class id of every token. Batches are concatenated along the batch
    dimension, so row i of the result belongs to example i of the dataset; this
    requires all batches to share the same sequence length.
    Inputs:
    - dataloader: An iterable of batches; each batch is a mapping that is
      unpacked as keyword arguments into the model.
    - model: Optional callable returning the logits for a batch. When omitted,
      the module-level `model_realis` is used.
    Output:
    - predicted_label: A tensor with one row of predicted labels per example.
    Raises:
    - NameError: if no model is passed and `model_realis` is not defined.
    """
    if model is None:
        # Historical behaviour: rely on the module-level realis model.
        model = model_realis

    predicted_label = []
    with torch.no_grad():
        for batch in dataloader:
            logits = model(**batch)
            predicted_label.append(logits.argmax(-1))
    # dim=0 stacks examples; concatenating on the last axis would glue the
    # sequences of different batches together (or fail on a short last batch).
    return torch.cat(predicted_label, dim=0)
def show_annotations(text_input, model_nugget=None):
    """
    Function: show_annotations(text_input, model_nugget=None)
    Description: Displays the predicted realis label of the tokens of the input
    text with Streamlit. Consecutive tokens sharing a label are merged into one
    span; spans labelled "O" are shown as plain text, all others as annotations.
    One annotated block is rendered per chunk of the tokenized dataset.
    Inputs:
    - text_input: The input text to be annotated and displayed.
    - model_nugget: The event-nugget model forwarded to create_dataloader. It
      defaults to None only to keep the old one-argument call shape valid;
      whether None is usable depends on get_event_nuggets (not visible here).
    Output:
    - An interactive display of annotated realis spans within the input text.
    """
    st.title("Event Realis")
    # create_dataloader takes the nugget model first; the former one-argument
    # call could never succeed.
    dataloader, tokenized_dataset_ner = create_dataloader(model_nugget, text_input)
    predicted_label = predict(dataloader)

    for row, labels in enumerate(predicted_label):
        input_ids = tokenized_dataset_ner[row]["input_ids"]
        # Keep ids above 2 only (presumably 0-2 are the RoBERTa special tokens
        # <s>, <pad>, </s> - confirm against the tokenizer vocabulary).
        token_mask = input_ids > 2
        kept_ids = input_ids[token_mask]

        tokens = tokenizer.convert_ids_to_tokens(kept_ids, skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
        text = tokenizer.decode(kept_ids)
        idxs = get_idxs_from_text(text, tokens)
        labels = labels[token_mask]

        annotated_text_list = []
        last_label = ""
        cumulative_tokens = ""
        last_id = 0

        def flush():
            # "O" runs are plain text, every other run is a (text, label) pair.
            if last_label == "O":
                annotated_text_list.append(cumulative_tokens)
            else:
                annotated_text_list.append((cumulative_tokens, last_label))

        for span, label in zip(idxs, labels):
            to_label = realis_list[label]
            label_short = to_label.split("-")[1] if "-" in to_label else to_label
            if last_label == label_short:
                # Extend the current run, including the text between the tokens.
                cumulative_tokens += text[last_id : span["end_idx"]]
            else:
                if last_label != "":
                    flush()
                last_label = label_short
                cumulative_tokens = span["word"]
            last_id = span["end_idx"]

        # Flush the trailing run; a chunk without tokens has nothing to show.
        if last_label != "":
            flush()
        if annotated_text_list:
            annotated_text(annotated_text_list)
def get_event_realis(text_input, model_nugget=None):
    """
    Function: get_event_realis(text_input, model_nugget=None)
    Description: Extracts the predicted event realis (event modality) spans from
    the input text. Consecutive tokens with the same predicted label form one
    span; spans labelled "O" are not reported. The span that ends a chunk is
    reported as well.
    Inputs:
    - text_input: The input text containing event realis to be extracted.
    - model_nugget: The event-nugget model forwarded to create_dataloader. It
      defaults to None only to keep the old one-argument call shape valid;
      whether None is usable depends on get_event_nuggets (not visible here).
    Output:
    - predicted_event_realis: A list of dictionaries with the keys "startOffset",
      "endOffset" (character offsets into text_input), "realis" and "text".
    """
    # create_dataloader takes the nugget model first; the former one-argument
    # call could never succeed.
    dataloader, tokenized_dataset_ner = create_dataloader(model_nugget, text_input)
    predicted_label = predict(dataloader)

    predicted_event_realis = []
    text_length = 0  # Offset in text_input at which the current chunk starts.

    def emit(start, end, realis):
        # Record a span given by chunk-relative offsets, skipping empty and "O" runs.
        span_text = text_input[text_length + start : text_length + end]
        if span_text != "" and realis not in ("", "O"):
            predicted_event_realis.append(
                {
                    "startOffset": text_length + start,
                    "endOffset": text_length + end,
                    "realis": realis,
                    "text": span_text,
                }
            )

    for row, labels in enumerate(predicted_label):
        input_ids = tokenized_dataset_ner[row]["input_ids"]
        # Keep ids above 2 only (presumably 0-2 are the RoBERTa special tokens
        # <s>, <pad>, </s> - confirm against the tokenizer vocabulary).
        token_mask = input_ids > 2

        tokens = tokenizer.convert_ids_to_tokens(input_ids[token_mask], skip_special_tokens=True)
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
        # Offsets are relative to the not yet consumed remainder of the text.
        idxs = get_idxs_from_text(text_input[text_length:], tokens)
        labels = labels[token_mask]

        start_idx = 0
        end_idx = 0
        last_label = ""
        chunk_end = 0
        for span, label in zip(idxs, labels):
            current_label = realis_list[label]
            if current_label == last_label:
                end_idx = span["end_idx"]
            else:
                emit(start_idx, end_idx, last_label)
                start_idx = span["start_idx"]
                end_idx = span["start_idx"] + len(span["word"])
                last_label = current_label
            chunk_end = span["end_idx"]

        # The final run has no following label change to trigger it, so flush
        # it here; otherwise a span ending the chunk would be lost.
        emit(start_idx, end_idx, last_label)

        # Advance past this chunk (by 0 when it held no tokens).
        text_length += chunk_end

    return predicted_event_realis