|
import torch |
|
from torch import nn |
|
import en_core_web_sm |
|
from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel |
|
from transformers import AutoTokenizer |
|
|
|
# Hugging Face checkpoint used for both the tokenizer and the encoder below.
model_checkpoint = "ehsanaghaei/SecureBERT"

# Use the default CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



# add_prefix_space=True: the tokenization helper in this file calls the tokenizer
# with is_split_into_words=True, which RoBERTa-style tokenizers require this flag for.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

# Encoder loaded once at import time (downloads on first use) and moved to `device`.
# CustomRobertaWithPOS stores this same instance as its `roberta` submodule.
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)



# Event-nugget BIO tags. len() sizes CustomRobertaWithPOS.nugget_embed.
# NOTE(review): presumably a tag's list position is the integer id used in the
# dataset's `ner_tags` column, so the order must not change — verify against the data.
event_nugget_list = ['B-Phishing',

'I-Phishing',

'O',

'B-DiscoverVulnerability',

'B-Ransom',

'I-Ransom',

'B-Databreach',

'I-DiscoverVulnerability',

'B-PatchVulnerability',

'I-PatchVulnerability',

'I-Databreach']



# spaCy small English pipeline, loaded at import time; its components supply the
# NER and dependency label inventories below.
nlp = en_core_web_sm.load()

# Universal POS tag inventory; len() sizes CustomRobertaWithPOS.pos_embed.
# NOTE(review): assumes `pos_spacy` ids index into this list — confirm with the preprocessing.
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]

# BIO-expanded NER labels from spaCy's "ner" pipe: B-X, I-X for each label X, then "O".
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]

