# Source: malconv/src/predict.py
# Author: cycloevan
# Commit: 52b5518 "Update script files" (verified)
import os
import sys
import argparse
import numpy as np
import tensorflow as tf
import pandas as pd
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.utils import read_binary_file
from src.model import MalConv
def predict_file(model_path, file_path, max_length=2_000_000):
    """Predict a single file with a MalConv model.

    Args:
        model_path: Path to the saved model weights.
        file_path: Path of the file to classify.
        max_length: Maximum input length, passed to both ``MalConv`` and
            ``read_binary_file``. NOTE(review): this default (2,000,000)
            differs from the 2**20 default used by ``predict_batch`` and the
            CLI -- confirm which length the weights were trained with.

    Returns:
        The model's output score for the file (close to 0 means malware,
        close to 1 means benign).
    """
    net = MalConv(max_input_length=max_length)
    # Run one forward pass on an all-zero batch so the model's variables
    # exist before load_weights tries to fill them.
    net(tf.zeros((1, max_length), dtype=tf.int32))
    net.load_weights(model_path)

    # The model expects a leading batch axis, so wrap the single sample.
    batch = np.expand_dims(read_binary_file(file_path, max_length), axis=0)
    scores = net.predict(batch, verbose=0)
    return scores[0][0]
def predict_batch(model_path, csv_path, output_path=None, max_length=2**20):
    """Classify every file listed in a CSV with a MalConv model.

    The CSV must have a ``filepath`` column. When it also has a ``label``
    column, accuracy over the rows that were predicted successfully (and have
    a non-missing label) is printed at the end.

    Args:
        model_path: Path to the saved model weights.
        csv_path: Path to the CSV listing the files to classify.
        output_path: Optional path; when given, the result table is written
            there as CSV (without the index).
        max_length: Maximum input length passed to ``MalConv`` and
            ``read_binary_file``.

    Returns:
        pandas.DataFrame: a copy of the input table with three added columns:
        ``prediction`` (model score, or -1 if the file was missing or could
        not be processed), ``predicted_label`` (1 if score > 0.5, 0 otherwise,
        -1 for failed rows) and ``prediction_text``.
    """
    # Load the model.
    print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
    model = MalConv(max_input_length=max_length)
    # Call the model once on a dummy batch so its variables are created
    # before the weights are loaded into them.
    dummy_input = tf.zeros((1, max_length), dtype=tf.int32)
    model(dummy_input)
    model.load_weights(model_path)

    df = pd.read_csv(csv_path)
    has_labels = 'label' in df.columns
    predictions = []  # exactly one score per row; -1 marks a failed row
    labels = []  # one label per row when has_labels, aligned with predictions

    print("์˜ˆ์ธก ์ค‘...")
    for _, row in df.iterrows():
        file_path = row['filepath']
        # Record the label for every row, including failures, so that
        # labels[i] always belongs to predictions[i].
        if has_labels:
            labels.append(row['label'])

        if not os.path.exists(file_path):
            print(f"ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {file_path}")
            predictions.append(-1)
            continue

        # Guard only reading + inference, so a row can never end up with two
        # entries in `predictions`.
        try:
            byte_array = read_binary_file(file_path, max_length)
            input_data = np.expand_dims(byte_array, axis=0)
            pred = model.predict(input_data, verbose=0)[0][0]
        except Exception as e:  # best-effort: keep going with the remaining files
            print(f"Error processing {file_path}: {e}")
            predictions.append(-1)
            continue

        predictions.append(pred)
        status = "์ •์ƒ" if pred > 0.5 else "์•…์„ฑ์ฝ”๋“œ"
        confidence = pred if pred > 0.5 else 1 - pred
        print(f"{file_path}: {status} (์‹ ๋ขฐ๋„: {confidence:.4f})")

    pred_array = np.asarray(predictions, dtype=float)
    failed = pred_array < 0

    result_df = df.copy()
    result_df['prediction'] = predictions
    # Failed rows get -1 instead of 0, because 0 is the "malware" class.
    result_df['predicted_label'] = np.where(failed, -1, (pred_array > 0.5).astype(int))
    result_df['prediction_text'] = ['์ •์ƒ' if p > 0.5 else '์•…์„ฑ์ฝ”๋“œ' if p >= 0 else '์—๋Ÿฌ'
                                    for p in predictions]
    if output_path:
        result_df.to_csv(output_path, index=False)
        print(f"๊ฒฐ๊ณผ๊ฐ€ ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {output_path}")

    # Accuracy over rows that were predicted successfully and carry a label.
    if has_labels:
        scored = [(p, y) for p, y in zip(predictions, labels)
                  if p >= 0 and not pd.isna(y)]
        if scored:
            pred_binary = (np.array([p for p, _ in scored]) > 0.5).astype(int)
            accuracy = np.mean(pred_binary == np.array([y for _, y in scored]))
            print(f"\n์ •ํ™•๋„: {accuracy:.4f}")
    return result_df
def main():
    """Parse command-line arguments and run single-file or CSV batch prediction.

    ``--file`` takes precedence over ``--csv``. If neither is given, a usage
    hint is printed and nothing else happens.
    """
    cli = argparse.ArgumentParser(description='MalConv ๋ชจ๋ธ ์˜ˆ์ธก')
    cli.add_argument('model_path', help='์ €์žฅ๋œ ๋ชจ๋ธ ๊ฒฝ๋กœ')
    cli.add_argument('--file', help='๋‹จ์ผ ํŒŒ์ผ ์˜ˆ์ธก')
    cli.add_argument('--csv', help='๋ฐฐ์น˜ ์˜ˆ์ธก์šฉ CSV ํŒŒ์ผ')
    cli.add_argument('--output', help='๊ฒฐ๊ณผ ์ €์žฅ ๊ฒฝ๋กœ')
    cli.add_argument('--max_length', type=int, default=2**20, help='์ตœ๋Œ€ ์ž…๋ ฅ ๊ธธ์ด')
    args = cli.parse_args()

    if args.file:
        # Single-file mode: a score above 0.5 is reported as benign.
        score = predict_file(args.model_path, args.file, args.max_length)
        if score > 0.5:
            status, confidence = "์ •์ƒ", score
        else:
            status, confidence = "์•…์„ฑ์ฝ”๋“œ", 1 - score
        print(f"ํŒŒ์ผ: {args.file}")
        print(f"์˜ˆ์ธก: {status} (์‹ ๋ขฐ๋„: {confidence:.4f})")
        return

    if args.csv:
        # Batch mode; predict_batch does its own reporting.
        predict_batch(args.model_path, args.csv, args.output, args.max_length)
        return

    print("--file ๋˜๋Š” --csv ์˜ต์…˜์„ ์ง€์ •ํ•ด์ฃผ์„ธ์š”.")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()