|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
from src.model import create_malconv_model |
|
|
from src.utils import ( |
|
|
configure_gpu_memory, |
|
|
plot_training_history, |
|
|
evaluate_model, |
|
|
get_file_paths_and_labels, |
|
|
data_generator, |
|
|
read_binary_file |
|
|
) |
|
|
|
|
|
def train_malconv(data_source, |
|
|
epochs=10, |
|
|
batch_size=256, |
|
|
max_length=2_000_000, |
|
|
validation_split=0.2, |
|
|
save_path="models/malconv_model.h5"): |
|
|
""" |
|
|
MalConv ๋ชจ๋ธ ํ๋ จ (๋ฐ์ดํฐ ์ ๋๋ ์ดํฐ ์ฌ์ฉ) |
|
|
|
|
|
Args: |
|
|
data_source: (malware_dir, benign_dir) ํํ |
|
|
epochs: ํ๋ จ ์ํฌํฌ ์ |
|
|
batch_size: ๋ฐฐ์น ํฌ๊ธฐ |
|
|
max_length: ์ต๋ ์
๋ ฅ ๊ธธ์ด (2MB) |
|
|
validation_split: ๊ฒ์ฆ ๋ฐ์ดํฐ ๋น์จ |
|
|
save_path: ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก |
|
|
""" |
|
|
|
|
|
print("=" * 60) |
|
|
print("MalConv ๋ชจ๋ธ ํ๋ จ ์์ (๋ฐ์ดํฐ ์ ๋๋ ์ดํฐ ๋ชจ๋)") |
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
configure_gpu_memory() |
|
|
|
|
|
|
|
|
if isinstance(data_source, tuple) and len(data_source) == 2: |
|
|
malware_dir, benign_dir = data_source |
|
|
filepaths, labels = get_file_paths_and_labels(malware_dir, benign_dir) |
|
|
else: |
|
|
raise ValueError("data_source๋ (malware_dir, benign_dir) ํํ์ด์ด์ผ ํฉ๋๋ค.") |
|
|
|
|
|
|
|
|
filepaths_train, filepaths_val, labels_train, labels_val = train_test_split( |
|
|
filepaths, labels, test_size=validation_split, random_state=42, stratify=labels |
|
|
) |
|
|
|
|
|
print(f"์ด ๋ฐ์ดํฐ: {len(filepaths)}") |
|
|
print(f"ํ๋ จ ๋ฐ์ดํฐ: {len(filepaths_train)}, ๊ฒ์ฆ ๋ฐ์ดํฐ: {len(filepaths_val)}") |
|
|
|
|
|
|
|
|
train_gen = data_generator(filepaths_train, labels_train, batch_size, max_length) |
|
|
val_gen = data_generator(filepaths_val, labels_val, batch_size, max_length, shuffle=False) |
|
|
|
|
|
|
|
|
print("MalConv ๋ชจ๋ธ ์์ฑ ์ค...") |
|
|
model = create_malconv_model(max_length) |
|
|
|
|
|
|
|
|
dummy_input = np.zeros((1, max_length), dtype=np.uint8) |
|
|
_ = model(dummy_input) |
|
|
|
|
|
print("\n=== ๋ชจ๋ธ ์ํคํ
์ฒ ===") |
|
|
model.summary() |
|
|
print(f"์ด ํ๋ผ๋ฏธํฐ ์: {model.count_params():,}") |
|
|
|
|
|
|
|
|
callbacks = [ |
|
|
tf.keras.callbacks.EarlyStopping( |
|
|
monitor='val_loss', |
|
|
patience=5, |
|
|
restore_best_weights=True, |
|
|
verbose=1 |
|
|
), |
|
|
tf.keras.callbacks.ModelCheckpoint( |
|
|
save_path, |
|
|
monitor='val_auc', |
|
|
save_best_only=True, |
|
|
verbose=1, |
|
|
mode='max' |
|
|
) |
|
|
] |
|
|
|
|
|
|
|
|
print(f"\n=== ํ๋ จ ์์ ===") |
|
|
print(f"๋ฐฐ์น ํฌ๊ธฐ: {batch_size}") |
|
|
print(f"์ํฌํฌ: {epochs}") |
|
|
|
|
|
history = model.fit( |
|
|
train_gen, |
|
|
steps_per_epoch=len(filepaths_train) // batch_size, |
|
|
epochs=epochs, |
|
|
validation_data=val_gen, |
|
|
validation_steps=len(filepaths_val) // batch_size, |
|
|
callbacks=callbacks, |
|
|
verbose=1 |
|
|
) |
|
|
|
|
|
|
|
|
print("\n=== ์ต์ข
ํ๊ฐ ===") |
|
|
num_eval_samples = min(len(filepaths_val), 1024) |
|
|
X_eval = np.array([read_binary_file(fp, max_length) for fp in filepaths_val[:num_eval_samples]]) |
|
|
y_eval = np.array(labels_val[:num_eval_samples]) |
|
|
|
|
|
if X_eval.size > 0: |
|
|
results = evaluate_model(model, X_eval, y_eval, batch_size=batch_size//2) |
|
|
else: |
|
|
print("ํ๊ฐํ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.") |
|
|
results = {} |
|
|
|
|
|
|
|
|
plot_training_history(history) |
|
|
|
|
|
print(f"\n๋ชจ๋ธ์ด ์ ์ฅ๋์์ต๋๋ค: {save_path}") |
|
|
|
|
|
return model, history, results |
|
|
|
|
|
def main(argv=None):
    """Parse command-line options and launch MalConv training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the script behaviour.
    """
    parser = argparse.ArgumentParser(description='MalConv 모델 훈련')

    # Input data locations.
    parser.add_argument('--malware_dir', required=True, help='악성코드 디렉토리')
    parser.add_argument('--benign_dir', required=True, help='정상파일 디렉토리')

    # Training hyper-parameters.
    parser.add_argument('--epochs', type=int, default=20, help='에포크 수')
    parser.add_argument('--batch_size', type=int, default=64, help='배치 크기')
    parser.add_argument('--max_length', type=int, default=2_000_000, help='최대 입력 길이')
    parser.add_argument('--validation_split', type=float, default=0.2, help='검증 데이터 비율')
    parser.add_argument('--save_path', default='models/malconv_model.h5', help='모델 저장 경로')

    args = parser.parse_args(argv)

    # os.path.dirname() is "" for a bare filename such as "model.h5", and
    # os.makedirs("") raises FileNotFoundError, so only create a real directory.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    train_malconv(
        data_source=(args.malware_dir, args.benign_dir),
        epochs=args.epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        validation_split=args.validation_split,
        save_path=args.save_path,
    )
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()