# Dependency-relation labels known to spaCy's parser.
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
|
|
class CustomRobertaWithPOS(nn.Module):
    """Token classifier over RoBERTa hidden states enriched with linguistic features.

    For every token the model concatenates

        [roberta_hidden | pos(16) | ner(8) | dep(8) | depth(8) | nugget(8)]

    and maps the result to ``num_classes_realis`` logits with one linear layer.
    Feature positions holding ``IGNORE_INDEX`` (special / ignored sub-tokens)
    contribute an all-zero embedding instead of a looked-up one.
    """

    # Feature id marking "no value at this position" (matches the -100 padding
    # produced by the tokenization/alignment helper).
    IGNORE_INDEX = -100

    def __init__(self, num_classes_realis, roberta=None):
        """
        Args:
            num_classes_realis: number of output classes per token.
            roberta: encoder module exposing ``config.hidden_size`` and returning
                an object with ``last_hidden_state``. Defaults to the module-level
                ``roberta_model`` (shared by every instance, as before).
        """
        super().__init__()
        self.num_classes_realis = num_classes_realis

        # Registration order is kept identical so state_dict keys/ordering don't change.
        self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
        self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 8)
        self.dep_embed = nn.Embedding(len(dep_spacy_tag_list), 8)
        # NOTE(review): 17 depth values (0..16) are hard-coded — confirm against the data.
        self.depth_embed = nn.Embedding(17, 8)
        self.nugget_embed = nn.Embedding(len(event_nugget_list), 8)
        self.roberta = roberta_model if roberta is None else roberta
        self.dropout1 = nn.Dropout(0.2)

        # Derive the extra width (48) from the embeddings so the two can't drift apart.
        extra_dim = sum(
            emb.embedding_dim
            for emb in (self.pos_embed, self.ner_embed, self.dep_embed,
                        self.depth_embed, self.nugget_embed)
        )
        self.fc1 = nn.Linear(self.roberta.config.hidden_size + extra_dim, self.num_classes_realis)

    @classmethod
    def _masked_embedding(cls, embedding, ids):
        """Embed ``ids`` of shape (batch, seq); ``IGNORE_INDEX`` positions become zeros.

        Returns a (batch, seq, embedding_dim) tensor on ``ids``' device with the
        embedding's dtype. Only valid ids are looked up, so -100 never reaches
        ``nn.Embedding`` (which would raise on a negative index).
        """
        mask = ids != cls.IGNORE_INDEX
        out = torch.zeros(
            (*ids.shape, embedding.embedding_dim),
            dtype=embedding.weight.dtype,
            device=ids.device,
        )
        # Each feature is scattered back with ITS OWN mask (the original reused
        # the dep mask for depth and nugget features, misplacing/crashing).
        out[mask] = embedding(ids[mask])
        return out

    def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy, ner_tags):
        """Return per-token logits of shape (batch, seq, num_classes_realis).

        All feature tensors are integer ids aligned with ``input_ids``
        (same (batch, seq) shape), using ``IGNORE_INDEX`` where no value applies.
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_output = outputs.last_hidden_state

        features_concat = torch.cat(
            (
                last_hidden_output,
                self._masked_embedding(self.pos_embed, pos_spacy),
                self._masked_embedding(self.ner_embed, ner_spacy),
                self._masked_embedding(self.dep_embed, dep_spacy),
                self._masked_embedding(self.depth_embed, depth_spacy),
                self._masked_embedding(self.nugget_embed, ner_tags),
            ),
            dim=2,
        )
        # Dropout is applied once; the original applied the same layer twice in a
        # row (effective drop rate ~0.36), which looked like a duplicated line.
        features_concat = self.dropout1(features_concat)

        return self.fc1(features_concat)
|
|
|
|
|
def get_entity_for_realis_from_idx(start_idx, end_idx, event_nuggets):
    """Return the BIO event tag for the character span ``[start_idx, end_idx)``.

    ``event_nuggets`` is a sequence of dicts with ``"startOffset"``,
    ``"endOffset"`` and ``"subtype"`` keys (offsets assumed to satisfy
    start < end). Nuggets are examined in order; the first one that matches
    decides the tag:

    * ``"B-<subtype>"`` if the span starts exactly at the nugget start, or
      starts before it and extends past the nugget start. This includes spans
      that fully contain the nugget, which previously fell through to ``"O"``.
    * ``"I-<subtype>"`` if the span starts strictly after the nugget start and
      either starts before the nugget end or ends within the nugget.
    * ``"O"`` if no nugget matches (including an empty ``event_nuggets``).
    """
    for nugget in event_nuggets:
        nugget_start = nugget["startOffset"]
        nugget_end = nugget["endOffset"]

        # The span covers the nugget's first character -> beginning of the nugget.
        if start_idx == nugget_start or start_idx < nugget_start < end_idx:
            return "B-" + nugget["subtype"]

        # The span starts inside the nugget -> continuation of the nugget.
        if nugget_start < start_idx and (start_idx < nugget_end or end_idx <= nugget_end):
            return "I-" + nugget["subtype"]

    return "O"
|
|
|
def tokenize_and_align_labels_with_pos_ner_realis(examples, tokenizer, ner_names, label_all_tokens = True):
    """Tokenize pre-split words and align per-word feature ids to tokenizer positions.

    ``examples`` is a batch mapping (column name -> list of rows) with
    ``"tokens"`` (the words of each row) and the per-word integer columns
    ``"ner_tags"``, ``"pos_spacy"``, ``"ner_spacy"``, ``"dep_spacy"`` and
    ``"depth_spacy"``. Each feature column is re-indexed from words to
    tokenizer positions:

    * positions whose ``word_ids`` entry is ``None`` get ``-100``;
    * the first sub-token of a word gets that word's value;
    * further sub-tokens of the same word get the word's value when
      ``label_all_tokens`` is true, otherwise ``-100``.

    Args:
        examples: batch mapping as described above.
        tokenizer: callable tokenizer whose output provides
            ``word_ids(batch_index=...)`` (e.g. a fast Hugging Face tokenizer).
        ner_names: unused; kept for call-site compatibility.
        label_all_tokens: whether continuation sub-tokens inherit the word's value.

    Returns:
        The tokenizer output with the five aligned feature columns stored
        under the same column names.
    """
    tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)

    # Order here is the order the aligned columns are added to the output.
    feature_names = ("ner_tags", "pos_spacy", "ner_spacy", "dep_spacy", "depth_spacy")
    ignore_id = -100
    aligned = {name: [] for name in feature_names}

    # zip stops at the shortest feature column (same as iterating the five columns together).
    for i, row_values in enumerate(zip(*(examples[name] for name in feature_names))):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        row_aligned = [[] for _ in feature_names]
        previous_word_idx = None

        for word_idx in word_ids:
            # A word's value is copied to its first sub-token always, and to
            # continuation sub-tokens only when label_all_tokens is set.
            take_value = word_idx is not None and (word_idx != previous_word_idx or label_all_tokens)
            for values, out in zip(row_values, row_aligned):
                out.append(values[word_idx] if take_value else ignore_id)
            previous_word_idx = word_idx

        for name, out in zip(feature_names, row_aligned):
            aligned[name].append(out)

    for name in feature_names:
        tokenized_inputs[name] = aligned[name]

    return tokenized_inputs
|
|