"""Streamlit demo: annotate cyber-incident text with event information.

The script loads a token-classification model from the Hugging Face Hub,
loads one role classifier per ``*.joblib`` file found in ``arg_role_models``
under the current working directory, and renders four annotated views of the
entered text: event nuggets, event arguments, realis, and argument roles.
"""

import os

import joblib
import numpy as np
import streamlit as st
from annotated_text import annotated_text
from transformers import AutoModelForTokenClassification

from utils import get_idxs_from_text

# NOTE: trust_remote_code=True executes Python code shipped in the model
# repository; only keep it enabled for repositories you trust.
# NOTE(review): Streamlit re-runs this script on every interaction, so the
# model and classifiers are re-created each time - consider caching them.
model = AutoModelForTokenClassification.from_pretrained(
    "CyberPeace-Institute/Cybersecurity-Knowledge-Graph",
    trust_remote_code=True,
)

# Maps an argument label (the file name up to its first ".") to the
# classifier used to pick a role for arguments carrying that label.
role_classifiers = {}
folder_path = '/arg_role_models'
_models_dir = os.path.join(os.getcwd(), folder_path.lstrip("/"))
for filename in os.listdir(_models_dir):
    if filename.endswith('.joblib'):
        # NOTE: joblib.load unpickles the file; only load trusted files.
        clf = joblib.load(os.path.join(_models_dir, filename))
        arg = filename.split(".")[0]
        role_classifiers[arg] = clf


def _strip_bio_prefix(label):
    """Return *label* without a leading ``B-``/``I-`` tag, if it has one.

    Only the tag is removed, so hyphens inside the label itself survive.
    This matches how ``get_arg_roles`` interprets argument labels.
    """
    return label[2:] if label.startswith(("B-", "I-")) else label


def _build_annotated_spans(text, idxs, labels):
    """Merge consecutive tokens sharing a label into display spans.

    Args:
        text: The decoded text the tokens were located in.
        idxs: One dict per token with at least ``"word"`` and ``"end_idx"``
            (as returned by ``get_idxs_from_text``).
        labels: One label per token, parallel to ``idxs``.

    Returns:
        A list where tokens labelled ``"O"`` become plain strings and every
        other run becomes a ``(text, label)`` tuple. Empty when there are no
        tokens.
    """
    spans = []
    current_label = None
    chunk = ""
    last_end = 0

    def flush():
        # "O" means "no label": emit bare text instead of a labelled tuple.
        spans.append(chunk if current_label == "O" else (chunk, current_label))

    for idx, label in zip(idxs, labels):
        short_label = _strip_bio_prefix(label)
        if short_label == current_label:
            # Extend the running span up to the end of this token, which also
            # picks up whatever separates it from the previous token.
            chunk += text[last_end:idx["end_idx"]]
        else:
            if current_label is not None:
                flush()
            current_label = short_label
            chunk = idx["word"]
        last_end = idx["end_idx"]

    if current_label is not None:
        flush()
    return spans


def annotate(name, items=None):
    """Render the tokens of a model output, highlighted by one label field.

    Args:
        name: Key of the per-token label to display (this script uses
            ``"nugget"``, ``"argument"``, ``"realis"`` and ``"role"``).
        items: Per-token dicts with ``"token"``, ``"id"`` and ``name`` keys.
            When omitted, the module-level ``output`` variable is used, which
            is how this function was originally called.
    """
    if items is None:
        items = output  # backward-compatible fallback to the script global

    tokens = [item["token"].replace(" ", "") for item in items]
    text = model.tokenizer.decode([item["id"] for item in items])
    idxs = get_idxs_from_text(text, tokens)
    labels = [item[name] for item in items]

    annotated_text(_build_annotated_spans(text, idxs, labels))


def _extract_entities(output):
    """Collect argument entities from ``B-``/``I-`` tagged tokens.

    Tokens tagged ``"O"`` are skipped. A ``B-`` tag starts a new entity; an
    ``I-`` tag extends the entity currently being built and is ignored when
    there is none.

    Returns:
        A list of dicts with ``"label"``, ``"text"``, ``"start"`` and
        ``"end"`` (token positions in *output*, inclusive).
    """
    entities = []
    current = None
    for position, item in enumerate(output):
        label = item["argument"]
        if label == "O":
            continue
        token_text = item["token"].replace(" ", "")
        if label.startswith('B-'):
            if current is not None:
                entities.append(current)
            current = {
                'label': label[2:],
                'text': token_text,
                'start': position,
                'end': position,
            }
        elif label.startswith('I-') and current is not None:
            current['text'] += ' ' + token_text
            current['end'] = position

    # The entity still being built when the tokens run out must be kept too;
    # without this the last argument in the text would never receive a role.
    if current is not None:
        entities.append(current)
    return entities


def get_arg_roles(output):
    """Add a ``"role"`` entry to every token dict in *output*.

    For each argument entity, the role is looked up in
    ``model.arg_2_role[label]``. When several roles are possible, the
    classifier registered for that label in ``role_classifiers`` chooses one
    from the concatenated embeddings of the surrounding context and of the
    entity text. Tokens outside any entity get the role ``"O"``.

    Args:
        output: Per-token dicts with ``"argument"``, ``"token"`` and ``"id"``.

    Returns:
        The same list, modified in place.

    Raises:
        KeyError: If a label has several candidate roles but no classifier
            was loaded for it.
    """
    entities = _extract_entities(output)

    for entity in entities:
        # Context window of tokens around the entity, clipped to the output.
        window = output[max(0, entity["start"] - 15) : min(len(output), entity["end"] + 15)]
        entity["context"] = model.tokenizer.decode([item["id"] for item in window])

        candidate_roles = model.arg_2_role[entity["label"]]
        if len(candidate_roles) > 1:
            sent_embed = model.embed_model.encode(entity["context"])
            arg_embed = model.embed_model.encode(entity["text"])
            embed = np.concatenate((sent_embed, arg_embed))

            arg_clf = role_classifiers[entity["label"]]
            role_id = arg_clf.predict(embed.reshape(1, -1))
            entity["role"] = candidate_roles[role_id[0]]
        else:
            entity["role"] = candidate_roles[0]

    for item in output:
        item["role"] = "O"
    for entity in entities:
        for i in range(entity["start"], entity["end"] + 1):
            output[i]["role"] = entity["role"]
    return output


st.title("Create Knowledge Graphs from Cyber Incidents")

text_input = st.text_area("Enter your text here", height=100)
# Create the button unconditionally: inside an `or` it was skipped (and so
# disappeared from the page) as soon as the text area had content.
apply_clicked = st.button('Apply')

if text_input.strip():
    output = model(text_input)

    st.subheader("Event Nuggets")
    annotate("nugget", output)

    st.subheader("Event Arguments")
    annotate("argument", output)

    st.subheader("Realis of Event Nuggets")
    annotate("realis", output)

    output = get_arg_roles(output)
    st.subheader("Role of the Event Arguments")
    annotate("role", output)
elif apply_clicked:
    # Nothing to analyse: tell the user instead of running the model on "".
    st.warning("Please enter some text before pressing Apply.")