# Hugging Face Space page header (status at capture time: Runtime error)
import contextlib
import os
import tempfile

import gradio as gr
import requests
import torch
from torch.nn.functional import softmax
from transformers import GPT2Tokenizer
# Model loading and prediction function
def check_vulnerabilities(solidity_file_path, model_directory='models/', device='cuda'):
    """Run every vulnerability classifier found in a directory on one Solidity file.

    The file is tokenized once with the pretrained ``gpt2`` tokenizer (EOS is
    reused as the padding token) and the same encoded input is fed to each
    checkpoint in ``model_directory``.

    Args:
        solidity_file_path: Path to the Solidity source file; read as UTF-8.
        model_directory: Directory holding checkpoints loadable with
            ``torch.load``. The part of each file name before the first ``_``
            selects the vulnerability type (``reentrancy``, ``timestamp``,
            ``delegatecall`` or ``integeroverflow``). Entries that are not
            regular files or whose prefix is not one of those keys are skipped.
        device: ``'cuda'`` to use the GPU when one is available; any other
            value (or no available GPU) falls back to ``'cpu'``.

    Returns:
        A dict mapping each vulnerability's display name to
        ``{'result': <verdict string>, 'confidence': <float>}``, where the
        confidence is the largest softmax probability of the two-way
        prediction.

    Raises:
        OSError: If the source file or the model directory cannot be read.
    """
    device = 'cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu'

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 ships without a pad token; padding=True below needs one.
    tokenizer.pad_token = tokenizer.eos_token

    with open(solidity_file_path, 'r', encoding='utf-8') as f:
        test_code = f.read()

    inputs = tokenizer([test_code], padding=True, truncation=True, return_tensors="pt")
    # The encoded input is the same for every model, so move it to the device once.
    inputs = inputs.to(device)

    dic_name = {
        'reentrancy': 'Reentrancy Vulnerability',
        'timestamp': 'Timestamp Dependency Vulnerability',
        'delegatecall': 'Delegate Call Vulnerability',
        'integeroverflow': 'Integer Overflow Vulnerability',
    }
    dic01 = {0: 'The vulnerability does not exist', 1: 'The vulnerability exists'}

    results = {}
    # Sorted so the order of the results does not depend on the filesystem.
    for model_name in sorted(os.listdir(model_directory)):
        cp_file = os.path.join(model_directory, model_name)
        vulnerability_name = dic_name.get(model_name.split('_')[0])
        # Stray entries (hidden files, READMEs, sub-directories) must not abort
        # the whole scan with a KeyError.
        if vulnerability_name is None or not os.path.isfile(cp_file):
            continue

        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        # NOTE(review): the checkpoints are used as whole model objects; newer
        # PyTorch releases default to weights_only=True, which may reject them
        # -- confirm against the pinned torch version.
        model = torch.load(cp_file, map_location=device)
        model.to(device)
        model.eval()

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            pred = softmax(model(**inputs).logits, dim=1)[0]

        results[vulnerability_name] = {
            'result': dic01[int(pred.argmax(0))],
            'confidence': pred.max().item(),
        }
    return results
# Gradio interface function
def check_vulnerabilities_interface(solidity_file_url, file_id):
    """Download a Solidity file and return the vulnerability report for it.

    This is the Gradio callback: it never raises, failures are reported as
    ``{'error': <message>}`` so they can be shown in the JSON output.

    Args:
        solidity_file_url: URL the Solidity source is fetched from.
        file_id: Caller-supplied identifier; only used to tag the name of the
            temporary file the download is written to.

    Returns:
        The dict produced by ``check_vulnerabilities`` on success, otherwise a
        dict with a single ``'error'`` key.
    """
    model_directory = 'models/'
    device = 'cuda'
    request_timeout = 30  # seconds; without it a stalled server blocks forever

    # file_id is untrusted text that ends up in a file name: keep only
    # characters that cannot act as path separators, and bound the length.
    safe_id = "".join(ch for ch in str(file_id) if ch.isalnum() or ch in "-_")[:64]

    temp_path = None
    try:
        # NOTE(security): the URL is user supplied and fetched server-side.
        response = requests.get(solidity_file_url, timeout=request_timeout)
        response.raise_for_status()

        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{safe_id}.sol") as temp_file:
            temp_path = temp_file.name
            temp_file.write(response.content)
        # The file is closed (and therefore flushed) before it is reopened for
        # analysis.
        return check_vulnerabilities(temp_path, model_directory, device)
    except requests.exceptions.RequestException as e:
        return {'error': f'Error fetching file: {e}'}
    except Exception as e:
        # UI boundary: surface any other failure as a message instead of a crash.
        return {'error': str(e)}
    finally:
        # Always remove the download, including when the analysis failed.
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_path)
# Set up the Gradio interface
interface = gr.Interface(
    title="Solidity Vulnerability Checker",
    description="Enter the URL of a Solidity file and a file ID to check for vulnerabilities.",
    # Callback receives the two textbox values in the order listed in `inputs`.
    fn=check_vulnerabilities_interface,
    inputs=[
        gr.components.Textbox(
            label="Solidity File URL",
            placeholder="Enter URL here...",
            lines=2,
        ),
        gr.components.Textbox(
            label="File ID",
            placeholder="Enter file ID here...",
            lines=1,
        ),
    ],
    # The callback returns a dict, rendered as JSON.
    outputs=gr.components.JSON(),
)

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    interface.launch(share=True)