minhan6559 committed
Commit 56517e7 · verified · 1 Parent(s): 22c850f

update progress bar for streamlit

app.py CHANGED
@@ -10,6 +10,7 @@ import os
 import sys
 import tempfile
 import shutil
+import time
 import streamlit as st
 from pathlib import Path
 from typing import Dict, Any, Optional
@@ -87,6 +88,9 @@ def run_analysis(
     temp_dirs: Dict[str, str],
     api_key: str,
     provider: str,
+    max_log_analysis_iterations: int,
+    max_retrieval_iterations: int,
+    progress_callback=None,
 ) -> Dict[str, Any]:
     """Run the cybersecurity analysis pipeline."""
@@ -106,8 +110,11 @@ def run_analysis(
             tactic=None,
             model_name=model_name,
             temperature=0.1,
+            max_log_analysis_iterations=max_log_analysis_iterations,
+            max_retrieval_iterations=max_retrieval_iterations,
             log_agent_output_dir=temp_dirs["analysis"],
             response_agent_output_dir=temp_dirs["final_response"],
+            progress_callback=progress_callback,
         )
         return {"success": True, "result": result}
     except Exception as e:
@@ -181,14 +188,6 @@ def main():
         help=f"Your {selected_provider} API key",
     )
 
-    # Additional query
-    st.subheader("Additional Context")
-    user_query = st.text_area(
-        "Optional Query",
-        placeholder="e.g., 'Focus on credential access attacks'",
-        help="Provide additional context or specific focus areas for the analysis",
-    )
-
     # Main content area
     col1, col2 = st.columns([2, 1])
 
@@ -237,19 +236,31 @@ def main():
             status_text.text("Initializing analysis...")
             progress_bar.progress(10)
 
-            # Run analysis
-            status_text.text("Running cybersecurity analysis...")
-            progress_bar.progress(50)
+            # Start timing
+            start_time = time.time()
+
+            # Create progress callback
+            def update_progress(progress: int, message: str):
+                progress_bar.progress(progress)
+                status_text.text(message)
 
+            # Run analysis
             analysis_result = run_analysis(
                 log_file_path=log_file_path,
                 model_name=selected_model,
-                query=user_query,
+                query="",
                 temp_dirs=temp_dirs,
                 api_key=api_key,
                 provider=selected_provider,
+                max_log_analysis_iterations=3,
+                max_retrieval_iterations=2,
+                progress_callback=update_progress,
             )
 
+            # Calculate execution time
+            end_time = time.time()
+            execution_time = end_time - start_time
+
             progress_bar.progress(90)
             status_text.text("Finalizing results...")
 
@@ -278,15 +289,7 @@ def main():
             st.metric("Abnormal Events", len(abnormal_events))
 
         with col3:
-            execution_time = result.get("execution_time", "N/A")
-            st.metric(
-                "Execution Time",
-                (
-                    f"{execution_time:.2f}s"
-                    if isinstance(execution_time, (int, float))
-                    else execution_time
-                ),
-            )
+            st.metric("Execution Time", f"{execution_time:.2f}s")
 
         # Show markdown report
         markdown_report = result.get("markdown_report", "")
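
For context, the app's progress bar now advances via a callback passed into the pipeline instead of fixed jumps at 10/50/90. A minimal standalone sketch of that pattern, assuming nothing beyond Streamlit itself (long_running_job and its stages are illustrative placeholders, not part of this app):

import time

import streamlit as st

def long_running_job(progress_callback=None):
    # Stand-in for the analysis pipeline: each stage reports its own progress.
    for pct, message in [(20, "Stage 1..."), (60, "Stage 2..."), (100, "Done")]:
        time.sleep(1)
        if progress_callback:
            progress_callback(pct, message)

progress_bar = st.progress(0)
status_text = st.empty()

def update_progress(progress: int, message: str):
    # st.progress accepts an int in [0, 100]; st.empty() gives a slot we overwrite.
    progress_bar.progress(progress)
    status_text.text(message)

long_running_job(progress_callback=update_progress)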
src/agents/log_analysis_agent/__pycache__/prompts.cpython-311.pyc CHANGED
Binary files a/src/agents/log_analysis_agent/__pycache__/prompts.cpython-311.pyc and b/src/agents/log_analysis_agent/__pycache__/prompts.cpython-311.pyc differ
 
src/agents/log_analysis_agent/__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/src/agents/log_analysis_agent/__pycache__/utils.cpython-311.pyc and b/src/agents/log_analysis_agent/__pycache__/utils.cpython-311.pyc differ
 
src/full_pipeline/__pycache__/simple_pipeline.cpython-311.pyc CHANGED
Binary files a/src/full_pipeline/__pycache__/simple_pipeline.cpython-311.pyc and b/src/full_pipeline/__pycache__/simple_pipeline.cpython-311.pyc differ
 
src/full_pipeline/simple_pipeline.py CHANGED
@@ -42,20 +42,12 @@ class PipelineState(TypedDict):
 def create_simple_pipeline(
     model_name: str = "google_genai:gemini-2.0-flash",
     temperature: float = 0.1,
+    max_log_analysis_iterations: int = 2,
+    max_retrieval_iterations: int = 2,
     log_agent_output_dir: str = "analysis",
     response_agent_output_dir: str = "final_response",
+    progress_callback=None,
 ):
-    """
-    Create the simplified pipeline that directly connects the agents.
-
-    Args:
-        model_name: Name of the model to use (e.g., "gemini-2.0-flash", "gpt-oss-120b", "llama-3.1-8b-instant")
-        temperature: Temperature for model generation
-
-    Returns:
-        Compiled pipeline workflow
-    """
-
     # Initialize LLM client directly
     print("\n" + "=" * 60)
     print("INITIALIZING LLM CLIENT")
@@ -82,11 +74,15 @@ def create_simple_pipeline(
 
     # Initialize agents with shared LLM client
     log_agent = LogAnalysisAgent(
-        output_dir=log_agent_output_dir, max_iterations=2, llm_client=llm_client
+        output_dir=log_agent_output_dir,
+        max_iterations=max_log_analysis_iterations,
+        llm_client=llm_client,
     )
 
     retrieval_supervisor = RetrievalSupervisor(
-        kb_path="./cyber_knowledge_base", max_iterations=2, llm_client=llm_client
+        kb_path="./cyber_knowledge_base",
+        max_iterations=max_retrieval_iterations,
+        llm_client=llm_client,
     )
 
     response_agent = ResponseAgent(
@@ -104,12 +100,18 @@ def create_simple_pipeline(
         log_file = state["log_file"]
         print(f"Analyzing log file: {log_file}")
 
+        if progress_callback:
+            progress_callback(20, "Running log analysis...")
+
         # Run log analysis (agent should not print its own phase headers)
         analysis_result = log_agent.analyze(log_file)
 
         # Store results in state
         state["log_analysis_result"] = analysis_result
 
+        if progress_callback:
+            progress_callback(40, "Log analysis completed")
+
         print(
             f"\nLog Analysis Assessment: {analysis_result.get('overall_assessment', 'UNKNOWN')}"
         )
@@ -133,6 +135,9 @@ def create_simple_pipeline(
         print(f"Generated retrieval query based on {assessment} assessment")
         print("\nStarting retrieval supervisor with log analysis context...\n")
 
+        if progress_callback:
+            progress_callback(50, "Running threat intelligence retrieval...")
+
         # Run retrieval supervisor
         retrieval_result = retrieval_supervisor.invoke(
             query=query,
@@ -141,6 +146,9 @@ def create_simple_pipeline(
             trace=False,  # Set True to show the agent conversations in the terminal
         )
 
+        if progress_callback:
+            progress_callback(70, "Threat intelligence retrieval completed")
+
         # Store retrieval results in state
         state["retrieval_result"] = retrieval_result
 
@@ -153,6 +161,9 @@ def create_simple_pipeline(
         print("=" * 60)
         print("Creating Event ID → MITRE technique mappings...")
 
+        if progress_callback:
+            progress_callback(80, "Running response correlation analysis...")
+
         # Run response agent analysis (agent should not print its own phase headers)
         response_analysis, markdown_report = response_agent.analyze_and_map(
             log_analysis_result=state["log_analysis_result"],
@@ -161,6 +172,9 @@ def create_simple_pipeline(
             tactic=state.get("tactic"),
         )
 
+        if progress_callback:
+            progress_callback(90, "Response analysis completed")
+
         # Store response analysis in state
         state["response_analysis"] = response_analysis
 
@@ -246,8 +260,11 @@ def analyze_log_file(
     tactic: str = None,
     model_name: str = "google_genai:gemini-2.0-flash",
     temperature: float = 0.1,
+    max_log_analysis_iterations: int = 2,
+    max_retrieval_iterations: int = 2,
     log_agent_output_dir: str = "analysis",
     response_agent_output_dir: str = "final_response",
+    progress_callback=None,
 ):
     """
     Analyze a single log file through the integrated pipeline.
@@ -258,6 +275,8 @@ def analyze_log_file(
         tactic: Optional tactic name for organizing output
         model_name: Name of the model to use (e.g., "gemini-2.0-flash", "gpt-oss-120b", "llama-3.1-8b-instant")
         temperature: Temperature for model generation
+        max_log_analysis_iterations: Maximum number of iterations for the log analysis agent
+        max_retrieval_iterations: Maximum number of iterations for the retrieval supervisor
         log_agent_output_dir: Directory to save log agent output
         response_agent_output_dir: Directory to save response agent output
     """
@@ -276,8 +295,11 @@ def analyze_log_file(
     pipeline = create_simple_pipeline(
         model_name=model_name,
         temperature=temperature,
+        max_log_analysis_iterations=max_log_analysis_iterations,
+        max_retrieval_iterations=max_retrieval_iterations,
         log_agent_output_dir=log_agent_output_dir,
         response_agent_output_dir=response_agent_output_dir,
+        progress_callback=progress_callback,
     )
 
     # Initialize state
@@ -293,9 +315,16 @@ def analyze_log_file(
 
     # Run pipeline
     start_time = time.time()
+
+    if progress_callback:
+        progress_callback(10, "Initializing pipeline...")
+
     final_state = pipeline.invoke(initial_state)
     end_time = time.time()
 
+    if progress_callback:
+        progress_callback(100, "Analysis complete!")
+
     print(f"\nTotal execution time: {end_time - start_time:.2f} seconds")
     print("Analysis complete!")
     return final_state
@@ -326,6 +355,8 @@ def main():
     query = None
     model_name = "gemini-2.0-flash"  # Default model
     temperature = 0.1
+    max_log_analysis_iterations = 2
+    max_retrieval_iterations = 2
     log_agent_output_dir = "analysis"
     response_agent_output_dir = "final_response"
 
@@ -350,6 +381,8 @@ def main():
         tactic=None,
         model_name=model_name,
         temperature=temperature,
+        max_log_analysis_iterations=max_log_analysis_iterations,
+        max_retrieval_iterations=max_retrieval_iterations,
         log_agent_output_dir=log_agent_output_dir,
         response_agent_output_dir=response_agent_output_dir,
    )
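
With the iteration caps and callback now exposed on analyze_log_file, the same pipeline can be driven outside Streamlit by any callable taking (progress: int, message: str). A hedged usage sketch; the import path, the positional log-file argument, and the sample path are assumptions, not confirmed by this diff:

from src.full_pipeline.simple_pipeline import analyze_log_file

def print_progress(progress: int, message: str):
    # Mirror the Streamlit callback contract on the terminal.
    print(f"[{progress:3d}%] {message}")

final_state = analyze_log_file(
    "logs/sample.log",  # hypothetical log file path
    model_name="google_genai:gemini-2.0-flash",
    temperature=0.1,
    max_log_analysis_iterations=2,
    max_retrieval_iterations=2,
    progress_callback=print_progress,
)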