# test_scratch / main.py
# Uploaded by khoicrtp in commit 953210f ("Upload 2 files"); file size 6.08 kB.
import json
import transformers
import textwrap
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
import sys
from typing import List
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
)
import fire
import torch
from datasets import load_dataset
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from pylab import rcParams
# Seaborn defaults: figure size and dpi first, then the white style with the
# muted palette and a larger font scale.
sns.set(rc={'figure.figsize': (10, 7), 'figure.dpi': 100})
sns.set(style='white', palette='muted', font_scale=1.2)

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)
def find_files(directory):
    """Return the path of every file found beneath ``directory``.

    The tree is traversed with ``os.walk``; each returned path is the walked
    folder joined with the file name, in the order ``os.walk`` yields them.
    A directory that does not exist produces an empty list.
    """
    return [
        os.path.join(folder, filename)
        for folder, _subdirs, filenames in os.walk(directory)
        for filename in filenames
    ]
def load_all_mitre_dataset(filepath):
    """Collect every named object from the JSON files under ``filepath``.

    Every path returned by ``find_files(filepath)`` that ends in ``.json`` is
    parsed, and each entry of its top-level ``"objects"`` list that has a
    ``"name"`` key is collected.

    Args:
        filepath: Root directory that is searched recursively.

    Returns:
        A list of the collected object dicts, in file order and then in the
        order they appear inside each file.

    Raises:
        KeyError: If a parsed JSON document has no ``"objects"`` key.
        json.JSONDecodeError: If a ``.json`` file is not valid JSON.
    """
    res = []
    for file in find_files(filepath):
        if not file.endswith(".json"):
            continue
        # Use a context manager so the handle is closed right after parsing
        # (the walk can touch many files), and decode as UTF-8 explicitly so
        # the result does not depend on the platform's locale encoding.
        with open(file, encoding="utf-8") as handle:
            data = json.load(handle)
        res.extend(
            object_data
            for object_data in data["objects"]
            if "name" in object_data
        )
    return res
# Gather every named object from the local ATT&CK checkout and report how
# many were found.
loaded_data = load_all_mitre_dataset(filepath="./cti-ATT-CK-v13.1")
print("[+] ALL FILES: ", len(loaded_data))
"""
{
"instruction": "What is",
"input": "field definition",
"output": "field )
}
"""
def formal_dataset(loaded_data):
    """Convert raw objects into instruction-tuning records.

    Each object that has both a ``"name"`` and a ``"description"`` becomes
    ``{"instruction": "What is", "input": <name>, "output": <description>}``.
    Objects missing either field are skipped.

    Args:
        loaded_data: Sequence of object dicts.

    Returns:
        A list of record dicts, in the same order as the usable inputs.
        An empty input yields an empty list.
    """
    if loaded_data:
        # Debug aid: show one raw object. Guarded so that an empty input no
        # longer raises IndexError.
        print(loaded_data[0])

    res = []
    for data in loaded_data:
        try:
            record = {
                "instruction": "What is",
                "input": data["name"],
                "output": data["description"],
            }
        except (KeyError, TypeError):
            # Best-effort skip of entries that lack the fields (or are not
            # mappings), without also swallowing unrelated errors such as
            # KeyboardInterrupt the way a bare ``except`` would.
            continue
        res.append(record)
    return res
# Build the instruction records, show their count and the first one, then
# persist them as a single JSON array for the datasets loader.
dataset_data = formal_dataset(loaded_data)
print("[+] DATASET LEN: ", len(dataset_data))
print(dataset_data[0])
with open("mitre-dataset.json", "w") as dataset_file:
    dataset_file.write(json.dumps(dataset_data))
# Load the base model, its tokenizer, and the JSON dataset written earlier.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Quantization settings handed to from_pretrained below; only the fp32 CPU
# offload flag is set here.
# NOTE(review): no load_in_8bit flag is passed, yet the model is later given
# to prepare_model_for_int8_training - confirm 8-bit loading really happens.
quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
BASE_MODEL = "decapoda-research/llama-7b-hf"
# NOTE(review): this explicit placement map is never used - from_pretrained
# below is called with device_map="auto". The module names also do not look
# like LLaMA module names; confirm before wiring it in.
device_map = {
"transformer.word_embeddings": 0,
"transformer.word_embeddings_layernorm": 0,
"lm_head": "cpu",
"transformer.h": 0,
"transformer.ln_f": 0,
}
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto",
quantization_config=quantization_config,
)
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
# Padding uses token id 0 and is applied on the left side of each sequence.
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left"
# Read back the file written above; the records land in the "train" split.
data = load_dataset("json", data_files="mitre-dataset.json")
print(data["train"])
def generate_prompt(data_point):
    """Render one dataset record as an instruction-style training prompt.

    Args:
        data_point: Mapping with ``"instruction"``, ``"input"`` and
            ``"output"`` entries.

    Returns:
        The prompt text: a fixed preamble followed by the ``### Instruction:``,
        ``### Input:`` and ``### Response:`` sections, one per line.

    Raises:
        KeyError: If any of the three entries is missing.
    """
    # The preamble is assembled from adjacent literals so that no lint
    # directive is needed; previously a "# noqa: E501" marker sat inside the
    # string itself and was emitted as part of every prompt.
    return (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n"
        "### Instruction:\n"
        f"{data_point['instruction']}\n"
        "### Input:\n"
        f"{data_point['input']}\n"
        "### Response:\n"
        f"{data_point['output']}"
    )
