cpi-connect committed on
Commit
9df8979
1 Parent(s): 008fd4d

Upload model

Browse files
Files changed (1) hide show
  1. model.py +34 -34
model.py CHANGED
@@ -88,43 +88,43 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
88
  structured_output.extend(batch_output)
89
 
90
 
91
- args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(structured_output) if item["argument"]!= "O"]
92
 
93
- entities = []
94
- current_entity = None
95
- for position, label, token in args:
96
- if label.startswith('B-'):
97
- if current_entity is not None:
98
- entities.append(current_entity)
99
- current_entity = {'label': label[2:], 'text': token.replace(" ", ""), 'start': position, 'end': position}
100
- elif label.startswith('I-'):
101
- if current_entity is not None:
102
- current_entity['text'] += ' ' + token.replace(" ", "")
103
- current_entity['end'] = position
104
-
105
- for entity in entities:
106
- context = self.tokenizer.decode([item["id"] for item in structured_output[max(0, entity["start"] - 15) : min(len(structured_output), entity["end"] + 15)]])
107
- entity["context"] = context
108
 
109
- for entity in entities:
110
- if len(self.arg_2_role[entity["label"]]) > 1:
111
- sent_embed = self.embed_model.encode(entity["context"])
112
- arg_embed = self.embed_model.encode(entity["text"])
113
- embed = np.concatenate((sent_embed, arg_embed))
114
-
115
- arg_clf = self.role_classifiers[entity["label"]]
116
- role_id = arg_clf.predict(embed.reshape(1, -1))
117
- role = self.arg_2_role[entity["label"]][role_id[0]]
118
-
119
- entity["role"] = role
120
- else:
121
- entity["role"] = self.arg_2_role[entity["label"]][0]
122
 
123
- for item in structured_output:
124
- item["role"] = "O"
125
- for entity in entities:
126
- for i in range(entity["start"], entity["end"] + 1):
127
- structured_output[i]["role"] = entity["role"]
128
  return structured_output
129
 
130
  def forward_model(self, model, dataloader):
 
88
  structured_output.extend(batch_output)
89
 
90
 
91
+ # args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(structured_output) if item["argument"]!= "O"]
92
 
93
+ # entities = []
94
+ # current_entity = None
95
+ # for position, label, token in args:
96
+ # if label.startswith('B-'):
97
+ # if current_entity is not None:
98
+ # entities.append(current_entity)
99
+ # current_entity = {'label': label[2:], 'text': token.replace(" ", ""), 'start': position, 'end': position}
100
+ # elif label.startswith('I-'):
101
+ # if current_entity is not None:
102
+ # current_entity['text'] += ' ' + token.replace(" ", "")
103
+ # current_entity['end'] = position
104
+
105
+ # for entity in entities:
106
+ # context = self.tokenizer.decode([item["id"] for item in structured_output[max(0, entity["start"] - 15) : min(len(structured_output), entity["end"] + 15)]])
107
+ # entity["context"] = context
108
 
109
+ # for entity in entities:
110
+ # if len(self.arg_2_role[entity["label"]]) > 1:
111
+ # sent_embed = self.embed_model.encode(entity["context"])
112
+ # arg_embed = self.embed_model.encode(entity["text"])
113
+ # embed = np.concatenate((sent_embed, arg_embed))
114
+
115
+ # arg_clf = self.role_classifiers[entity["label"]]
116
+ # role_id = arg_clf.predict(embed.reshape(1, -1))
117
+ # role = self.arg_2_role[entity["label"]][role_id[0]]
118
+
119
+ # entity["role"] = role
120
+ # else:
121
+ # entity["role"] = self.arg_2_role[entity["label"]][0]
122
 
123
+ # for item in structured_output:
124
+ # item["role"] = "O"
125
+ # for entity in entities:
126
+ # for i in range(entity["start"], entity["end"] + 1):
127
+ # structured_output[i]["role"] = entity["role"]
128
  return structured_output
129
 
130
  def forward_model(self, model, dataloader):