retr04error commited on
Commit
5e9aa8c
1 Parent(s): 817a92a
Files changed (1) hide show
  1. flaskAPI.py +103 -0
flaskAPI.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn.functional import softmax
3
+ from transformers import GPT2Tokenizer
4
+ import os
5
+ import requests
6
+ import tempfile
7
+ from flask import Flask, request, jsonify
8
+ import gradio as gr
9
+
10
+ app = Flask(__name__)
11
+
12
+ def check_vulnerabilities(solidity_file_path, model_directory='models/', device='cuda'):
13
+ device = 'cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu'
14
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
15
+ tokenizer.pad_token = tokenizer.eos_token
16
+
17
+ with open(solidity_file_path, 'r', encoding='utf-8') as f:
18
+ test_code = f.read()
19
+
20
+ X = tokenizer([test_code], padding=True, truncation=True, return_tensors="pt")
21
+
22
+ dic_name = {
23
+ 'reentrancy': 'Reentrancy Vulnerability',
24
+ 'timestamp': 'Timestamp Dependency Vulnerability',
25
+ 'delegatecall': 'Delegate Call Vulnerability',
26
+ 'integeroverflow': 'Integer Overflow Vulnerability',
27
+ }
28
+
29
+ dic01 = {0: 'The vulnerability does not exist', 1: 'The vulnerability exists'}
30
+ results = {}
31
+
32
+ for model_name in os.listdir(model_directory):
33
+ vulnerability_name = dic_name[model_name.split('_')[0]]
34
+ cp_file = os.path.join(model_directory, model_name)
35
+ model = torch.load(cp_file)
36
+ X = X.to(device)
37
+ model.to(device)
38
+ model.eval()
39
+ pred = softmax(model(**X).logits, dim=1)[0]
40
+ results[vulnerability_name] = {
41
+ 'result': dic01[int(pred.argmax(0))],
42
+ 'confidence': pred.max().item()
43
+ }
44
+
45
+ return results
46
+
47
+ @app.route('/check_vulnerabilities', methods=['POST'])
48
+ def check_vulnerabilities_api():
49
+ if 'solidity_file_url' not in request.json or 'file_id' not in request.json:
50
+ return jsonify({'error': 'No solidity_file_url or file_id provided'}), 400
51
+
52
+ solidity_file_url = request.json['solidity_file_url']
53
+ file_id = request.json['file_id']
54
+
55
+ model_directory = 'models/'
56
+ device = 'cuda'
57
+
58
+ try:
59
+ response = requests.get(solidity_file_url)
60
+ response.raise_for_status()
61
+
62
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file_id}.sol") as temp_file:
63
+ temp_file.write(response.content)
64
+ temp_file.flush()
65
+ results = check_vulnerabilities(temp_file.name, model_directory, device)
66
+ os.remove(temp_file.name)
67
+ return jsonify(results)
68
+
69
+ except requests.exceptions.RequestException as e:
70
+ return jsonify({'error': f'Error fetching file: {e}'}), 500
71
+ except Exception as e:
72
+ return jsonify({'error': str(e)}), 500
73
+
74
+ def check_vulnerabilities_interface(solidity_file_url, file_id):
75
+ model_directory = 'models/'
76
+ device = 'cuda'
77
+
78
+ try:
79
+ response = requests.get(solidity_file_url)
80
+ response.raise_for_status()
81
+
82
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file_id}.sol") as temp_file:
83
+ temp_file.write(response.content)
84
+ temp_file.flush()
85
+ results = check_vulnerabilities(temp_file.name, model_directory, device)
86
+ os.remove(temp_file.name)
87
+ return results
88
+
89
+ except requests.exceptions.RequestException as e:
90
+ return {'error': f'Error fetching file: {e}'}
91
+ except Exception as e:
92
+ return {'error': str(e)}
93
+
94
+ interface = gr.Interface(
95
+ fn=check_vulnerabilities_interface,
96
+ inputs=[gr.inputs.Textbox(label="Solidity File URL"), gr.inputs.Textbox(label="File ID")],
97
+ outputs="json",
98
+ title="Solidity Vulnerability Checker",
99
+ description="Enter the URL of a Solidity file and a file ID to check for vulnerabilities."
100
+ )
101
+
102
+ if __name__ == '__main__':
103
+ app.run()