khoicrtp committed
Commit
469f565
1 Parent(s): c62d375
Files changed (1)
  1. main.py +261 -249
main.py CHANGED
@@ -1,250 +1,262 @@
- import json
- import transformers
- import textwrap
- from transformers import LlamaTokenizer, LlamaForCausalLM
- import os
- import sys
- from typing import List
-
- from peft import (
-     LoraConfig,
-     get_peft_model,
-     get_peft_model_state_dict,
-     prepare_model_for_int8_training,
- )
-
- import fire
- import torch
- from datasets import load_dataset
- import pandas as pd
-
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- import seaborn as sns
- from pylab import rcParams
-
- sns.set(rc={'figure.figsize': (10, 7)})
- sns.set(rc={'figure.dpi': 100})
- sns.set(style='white', palette='muted', font_scale=1.2)
-
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- print(DEVICE)
-
-
- def find_files(directory):
-     file_list = []
-     for root, dirs, files in os.walk(directory):
-         for file in files:
-             file_path = os.path.join(root, file)
-             file_list.append(file_path)
-     return file_list
-
-
- def load_all_mitre_dataset(filepath):
-     res = []
-     for file in find_files(filepath):
-         # print(file)
-         if file.endswith(".json"):
-             # filename = os.path.join(filepath, file)
-             data = json.load(open(file))
-             for object_data in data["objects"]:
-                 if "name" in object_data:
-                     # print(object_data["name"])
-                     res.append(object_data)
-     return res
-
-
- loaded_data = load_all_mitre_dataset("./cti-ATT-CK-v13.1")
- print("[+] ALL FILES: ", len(loaded_data))
- # print(loaded_data[0])
-
-
- """
- {
-     "instruction": "What is",
-     "input": "field definition",
-     "output": "field )
- }
- """
-
-
- def formal_dataset(loaded_data):
-     res = []
-     print(loaded_data[0])
-     for data in loaded_data:
-         try:
-             # print(object_data["name"])
-             res.append({
-                 "instruction": "What is",
-                 "input": data["name"],
-                 "output": data["description"]
-             })
-         except:
-             pass
-     # print(len(res))
-     return res
-
-
- dataset_data = formal_dataset(loaded_data)
- print("[+] DATASET LEN: ", len(dataset_data))
- print(dataset_data[0])
-
- with open("mitre-dataset.json", "w") as f:
-     json.dump(dataset_data, f)
-
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-
- quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
-
- BASE_MODEL = "decapoda-research/llama-7b-hf"
-
- device_map = {
-     "transformer.word_embeddings": 0,
-     "transformer.word_embeddings_layernorm": 0,
-     "lm_head": "cpu",
-     "transformer.h": 0,
-     "transformer.ln_f": 0,
- }
-
- model = AutoModelForCausalLM.from_pretrained(
-     BASE_MODEL,
-     device_map="auto",
-     quantization_config=quantization_config,
- )
-
- tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
-
- tokenizer.pad_token_id = (
-     0 # unk. we want this to be different from the eos token
- )
- tokenizer.padding_side = "left"
-
- data = load_dataset("json", data_files="mitre-dataset.json")
- print(data["train"])
-
-
- def generate_prompt(data_point):
-     return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
- ### Instruction:
- {data_point["instruction"]}
- ### Input:
- {data_point["input"]}
- ### Response:
- {data_point["output"]}"""
-
-
- CUTOFF_LEN = 256
-
-
- def tokenize(prompt, add_eos_token=True):
-     result = tokenizer(
-         prompt,
-         truncation=True,
-         max_length=CUTOFF_LEN,
-         padding=False,
-         return_tensors=None,
-     )
-     if (
-         result["input_ids"][-1] != tokenizer.eos_token_id
-         and len(result["input_ids"]) < CUTOFF_LEN
-         and add_eos_token
-     ):
-         result["input_ids"].append(tokenizer.eos_token_id)
-         result["attention_mask"].append(1)
-
-     result["labels"] = result["input_ids"].copy()
-
-     return result
-
-
- def generate_and_tokenize_prompt(data_point):
-     full_prompt = generate_prompt(data_point)
-     tokenized_full_prompt = tokenize(full_prompt)
-     return tokenized_full_prompt
-
-
- train_val = data["train"].train_test_split(
-     test_size=200, shuffle=True, seed=42
- )
- train_data = (
-     train_val["train"].map(generate_and_tokenize_prompt)
- )
- val_data = (
-     train_val["test"].map(generate_and_tokenize_prompt)
- )
-
- LORA_R = 8
- LORA_ALPHA = 16
- LORA_DROPOUT = 0.05
- LORA_TARGET_MODULES = [
-     "q_proj",
-     "v_proj",
- ]
-
- BATCH_SIZE = 128
- MICRO_BATCH_SIZE = 4
- GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
- LEARNING_RATE = 3e-4
- TRAIN_STEPS = 300
- OUTPUT_DIR = "experiments"
-
- model = prepare_model_for_int8_training(model)
- config = LoraConfig(
-     r=LORA_R,
-     lora_alpha=LORA_ALPHA,
-     target_modules=LORA_TARGET_MODULES,
-     lora_dropout=LORA_DROPOUT,
-     bias="none",
-     task_type="CAUSAL_LM",
- )
- model = get_peft_model(model, config)
- model.print_trainable_parameters()
-
- training_arguments = transformers.TrainingArguments(
-     per_device_train_batch_size=MICRO_BATCH_SIZE,
-     gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
-     warmup_steps=100,
-     max_steps=TRAIN_STEPS,
-     learning_rate=LEARNING_RATE,
-     logging_steps=10,
-     optim="adamw_torch",
-     evaluation_strategy="steps",
-     save_strategy="steps",
-     eval_steps=50,
-     save_steps=50,
-     output_dir=OUTPUT_DIR,
-     save_total_limit=3,
-     load_best_model_at_end=True,
-     report_to="tensorboard"
- )
-
- data_collator = transformers.DataCollatorForSeq2Seq(
-     tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
- )
-
- trainer = transformers.Trainer(
-     model=model,
-     train_dataset=train_data,
-     eval_dataset=val_data,
-     args=training_arguments,
-     data_collator=data_collator
- )
- model.config.use_cache = False
- old_state_dict = model.state_dict
- model.state_dict = (
-     lambda self, *_, **__: get_peft_model_state_dict(
-         self, old_state_dict()
-     )
- ).__get__(model, type(model))
-
- print("Compiling model...")
- model = torch.compile(model)
- print("Done compiling model...")
-
- print("Training model...")
- trainer.train()
- print("Done training model...")
-
- print("Saving model...")
- model.save_pretrained(OUTPUT_DIR)
+ import json
+ import transformers
+ import textwrap
+ from transformers import LlamaTokenizer, LlamaForCausalLM
+ import os
+ import sys
+ from typing import List
+
+ from peft import (
+     LoraConfig,
+     get_peft_model,
+     get_peft_model_state_dict,
+     prepare_model_for_int8_training,
+ )
+
+ import fire
+ import torch
+ from datasets import load_dataset
+ import pandas as pd
+
+ import matplotlib.pyplot as plt
+ import matplotlib as mpl
+ import seaborn as sns
+ from pylab import rcParams
+
+ sns.set(rc={'figure.figsize': (10, 7)})
+ sns.set(rc={'figure.dpi': 100})
+ sns.set(style='white', palette='muted', font_scale=1.2)
+
+ #DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DEVICE = "cpu"
+ print(DEVICE)
+
+
+ def find_files(directory):
+     file_list = []
+     for root, dirs, files in os.walk(directory):
+         for file in files:
+             file_path = os.path.join(root, file)
+             file_list.append(file_path)
+     return file_list
+
+
+ def load_all_mitre_dataset(filepath):
+     res = []
+     for file in find_files(filepath):
+         # print(file)
+         if file.endswith(".json"):
+             # filename = os.path.join(filepath, file)
+             data_local = json.load(open(file))
+             for object_data in data_local["objects"]:
+                 if "name" in object_data:
+                     # print(object_data["name"])
+                     res.append(object_data)
+     return res
+
+
+ loaded_data = load_all_mitre_dataset("./cti-ATT-CK-v13.1")
+ print("[+] ALL FILES: ", len(loaded_data))
+ # print(loaded_data[0])
+
+
+ """
+ {
+     "instruction": "What is",
+     "input": "field definition",
+     "output": "field )
+ }
+ """
+
+
+ def formal_dataset(loaded_data):
+     res = []
+     print(loaded_data[0])
+     for data in loaded_data:
+         try:
+             # print(object_data["name"])
+             res.append({
+                 "instruction": "What is",
+                 "input": data["name"],
+                 "output": data["description"]
+             })
+         except:
+             pass
+     print("[+] FORMAL DATASET:", len(res))
+     return res
+
+
+ dataset_data = formal_dataset(loaded_data)
+ print("[+] DATASET LEN: ", len(dataset_data))
+ print(dataset_data[0])
+
+ with open("mitre-dataset.json", "w") as f:
+     json.dump(dataset_data, f)
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+ quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
+
+ BASE_MODEL = "decapoda-research/llama-7b-hf"
+
+ device_map = {
+     "transformer.word_embeddings": 0,
+     "transformer.word_embeddings_layernorm": 0,
+     "lm_head": "cpu",
+     "transformer.h": 0,
+     "transformer.ln_f": 0,
+ }
+
+ model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL,
+     quantization_config=quantization_config,
+     return_dict=True,
+     load_in_8bit=True
+     #torch_dtype=torch.float16,
+     # device_map={'': 0},
+ )
+
+ tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
+
+ tokenizer.pad_token_id = (
+     0 # unk. we want this to be different from the eos token
+ )
+ tokenizer.padding_side = "left"
+
+ data = load_dataset("json", data_files="mitre-dataset.json")
+ print("[+] DATA TRAIN:", data["train"])
+
+
+ def generate_prompt(data_point):
+     return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
+ ### Instruction:
+ {data_point["instruction"]}
+ ### Input:
+ {data_point["input"]}
+ ### Response:
+ {data_point["output"]}"""
+
+
+ CUTOFF_LEN = 256
+
+
+ def tokenize(prompt, add_eos_token=True):
+     result = tokenizer(
+         prompt,
+         truncation=True,
+         max_length=CUTOFF_LEN,
+         padding=False,
+         return_tensors=None,
+     )
+     if (
+         result["input_ids"][-1] != tokenizer.eos_token_id
+         and len(result["input_ids"]) < CUTOFF_LEN
+         and add_eos_token
+     ):
+         result["input_ids"].append(tokenizer.eos_token_id)
+         result["attention_mask"].append(1)
+
+     result["labels"] = result["input_ids"].copy()
+
+     return result
+
+
+ def generate_and_tokenize_prompt(data_point):
+     full_prompt = generate_prompt(data_point)
+     tokenized_full_prompt = tokenize(full_prompt)
+     return tokenized_full_prompt
+
+ print("-------------------------------")
+ print("DATA[TRAIN]", data["train"])
+ train_val = data["train"].train_test_split(
+     test_size=200, shuffle=True, seed=42
+ )
+ train_data = (
+     train_val["train"].map(generate_and_tokenize_prompt)
+ )
+ val_data = (
+     train_val["test"].map(generate_and_tokenize_prompt)
+ )
+ print("--------------------------")
+ print(train_val)
+ print("--------------------------")
+ print(train_data)
+ print("--------------------------")
+ print(val_data)
+ LORA_R = 8
+ LORA_ALPHA = 16
+ LORA_DROPOUT = 0.05
+ LORA_TARGET_MODULES = [
+     "q_proj",
+     "v_proj",
+ ]
+
+ BATCH_SIZE = 128
+ MICRO_BATCH_SIZE = 4
+ GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
+ LEARNING_RATE = 3e-4
+ TRAIN_STEPS = 300
+ OUTPUT_DIR = "experiments"
+
+ model = prepare_model_for_int8_training(model)
+ config = LoraConfig(
+     r=LORA_R,
+     lora_alpha=LORA_ALPHA,
+     target_modules=LORA_TARGET_MODULES,
+     lora_dropout=LORA_DROPOUT,
+     bias="none",
+     task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, config)
+ model.print_trainable_parameters()
+
+ training_arguments = transformers.TrainingArguments(
+     per_device_train_batch_size=MICRO_BATCH_SIZE,
+     gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+     warmup_steps=100,
+     max_steps=TRAIN_STEPS,
+     learning_rate=LEARNING_RATE,
+     logging_steps=10,
+     optim="adamw_torch",
+     evaluation_strategy="steps",
+     save_strategy="steps",
+     eval_steps=50,
+     save_steps=50,
+     output_dir=OUTPUT_DIR,
+     save_total_limit=3,
+     no_cuda=True,
+     load_best_model_at_end=True,
+     report_to="tensorboard"
+ )
+
+ data_collator = transformers.DataCollatorForSeq2Seq(
+     tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+ )
+
+
+ model.config.use_cache = False
+ old_state_dict = model.state_dict
+ model.state_dict = (
+     lambda self, *_, **__: get_peft_model_state_dict(
+         self, old_state_dict()
+     )
+ ).__get__(model, type(model))
+
+ print("Compiling model...")
+ model = torch.compile(model)
+ print("Done compiling model...")
+ print(model)
+ trainer = transformers.Trainer(
+     model=model,
+     train_dataset=train_data,
+     eval_dataset=val_data,
+     args=training_arguments,
+     data_collator=data_collator
+ )
+ print("Training model...")
+ trainer.train()
+ print("Done training model...")
+
+ print("Saving model...")
+ model.save_pretrained(OUTPUT_DIR)
  print("Done saving model...")
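Not part of the commit: a minimal usage sketch of how the LoRA adapter written to experiments/ by model.save_pretrained(OUTPUT_DIR) could be loaded back for inference with peft.PeftModel, assuming the same base checkpoint and the same Alpaca-style prompt template as generate_prompt above. The loading precision, device placement, example input ("Phishing"), and generation settings are illustrative assumptions, not values taken from the training script.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaTokenizer

BASE_MODEL = "decapoda-research/llama-7b-hf"
OUTPUT_DIR = "experiments"  # directory the training script saves the adapter to

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
# Assumed loading settings; the training run used 8-bit CPU offload instead.
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.float16, device_map="auto"
)
# Attach the fine-tuned LoRA weights saved by the script.
model = PeftModel.from_pretrained(base, OUTPUT_DIR)
model.eval()

# Rebuild the training prompt format, leaving the response empty for generation.
prompt = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes the request.\n"
    "### Instruction:\nWhat is\n"
    "### Input:\nPhishing\n"
    "### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))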