|
import streamlit as st |
|
from transformers import AutoModelForTokenClassification |
|
from annotated_text import annotated_text |
|
import numpy as np |
|
import os, joblib |
|
|
|
from utils import get_idxs_from_text |
|
|
|
# Token-classification model that tags cybersecurity event nuggets, arguments
# and realis in free text (custom pipeline code pulled in via trust_remote_code).
model = AutoModelForTokenClassification.from_pretrained("CyberPeace-Institute/Cybersecurity-Knowledge-Graph", trust_remote_code=True)

# One scikit-learn classifier per argument type, used to disambiguate the role
# of an argument when its type admits several roles (see get_arg_roles below).
# Files are expected under <cwd>/arg_role_models/<ARG_LABEL>.joblib.
role_classifiers = {}
folder_path = os.path.join(os.getcwd(), "arg_role_models")
for filename in os.listdir(folder_path):
    if filename.endswith('.joblib'):
        # os.path.join instead of string concatenation: portable and avoids
        # the previous cwd-prefix duplication.
        file_path = os.path.join(folder_path, filename)
        clf = joblib.load(file_path)
        # splitext is robust to dots inside the file name, unlike split(".")[0].
        arg = os.path.splitext(filename)[0]
        role_classifiers[arg] = clf
|
|
|
def annotate(name):
    """Render the module-level `output` (set in the Streamlit handler below)
    as annotated text, grouping consecutive tokens that share the same value
    for the `name` field (e.g. "nugget", "argument", "realis", "role").

    Tokens labelled "O" are emitted as plain strings; any other label is
    emitted as an (text, label) tuple for `annotated_text`. BIO prefixes
    ("B-", "I-") are stripped so adjacent B-/I- tokens of the same type
    merge into a single highlighted span.
    """
    # Guard: with no tokens the original flush logic would emit a bogus
    # ("", "") annotation; render nothing instead.
    if not output:
        return

    tokens = [item["token"] for item in output]
    tokens = [token.replace(" ", "") for token in tokens]
    text = model.tokenizer.decode([item["id"] for item in output])
    # Map each token back to its character span in the decoded text.
    idxs = get_idxs_from_text(text, tokens)
    labels = [item[name] for item in output]

    annotated_text_list = []
    last_label = ""
    cumulative_tokens = ""
    last_id = 0
    for idx, label in zip(idxs, labels):
        # Strip the BIO prefix ("B-Phishing" -> "Phishing"); "O" is unchanged.
        label_short = label.split("-")[1] if "-" in label else label
        if last_label == label_short:
            # Same span: extend with the raw text slice so inter-token
            # whitespace from the decoded text is preserved.
            cumulative_tokens += text[last_id : idx["end_idx"]]
            last_id = idx["end_idx"]
        else:
            # Label changed: flush the previous span (if any)...
            if last_label != "":
                if last_label == "O":
                    annotated_text_list.append(cumulative_tokens)
                else:
                    annotated_text_list.append((cumulative_tokens, last_label))
            # ...and start a new span at this token.
            last_label = label_short
            cumulative_tokens = idx["word"]
            last_id = idx["end_idx"]
    # Flush the final span.
    if last_label == "O":
        annotated_text_list.append(cumulative_tokens)
    else:
        annotated_text_list.append((cumulative_tokens, last_label))
    annotated_text(annotated_text_list)
|
|
|
def get_arg_roles(output):
    """Assign a semantic role to every token of each event argument.

    Merges the per-token BIO "argument" labels in `output` into entity
    spans, classifies each span's role (using the per-argument-type
    classifiers in `role_classifiers` when the type is ambiguous), and
    writes the result back into each token dict under the "role" key
    ("O" for tokens outside any argument).

    Mutates and returns `output` (a list of per-token dicts produced by
    the model pipeline).
    """
    # Collect (position, BIO label, token) for every argument token.
    args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(output) if item["argument"] != "O"]

    # Merge consecutive B-/I- tokens into entity spans.
    entities = []
    current_entity = None
    for position, label, token in args:
        if label.startswith('B-'):
            if current_entity is not None:
                entities.append(current_entity)
            current_entity = {'label': label[2:], 'text': token.replace(" ", ""), 'start': position, 'end': position}
        elif label.startswith('I-'):
            if current_entity is not None:
                current_entity['text'] += ' ' + token.replace(" ", "")
                current_entity['end'] = position
    # Bug fix: flush the last open entity — previously the final argument
    # span was silently dropped and never received a role.
    if current_entity is not None:
        entities.append(current_entity)

    # Attach a +/-15-token decoded context window to each entity; this is
    # the sentence input for the role classifier.
    for entity in entities:
        context = model.tokenizer.decode([item["id"] for item in output[max(0, entity["start"] - 15) : min(len(output), entity["end"] + 15)]])
        entity["context"] = context

    for entity in entities:
        if len(model.arg_2_role[entity["label"]]) > 1:
            # Ambiguous argument type: predict the role from the
            # concatenated [context ; argument text] sentence embeddings.
            sent_embed = model.embed_model.encode(entity["context"])
            arg_embed = model.embed_model.encode(entity["text"])
            embed = np.concatenate((sent_embed, arg_embed))
            arg_clf = role_classifiers[entity["label"]]
            role_id = arg_clf.predict(embed.reshape(1, -1))
            role = model.arg_2_role[entity["label"]][role_id[0]]
            entity["role"] = role
        else:
            # Only one possible role for this argument type.
            entity["role"] = model.arg_2_role[entity["label"]][0]

    # Project the entity-level roles back onto the per-token output.
    for item in output:
        item["role"] = "O"
    for entity in entities:
        for i in range(entity["start"], entity["end"] + 1):
            output[i]["role"] = entity["role"]
    return output
|
|
|
# --- Streamlit UI -----------------------------------------------------------
st.title("Create Knowledge Graphs from Cyber Incidents")

text_input = st.text_area("Enter your text here", height=100)

# Run the pipeline as soon as text is present (or the button is pressed).
if text_input or st.button('Apply'):
    # NOTE: `annotate` reads this module-level `output` variable.
    output = model(text_input)

    # The first three views only need the raw token-level predictions.
    for heading, field in (
        ("Event Nuggets", "nugget"),
        ("Event Arguments", "argument"),
        ("Realis of Event Nuggets", "realis"),
    ):
        st.subheader(heading)
        annotate(field)

    # Role assignment requires the extra classification pass.
    output = get_arg_roles(output)
    st.subheader("Role of the Event Arguments")
    annotate("role")
|
|