File size: 4,041 Bytes
83fd625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec68b63
83fd625
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
from transformers import AutoModelForTokenClassification
from annotated_text import annotated_text
import numpy as np
import os, joblib

from utils import get_idxs_from_text

# Load the pretrained token-classification pipeline. trust_remote_code=True
# allows Python code shipped with the model repository to run during loading.
model = AutoModelForTokenClassification.from_pretrained(
    "CyberPeace-Institute/Cybersecurity-Knowledge-Graph",
    trust_remote_code=True,
)

# Build a lookup from argument type to its role classifier. Every
# "<name>.joblib" file in <cwd>/arg_role_models is loaded and keyed by the
# part of its file name before the first dot.
# NOTE: joblib files are pickle-based; only trusted files belong in that folder.
folder_path = '/arg_role_models'
role_classifiers = {}
_models_dir = os.getcwd() + folder_path
for filename in os.listdir(_models_dir):
    if not filename.endswith('.joblib'):
        continue
    arg = filename.split(".")[0]
    role_classifiers[arg] = joblib.load(os.path.join(_models_dir, filename))

def annotate(name):
    """Render the current model output as annotated text for one label field.

    Reads the module-level ``output`` (a sequence of per-token dicts with at
    least the keys ``"token"``, ``"id"`` and ``name``), merges consecutive
    tokens that share the same label into one span, and passes the resulting
    list to ``annotated_text``. Spans labelled ``"O"`` are emitted as plain
    strings; every other span is emitted as a ``(text, label)`` tuple.

    Args:
        name: Key of the per-token label to visualise, e.g. ``"nugget"``,
            ``"argument"``, ``"realis"`` or ``"role"``.
    """
    tokens = [item["token"].replace(" ", "") for item in output]
    text = model.tokenizer.decode([item["id"] for item in output])
    # Each entry is read below through its "word" and "end_idx" keys;
    # presumably end_idx is the token's end offset inside `text` -- verify
    # against utils.get_idxs_from_text.
    idxs = get_idxs_from_text(text, tokens)
    labels = [item[name] for item in output]

    annotated_text_list = []

    def _flush(segment, label):
        # "O" (outside) spans are shown unhighlighted, all others as a pair.
        annotated_text_list.append(segment if label == "O" else (segment, label))

    last_label = ""
    cumulative_tokens = ""
    last_id = 0
    for idx, label in zip(idxs, labels):
        # Drop the BIO prefix ("B-Attacker" -> "Attacker") so that the B- and
        # I- tokens of one span compare equal.
        label_short = label.split("-")[1] if "-" in label else label
        if last_label == label_short:
            # Same span: extend it with the text up to this token's end.
            cumulative_tokens += text[last_id : idx["end_idx"]]
        else:
            if last_label != "":
                _flush(cumulative_tokens, last_label)
            last_label = label_short
            cumulative_tokens = idx["word"]
        last_id = idx["end_idx"]

    # Emit the span still being accumulated. When nothing was processed at all
    # (empty output), skip it so no empty ("", "") annotation is rendered.
    if last_label != "" or cumulative_tokens:
        _flush(cumulative_tokens, last_label)
    annotated_text(annotated_text_list)

def get_arg_roles(output):
    """Attach a ``"role"`` label to every token in ``output``.

    Tokens whose ``"argument"`` label is not ``"O"`` are grouped into entities
    following their ``B-``/``I-`` prefixes. Each entity receives a role from
    ``model.arg_2_role``: if its argument type maps to a single role, that role
    is used directly; otherwise the matching classifier in ``role_classifiers``
    picks one from the embeddings of the entity text and its surrounding
    context.

    Args:
        output: Sequence of per-token dicts with at least the keys ``"id"``,
            ``"token"`` and ``"argument"``.

    Returns:
        The same ``output`` object, modified in place: every item gets a
        ``"role"`` key, set to the entity's role for tokens between an
        entity's start and end positions (inclusive) and to ``"O"`` otherwise.
    """
    entities = []
    current_entity = None
    for position, item in enumerate(output):
        label = item["argument"]
        if label == "O":
            continue
        word = item["token"].replace(" ", "")
        if label.startswith('B-'):
            # A new entity begins; store the one that was being built.
            if current_entity is not None:
                entities.append(current_entity)
            current_entity = {'label': label[2:], 'text': word, 'start': position, 'end': position}
        elif label.startswith('I-') and current_entity is not None:
            current_entity['text'] += ' ' + word
            current_entity['end'] = position
    # Store the entity still open when the tokens run out; without this the
    # last entity of the text would never be assigned a role.
    if current_entity is not None:
        entities.append(current_entity)

    for entity in entities:
        # Context window: tokens from 15 before the entity start up to (but
        # not including) 15 after its end position, clipped to the sequence.
        window = output[max(0, entity["start"] - 15) : min(len(output), entity["end"] + 15)]
        entity["context"] = model.tokenizer.decode([item["id"] for item in window])

    for entity in entities:
        candidate_roles = model.arg_2_role[entity["label"]]
        if len(candidate_roles) > 1:
            # Several roles are possible: classify on the concatenated context
            # and entity embeddings (assumes encode() returns 1-D vectors --
            # TODO confirm against the embedding model).
            sent_embed = model.embed_model.encode(entity["context"])
            arg_embed = model.embed_model.encode(entity["text"])
            embed = np.concatenate((sent_embed, arg_embed))
            arg_clf = role_classifiers[entity["label"]]
            role_id = arg_clf.predict(embed.reshape(1, -1))
            entity["role"] = candidate_roles[role_id[0]]
        else:
            entity["role"] = candidate_roles[0]

    for item in output:
        item["role"] = "O"
    for entity in entities:
        for i in range(entity["start"], entity["end"] + 1):
            output[i]["role"] = entity["role"]
    return output

# --- Streamlit page -------------------------------------------------------
st.title("Create Knowledge Graphs from Cyber Incidents")

text_input = st.text_area("Enter your text here", height=100)

# Call st.button unconditionally. Inside `text_input or st.button(...)` the
# `or` short-circuits, so the button would only be created while the text box
# is empty, and clicking it would then run the model on an empty string.
apply_clicked = st.button('Apply')

if text_input:
    # `output` is deliberately a module-level name: annotate() reads it.
    output = model(text_input)
    st.subheader("Event Nuggets")
    annotate("nugget")
    st.subheader("Event Arguments")
    annotate("argument")
    st.subheader("Realis of Event Nuggets")
    annotate("realis")
    # Adds the "role" field to every token before the last view is drawn.
    output = get_arg_roles(output)
    st.subheader("Role of the Event Arguments")
    annotate("role")
elif apply_clicked:
    st.warning("Please enter some text to analyse.")