cpi-connect commited on
Commit
621df19
1 Parent(s): 8241ba7

Upload model

Browse files
event_arg_predict.py CHANGED
@@ -3,11 +3,11 @@ from annotated_text import annotated_text
3
  import torch
4
  from torch.utils.data import DataLoader
5
 
6
- from cybersecurity_knowledge_graph.args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
7
- from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
8
- from cybersecurity_knowledge_graph.utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list
9
 
10
- from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
11
  import spacy
12
  from transformers import AutoTokenizer
13
  from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
@@ -37,7 +37,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
37
  model_checkpoint = "ehsanaghaei/SecureBERT"
38
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
39
 
40
- from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
41
  model_nugget = ArgumentModel(num_classes=43)
42
  model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
43
  model_nugget.eval()
 
3
  import torch
4
  from torch.utils.data import DataLoader
5
 
6
+ from .args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
7
+ from .nugget_model_utils import CustomRobertaWithPOS
8
+ from .utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list
9
 
10
+ from .event_nugget_predict import get_event_nuggets
11
  import spacy
12
  from transformers import AutoTokenizer
13
  from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
 
37
  model_checkpoint = "ehsanaghaei/SecureBERT"
38
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
39
 
40
+ from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
41
  model_nugget = ArgumentModel(num_classes=43)
42
  model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
43
  model_nugget.eval()
event_nugget_predict.py CHANGED
@@ -3,9 +3,9 @@ from annotated_text import annotated_text
3
  import torch
4
  from torch import nn
5
  from torch.utils.data import DataLoader
6
- from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS as NuggetModel
7
- from cybersecurity_knowledge_graph.nugget_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
8
- from cybersecurity_knowledge_graph.utils import get_idxs_from_text, event_nugget_list
9
  import spacy
10
  from transformers import AutoTokenizer
11
  from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
 
3
  import torch
4
  from torch import nn
5
  from torch.utils.data import DataLoader
6
+ from .nugget_model_utils import CustomRobertaWithPOS as NuggetModel
7
+ from .nugget_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
8
+ from .utils import get_idxs_from_text, event_nugget_list
9
  import spacy
10
  from transformers import AutoTokenizer
11
  from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
event_realis_predict.py CHANGED
@@ -3,12 +3,12 @@ import spacy
3
  import torch
4
  from torch.utils.data import DataLoader
5
  from transformers import AutoTokenizer
6
- from cybersecurity_knowledge_graph.utils import get_idxs_from_text
7
  import streamlit as st
8
  from annotated_text import annotated_text
9
- from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
10
- from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
11
- from cybersecurity_knowledge_graph.realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
12
  from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
13
 
14
  event_nugget_list = ['B-Phishing',
@@ -49,7 +49,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
49
  model_checkpoint = "ehsanaghaei/SecureBERT"
50
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
51
 
52
- from cybersecurity_knowledge_graph.realis_model_utils import CustomRobertaWithPOS as RealisModel
53
  model_realis = RealisModel(num_classes_realis=4)
54
  model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
55
  model_realis.eval()
 
3
  import torch
4
  from torch.utils.data import DataLoader
5
  from transformers import AutoTokenizer
6
+ from .utils import get_idxs_from_text
7
  import streamlit as st
8
  from annotated_text import annotated_text
9
+ from .nugget_model_utils import CustomRobertaWithPOS
10
+ from .event_nugget_predict import get_event_nuggets
11
+ from .realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
12
  from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
13
 
14
  event_nugget_list = ['B-Phishing',
 
49
  model_checkpoint = "ehsanaghaei/SecureBERT"
50
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
51
 
52
+ from .realis_model_utils import CustomRobertaWithPOS as RealisModel
53
  model_realis = RealisModel(num_classes_realis=4)
54
  model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
55
  model_realis.eval()