# Author: cpi-connect
# Commit message: "Update app.py"
# Commit: ec68b63
import streamlit as st
from transformers import AutoModelForTokenClassification
from annotated_text import annotated_text
import numpy as np
import os, joblib
from utils import get_idxs_from_text
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive loads below are cached with st.cache_resource when this Streamlit
# version provides it (added in 1.18); older versions fall back to a no-op and
# behave exactly as before (reload on every run).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model():
    """Load the cybersecurity token-classification model from the Hugging Face Hub.

    NOTE(review): ``trust_remote_code=True`` executes Python code shipped with
    the model repository; keep it only for a repository you trust.
    """
    return AutoModelForTokenClassification.from_pretrained(
        "CyberPeace-Institute/Cybersecurity-Knowledge-Graph", trust_remote_code=True
    )


@_cache_resource
def _load_role_classifiers(directory):
    """Load every ``*.joblib`` file in *directory* into a dict.

    Each classifier is keyed by the part of its file name before the first
    ".", which is the argument label that ``get_arg_roles`` looks up.

    NOTE(review): ``joblib.load`` unpickles the file and can execute arbitrary
    code, so the folder must only contain trusted files.
    """
    classifiers = {}
    for filename in os.listdir(directory):
        if filename.endswith(".joblib"):
            label = filename.split(".")[0]
            classifiers[label] = joblib.load(os.path.join(directory, filename))
    return classifiers


model = _load_model()
# Classifier folder, resolved against the current working directory.
folder_path = os.path.join(os.getcwd(), "arg_role_models")
role_classifiers = _load_role_classifiers(folder_path)
def _strip_bio_prefix(label):
    """Return *label* without a one-letter chunk tag such as "B-" or "I-".

    Only a leading single-character tag followed by "-" is removed, so label
    names that themselves contain hyphens stay whole ("B-Foo-Bar" -> "Foo-Bar",
    "Foo-Bar" -> "Foo-Bar"), and untagged labels such as "O" pass through.
    """
    tag, sep, rest = label.partition("-")
    return rest if sep and len(tag) == 1 else label


def _as_span(text, label):
    """Return an ``annotated_text`` item: plain text for "O", else a (text, label) pair."""
    return text if label == "O" else (text, label)


def annotate(name):
    """Render the module-level ``output`` with the labels of column *name* highlighted.

    ``output`` is the per-token list assigned in the app section
    (``output = model(text_input)``). Each item is read for its "token" and
    "id" keys, and for the *name* key that holds that token's label.
    Consecutive tokens whose labels agree once the B-/I- style prefix is
    dropped are merged into one span, and the spans are displayed with
    ``annotated_text``.
    """
    tokens = [item["token"].replace(" ", "") for item in output]
    text = model.tokenizer.decode([item["id"] for item in output])
    idxs = get_idxs_from_text(text, tokens)
    labels = [item[name] for item in output]

    spans = []
    current_label = None  # None until the first token has been seen
    current_text = ""
    last_end = 0
    for idx, label in zip(idxs, labels):
        short_label = _strip_bio_prefix(label)
        if short_label == current_label:
            # Extend with the raw slice of the decoded text, so whatever lies
            # between the previous token and this one (e.g. spaces) is kept.
            current_text += text[last_end:idx["end_idx"]]
        else:
            if current_label is not None:
                spans.append(_as_span(current_text, current_label))
            current_label = short_label
            current_text = idx["word"]
        last_end = idx["end_idx"]

    # Flush the final span. With no tokens at all nothing is emitted, rather
    # than an empty ("", "") annotation.
    if current_label is not None:
        spans.append(_as_span(current_text, current_label))
    annotated_text(spans)
def _extract_entities(output):
    """Group the BIO-tagged "argument" labels of *output* into entity spans.

    Returns a list of dicts with these keys:

    - ``label``: the tag without its "B-"/"I-" prefix.
    - ``text``: the entity's tokens, with spaces stripped from each, joined by
      single spaces.
    - ``start`` / ``end``: inclusive positions in *output*.

    An "I-X" token extends the open entity only when that entity has label X.
    Every token is visited, so an "O" token, a "B-" token or any other label
    closes the open entity, and entity spans never cover unrelated tokens.
    An "I-X" that cannot extend anything opens a new entity, following the
    usual CoNLL chunking convention.
    """
    entities = []
    current = None
    for position, item in enumerate(output):
        label = item["argument"]
        token = item["token"].replace(" ", "")
        continues = (
            current is not None
            and label.startswith("I-")
            and label[2:] == current["label"]
        )
        if continues:
            current["text"] += " " + token
            current["end"] = position
            continue
        if current is not None:
            entities.append(current)
            current = None
        if label.startswith(("B-", "I-")):
            current = {"label": label[2:], "text": token, "start": position, "end": position}
    # The last entity is still open when the loop ends; keep it too.
    if current is not None:
        entities.append(current)
    return entities


def _predict_role(entity):
    """Return the role of *entity* from the candidates ``model.arg_2_role`` lists for its label.

    A label with a single candidate needs no classifier. Otherwise, the
    label's classifier from ``role_classifiers`` receives the concatenation of
    the context embedding and the entity-text embedding. The predicted index
    then selects the role.
    """
    candidate_roles = model.arg_2_role[entity["label"]]
    if len(candidate_roles) <= 1:
        return candidate_roles[0]
    sent_embed = model.embed_model.encode(entity["context"])
    arg_embed = model.embed_model.encode(entity["text"])
    features = np.concatenate((sent_embed, arg_embed)).reshape(1, -1)
    role_id = role_classifiers[entity["label"]].predict(features)
    return candidate_roles[role_id[0]]


def get_arg_roles(output):
    """Assign an event-argument role to every token of *output*, in place.

    Each item gains a "role" key. Tokens inside an argument entity get that
    entity's predicted role; all other tokens get "O". Returns the same list.
    """
    entities = _extract_entities(output)
    for entity in entities:
        # Decode a token window around the entity as its context: 15 tokens
        # before it and 14 after (the slice end is exclusive). The window is
        # left unchanged, presumably to match how the classifiers were trained.
        window = output[max(0, entity["start"] - 15) : entity["end"] + 15]
        entity["context"] = model.tokenizer.decode([item["id"] for item in window])
        entity["role"] = _predict_role(entity)

    for item in output:
        item["role"] = "O"
    for entity in entities:
        for position in range(entity["start"], entity["end"] + 1):
            output[position]["role"] = entity["role"]
    return output
st.title("Create Knowledge Graphs from Cyber Incidents")
text_input = st.text_area("Enter your text here", height=100)
# Create the button on every run so it stays visible. Placed after a
# short-circuiting `or`, it would be skipped whenever the box holds text.
# Clicking it simply triggers a rerun with the committed text.
apply_clicked = st.button('Apply')

if text_input.strip():
    # `output` stays module-level: annotate() reads it as a global.
    output = model(text_input)
    st.subheader("Event Nuggets")
    annotate("nugget")
    st.subheader("Event Arguments")
    annotate("argument")
    st.subheader("Realis of Event Nuggets")
    annotate("realis")
    output = get_arg_roles(output)
    st.subheader("Role of the Event Arguments")
    annotate("role")
elif apply_clicked:
    # Don't run the model on an empty or whitespace-only input.
    st.warning("Please enter some text first.")