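# Web file-viewer header captured together with the source (not part of the program):
#   krish-rebase / flaskAPI.py
#   retr04error's picture
#   llm
#   1a49556
#   raw
#   history blame
#   2.61 kB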
import torch
from torch.nn.functional import softmax
from transformers import GPT2Tokenizer
import os
import requests
import tempfile
from flask import Flask, request, jsonify
import gradio as gr
# Flask application object; the /check_vulnerabilities route is registered on it
# and it is started by the __main__ guard at the bottom of the file.
app = Flask(__name__)
def check_vulnerabilities(solidity_file_path, model_directory='models/', device='cuda'):
    """Run every vulnerability classifier found in ``model_directory`` on one file.

    The Solidity source is tokenized once with the pretrained ``gpt2``
    tokenizer and then fed to each checkpoint in turn.

    Args:
        solidity_file_path: Path of the UTF-8 encoded Solidity file to analyse.
        model_directory: Directory containing the pickled model checkpoints.
            A file is used only when the part of its name before the first
            ``_`` is one of ``reentrancy``, ``timestamp``, ``delegatecall`` or
            ``integeroverflow``; every other entry is skipped.
        device: Requested device. Anything other than ``'cuda'``, or CUDA not
            being available, falls back to ``'cpu'``.

    Returns:
        A dict mapping the human-readable vulnerability name to
        ``{'result': <verdict string>, 'confidence': <float probability>}``.

    Raises:
        OSError: If the Solidity file or the model directory cannot be read.
        KeyError: If a model predicts a class index other than 0 or 1.
    """
    device = 'cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu'

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 ships without a padding token; reuse EOS so padding=True works.
    tokenizer.pad_token = tokenizer.eos_token

    with open(solidity_file_path, 'r', encoding='utf-8') as f:
        test_code = f.read()

    # Tokenize and move to the target device once; the tensors are identical
    # for every model, so there is no reason to repeat this inside the loop.
    inputs = tokenizer(
        [test_code], padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    dic_name = {
        'reentrancy': 'Reentrancy Vulnerability',
        'timestamp': 'Timestamp Dependency Vulnerability',
        'delegatecall': 'Delegate Call Vulnerability',
        'integeroverflow': 'Integer Overflow Vulnerability',
    }
    dic01 = {0: 'The vulnerability does not exist', 1: 'The vulnerability exists'}

    results = {}
    # sorted() makes the iteration order (and therefore which checkpoint wins
    # when two share a prefix) deterministic; os.listdir order is unspecified.
    for model_name in sorted(os.listdir(model_directory)):
        vulnerability_name = dic_name.get(model_name.split('_')[0])
        cp_file = os.path.join(model_directory, model_name)
        # Ignore stray entries (.gitkeep, sub-directories, unknown prefixes)
        # instead of aborting the whole scan with a KeyError.
        if vulnerability_name is None or not os.path.isfile(cp_file):
            continue

        # SECURITY: torch.load unpickles the checkpoint and can therefore run
        # arbitrary code; model_directory must only contain trusted files.
        # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
        model = torch.load(cp_file, map_location=device)
        model.to(device)
        model.eval()

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            pred = softmax(model(**inputs).logits, dim=1)[0]

        results[vulnerability_name] = {
            'result': dic01[int(pred.argmax(0))],
            'confidence': pred.max().item(),
        }
    return results
@app.route('/check_vulnerabilities', methods=['POST'])
def check_vulnerabilities_api():
    """Download a Solidity file and return the vulnerability report as JSON.

    Expects a JSON object body with two keys:
        solidity_file_url: URL from which the Solidity source is fetched.
        file_id: Identifier embedded in the temporary file name.

    Returns:
        200 with the dict produced by ``check_vulnerabilities`` on success,
        400 when the body is not a JSON object or a required key is missing,
        500 with an ``error`` message when the download or the analysis fails.
    """
    # silent=True yields None for a missing or malformed JSON body, so such
    # requests get the documented 400 instead of an unhandled error.
    payload = request.get_json(silent=True)
    if (
        not isinstance(payload, dict)
        or 'solidity_file_url' not in payload
        or 'file_id' not in payload
    ):
        return jsonify({'error': 'No solidity_file_url or file_id provided'}), 400

    # SECURITY: this URL is caller-controlled and fetched server-side (SSRF
    # risk); restrict the allowed hosts/schemes before exposing this publicly.
    solidity_file_url = payload['solidity_file_url']
    # file_id is untrusted and ends up in a file name: keep only characters
    # that cannot introduce path separators, and bound the length.
    file_id = ''.join(
        ch for ch in str(payload['file_id']) if ch.isalnum() or ch in '-_'
    )[:64]
    model_directory = 'models/'
    device = 'cuda'

    temp_path = None
    try:
        # A timeout keeps a stalled remote host from blocking the worker forever.
        response = requests.get(solidity_file_url, timeout=30)
        response.raise_for_status()

        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file_id}.sol") as temp_file:
            temp_path = temp_file.name
            temp_file.write(response.content)
        # The file is closed before being reopened for analysis, so this also
        # works on platforms that forbid opening a file twice.
        results = check_vulnerabilities(temp_path, model_directory, device)
        return jsonify(results)
    except requests.exceptions.RequestException as e:
        return jsonify({'error': f'Error fetching file: {e}'}), 500
    except Exception as e:
        # Top-level request boundary: report any failure as a JSON error body.
        return jsonify({'error': str(e)}), 500
    finally:
        # delete=False means cleanup is manual; do it on every path so a failed
        # analysis does not leave the downloaded file behind.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
# Script entry point: when the file is executed directly (not imported),
# start the Flask application with its default run() settings.
if __name__ == '__main__':
    app.run()