"""Gradio front-end that scans a Solidity source file with per-vulnerability classifiers.

A Solidity file is downloaded from a user-supplied URL, tokenized with the
GPT-2 tokenizer and fed to every recognised model checkpoint found in the
model directory. Each checkpoint yields a binary verdict plus a confidence.
"""

import logging
import os
import tempfile

import gradio as gr
import requests
import torch
from torch.nn.functional import softmax
from transformers import GPT2Tokenizer

logger = logging.getLogger(__name__)

# Upper bound (seconds) on the download so a stalled server cannot hang a worker.
REQUEST_TIMEOUT_SECONDS = 30

# Checkpoint file-name prefix (text before the first "_") -> human-readable name.
_VULNERABILITY_NAMES = {
    'reentrancy': 'Reentrancy Vulnerability',
    'timestamp': 'Timestamp Dependency Vulnerability',
    'delegatecall': 'Delegate Call Vulnerability',
    'integeroverflow': 'Integer Overflow Vulnerability',
}

# Predicted class index -> message reported to the user.
_VERDICTS = {
    0: 'The vulnerability does not exist',
    1: 'The vulnerability exists',
}


def check_vulnerabilities(solidity_file_path, model_directory='models/', device='cuda'):
    """Run every recognised checkpoint in *model_directory* on one Solidity file.

    Args:
        solidity_file_path: Path of the UTF-8 encoded Solidity source to analyse.
        model_directory: Directory holding the checkpoints. A checkpoint is
            recognised when the part of its file name before the first ``_``
            is one of ``reentrancy``, ``timestamp``, ``delegatecall`` or
            ``integeroverflow``; other entries are skipped with a warning.
        device: ``'cuda'`` to request the GPU. Falls back to ``'cpu'`` when
            CUDA is unavailable or any other value is given.

    Returns:
        A dict mapping the human-readable vulnerability name to
        ``{'result': <verdict message>, 'confidence': <float>}``, where the
        confidence is the largest softmax probability of the prediction.

    Raises:
        OSError: If the source file or the model directory cannot be read.
        UnicodeDecodeError: If the source file is not valid UTF-8.
    """
    device = 'cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu'

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 ships without a padding token; reuse EOS so padding=True works.
    tokenizer.pad_token = tokenizer.eos_token

    with open(solidity_file_path, 'r', encoding='utf-8') as f:
        test_code = f.read()

    # NOTE(review): truncation=True presumably clips long contracts to the
    # tokenizer's maximum length, so only the leading part is analysed - confirm.
    inputs = tokenizer([test_code], padding=True, truncation=True, return_tensors="pt")
    # The encoded input is the same for every model, so move it to the device once.
    inputs = inputs.to(device)

    results = {}
    # Sorted so the outcome does not depend on the filesystem's listing order.
    for model_name in sorted(os.listdir(model_directory)):
        cp_file = os.path.join(model_directory, model_name)
        vulnerability_name = _VULNERABILITY_NAMES.get(model_name.split('_')[0])
        if vulnerability_name is None or not os.path.isfile(cp_file):
            # Stray entries (.gitkeep, README, sub-directories, ...) must not
            # abort the whole scan.
            logger.warning("Skipping unrecognised entry in model directory: %s", model_name)
            continue

        # SECURITY: torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only place trusted files in the model directory.
        model = torch.load(cp_file, map_location=device)
        model.to(device)
        model.eval()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            pred = softmax(model(**inputs).logits, dim=1)[0]

        results[vulnerability_name] = {
            'result': _VERDICTS[int(pred.argmax(0))],
            'confidence': pred.max().item(),
        }

    return results


def _sanitize_file_id(file_id):
    """Reduce *file_id* to ASCII letters, digits, ``-`` and ``_`` (max 64 chars)."""
    cleaned = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_" else "_"
        for ch in str(file_id)
    )
    return cleaned[:64]


def check_vulnerabilities_interface(solidity_file_url, file_id):
    """Download a Solidity file and return the vulnerability report for it.

    This is the Gradio callback, so it never raises: every failure is turned
    into an ``{'error': ...}`` dict that the JSON output component can show.

    Args:
        solidity_file_url: URL the Solidity source is fetched from.
        file_id: Caller-chosen identifier embedded in the temporary file name.
            It is sanitised first because it is untrusted user input.

    Returns:
        The dict produced by ``check_vulnerabilities`` on success, otherwise
        ``{'error': <message>}``.
    """
    model_directory = 'models/'
    device = 'cuda'
    temp_path = None
    try:
        # NOTE(review): the URL is user-controlled and fetched server-side;
        # consider restricting allowed hosts if this runs on a private network.
        response = requests.get(solidity_file_url, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()

        suffix = f"_{_sanitize_file_id(file_id)}.sol"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_path = temp_file.name
            temp_file.write(response.content)
        # The handle is closed before the file is re-opened by path, which is
        # required on platforms that forbid opening an open temp file twice.
        return check_vulnerabilities(temp_path, model_directory, device)
    except requests.exceptions.RequestException as e:
        return {'error': f'Error fetching file: {e}'}
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return {'error': str(e)}
    finally:
        # Always remove the download, including when inference failed.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning("Could not remove temporary file: %s", temp_path)


# Gradio UI: two text inputs (URL and file ID) and a JSON report as output.
interface = gr.Interface(
    fn=check_vulnerabilities_interface,
    inputs=[
        gr.components.Textbox(label="Solidity File URL", placeholder="Enter URL here...", lines=2),
        gr.components.Textbox(label="File ID", placeholder="Enter file ID here...", lines=1),
    ],
    outputs=gr.components.JSON(),
    title="Solidity Vulnerability Checker",
    description="Enter the URL of a Solidity file and a file ID to check for vulnerabilities.",
)

# share=True additionally requests a public Gradio share link.
if __name__ == "__main__":
    interface.launch(share=True)