import os

import joblib
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, PreTrainedModel


from .nugget_model_utils import CustomRobertaWithPOS as NuggetModel
from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
from .realis_model_utils import CustomRobertaWithPOS as RealisModel

from .configuration import CybersecurityKnowledgeGraphConfig

from .event_nugget_predict import create_dataloader as event_nugget_dataloader
from .event_realis_predict import create_dataloader as event_realis_dataloader
from .event_arg_predict import create_dataloader as event_argument_dataloader

class CybersecurityKnowledgeGraphModel(PreTrainedModel):
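    """Pipeline model that combines the event nugget, argument and realis taggers
    with per-argument-type role classifiers to turn raw text into token-level
    event annotations for a cybersecurity knowledge graph."""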
    config_class = CybersecurityKnowledgeGraphConfig

    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained("ehsanaghaei/SecureBERT")
        
        self.event_nugget_model_path = config.event_nugget_model_path
        self.event_argument_model_path = config.event_argument_model_path
        self.event_realis_model_path = config.event_realis_model_path

        self.event_nugget_dataloader = event_nugget_dataloader
        self.event_argument_dataloader = event_argument_dataloader
        self.event_realis_dataloader = event_realis_dataloader

        self.event_nugget_model = NuggetModel(num_classes=11)
        self.event_argument_model = ArgumentModel(num_classes=43)
        self.event_realis_model = RealisModel(num_classes_realis=4)

        # Per-argument-type role classifiers (expected to be populated with fitted
        # classifiers before roles can be predicted) and the sentence encoder used
        # to build their input features.
        self.role_classifiers = {}
        self.embed_model = SentenceTransformer('all-MiniLM-L6-v2')

        self.event_nugget_list = config.event_nugget_list
        self.event_args_list = config.event_args_list
        self.realis_list = config.realis_list
        self.arg_2_role = config.arg_2_role


    def forward(self, text):
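        """Run the full extraction pipeline on raw text and return a list of
        per-token dicts with keys "id", "token", "nugget", "argument",
        "realis" and "role"."""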
        nugget_dataloader, _ = self.event_nugget_dataloader(text)
        argument_dataloader, _ = self.event_argument_dataloader(self.event_nugget_model, text)
        realis_dataloader, _ = self.event_realis_dataloader(self.event_nugget_model, text)

        nugget_pred = self.forward_model(self.event_nugget_model, nugget_dataloader)
        # Sequences in which no token was tagged as an event nugget (label index 0 everywhere).
        no_nuggets = torch.all(nugget_pred == 0, dim=1)

        argument_preds = torch.empty(nugget_pred.size())
        realis_preds = torch.empty(nugget_pred.size())
        for idx, (batch, no_nugget) in enumerate(zip(nugget_pred, no_nuggets)):
            if no_nugget:
                # Nothing to tag in this sequence: fall back to label index 0 for every token.
                argument_pred, realis_pred = torch.zeros(batch.size()), torch.zeros(batch.size())
            else:
                argument_pred = self.forward_model(self.event_argument_model, argument_dataloader)
                realis_pred = self.forward_model(self.event_realis_model, realis_dataloader)
            argument_preds[idx] = argument_pred
            realis_preds[idx] = realis_pred
        
        attention_mask = [batch["attention_mask"] for batch in nugget_dataloader]
        attention_mask = torch.cat(attention_mask, dim=-1)

        input_ids = [batch["input_ids"] for batch in nugget_dataloader]
        input_ids = torch.cat(input_ids, dim=-1)
        
        output = {"nugget" : nugget_pred, "argument" : argument_preds, "realis" : realis_preds, "input_ids" : input_ids, "attention_mask" : attention_mask}
        no_of_batch = output['input_ids'].shape[0]

        structured_output = []
        for b in range(no_of_batch):
            # Drop special tokens before building the per-token output.
            token_mask = [self.tokenizer.decode(token) not in self.tokenizer.all_special_tokens for token in output['input_ids'][b]]
            filtered_ids = output['input_ids'][b][token_mask]
            filtered_tokens = [self.tokenizer.decode(token) for token in filtered_ids]

            filtered_nuggets = output['nugget'][b][token_mask]
            filtered_args = output['argument'][b][token_mask]
            filtered_realis = output['realis'][b][token_mask]

            batch_output = [{"id" : id.item(), "token" : token, "nugget" : self.event_nugget_list[int(nugget.item())], "argument" : self.event_args_list[int(arg.item())], "realis" : self.realis_list[int(realis.item())]} 
                            for id, token, nugget, arg, realis in zip(filtered_ids, filtered_tokens, filtered_nuggets, filtered_args, filtered_realis)]
            structured_output.extend(batch_output)
        
        
        args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(structured_output) if item["argument"] != "O"]
        
        # Merge BIO-tagged argument tokens into entity spans.
        entities = []
        current_entity = None
        for position, label, token in args:
            if label.startswith('B-'):
                if current_entity is not None:
                    entities.append(current_entity)
                current_entity = {'label': label[2:], 'text': token.replace(" ", ""), 'start': position, 'end': position}
            elif label.startswith('I-'):
                if current_entity is not None:
                    current_entity['text'] += ' ' + token.replace(" ", "")
                    current_entity['end'] = position
        # Append the final entity, which the loop above otherwise drops.
        if current_entity is not None:
            entities.append(current_entity)

        # Attach a window of surrounding tokens (about 15 on each side) as context
        # for role classification.
        for entity in entities:
            context_ids = [item["id"] for item in structured_output[max(0, entity["start"] - 15) : min(len(structured_output), entity["end"] + 15)]]
            entity["context"] = self.tokenizer.decode(context_ids)
        
        # Resolve each entity's role. When an argument type maps to more than one
        # possible role, a per-type classifier decides based on embeddings of the
        # entity's context and text; otherwise the single role is assigned directly.
        for entity in entities:
            if len(self.arg_2_role[entity["label"]]) > 1:
                sent_embed = self.embed_model.encode(entity["context"])
                arg_embed = self.embed_model.encode(entity["text"])
                embed = np.concatenate((sent_embed, arg_embed))

                arg_clf = self.role_classifiers[entity["label"]]
                role_id = arg_clf.predict(embed.reshape(1, -1))
                entity["role"] = self.arg_2_role[entity["label"]][role_id[0]]
            else:
                entity["role"] = self.arg_2_role[entity["label"]][0]
        
        # Default every token's role to "O", then project the resolved entity roles
        # back onto the tokens they span.
        for item in structured_output:
            item["role"] = "O"
        for entity in entities:
            for i in range(entity["start"], entity["end"] + 1):
                structured_output[i]["role"] = entity["role"]
        return structured_output

    def forward_model(self, model, dataloader):
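        """Run `model` over every batch in `dataloader` without gradient tracking
        and return the concatenated argmax label predictions."""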
        model.eval()  # make sure dropout layers are in inference mode
        predicted_label = []
        for batch in dataloader:
            with torch.no_grad():
                logits = model(**batch)
            batch_predicted_label = logits.argmax(-1)
            predicted_label.append(batch_predicted_label)
        return torch.cat(predicted_label, dim=-1)
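
# Usage sketch (illustrative only; assumes the checkpoints referenced by the config
# and the per-argument role classifiers have been loaded; the example text is made up):
#
#   config = CybersecurityKnowledgeGraphConfig(...)
#   model = CybersecurityKnowledgeGraphModel(config)
#   tokens = model("Attackers delivered ransomware through a phishing email.")
#   for tok in tokens:
#       print(tok["token"], tok["nugget"], tok["argument"], tok["realis"], tok["role"])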