Text Generation
Transformers
English
AI
NLP
Cybersecurity
Ethical Hacking
Pentesting
Inference Endpoints
Canstralian commited on
Commit
a6b8b8c
·
verified ·
1 Parent(s): 548a773

Update pentest_ai_streamlit.py

Browse files
Files changed (1) hide show
  1. pentest_ai_streamlit.py +46 -10
pentest_ai_streamlit.py CHANGED
@@ -1,13 +1,15 @@
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
4
 
 
5
  @st.cache(allow_output_mutation=True)
6
  def load_model():
7
  model_path = "Canstralian/pentest_ai"
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_path,
10
- torch_dtype=torch.float16,
11
  device_map="auto",
12
  load_in_4bit=False,
13
  load_in_8bit=True,
@@ -16,21 +18,55 @@ def load_model():
16
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
17
  return model, tokenizer
18
 
 
19
  def generate_text(model, tokenizer, instruction):
20
- tokens = tokenizer.encode(instruction, return_tensors='pt').to('cuda')
 
 
21
  generated_tokens = model.generate(
22
- tokens,
23
- max_length=1024,
24
- top_p=1.0,
25
- temperature=0.5,
26
  top_k=50
27
  )
28
  return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
29
 
30
- model, tokenizer = load_model()
 
 
 
 
 
 
 
 
 
31
 
 
32
  st.title("Penetration Testing AI Assistant")
33
- instruction = st.text_area("Enter your question:")
 
 
 
 
 
34
  if st.button("Generate"):
35
- response = generate_text(model, tokenizer, instruction)
36
- st.write(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
+ import json
5
 
6
+ # Load the model and tokenizer
7
  @st.cache(allow_output_mutation=True)
8
  def load_model():
9
  model_path = "Canstralian/pentest_ai"
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_path,
12
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float16 if CUDA is available
13
  device_map="auto",
14
  load_in_4bit=False,
15
  load_in_8bit=True,
 
18
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
19
  return model, tokenizer
20
 
21
+ # Function to generate text from the model
22
  def generate_text(model, tokenizer, instruction):
23
+ # Check if CUDA is available and send tensors to the appropriate device
24
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
+ tokens = tokenizer.encode(instruction, return_tensors='pt').to(device)
26
  generated_tokens = model.generate(
27
+ tokens,
28
+ max_length=1024,
29
+ top_p=1.0,
30
+ temperature=0.5,
31
  top_k=50
32
  )
33
  return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
34
 
35
+ # Load the JSON data (simulated here for simplicity)
36
+ @st.cache(allow_output_mutation=True)
37
+ def load_json_data():
38
+ json_data = [
39
+ {"name": "Raja Clarke", "email": "consectetuer@yahoo.edu", "country": "Chile", "company": "Urna Nunc Consulting"},
40
+ {"name": "Melissa Hobbs", "email": "massa.non@hotmail.couk", "country": "France", "company": "Gravida Mauris Limited"},
41
+ {"name": "John Doe", "email": "john.doe@example.com", "country": "USA", "company": "Example Corp"},
42
+ {"name": "Jane Smith", "email": "jane.smith@example.org", "country": "Canada", "company": "Innovative Solutions Inc"}
43
+ ]
44
+ return json_data
45
 
46
+ # Streamlit UI
47
  st.title("Penetration Testing AI Assistant")
48
+
49
+ # Load model and tokenizer
50
+ model, tokenizer = load_model()
51
+
52
+ # Generate some text based on user input
53
+ instruction = st.text_area("Enter your question for the AI assistant:")
54
  if st.button("Generate"):
55
+ if instruction:
56
+ response = generate_text(model, tokenizer, instruction)
57
+ st.subheader("Generated Response:")
58
+ st.write(response)
59
+ else:
60
+ st.warning("Please enter a question to generate a response.")
61
+
62
+ # Displaying user data from JSON
63
+ st.subheader("User Data (from JSON)")
64
+ user_data = load_json_data()
65
+
66
+ # Display user details in a readable format
67
+ for user in user_data:
68
+ st.write(f"**Name:** {user['name']}")
69
+ st.write(f"**Email:** {user['email']}")
70
+ st.write(f"**Country:** {user['country']}")
71
+ st.write(f"**Company:** {user['company']}")
72
+ st.write("---") # Separator