from flask import Flask, render_template, request, jsonify
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
import nvdlib

# Flask app initialization
app = Flask(__name__)

# Hugging Face checkpoint used for both the encoder and the tokenizer.
PRETRAINED_NAME = 'jackaduma/SecRoBERTa'
# Fine-tuned classifier weights; a relative path, so it is resolved against
# the process working directory.
WEIGHTS_PATH = 'tactic_predict.pt'
# Every description is padded/truncated to this many tokens.
MAX_TOKENS = 320

# Label for each classifier output, indexed by output position.
TACTIC_NAMES = (
    "Reconnaissance",
    "Resource Development",
    "Initial Access",
    "Execution",
    "Persistence",
    "Privilege Escalation",
    "Defense Evasion",
    "Credential Access",
    "Discovery",
    "Lateral Movement",
    "Collection",
    "Command and Control",
    "Exfiltration",
    "Impact",
)


# Define the model architecture
class Model(nn.Module):
    """Pretrained transformer encoder with a dropout + linear 14-way head.

    The attribute names (``transformer_model``, ``dropout``, ``output``) are
    the keys of the saved state dict and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.transformer_model = AutoModel.from_pretrained(PRETRAINED_NAME)
        self.dropout = nn.Dropout(0.3)
        # 14 outputs: one per entry of TACTIC_NAMES.
        self.output = nn.Linear(768, 14)

    def forward(self, input_ids, attention_mask=None):
        """Return raw (pre-sigmoid) logits for the given token ids."""
        # With return_dict=False the encoder returns a tuple; only its second
        # element is used (presumably the pooled sentence representation --
        # confirm against the encoder's documentation).
        _, o2 = self.transformer_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        x = self.dropout(o2)
        out = self.output(x)
        return out


# Function to predict MITRE ATT&CK techniques
def predict_techniques(model, tokenizer, cve_description, device):
    """Classify one CVE description and return a 0/1 indicator array.

    NOTE(review): despite the name, the caller in this file maps the outputs
    to ATT&CK *tactic* names.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` and returning
            logits.
        tokenizer: Tokenizer exposing ``encode_plus``.
        cve_description: Text to classify.
        device: Torch device the input tensors are moved to.

    Returns:
        NumPy array of sigmoid probabilities rounded to 0.0 or 1.0; the
        caller reads row 0 as the per-label indicators.
    """
    tokenized_input = tokenizer.encode_plus(
        cve_description,
        max_length=MAX_TOKENS,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = tokenized_input['input_ids'].to(device)
    attention_mask = tokenized_input['attention_mask'].to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    probs = torch.sigmoid(logits).cpu().numpy()
    predicted_techniques = np.round(probs)
    return predicted_techniques


# Global variables for model and tokenizer
global_model = None
global_tokenizer = None


# Lazy loading function to get the model and tokenizer
def get_model_and_tokenizer(device='cpu'):
    """Return the cached model and tokenizer, loading them on first use.

    ``device`` only takes effect on the call that performs the load; later
    calls return the cached objects unchanged.
    """
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        global_model = Model()
        global_model.load_state_dict(
            torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
        )
        global_model.to(device)
        # eval() disables dropout so predictions are deterministic.
        global_model.eval()
        global_tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME)
    return global_model, global_tokenizer


# Route for the home page
@app.route('/')
def home():
    """Render the landing page."""
    return render_template('index.html')


# Route to handle form submission and return results
@app.route('/predict', methods=['POST'])
def predict():
    """Look up the submitted CVE id and render its predicted tactics.

    Responds with a JSON error body and status 400 when the id is blank, or
    404 when the lookup returns no record or the record has no English
    description. A missing ``cve_id`` form field is rejected by Flask with
    its standard 400 response.
    """
    cve_id = request.form['cve_id'].strip()
    if not cve_id:
        return jsonify(error='cve_id must not be empty'), 400

    # An empty result would otherwise surface as an IndexError / HTTP 500.
    results = nvdlib.searchCVE(cveId=cve_id)
    if not results:
        return jsonify(error=f'No CVE found for {cve_id!r}'), 404

    # Default of None avoids an uncaught StopIteration when no English
    # description is present.
    cve_data = next(
        (desc.value for desc in results[0].descriptions if desc.lang == "en"),
        None,
    )
    if cve_data is None:
        return jsonify(error=f'No English description for {cve_id!r}'), 404

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model and tokenizer lazily
    model, tokenizer = get_model_and_tokenizer(device)
    predicted_techniques = predict_techniques(model, tokenizer, cve_data, device)

    predicted_tactic_names = [
        name
        for name, flag in zip(TACTIC_NAMES, predicted_techniques[0])
        if flag == 1
    ]

    return render_template(
        'result.html',
        tactics=predicted_tactic_names,
        cve_id=cve_id,
        cve_desc=cve_data,
    )


# Run the app
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)