File size: 4,563 Bytes
4e38daf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import os, json
from cybersecurity_knowledge_graph.utils import get_content, get_event_args, get_event_nugget, get_idxs_from_text, get_args_entity_from_idx, find_dict_by_overlap
from tqdm import tqdm
import spacy
import jsonlines
from sklearn.model_selection import train_test_split
import math
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import numpy as np
# Model identifiers for the shared NLP resources below.
_SENTENCE_ENCODER_NAME = 'all-MiniLM-L6-v2'
_NER_MODEL_NAME = "CyberPeace-Institute/SecureBERT-NER"
_SPACY_MODEL_NAME = 'en_core_web_sm'

# The models are loaded once, at import time, and shared by every caller of this module.
# Sentence encoder: embeds both sentences and argument texts.
embed_model = SentenceTransformer(_SENTENCE_ENCODER_NAME)
# Token-classification (NER) pipeline. Kept at module level for importers;
# EventArgumentRoleDataset itself does not use it.
pipe = pipeline(task="token-classification", model=_NER_MODEL_NAME)
# spaCy English pipeline, used for sentence segmentation.
nlp = spacy.load(_SPACY_MODEL_NAME)
class EventArgumentRoleDataset:
    """Dataset for training and evaluating event-argument role classifiers.

    Each sample is a dict ``{"embedding": ndarray, "label": ...}``:

    - ``embedding`` is the embedding of the sentence containing an argument,
      concatenated with the embedding of the argument text. Both come from
      the module-level ``embed_model``.
    - ``label`` is the argument's ``role["type"]`` value.

    Only arguments whose ``type`` equals ``arg`` produce samples.

    Attributes:
        path: Folder containing the JSON files with event data.
        tokenizer: Stored for callers. Not used by the methods of this class.
        arg: The argument type (subtype) the dataset is built for.
        data: List of samples appended by ``load_data``.
        train_data, val_data, test_data: Splits set by
            ``train_val_test_split``. ``val_data`` is never populated by
            this class.
        datapoint_id: Identifier counter. It is kept for compatibility and
            is not updated by this class.
    """

    def __init__(self, path, tokenizer, arg):
        self.path = path
        self.tokenizer = tokenizer
        self.arg = arg
        self.data = []
        self.train_data, self.val_data, self.test_data = None, None, None
        self.datapoint_id = 0

    def __len__(self):
        """Return the number of samples collected so far."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample dict stored at *index*."""
        return self.data[index]

    def to_jsonlines(self, train_path, val_path, test_path):
        """Write the dataset splits to JSON Lines files.

        NumPy values in each sample are converted to plain Python lists or
        scalars first, because the JSON encoder cannot serialize ``ndarray``.
        The stored samples are not modified.

        Args:
            train_path: Output file for ``train_data``.
            val_path: Output file for ``val_data``. It is only written when
                ``val_data`` has been populated (this class never does so)
                and a path is given.
            test_path: Output file for ``test_data``.

        Raises:
            ValueError: If ``train_val_test_split`` has not been run yet.
        """
        if self.train_data is None or self.test_data is None:
            raise ValueError("Do the train-val-test split")
        splits = [(train_path, self.train_data)]
        if self.val_data is not None and val_path is not None:
            splits.append((val_path, self.val_data))
        splits.append((test_path, self.test_data))
        for out_path, samples in splits:
            with jsonlines.open(out_path, "w") as writer:
                writer.write_all(self._to_serializable(sample) for sample in samples)

    @staticmethod
    def _to_serializable(sample):
        """Return a shallow copy of *sample* with NumPy values turned into lists/scalars."""
        return {
            key: value.tolist() if isinstance(value, (np.ndarray, np.generic)) else value
            for key, value in sample.items()
        }

    def train_val_test_split(self, test_size=0.1, random_state=42):
        """Shuffle ``data`` and split it into ``train_data`` and ``test_data``.

        Only a train/test split is made. ``val_data`` is left unchanged.

        Args:
            test_size: Fraction (or absolute count) of samples held out for
                the test set. It is passed to ``train_test_split``.
            random_state: Seed for the shuffle, so that the split is
                reproducible.
        """
        self.train_data, self.test_data = train_test_split(
            self.data, test_size=test_size, random_state=random_state, shuffle=True
        )

    def load_data(self):
        """Build argument-role samples from every ``*.json`` file in ``path``.

        Files are processed in sorted name order. This makes the sample
        order, and therefore the seeded split, reproducible across
        filesystems. Files that cannot be read or parsed are reported and
        skipped. New samples are appended to ``data``, so calling this twice
        duplicates them.
        """
        json_files = sorted(name for name in os.listdir(self.path) if name.endswith(".json"))
        for file_name in tqdm(json_files):
            try:
                file_json = self._read_json(file_name)
            except (OSError, ValueError) as exc:
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                # Skip the file rather than reusing the previous file's data.
                print("Error in ", file_name, exc)
                continue
            self._add_document_samples(file_json)

    def _read_json(self, file_name):
        """Parse and return the JSON content of *file_name* inside ``path``."""
        with open(os.path.join(self.path, file_name), "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _sentence_spans(doc):
        """Return ``(start, end)`` character offsets of each sentence in a spaCy ``Doc``."""
        return [(sent[0].idx, sent[-1].idx + len(sent[-1].text)) for sent in doc.sents]

    def _add_document_samples(self, file_json):
        """Append one sample for each ``self.arg``-typed argument that lies inside one sentence."""
        # Replacing "\xa0" with " " keeps the string length the same,
        # so the argument character offsets stay valid.
        content = get_content(file_json).replace("\xa0", " ")
        # Filter by type first. Documents and sentences without a relevant
        # argument then skip the costly spaCy parse and sentence encoding.
        target_args = [a for a in get_event_args(file_json) if a.get("type") == self.arg]
        if not target_args:
            return
        for start, end in self._sentence_spans(nlp(content)):
            # An argument that crosses a sentence boundary matches no sentence,
            # so it is dropped.
            sentence_args = [
                a for a in target_args if a["startOffset"] >= start and a["endOffset"] <= end
            ]
            if not sentence_args:
                continue
            sentence_embed = embed_model.encode(content[start:end])
            for arg in sentence_args:
                arg_embed = embed_model.encode(arg["text"])
                self.data.append({
                    "embedding": np.concatenate((sentence_embed, arg_embed)),
                    "label": arg["role"]["type"],
                })