# Maximum number of tokens kept per example.
CUTOFF_LEN = 256


def tokenize(prompt, add_eos_token=True):
    """Tokenize ``prompt`` for causal-LM training.

    The text is encoded with the module-level ``tokenizer``, truncated to
    ``CUTOFF_LEN`` tokens and left unpadded. When ``add_eos_token`` is true,
    the sequence is shorter than ``CUTOFF_LEN`` and does not already end with
    the EOS id, one EOS token (with attention mask 1) is appended. The
    returned encoding also carries ``labels``, a copy of ``input_ids``.
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
        return_tensors=None,
    )
    token_ids = encoded["input_ids"]
    missing_eos = token_ids[-1] != tokenizer.eos_token_id
    has_room = len(token_ids) < CUTOFF_LEN
    if missing_eos and has_room and add_eos_token:
        token_ids.append(tokenizer.eos_token_id)
        encoded["attention_mask"].append(1)
    encoded["labels"] = token_ids.copy()
    return encoded
def generate_and_tokenize_prompt(data_point):
    """Build the prompt for ``data_point`` and return its tokenized form."""
    return tokenize(generate_prompt(data_point))
# Hold out 200 shuffled examples (fixed seed) for evaluation, then tokenize
# the training and evaluation partitions.
train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42)
train_data = train_val["train"].map(generate_and_tokenize_prompt)
val_data = train_val["test"].map(generate_and_tokenize_prompt)
# LoRA adapter hyperparameters; all four are passed to LoraConfig below.
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
# Passed as target_modules to LoraConfig.
LORA_TARGET_MODULES = [
"q_proj",
"v_proj",
]
# Per-device micro-batches of 4 are accumulated 128 // 4 = 32 times per step.
BATCH_SIZE = 128
MICRO_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
TRAIN_STEPS = 300
# Used both as the Trainer output_dir and as the final save_pretrained target.
OUTPUT_DIR = "experiments"
# Prepare the loaded model for int8 training, then wrap it with the adapters.
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=LORA_TARGET_MODULES,
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
# Training configuration: 300 optimizer steps with 100 warmup steps, logging
# every 10 steps, evaluation and checkpointing every 50 steps, at most three
# checkpoints kept, metrics reported to TensorBoard.
# NOTE(review): newer transformers releases rename evaluation_strategy to
# eval_strategy - confirm against the pinned library version.
training_arguments = transformers.TrainingArguments(
per_device_train_batch_size=MICRO_BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
warmup_steps=100,
max_steps=TRAIN_STEPS,
learning_rate=LEARNING_RATE,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=50,
save_steps=50,
output_dir=OUTPUT_DIR,
save_total_limit=3,
load_best_model_at_end=True,
report_to="tensorboard"
)
# Collator configured with padding=True, pad_to_multiple_of=8 and PyTorch
# tensor output.
data_collator = transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)
# The Trainer keeps a reference to the model object that exists at this point.
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=training_arguments,
data_collator=data_collator
)
# Switch off the model's use_cache flag before training.
model.config.use_cache = False
# Replace model.state_dict with a bound function that returns
# get_peft_model_state_dict(self, old_state_dict()).
# Presumably so that saved checkpoints hold only the adapter weights - confirm.
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
print("Compiling model...")
# NOTE(review): trainer was constructed with the pre-compile model object, so
# rebinding `model` here does not change what trainer.train() runs - confirm
# whether compilation is meant to take effect during training.
model = torch.compile(model)
print("Done compiling model...")
print("Training model...")
trainer.train()
print("Done training model...")
print("Saving model...")
# Saves through the object returned by torch.compile, into OUTPUT_DIR.
model.save_pretrained(OUTPUT_DIR)
print("Done saving model...")