"""Flask service that classifies a Solidity source file with saved models.

The service downloads a Solidity file from a caller-supplied URL, tokenizes
it with the GPT-2 tokenizer, and runs it through every known checkpoint
found in the model directory.
"""

import os
import re
import tempfile

import gradio as gr  # NOTE(review): unused in this file; kept in case other code relies on it
import requests
import torch
from flask import Flask, jsonify, request
from torch.nn.functional import softmax
from transformers import GPT2Tokenizer

app = Flask(__name__)

# Checkpoint filename prefix (the text before the first "_") -> display name.
VULNERABILITY_NAMES = {
    'reentrancy': 'Reentrancy Vulnerability',
    'timestamp': 'Timestamp Dependency Vulnerability',
    'delegatecall': 'Delegate Call Vulnerability',
    'integeroverflow': 'Integer Overflow Vulnerability',
}

# Predicted class index -> human-readable verdict.
PREDICTION_LABELS = {
    0: 'The vulnerability does not exist',
    1: 'The vulnerability exists',
}

# Upper bound (seconds) on the download so a stalled remote server cannot
# block a worker forever.
DOWNLOAD_TIMEOUT_SECONDS = 30

# Characters allowed to survive in the temp-file suffix built from file_id.
_UNSAFE_FILE_ID_CHARS = re.compile(r'[^A-Za-z0-9_.-]')


def check_vulnerabilities(solidity_file_path, model_directory='models/', device='cuda'):
    """Run every known checkpoint in *model_directory* on a Solidity file.

    Args:
        solidity_file_path: Path to a UTF-8 encoded Solidity source file.
        model_directory: Directory holding checkpoints whose filenames start
            with one of the keys of ``VULNERABILITY_NAMES`` followed by "_".
        device: ``'cuda'`` to use the GPU when available; any other value
            (or no GPU) selects the CPU.

    Returns:
        A dict mapping each vulnerability display name to
        ``{'result': <verdict string>, 'confidence': <float>}``.

    Raises:
        OSError: If the source file or model directory cannot be read.
        UnicodeDecodeError: If the source file is not valid UTF-8.
    """
    device = 'cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu'

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 defines no pad token; reuse EOS so padding=True works.
    tokenizer.pad_token = tokenizer.eos_token

    with open(solidity_file_path, 'r', encoding='utf-8') as f:
        test_code = f.read()

    encoded = tokenizer(
        [test_code], padding=True, truncation=True, return_tensors="pt"
    )
    # Move the inputs once; they are identical for every model.
    encoded = encoded.to(device)

    results = {}
    # Sorted so the outcome is deterministic when several checkpoints share
    # a prefix (the last one wins), regardless of os.listdir ordering.
    for model_name in sorted(os.listdir(model_directory)):
        cp_file = os.path.join(model_directory, model_name)
        vulnerability_name = VULNERABILITY_NAMES.get(model_name.split('_')[0])
        # Skip stray entries (sub-directories, .gitkeep, .DS_Store, ...)
        # instead of failing the whole scan with a KeyError.
        if vulnerability_name is None or not os.path.isfile(cp_file):
            continue

        # SECURITY: torch.load unpickles arbitrary objects; only trusted
        # checkpoints may be placed in model_directory.
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        model = torch.load(cp_file, map_location=device)
        model.to(device)
        model.eval()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = softmax(model(**encoded).logits, dim=1)[0]

        results[vulnerability_name] = {
            'result': PREDICTION_LABELS[int(pred.argmax(0))],
            'confidence': pred.max().item(),
        }

    return results


@app.route('/check_vulnerabilities', methods=['POST'])
def check_vulnerabilities_api():
    """Handle ``POST /check_vulnerabilities``.

    Expects a JSON object with ``solidity_file_url`` and ``file_id``. The
    file is downloaded to a temporary path, analysed, and removed again.

    Returns:
        The JSON results of ``check_vulnerabilities`` on success, a 400
        response when the body is missing the required keys, or a 500
        response carrying the error message otherwise.
    """
    # silent=True yields None for a missing/invalid JSON body so the caller
    # gets the 400 below rather than a framework error.
    payload = request.get_json(silent=True)
    if (
        not isinstance(payload, dict)
        or 'solidity_file_url' not in payload
        or 'file_id' not in payload
    ):
        return jsonify({'error': 'No solidity_file_url or file_id provided'}), 400

    solidity_file_url = payload['solidity_file_url']
    # file_id is client-controlled and ends up in a filename: strip anything
    # (notably path separators) that could move the file out of the temp dir.
    safe_file_id = _UNSAFE_FILE_ID_CHARS.sub('_', str(payload['file_id']))

    model_directory = 'models/'
    device = 'cuda'

    temp_path = None
    try:
        # SECURITY NOTE(review): the URL is client-supplied and fetched from
        # the server (SSRF risk); consider restricting allowed hosts.
        response = requests.get(solidity_file_url, timeout=DOWNLOAD_TIMEOUT_SECONDS)
        response.raise_for_status()

        with tempfile.NamedTemporaryFile(
            delete=False, suffix=f"_{safe_file_id}.sol"
        ) as temp_file:
            temp_path = temp_file.name
            temp_file.write(response.content)
        # The file is closed (and therefore flushed) before it is re-opened
        # by check_vulnerabilities.

        results = check_vulnerabilities(temp_path, model_directory, device)
        return jsonify(results)
    except requests.exceptions.RequestException as e:
        return jsonify({'error': f'Error fetching file: {e}'}), 500
    except Exception as e:
        # Top-level boundary: log the traceback, report the message.
        app.logger.exception('Vulnerability check failed')
        return jsonify({'error': str(e)}), 500
    finally:
        # Always remove the temp file, including when analysis raised.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                app.logger.warning('Could not remove temp file %s', temp_path)


if __name__ == '__main__':
    app.run()