"""Gradio Space that asks the CodeAstra-7B LoRA adapter to review code for vulnerabilities."""

import logging

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
import spaces

logger = logging.getLogger(__name__)

# Load the model and tokenizer
peft_model_id = "rootxhacker/CodeAstra-7B"
config = PeftConfig.from_pretrained(peft_model_id)


def to_cpu(obj):
    """Recursively move tensors to the CPU.

    Tensors nested inside lists, tuples and dicts are moved as well; the
    container types are preserved. Any other object is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, list):
        return [to_cpu(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(to_cpu(item) for item in obj)
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    return obj


# Load the base model quantized to 4 bits. Passing `load_in_4bit=True` directly
# to `from_pretrained` is deprecated; an explicit BitsAndBytesConfig is the
# equivalent, supported spelling.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Wrap the base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)


@spaces.GPU()
def get_completion(query, model, tokenizer):
    """Generate a sampled completion for *query*.

    Args:
        query: Prompt text to feed to the model.
        model: The (PEFT-wrapped) causal LM used for generation.
        tokenizer: Tokenizer matching *model*.

    Returns:
        The decoded generated sequence with special tokens removed. For a
        causal LM, ``generate`` returns the prompt tokens followed by the new
        tokens, so the prompt text is part of the returned string. If anything
        fails, a string starting with "An error occurred:" is returned instead
        of raising, so the Gradio UI always gets text to display.
    """
    try:
        # The weights were placed by device_map="auto" at load time. Do not
        # move the 4-bit model with .cuda()/.cpu() per request: that fights
        # accelerate's placement hooks and is unsupported/unsafe for
        # bitsandbytes-quantized weights. Send the inputs to the model instead.
        inputs = tokenizer(query, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
            )
        # Move outputs to CPU before decoding
        outputs = to_cpu(outputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort for the UI, but keep the traceback in the logs.
        logger.exception("Generation failed")
        return f"An error occurred: {str(e)}"
    finally:
        # Release cached activation memory between requests.
        torch.cuda.empty_cache()


@spaces.GPU()
def code_review(code_to_analyze):
    """Ask the model to find vulnerabilities in *code_to_analyze*.

    Returns the model's full output (see ``get_completion``) without any
    post-processing.
    """
    # NOTE(review): the prompt wording is kept verbatim; the adapter may have
    # been fine-tuned on this exact template — confirm before rewording.
    prompt = f"""find all vulnerabilities which in the code {code_to_analyze} """
    full_response = get_completion(prompt, model, tokenizer)
    # Return the full response without any processing
    return full_response


# Create Gradio interface
iface = gr.Interface(
    fn=code_review,
    inputs=gr.Textbox(lines=10, label="Enter code to analyze"),
    outputs=gr.Textbox(label="Code Review Result"),
    title="Code Review Expert",
    description=(
        "This tool analyzes code for potential security flaws, logic "
        "vulnerabilities, and provides guidance on secure coding practices."
    ),
)

# Launch the Gradio app
iface.launch()