|
import streamlit as st |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import json |
|
|
|
|
|
@st.cache_resource
def load_model(model_path="Canstralian/pentest_ai"):
    """Load the causal language model and its tokenizer, cached across reruns.

    ``st.cache_resource`` keeps one shared instance per *model_path* for the
    lifetime of the Streamlit server, so the weights are loaded only once.

    Args:
        model_path: Hugging Face Hub repo id or local directory of the model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    use_cuda = torch.cuda.is_available()
    model_kwargs = {
        "torch_dtype": torch.float16 if use_cuda else torch.float32,
        "device_map": "auto",
        # SECURITY: trust_remote_code executes Python code shipped inside the
        # model repository. Keep it only for repos you trust; consider pinning
        # a specific `revision=` so the executed code cannot change underneath.
        "trust_remote_code": True,
    }
    if use_cuda:
        # 8-bit bitsandbytes quantization generally requires a CUDA GPU, so it
        # is requested only when one is present. CPU hosts load float32 weights.
        # BitsAndBytesConfig replaces the deprecated bare load_in_8bit kwarg.
        from transformers import BitsAndBytesConfig

        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
|
|
|
|
|
def generate_text(
    model,
    tokenizer,
    instruction,
    *,
    max_new_tokens=1024,
    temperature=0.5,
    top_p=1.0,
    top_k=50,
):
    """Generate a completion for *instruction* and return only the new text.

    Args:
        model: Causal LM exposing ``generate`` and ``device``, as returned by
            ``load_model``.
        tokenizer: Tokenizer matching *model*.
        instruction: Prompt text entered by the user.
        max_new_tokens: Maximum number of tokens generated after the prompt.
            The prompt length is not counted against this budget.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding.
        top_p: Nucleus-sampling probability mass. Used only when sampling.
        top_k: Top-k cutoff. Used only when sampling.

    Returns:
        The decoded continuation, without the echoed prompt.
    """
    # Put the inputs on the model's own device. With device_map="auto" the
    # placement is chosen at load time, not by torch.cuda.is_available().
    device = model.device
    encoded = tokenizer(instruction, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    # Many causal-LM tokenizers define no pad token; fall back to EOS so
    # generate() does not have to guess.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(tokenizer, "eos_token_id", None)

    gen_kwargs = {"max_new_tokens": max_new_tokens, "pad_token_id": pad_id}
    if temperature > 0:
        # temperature/top_p/top_k only take effect when sampling is enabled.
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    generated = model.generate(**inputs, **gen_kwargs)

    # Decoder-only models return prompt + continuation; keep the continuation.
    return tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)
|
|
|
|
|
def load_json_data():
    """Return the sample user records shown in the "User Data" section.

    The records are an in-code literal: nothing is read from disk or parsed
    as JSON. Building a fresh list on each call is trivial and means a caller
    that mutates the result cannot affect later calls or other sessions.

    Returns:
        A list of dicts with keys ``name``, ``email``, ``country`` and
        ``company``.
    """
    return [
        {"name": "Raja Clarke", "email": "consectetuer@yahoo.edu", "country": "Chile", "company": "Urna Nunc Consulting"},
        {"name": "Melissa Hobbs", "email": "massa.non@hotmail.couk", "country": "France", "company": "Gravida Mauris Limited"},
        {"name": "John Doe", "email": "john.doe@example.com", "country": "USA", "company": "Example Corp"},
        {"name": "Jane Smith", "email": "jane.smith@example.org", "country": "Canada", "company": "Innovative Solutions Inc"},
    ]
|
|
|
|
|
# Page layout: title, prompt box with a Generate button, then the sample
# user records. Streamlit re-executes this script on every interaction.
st.title("Penetration Testing AI Assistant")

# Cached by load_model's decorator, so reruns reuse the loaded model.
model, tokenizer = load_model()

instruction = st.text_area("Enter your question for the AI assistant:")

if st.button("Generate"):
    # Whitespace-only input counts as empty instead of being sent to the model.
    if instruction.strip():
        # Generation can take a while; show progress instead of a frozen page.
        with st.spinner("Generating..."):
            response = generate_text(model, tokenizer, instruction)
        st.subheader("Generated Response:")
        st.write(response)
    else:
        st.warning("Please enter a question to generate a response.")

st.subheader("User Data (from JSON)")

user_data = load_json_data()

for user in user_data:
    st.write(f"**Name:** {user['name']}")
    st.write(f"**Email:** {user['email']}")
    st.write(f"**Country:** {user['country']}")
    st.write(f"**Company:** {user['company']}")
    st.write("---")
|
|