File size: 6,600 Bytes
77eacb7 cb8b12f 77eacb7 29d5112 8241ba7 77eacb7 acd5094 77eacb7 8241ba7 77eacb7 cb8b12f 77eacb7 816e104 4fa0a53 cb8b12f 77eacb7 303b1b2 77eacb7 cb8b12f 77eacb7 cb8b12f 77eacb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from transformers import PreTrainedModel
import torch
import joblib, os
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from .nugget_model_utils import CustomRobertaWithPOS as NuggetModel
from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
from .realis_model_utils import CustomRobertaWithPOS as RealisModel
from .configuration import CybersecurityKnowledgeGraphConfig
from .event_nugget_predict import create_dataloader as event_nugget_dataloader
from .event_realis_predict import create_dataloader as event_realis_dataloader
from .event_arg_predict import create_dataloader as event_argument_dataloader
class CybersecurityKnowledgeGraphModel(PreTrainedModel):
    """Token-level event extraction with argument-role resolution.

    ``forward`` runs three ``CustomRobertaWithPOS`` heads (event nugget,
    event argument, event realis) over dataloaders built from the input text,
    flattens their per-token predictions into a list of dicts, groups the
    BIO argument labels into spans, and assigns each span a role taken from
    ``config.arg_2_role`` -- directly when the label has a single candidate
    role, otherwise via the per-label classifier in ``self.role_classifiers``.
    """

    config_class = CybersecurityKnowledgeGraphConfig

    # Tokens of context taken on each side of an argument span when decoding
    # the text that is embedded for role classification.
    _CONTEXT_WINDOW = 15

    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained("ehsanaghaei/SecureBERT")
        self.event_nugget_model_path = config.event_nugget_model_path
        self.event_argument_model_path = config.event_argument_model_path
        self.event_realis_model_path = config.event_realis_model_path
        self.event_nugget_dataloader = event_nugget_dataloader
        self.event_argument_dataloader = event_argument_dataloader
        self.event_realis_dataloader = event_realis_dataloader
        self.event_nugget_model = NuggetModel(num_classes=11)
        self.event_argument_model = ArgumentModel(num_classes=43)
        self.event_realis_model = RealisModel(num_classes_realis=4)
        # Argument label -> fitted classifier exposing ``predict``.
        # NOTE(review): empty here; presumably populated elsewhere before
        # forward() runs -- a multi-role label without an entry raises KeyError.
        self.role_classifiers = {}
        self.embed_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Index -> label lookup tables for the three heads' predictions.
        self.event_nugget_list = config.event_nugget_list
        self.event_args_list = config.event_args_list
        self.realis_list = config.realis_list
        # Argument label -> sequence of candidate roles.
        self.arg_2_role = config.arg_2_role

    def forward(self, text):
        """Run the full extraction pipeline on ``text``.

        Args:
            text: Passed unchanged to the three ``create_dataloader`` helpers
                (presumably a raw document string -- TODO confirm).

        Returns:
            A list with one dict per non-special token, in order, with keys
            ``"id"``, ``"token"``, ``"nugget"``, ``"argument"``, ``"realis"``
            and ``"role"``.
        """
        output = self._predict(text)
        structured_output = self._to_structured_output(output)
        entities = self._extract_entities(structured_output)
        self._assign_roles(entities, structured_output)
        return structured_output

    def _predict(self, text):
        """Run the three heads and return row-aligned prediction tensors.

        Returns a dict with ``"nugget"``, ``"argument"``, ``"realis"``,
        ``"input_ids"`` and ``"attention_mask"`` entries.
        """
        nugget_dataloader, _ = self.event_nugget_dataloader(text)
        argument_dataloader, _ = self.event_argument_dataloader(self.event_nugget_model, text)
        realis_dataloader, _ = self.event_realis_dataloader(self.event_nugget_model, text)

        nugget_pred = self.forward_model(self.event_nugget_model, nugget_dataloader)
        # A row whose nugget labels are all 0 keeps all-0 argument/realis labels.
        has_nugget = (~torch.all(nugget_pred == 0, dim=1)).tolist()

        argument_preds = torch.zeros(nugget_pred.size())
        realis_preds = torch.zeros(nugget_pred.size())
        if any(has_nugget):
            # Run each head once and copy only row ``idx`` of its output into
            # row ``idx``. Assumes the argument/realis dataloaders yield batches
            # shaped like the nugget dataloader's -- TODO confirm.
            argument_pred = self.forward_model(self.event_argument_model, argument_dataloader)
            realis_pred = self.forward_model(self.event_realis_model, realis_dataloader)
            for idx, row_has_nugget in enumerate(has_nugget):
                if row_has_nugget:
                    argument_preds[idx] = argument_pred[idx]
                    realis_preds[idx] = realis_pred[idx]

        # Concatenated the same way as forward_model's predictions so the
        # columns stay aligned with the prediction tensors.
        attention_mask = torch.cat([batch["attention_mask"] for batch in nugget_dataloader], dim=-1)
        input_ids = torch.cat([batch["input_ids"] for batch in nugget_dataloader], dim=-1)
        return {
            "nugget": nugget_pred,
            "argument": argument_preds,
            "realis": realis_preds,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def _to_structured_output(self, output):
        """Flatten predictions into one labelled dict per non-special token."""
        special_ids = set(self.tokenizer.all_special_ids)
        structured_output = []
        for b in range(output["input_ids"].shape[0]):
            row_ids = output["input_ids"][b]
            # Drop the tokenizer's special tokens from every aligned row.
            token_mask = [token_id not in special_ids for token_id in row_ids.tolist()]
            filtered_ids = row_ids[token_mask]
            filtered_nuggets = output["nugget"][b][token_mask]
            filtered_args = output["argument"][b][token_mask]
            filtered_realis = output["realis"][b][token_mask]
            for token_id, nugget, arg, realis in zip(
                filtered_ids, filtered_nuggets, filtered_args, filtered_realis
            ):
                structured_output.append({
                    "id": token_id.item(),
                    "token": self.tokenizer.decode(token_id),
                    "nugget": self.event_nugget_list[int(nugget.item())],
                    "argument": self.event_args_list[int(arg.item())],
                    "realis": self.realis_list[int(realis.item())],
                })
        return structured_output

    @staticmethod
    def _extract_entities(structured_output):
        """Group BIO argument labels into spans.

        ``"O"`` tokens are skipped before grouping, so an ``I-`` token extends
        the most recent ``B-`` span even across intervening ``"O"`` tokens; an
        ``I-`` token with no open span is ignored. Each span is a dict with
        ``"label"``, ``"text"``, and inclusive ``"start"``/``"end"`` positions
        into ``structured_output``.
        """
        entities = []
        current_entity = None
        for position, item in enumerate(structured_output):
            label = item["argument"]
            if label == "O":
                continue
            token = item["token"].replace(" ", "")
            if label.startswith("B-"):
                if current_entity is not None:
                    entities.append(current_entity)
                current_entity = {"label": label[2:], "text": token, "start": position, "end": position}
            elif label.startswith("I-") and current_entity is not None:
                current_entity["text"] += " " + token
                current_entity["end"] = position
        # Close the span still open at the end of the sequence.
        if current_entity is not None:
            entities.append(current_entity)
        return entities

    def _assign_roles(self, entities, structured_output):
        """Resolve a role per entity and write roles onto ``structured_output``.

        Each entity gains ``"context"`` and ``"role"`` keys. Every token gets
        ``"role": "O"`` unless it lies inside an entity span, in which case it
        takes that entity's role.
        """
        window = self._CONTEXT_WINDOW
        for entity in entities:
            context_items = structured_output[max(0, entity["start"] - window): entity["end"] + window]
            entity["context"] = self.tokenizer.decode([item["id"] for item in context_items])
            candidate_roles = self.arg_2_role[entity["label"]]
            if len(candidate_roles) > 1:
                # Disambiguate with the label's classifier over the
                # concatenated context and span embeddings.
                sent_embed = self.embed_model.encode(entity["context"])
                arg_embed = self.embed_model.encode(entity["text"])
                embed = np.concatenate((sent_embed, arg_embed))
                role_id = self.role_classifiers[entity["label"]].predict(embed.reshape(1, -1))
                entity["role"] = candidate_roles[role_id[0]]
            else:
                entity["role"] = candidate_roles[0]

        for item in structured_output:
            item["role"] = "O"
        for entity in entities:
            for i in range(entity["start"], entity["end"] + 1):
                structured_output[i]["role"] = entity["role"]

    def forward_model(self, model, dataloader):
        """Return argmax class predictions of ``model`` over ``dataloader``.

        Each batch dict is unpacked as keyword arguments to ``model``, which is
        expected to return logits with classes on the last axis. Per-batch
        predictions are concatenated along the last dimension.
        """
        predicted_label = []
        with torch.no_grad():
            for batch in dataloader:
                logits = model(**batch)
                predicted_label.append(logits.argmax(-1))
        return torch.cat(predicted_label, dim=-1)