# malconv/src/train.py
# Hosting-page header captured with the file: cycloevan's picture,
# commit message "Upload 17 files", commit b92918a (verified).
import os
import sys
import argparse
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model import create_malconv_model
from src.utils import (
configure_gpu_memory,
plot_training_history,
evaluate_model,
get_file_paths_and_labels,
data_generator,
read_binary_file
)
def train_malconv(data_source,
                  epochs=10,
                  batch_size=256,
                  max_length=2_000_000,
                  validation_split=0.2,
                  save_path="models/malconv_model.h5"):
    """
    Train a MalConv model, streaming the samples through data generators.

    The file paths are split into stratified train/validation sets, the model
    is trained with early stopping and best-model checkpointing, a capped
    subset of the validation files is evaluated, and the training history is
    plotted.

    Args:
        data_source: ``(malware_dir, benign_dir)`` tuple.
        epochs: Number of training epochs.
        batch_size: Batch size (must be at least 1).
        max_length: Maximum input length passed to the model, the generators
            and ``read_binary_file`` (default 2,000,000).
        validation_split: Fraction of the files held out for validation.
        save_path: Where ``ModelCheckpoint`` writes the best model.

    Returns:
        ``(model, history, results)`` where ``results`` is whatever
        ``evaluate_model`` returns, or ``{}`` when there is nothing to
        evaluate.

    Raises:
        ValueError: If ``data_source`` is not a 2-tuple or ``batch_size`` is
            smaller than 1.
    """
    # Validate the arguments before touching the GPU or the file system so
    # that a bad call fails immediately.
    if not (isinstance(data_source, tuple) and len(data_source) == 2):
        raise ValueError("data_source๋Š” (malware_dir, benign_dir) ํŠœํ”Œ์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    malware_dir, benign_dir = data_source

    print("=" * 60)
    print("MalConv ๋ชจ๋ธ ํ›ˆ๋ จ ์‹œ์ž‘ (๋ฐ์ดํ„ฐ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ ๋ชจ๋“œ)")
    print("=" * 60)

    # GPU setup.
    configure_gpu_memory()

    # Collect the sample paths and their labels.
    filepaths, labels = get_file_paths_and_labels(malware_dir, benign_dir)

    # Train/validation split on file paths (stratified, fixed seed).
    filepaths_train, filepaths_val, labels_train, labels_val = train_test_split(
        filepaths, labels, test_size=validation_split, random_state=42, stratify=labels
    )
    print(f"์ด ๋ฐ์ดํ„ฐ: {len(filepaths)}")
    print(f"ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ: {len(filepaths_train)}, ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ: {len(filepaths_val)}")

    # Build the data generators; validation data is not shuffled.
    train_gen = data_generator(filepaths_train, labels_train, batch_size, max_length)
    val_gen = data_generator(filepaths_val, labels_val, batch_size, max_length, shuffle=False)

    # Create the model.
    print("MalConv ๋ชจ๋ธ ์ƒ์„ฑ ์ค‘...")
    model = create_malconv_model(max_length)

    # Run one dummy forward pass so the model is built before summary().
    dummy_input = np.zeros((1, max_length), dtype=np.uint8)
    _ = model(dummy_input)
    print("\n=== ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ===")
    model.summary()
    print(f"์ด ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {model.count_params():,}")

    # Callbacks.
    # NOTE(review): monitoring 'val_auc' presumes the model is compiled with a
    # metric named 'auc' - confirm in create_malconv_model.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.ModelCheckpoint(
            save_path,
            monitor='val_auc',
            save_best_only=True,
            verbose=1,
            mode='max'  # higher AUC is better
        )
    ]

    # Floor division yields 0 steps when a split holds fewer files than one
    # batch; clamp to at least one step so fit() always has work to do.
    # NOTE(review): assumes data_generator can serve a batch from fewer than
    # batch_size files - confirm in src.utils.
    steps_per_epoch = max(1, len(filepaths_train) // batch_size)
    validation_steps = max(1, len(filepaths_val) // batch_size)

    # Training.
    print("\n=== ํ›ˆ๋ จ ์‹œ์ž‘ ===")
    print(f"๋ฐฐ์น˜ ํฌ๊ธฐ: {batch_size}")
    print(f"์—ํฌํฌ: {epochs}")
    history = model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=val_gen,
        validation_steps=validation_steps,
        callbacks=callbacks,
        verbose=1
    )

    # Evaluation: only part of the validation data is loaded, to limit memory.
    print("\n=== ์ตœ์ข… ํ‰๊ฐ€ ===")
    num_eval_samples = min(len(filepaths_val), 1024)  # cap the evaluation set
    X_eval = np.array([read_binary_file(fp, max_length) for fp in filepaths_val[:num_eval_samples]])
    y_eval = np.array(labels_val[:num_eval_samples])
    if X_eval.size > 0:
        # Half the training batch size, but never 0 (batch_size == 1).
        eval_batch_size = max(1, batch_size // 2)
        results = evaluate_model(model, X_eval, y_eval, batch_size=eval_batch_size)
    else:
        print("ํ‰๊ฐ€ํ•  ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
        results = {}

    # Visualization.
    plot_training_history(history)
    print(f"\n๋ชจ๋ธ์ด ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {save_path}")
    return model, history, results
def main():
    """
    Command-line entry point: parse the options and run MalConv training.

    Required options are ``--malware_dir`` and ``--benign_dir``. The CLI
    defaults for ``--epochs`` (20) and ``--batch_size`` (64) differ from the
    keyword defaults of ``train_malconv``. The directory part of
    ``--save_path`` is created before training starts, if there is one.
    """
    parser = argparse.ArgumentParser(description='MalConv ๋ชจ๋ธ ํ›ˆ๋ จ')

    # Data source options.
    parser.add_argument('--malware_dir', required=True, help='์•…์„ฑ์ฝ”๋“œ ๋””๋ ‰ํ† ๋ฆฌ')
    parser.add_argument('--benign_dir', required=True, help='์ •์ƒํŒŒ์ผ ๋””๋ ‰ํ† ๋ฆฌ')

    # Training options.
    parser.add_argument('--epochs', type=int, default=20, help='์—ํฌํฌ ์ˆ˜')
    parser.add_argument('--batch_size', type=int, default=64, help='๋ฐฐ์น˜ ํฌ๊ธฐ')
    parser.add_argument('--max_length', type=int, default=2_000_000, help='์ตœ๋Œ€ ์ž…๋ ฅ ๊ธธ์ด')
    parser.add_argument('--save_path', default='models/malconv_model.h5', help='๋ชจ๋ธ ์ €์žฅ ๊ฒฝ๋กœ')
    args = parser.parse_args()

    data_source = (args.malware_dir, args.benign_dir)

    # Create the output directory. A bare file name such as "model.h5" has an
    # empty dirname, and os.makedirs("") raises FileNotFoundError, so only
    # create the directory when the path actually contains one.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Train the model.
    train_malconv(
        data_source=data_source,
        epochs=args.epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        save_path=args.save_path
    )
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()