Add initial implementation of the deepforest-agent

#1
Files changed (34)
  1. LICENSE +21 -0
  2. README.md +100 -12
  3. app.py +501 -0
  4. pyproject.toml +66 -0
  5. requirements.txt +43 -0
  6. src/__init__.py +0 -0
  7. src/deepforest_agent/__init__.py +0 -0
  8. src/deepforest_agent/agents/__init__.py +0 -0
  9. src/deepforest_agent/agents/deepforest_detector_agent.py +403 -0
  10. src/deepforest_agent/agents/ecology_analysis_agent.py +92 -0
  11. src/deepforest_agent/agents/memory_agent.py +238 -0
  12. src/deepforest_agent/agents/orchestrator.py +795 -0
  13. src/deepforest_agent/agents/visual_analysis_agent.py +307 -0
  14. src/deepforest_agent/conf/__init__.py +0 -0
  15. src/deepforest_agent/conf/config.py +60 -0
  16. src/deepforest_agent/models/__init__.py +0 -0
  17. src/deepforest_agent/models/llama32_3b_instruct.py +242 -0
  18. src/deepforest_agent/models/qwen_vl_3b_instruct.py +152 -0
  19. src/deepforest_agent/models/smollm3_3b.py +244 -0
  20. src/deepforest_agent/prompts/__init__.py +0 -0
  21. src/deepforest_agent/prompts/prompt_templates.py +257 -0
  22. src/deepforest_agent/tools/__init__.py +0 -0
  23. src/deepforest_agent/tools/deepforest_tool.py +323 -0
  24. src/deepforest_agent/tools/tool_handler.py +188 -0
  25. src/deepforest_agent/utils/__init__.py +0 -0
  26. src/deepforest_agent/utils/cache_utils.py +306 -0
  27. src/deepforest_agent/utils/detection_narrative_generator.py +445 -0
  28. src/deepforest_agent/utils/image_utils.py +465 -0
  29. src/deepforest_agent/utils/logging_utils.py +449 -0
  30. src/deepforest_agent/utils/parsing_utils.py +238 -0
  31. src/deepforest_agent/utils/rtree_spatial_utils.py +394 -0
  32. src/deepforest_agent/utils/state_manager.py +574 -0
  33. src/deepforest_agent/utils/tile_manager.py +211 -0
  34. tests/test_deepforest_tool.py +465 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 DeepForest Agent
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,100 @@
- ---
- title: Deepforest Agent
- emoji: 🔥
- colorFrom: pink
- colorTo: purple
- sdk: gradio
- sdk_version: 5.44.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DeepForest Multi-Agent System
+
+ The DeepForest Multi-Agent System provides ecological image analysis by orchestrating multiple AI agents that work together to understand ecological images. Simply upload an image of a forest, wildlife habitat, or other ecological scene, and ask questions in natural language.
+
+ ## Installation
+
+ ### 1. Clone the repository
+
+ ```bash
+ git clone https://github.com/weecology/deepforest-agent.git
+ cd deepforest-agent
+ ```
+
+ ### 2. Create and activate a Conda environment
+
+ ```bash
+ conda create -n deepforest_agent python=3.12.11
+ conda activate deepforest_agent
+ ```
+
+ ### 3. Install dependencies
+
+ ```bash
+ pip install -r requirements.txt
+ pip install -e .
+ ```
+
+ ### 4. Configure the HuggingFace Token
+ Create a `.env` file in the root directory of the deepforest-agent project and add your HuggingFace token as shown below:
+
+ ```bash
+ HF_TOKEN="your_huggingface_token_here"
+ ```
+
+ You can obtain your token from [HuggingFace Access Tokens](https://huggingface.co/settings/tokens). Make sure the token type is "Write".
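The diff does not show where the token is read, but `python-dotenv` appears in the dependency list, so the usual pattern is a minimal sketch like the one below; the `huggingface_hub.login` call (available because `huggingface_hub` ships as a `transformers` dependency) is an assumption, not code from this PR.

```python
# Minimal sketch (not part of this PR): consuming HF_TOKEN from the .env file.
import os
from dotenv import load_dotenv
from huggingface_hub import login

load_dotenv()                        # reads HF_TOKEN from .env into the environment
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)               # authenticates model downloads from the Hub
```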
+
+ ## Usage
+
+ The DeepForest Agent runs through a Gradio web interface. To start the interface, execute:
+
+ ```bash
+ python app.py
+ ```
+
+ A link like http://127.0.0.1:7860 will appear in the terminal. Open it in your browser to interact with the agent. A public Gradio link may also be provided if available.
+
+ **Sample Recording of Running the System:** [Drive Link](https://drive.google.com/file/d/1gNMn-xJd48Ld3TZU4oiYvTbiWaiLsc8G/view?usp=sharing)
+
+
+ ### How to Use
+
+ 1. Upload an ecological image (aerial/drone photography works best)
+ 2. Ask questions about wildlife, forest health, or ecological patterns. For example:
+    - How many trees are detected, and how many of them are alive vs dead?
+    - How many birds are around each dead tree?
+    - What objects are in the northwest region of the image?
+    - Do any birds overlap with livestock in this image?
+    - What percentage of the image is covered by trees vs birds vs livestock?
+ 3. Get a comprehensive analysis combining computer vision and ecological insights. The gallery shows the annotated image with detected objects, and the detection monitor presents a summary of the DeepForest detections.
+
+
+ ## Features
+
+ - **Multi-Species Detection**: Automatically detects trees, birds, and livestock using specialized DeepForest models
+ - **Tree Health Assessment**: Identifies alive and dead trees with the DeepForest tree detector when the user asks for it
+ - **Visual Analysis**: Dual analysis of the original and annotated images using the Qwen2.5-VL-3B-Instruct model
+ - **Memory Context**: Maintains conversation history for contextual understanding across multiple queries
+ - **Image Tiling for the Visual Agent**: Larger images are tiled, and each tile is processed individually by the visual agent
+ - **R-Tree Spatial Indexing**: Stores DeepForest results in an R-tree spatial index and uses spatial queries to retrieve the information relevant to the user's question (see the sketch after this list)
+ - **Ecological Insights**: Synthesizes detection data with visual analysis and memory context for comprehensive ecological understanding
+ - **Streaming Responses**: Real-time updates as each agent processes your query
+
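To illustrate the R-tree idea, here is a minimal sketch using the `rtree` package from the dependency list. The bounding-box fields follow DeepForest's `xmin`/`ymin`/`xmax`/`ymax` convention, but the variable names and the query are illustrative only and are not the utilities in `rtree_spatial_utils.py`.

```python
# Illustrative only: index DeepForest boxes and answer a region query.
from rtree import index

detections = [
    {"label": "tree", "xmin": 10, "ymin": 12, "xmax": 40, "ymax": 45},
    {"label": "bird", "xmin": 220, "ymin": 30, "xmax": 232, "ymax": 41},
]

idx = index.Index()
for i, det in enumerate(detections):
    # rtree expects (minx, miny, maxx, maxy)
    idx.insert(i, (det["xmin"], det["ymin"], det["xmax"], det["ymax"]))

# "What is in the top-left (northwest) quarter of a 400x400 image?"
northwest = (0, 0, 200, 200)
hits = [detections[i]["label"] for i in idx.intersection(northwest)]
print(hits)  # -> ['tree']
```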
+
+ ## Requirements
+
+ ### Hardware Requirements
+ - **GPU**: A GPU with at least 24GB of VRAM is recommended for optimal performance. The system is optimized for GPU execution; running on CPU leads to significantly longer processing times
+ - **Storage**: At least 35GB of free space for model downloads
+
+ ### API Requirements
+ - **HuggingFace Token**: Required for model access
+
+
+ ## Image Processing Times
+
+ - **Standard Images**: Most ecological images process within 30 seconds on GPU
+ - **Large GeoTIFF Files**: Larger geospatial images may require significantly more time for complete analysis
+
+
+ ## Models Used
+
+ - **SmolLM3-3B**: Used by the Memory Agent to extract relevant context and by the Detector Agent to call the detection tool with appropriate parameters
+ - **Qwen2.5-VL-3B-Instruct**: Used by the Visual Agent for multimodal image-text understanding
+ - **Llama-3.2-3B-Instruct**: Used by the Ecology Agent for text understanding and generation
+ - **DeepForest Models**: Used for tree, bird, and livestock detection, and for alive/dead tree classification
+
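For orientation, a minimal sketch of a single DeepForest detection call outside the agent pipeline. The `main.deepforest` entry point and `predict_image` are part of the `deepforest` package, but the model name and the `load_model` signature vary across releases, so treat this as an assumption-laden sketch rather than the exact code behind `deepforest_tool.py`.

```python
# Illustrative only: one direct DeepForest prediction, not the agent's tool code.
from deepforest import main

model = main.deepforest()
# Model name and loading API are assumptions; check the installed deepforest version.
model.load_model(model_name="weecology/deepforest-tree")

boxes = model.predict_image(path="sample_aerial_image.png")
# boxes is a pandas DataFrame with xmin, ymin, xmax, ymax, label, score columns
print(boxes[["label", "score"]].head())
```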
+
+ ## Multi-Agent Workflow
+
+ [![](https://mermaid.ink/img/pako:eNqlV9ty4jgQ_RWVt2pexmQI9_hhtwiQK4SLgVyceVBMO7giLEq2kzAk_75tSTYiM9k8LA8UpruPuk93H9tby-cLsBzrUdD1kky79xHBT9ubxSDIbM04XcTkfEUf4Scplf4mx15HAE2AuBDHIY_IN3IehUlIWfgL_0zQ9FNhHEv_jkJyIUKccQpio80dae56Q-EvIU4ETbhwyCBlSVhqP0KUkGsungLGXzJUkegw9d2VwT1vACsuNkT6O8RdcdYfVEvVY-3ck24nW-12RmPSDQX4CWlH8QuIf95N0JPM--0W4jdy6vVeMSV0nHLOSIdijuS8q2FPJeyZN4FEhPAMyr4gXUgQOyPligosKHzOuTiTEedez-eMPxYJ9xldUVI9qGDKeWbuJkqQkDDWoecy9M47CSPKyATiNY9iIC9hsiS6rg6PEnjdZ0gVc8XfyIU3D-MUY9sIsEHg_PTxC0SleX9H14U86nJ7kjKmer6LcVPfx44HKdsn7XJHWt9zw-iRwcfQDl-tGRRzcakzIyUyHA5ITwgu3sjAO6GMPVD_ySHTkEHpmMZIaQ6iYwcyw6t8BtVBmXusCBnRxF8SF0dRB1zJgKEncXBAe9gpGYATuabYI2D5QA6l68jDdB_CCPNHEqQnco5Tmacwkm59k4O-_GuM8xBzlsoB6CzBfyIBFzgUsD7hAkccOQwT-hCyMMnPHMvIyVYVMsYuoY2cco6VX3WJAahiGeyzPym67HruU7g2TyumUZ_lyrOmXp8_Euk7ARrjLGnzVJpnelhKw4htdi1EXpc_fz9Ytn3u_XYolv3ZSs7lMdfeKUSQ0Z8vGJItO6iSwjnS_tfS_2a7c1PLts_DTZ4ODpVa1rMweSO3v63ofi9vdqOoogZhjBW127j-4KeY3X_wqZWyrWTx2Jukkek-QF1lsUMSAfDjIRSLHwz1IE64_5QLpFbIzo6MnYK46WpFcbe_4fpwscDlTyBPu6O1s-u5SHUxoCSMDLnSvl0llbdmzrdq0kfepJRlR9w1zQQchXwBr0g97kaSrsn3Pwka0blyke-DWojxOF8c5w9VfC_OmACjmSlehuu8nrFeg8mOiEwzBCwhirMzPxdW9T2T8a7rjUS21R_D9_VxMtHeei30Xky9fTXFnLVu7v74Ko-pXqLZTug_aO6e4rvIPl3tZn2m6pjPvfwm8EvJkM4g63DCyf6dIN8rvVjXnkLd3Smm_Aki8rBRP_K10nt1o0domoqoKDSTravsh2bEvG3rxblRU3XrzdYL82lAPgDIoY2eQcSy1biLOPYFwq0av7s7rxvGa0Y3xfx-R7oiniEslLTHxuAMUBV2U6e-7-4UlPlfnGxQs9skCBlz_oLDoB6AaelqS1CFelA3Lb3cEqCtbFoucrRWUIeWaZl_GoN7oU0-1MA3TdhnjVcOKsGhabrLgw6DFhyZdfmMxnEXArKSTVGPSObhNj5EYYezy6NWOb8svYSLZOlU1q8fYJ7lcJswqroCpubToP4lzEIL_v_OJ1Z9LhbGJK-AgqNDaFS_ggK1fHu1SaYLnHL5qNFqfYXDjUfTvakpcI78SvPhs9IMNNKzT83GmaYL-9Lu2wP7yh7aI7MtptPcvrbbbfv42O50bNT0PdpNx9HIHo9t1LgPfJo-s5k9R8DrPaJMh67tuvZ0at_c2LitJg2WjW8K4cJyAspisK0ViBXNrq1tBnBvoWyt4N5y8GcEKQaxe-s-ese4NY3uOF9ZTiJSjBQ8fVzmF6lUkW5I8TVkVYALfGkA0eFplFhOrXwoMSxna71azmGzfFArN8qNZr1Rb7RqRw3b2lhOqVKttA6alWat1qgf1hvVZu3dtn7JcxsH5VqzUa81Kq1mvdpoVGwLFpmmDNQrkHwTev8XjGAuiw?type=png)](https://mermaid.live/edit#pako:eNqlV9ty4jgQ_RWVt2pexmQI9_hhtwiQK4SLgVyceVBMO7giLEq2kzAk_75tSTYiM9k8LA8UpruPuk93H9tby-cLsBzrUdD1kky79xHBT9ubxSDIbM04XcTkfEUf4Scplf4mx15HAE2AuBDHIY_IN3IehUlIWfgL_0zQ9FNhHEv_jkJyIUKccQpio80dae56Q-EvIU4ETbhwyCBlSVhqP0KUkGsungLGXzJUkegw9d2VwT1vACsuNkT6O8RdcdYfVEvVY-3ck24nW-12RmPSDQX4CWlH8QuIf95N0JPM--0W4jdy6vVeMSV0nHLOSIdijuS8q2FPJeyZN4FEhPAMyr4gXUgQOyPligosKHzOuTiTEedez-eMPxYJ9xldUVI9qGDKeWbuJkqQkDDWoecy9M47CSPKyATiNY9iIC9hsiS6rg6PEnjdZ0gVc8XfyIU3D-MUY9sIsEHg_PTxC0SleX9H14U86nJ7kjKmer6LcVPfx44HKdsn7XJHWt9zw-iRwcfQDl-tGRRzcakzIyUyHA5ITwgu3sjAO6GMPVD_ySHTkEHpmMZIaQ6iYwcyw6t8BtVBmXusCBnRxF8SF0dRB1zJgKEncXBAe9gpGYATuabYI2D5QA6l68jDdB_CCPNHEqQnco5Tmacwkm59k4O-_GuM8xBzlsoB6CzBfyIBFzgUsD7hAkccOQwT-hCyMMnPHMvIyVYVMsYuoY2cco6VX3WJAahiGeyzPym67HruU7g2TyumUZ_lyrOmXp8_Euk7ARrjLGnzVJpnelhKw4htdi1EXpc_fz9Ytn3u_XYolv3ZSs7lMdfeKUSQ0Z8vGJItO6iSwjnS_tfS_2a7c1PLts_DTZ4ODpVa1rMweSO3v63ofi9vdqOoogZhjBW127j-4KeY3X_wqZWyrWTx2Jukkek-QF1lsUMSAfDjIRSLHwz1IE64_5QLpFbIzo6MnYK46WpFcbe_4fpwscDlTyBPu6O1s-u5SHUxoCSMDLnSvl0llbdmzrdq0kfepJRlR9w1zQQchXwBr0g97kaSrsn3Pwka0blyke-DWojxOF8c5w9VfC_OmACjmSlehuu8nrFeg8mOiEwzBCwhirMzPxdW9T2T8a7rjUS21R_D9_VxMtHeei30Xky9fTXFnLVu7v74Ko-pXqLZTug_aO6e4rvIPl3tZn2m6pjPvfwm8EvJkM4g63DCyf6dIN8rvVjXnkLd3Smm_Aki8rBRP_K10nt1o0domoqoKDSTravsh2bEvG3rxblRU3XrzdYL82lAPgDIoY2eQcSy1biLOPYFwq0av7s7rxvGa0Y3xfx-R7oiniEslLTHxuAMUBV2U6e-7-4UlPlfnGxQs9skCBlz_oLDoB6AaelqS1CFelA3Lb3cEqCtbFoucrRWUIeWaZl_GoN7oU0-1MA3TdhnjVcOKsGhabrLgw6DFhyZdfmMxnEXArKSTVGPSObhNj5EYYezy6NWOb8svYSLZOlU1q8fYJ7lcJswqroCpubToP4lzEIL_v_OJ1Z9LhbGJK-AgqNDaFS_ggK1fHu1SaYLnHL5qNFqfYXDjUfTvakpcI78SvPhs9IMNN
KzT83GmaYL-9Lu2wP7yh7aI7MtptPcvrbbbfv42O50bNT0PdpNx9HIHo9t1LgPfJo-s5k9R8DrPaJMh67tuvZ0at_c2LitJg2WjW8K4cJyAspisK0ViBXNrq1tBnBvoWyt4N5y8GcEKQaxe-s-ese4NY3uOF9ZTiJSjBQ8fVzmF6lUkW5I8TVkVYALfGkA0eFplFhOrXwoMSxna71azmGzfFArN8qNZr1Rb7RqRw3b2lhOqVKttA6alWat1qgf1hvVZu3dtn7JcxsH5VqzUa81Kq1mvdpoVGwLFpmmDNQrkHwTev8XjGAuiw)
app.py ADDED
@@ -0,0 +1,501 @@
1
+ import sys
2
+ import os
3
+ from pathlib import Path
4
+ import time
5
+ import json
6
+ import gradio as gr
7
+
8
+ # This allows imports to work when app.py is in root but modules are in src/
9
+ current_dir = Path(__file__).parent.absolute()
10
+ src_dir = current_dir / "src"
11
+
12
+ if not src_dir.exists():
13
+ raise RuntimeError(f"Source directory not found: {src_dir}")
14
+
15
+ # Add to Python path if not already there
16
+ if str(src_dir) not in sys.path:
17
+ sys.path.insert(0, str(src_dir))
18
+
19
+ print(f"App running from: {current_dir}")
20
+ print(f"Source directory: {src_dir}")
21
+ print(f"Python path includes src: {str(src_dir) in sys.path}")
22
+
23
+ from deepforest_agent.agents.orchestrator import AgentOrchestrator
24
+ from deepforest_agent.utils.state_manager import session_state_manager
25
+ from deepforest_agent.utils.image_utils import (
26
+ encode_pil_image_to_base64_url,
27
+ load_pil_image_from_path,
28
+ get_image_info,
29
+ validate_image_path
30
+ )
31
+ from deepforest_agent.utils.logging_utils import multi_agent_logger
32
+
33
+
34
+ def upload_image(image_path):
35
+ """
36
+ Handle image upload and initialize a new session for the multi-agent workflow.
37
+
38
+ This function is triggered when a user uploads an image. It creates a new
39
+ session with isolated state and updates the UI to show the chat interface
40
+ and monitoring components.
41
+
42
+ Args:
43
+ image_path (str or None): The file path to uploaded image from Gradio
44
+
45
+ Returns:
46
+ tuple: A tuple containing 9 Gradio component updates:
47
+ - gr.Chatbot: Chat interface (visible/hidden)
48
+ - image: Uploaded image state
49
+ - str: Upload status message
50
+ - gr.Textbox: Message input field (visible/hidden)
51
+ - gr.Button: Send button (visible/hidden)
52
+ - gr.Button: Clear button (visible/hidden)
53
+ - gr.Gallery: Generated images gallery (visible/hidden)
54
+ - str: Monitor text with session information
55
+ - str: Session ID for this user
56
+ """
57
+ if image_path is None:
58
+ return (
59
+ gr.Chatbot(visible=False),
60
+ None, # uploaded_image_state
61
+ "No image uploaded",
62
+ gr.Textbox(visible=False),
63
+ gr.Button(visible=False), # send_btn
64
+ gr.Button(visible=False), # clear_btn
65
+ gr.Gallery(visible=False),
66
+ "No image uploaded",
67
+ None # session_id
68
+ )
69
+
70
+ if not validate_image_path(image_path):
71
+ return (
72
+ gr.Chatbot(visible=False),
73
+ None,
74
+ "Invalid image file or path not accessible",
75
+ gr.Textbox(visible=False),
76
+ gr.Button(visible=False),
77
+ gr.Button(visible=False),
78
+ gr.Gallery(visible=False),
79
+ "Invalid image file for analysis.",
80
+ None
81
+ )
82
+
83
+ try:
84
+ pil_image = load_pil_image_from_path(image_path)
85
+ if pil_image is None:
86
+ raise Exception("Failed to load image")
87
+ image_info = get_image_info(image_path)
88
+ except Exception as e:
89
+ return (
90
+ gr.Chatbot(visible=False),
91
+ None,
92
+ f"Error loading image: {str(e)}",
93
+ gr.Textbox(visible=False),
94
+ gr.Button(visible=False),
95
+ gr.Button(visible=False),
96
+ gr.Gallery(visible=False),
97
+ "Error loading image for analysis.",
98
+ None
99
+ )
100
+
101
+ # Create new session for this user
102
+ session_id = session_state_manager.create_session(pil_image)
103
+ session_state_manager.set(session_id, "image_file_path", image_path)
104
+
105
+ detection_monitor = ""
106
+
107
+ multi_agent_logger.log_session_event(
108
+ session_id=session_id,
109
+ event_type="session_created",
110
+ details={
111
+ "image_size": image_info.get("size") if image_info else pil_image.size,
112
+ "image_mode": image_info.get("mode") if image_info else pil_image.mode,
113
+ "image_path": image_path,
114
+ "file_size_bytes": image_info.get("file_size_bytes") if image_info else "unknown"
115
+ }
116
+ )
117
+
118
+ return (
119
+ gr.Chatbot(visible=True, value=[]),
120
+ pil_image,
121
+ f"Image uploaded successfully! Size: {pil_image.size}",
122
+ gr.Textbox(visible=True),
123
+ gr.Button(visible=True), # send_btn
124
+ gr.Button(visible=True), # clear_btn
125
+ gr.Gallery(visible=True, value=[]),
126
+ detection_monitor,
127
+ session_id # Return session ID
128
+ )
129
+
130
+
131
+ def process_message_streaming(user_message, chatbot_history, generated_images, detection_monitor, session_id):
132
+ """
133
+ Process user message through the multi-agent workflow with streaming updates.
134
+
135
+ Args:
136
+ user_message (str): The user's input message
137
+ chatbot_history (list): Current chat history for display
138
+ generated_images (list): List of annotated images in PIL Image objects
139
+ detection_monitor (str): Current detection data monitoring text
140
+ session_id (str): Unique session identifier for this user
141
+
142
+ Yields:
143
+ tuple: A tuple containing 6 updated components:
144
+ - chatbot_history: Updated conversation history
145
+ - msg_input_clear: Empty string to clear message input field
146
+ - generated_images: Updated list of annotated images
147
+ - detection_monitor: Updated detection data monitor
148
+ - send_btn: Button component with interactive state
149
+ - msg_input: Input field component with interactive state
150
+ """
151
+ if not user_message.strip():
152
+ yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
153
+ return
154
+
155
+ # Check if session exists
156
+ if session_id is None or not session_state_manager.session_exists(session_id):
157
+ error_msg = "Session expired or invalid. Please upload an image to start a new session."
158
+ chatbot_history.append({"role": "user", "content": user_message})
159
+ chatbot_history.append({"role": "assistant", "content": error_msg})
160
+ yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
161
+ return
162
+
163
+ # Check if image is available in session
164
+ current_image = session_state_manager.get(session_id, "current_image")
165
+ if current_image is None:
166
+ error_msg = "No image found in your session. Please upload an image first."
167
+ chatbot_history.append({"role": "user", "content": user_message})
168
+ chatbot_history.append({"role": "assistant", "content": error_msg})
169
+ yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
170
+ return
171
+
172
+ total_execution_start = time.perf_counter()
173
+
174
+ multi_agent_logger.log_user_query(
175
+ session_id=session_id,
176
+ user_message=user_message
177
+ )
178
+
179
+ try:
180
+ if session_state_manager.get(session_id, "first_message", True):
181
+ image_base64_url = encode_pil_image_to_base64_url(current_image)
182
+ user_msg = {
183
+ "role": "user",
184
+ "content": [
185
+ {"type": "image", "image": image_base64_url},
186
+ {"type": "text", "text": user_message}
187
+ ]
188
+ }
189
+ session_state_manager.set(session_id, "first_message", False)
190
+ else:
191
+ user_msg = {
192
+ "role": "user",
193
+ "content": [
194
+ {"type": "text", "text": user_message}
195
+ ]
196
+ }
197
+
198
+ session_state_manager.add_to_conversation(session_id, user_msg)
199
+ chatbot_history.append({"role": "user", "content": user_message})
200
+
201
+ chatbot_history.append({"role": "assistant", "content": "Starting analysis..."})
202
+
203
+ yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=False), gr.Textbox(interactive=False)
204
+
205
+ conversation_history = session_state_manager.get(session_id, "conversation_history", [])
206
+
207
+ print(f"Session {session_id} - User message: {user_message}")
208
+
209
+ orchestrator = AgentOrchestrator()
210
+
211
+ start_time = time.perf_counter()
212
+
213
+ try:
214
+ # Process with streaming updates
215
+ final_result = None
216
+
217
+ for result in orchestrator.process_user_message_streaming(
218
+ user_message=user_message,
219
+ conversation_history=conversation_history,
220
+ session_id=session_id
221
+ ):
222
+ if result["type"] == "progress":
223
+ chatbot_history[-1] = {"role": "assistant", "content": result["message"]}
224
+
225
+ yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=False), gr.Textbox(interactive=False)
226
+
227
+ elif result["type"] == "memory_direct":
228
+ final_response = result["message"]
229
+ chatbot_history[-1] = {"role": "assistant", "content": final_response}
230
+
231
+ updated_detection_monitor = result.get("detection_data", "")
232
+
233
+ final_result = result
234
+
235
+ yield chatbot_history, "", generated_images, updated_detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
236
+ break
237
+
238
+ elif result["type"] == "streaming":
239
+ # Update the last message with streaming response
240
+ chatbot_history[-1] = {"role": "assistant", "content": result["message"]}
241
+
242
+ yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=False), gr.Textbox(interactive=False)
243
+
244
+ if result.get("is_complete", False):
245
+ final_response = result["message"]
246
+
247
+ elif result["type"] == "final":
248
+ final_response = result["message"]
249
+ chatbot_history[-1] = {"role": "assistant", "content": final_response}
250
+
251
+ final_result = result
252
+ break
253
+
254
+ if final_result:
255
+ total_execution_time = time.perf_counter() - total_execution_start
256
+
257
+ execution_summary = final_result.get("execution_summary", {})
258
+ agent_results = final_result.get("agent_results", {})
259
+ execution_time = final_result.get("execution_time", 0)
260
+
261
+ assistant_msg = {
262
+ "role": "assistant",
263
+ "content": [{"type": "text", "text": final_response}]
264
+ }
265
+ session_state_manager.add_to_conversation(session_id, assistant_msg)
266
+
267
+ multi_agent_logger.log_agent_execution(
268
+ session_id=session_id,
269
+ agent_name="ecology",
270
+ agent_input="Final synthesis of all agent outputs",
271
+ agent_output=final_response,
272
+ execution_time=total_execution_time
273
+ )
274
+
275
+ annotated_image = session_state_manager.get(session_id, "annotated_image")
276
+ if annotated_image:
277
+ generated_images.append(annotated_image)
278
+
279
+ updated_detection_monitor = final_result.get("detection_data", "")
280
+
281
+ yield chatbot_history, "", generated_images, updated_detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
282
+
283
+ finally:
284
+ orchestrator.cleanup_all_agents()
285
+
286
+ except Exception as e:
287
+ total_execution_time = time.perf_counter() - total_execution_start
288
+ error_msg = f"Workflow error: {str(e)}"
289
+ print(f"MAIN APP ERROR (Session {session_id}): {error_msg}")
290
+
291
+ multi_agent_logger.log_error(
292
+ session_id=session_id,
293
+ error_type="app_workflow_error",
294
+ error_message=f"Workflow failed after {total_execution_time:.2f}s: {str(e)}"
295
+ )
296
+
297
+ if chatbot_history and chatbot_history[-1]["role"] == "assistant":
298
+ chatbot_history[-1] = {"role": "assistant", "content": error_msg}
299
+ else:
300
+ chatbot_history.append({"role": "assistant", "content": error_msg})
301
+
302
+ error_detection_monitor = "ERROR: Workflow failed - no detection data available"
303
+
304
+ yield chatbot_history, "", generated_images, error_detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
305
+
306
+ def clear_chat(session_id):
307
+ """
308
+ Clear chat history and cancel any ongoing processing for the session.
309
+
310
+ Args:
311
+ session_id (str): The session identifier to clear. Must correspond to
312
+ an existing active session.
313
+
314
+ Returns:
315
+ tuple: A tuple containing 5 updated components:
316
+ - chatbot_history: Empty list clearing chat display
317
+ - generated_images: Empty list clearing image gallery
318
+ - monitor_message: Status message indicating successful clear
319
+ operation and session ID
320
+ - send_btn: Re-enabled send button component
321
+ - msg_input: Re-enabled message input component
322
+
323
+ """
324
+ if session_id and session_state_manager.session_exists(session_id):
325
+ session_state_manager.cancel_session(session_id)
326
+ session_state_manager.clear_conversation(session_id)
327
+
328
+ multi_agent_logger.log_session_event(
329
+ session_id=session_id,
330
+ event_type="conversation_cleared"
331
+ )
332
+
333
+ return (
334
+ [], # chatbot
335
+ [], # generated_images
336
+ "",
337
+ gr.Button(interactive=True), # Re-enable send button
338
+ gr.Textbox(interactive=True) # Re-enable message input
339
+ )
340
+ else:
341
+ return (
342
+ [], # chatbot
343
+ [], # generated_images
344
+ "",
345
+ gr.Button(interactive=True), # Re-enable send button
346
+ gr.Textbox(interactive=True) # Re-enable message input
347
+ )
348
+
349
+
350
+ def create_interface():
351
+ """
352
+ Create and configure the complete Gradio web interface with streaming support.
353
+
354
+ Returns:
355
+ gr.Blocks: Complete Gradio application interface
356
+ """
357
+
358
+ with gr.Blocks(
359
+ title="DeepForest Multi-Agent System",
360
+ theme=gr.themes.Default(
361
+ spacing_size=gr.themes.sizes.spacing_sm,
362
+ radius_size=gr.themes.sizes.radius_none,
363
+ primary_hue=gr.themes.colors.emerald,
364
+ secondary_hue=gr.themes.colors.lime
365
+ )
366
+ ) as app:
367
+
368
+ # Gradio State variables
369
+ uploaded_image_state = gr.State(None)
370
+ generated_images_state = gr.State([])
371
+ session_id_state = gr.State(None)
372
+
373
+ gr.Markdown("# DeepForest Multi-Agent System")
374
+ gr.Markdown("*DeepForest with SmolLM3-3B + Qwen-VL-3B-Instruct + Llama 3.2-3B-Instruct*")
375
+
376
+ with gr.Row():
377
+ # Left column
378
+ with gr.Column(scale=1):
379
+ image_upload = gr.Image(
380
+ type="filepath",
381
+ label="Upload Ecological Image",
382
+ height=300
383
+ )
384
+ upload_status = gr.Textbox(
385
+ label="Upload Status",
386
+ value="Upload an image to begin analysis",
387
+ interactive=False
388
+ )
389
+
390
+ # Right column
391
+ with gr.Column(scale=2):
392
+ chatbot = gr.Chatbot(
393
+ label="Multi-Agent Ecological Analysis",
394
+ height=400,
395
+ visible=False,
396
+ show_copy_button=True,
397
+ type='messages'
398
+ )
399
+
400
+ with gr.Row():
401
+ msg_input = gr.Textbox(
402
+ placeholder="Ask about wildlife, forest health, ecological patterns...",
403
+ scale=4,
404
+ visible=False
405
+ )
406
+ send_btn = gr.Button("Analyze", scale=1, visible=False, variant="primary")
407
+ clear_btn = gr.Button("Clear", scale=1, visible=False)
408
+
409
+ with gr.Row():
410
+ generated_images_display = gr.Gallery(
411
+ label="Annotated Images after DeepForest Detection",
412
+ columns=2,
413
+ height=400,
414
+ visible=False,
415
+ show_label=True
416
+ )
417
+
418
+ with gr.Row():
419
+ with gr.Column():
420
+ gr.Markdown("### Detection Data Monitor")
421
+
422
+ detection_data_monitor = gr.Textbox(
423
+ label="Detection Data Monitor",
424
+ value="Upload an image and ask a question to see detection data",
425
+ interactive=False,
426
+ show_copy_button=True
427
+ )
428
+
429
+ with gr.Row(visible=False) as example_row:
430
+ gr.Markdown("""
431
+ **Multi-agent test questions:**
432
+ - How many trees are detected, and how many of them are alive vs dead?
433
+ - How many birds are around each dead tree?
434
+ - What objects are in the northwest region of the image?
435
+ - Do any birds overlap with livestock in this image?
436
+ - What percentage of the image is covered by trees vs birds vs livestock?
437
+ """)
438
+
439
+ # Image upload
440
+ image_upload.change(
441
+ fn=upload_image,
442
+ inputs=[image_upload],
443
+ outputs=[
444
+ chatbot,
445
+ uploaded_image_state,
446
+ upload_status,
447
+ msg_input,
448
+ send_btn,
449
+ clear_btn,
450
+ generated_images_display,
451
+ detection_data_monitor,
452
+ session_id_state
453
+ ]
454
+ ).then(
455
+ fn=lambda: gr.Row(visible=True),
456
+ outputs=[example_row]
457
+ )
458
+
459
+ # Send button with streaming
460
+ send_btn.click(
461
+ fn=process_message_streaming,
462
+ inputs=[msg_input, chatbot, generated_images_state, detection_data_monitor, session_id_state],
463
+ outputs=[chatbot, msg_input, generated_images_state, detection_data_monitor, send_btn, msg_input]
464
+ ).then(
465
+ fn=lambda images: images,
466
+ inputs=[generated_images_state],
467
+ outputs=[generated_images_display]
468
+ )
469
+
470
+ # Enter key with streaming
471
+ msg_input.submit(
472
+ fn=process_message_streaming,
473
+ inputs=[msg_input, chatbot, generated_images_state, detection_data_monitor, session_id_state],
474
+ outputs=[chatbot, msg_input, generated_images_state, detection_data_monitor, send_btn, msg_input]
475
+ ).then(
476
+ fn=lambda images: images,
477
+ inputs=[generated_images_state],
478
+ outputs=[generated_images_display]
479
+ )
480
+
481
+ clear_btn.click(
482
+ fn=clear_chat,
483
+ inputs=[session_id_state],
484
+ outputs=[chatbot, generated_images_state, detection_data_monitor, send_btn, msg_input]
485
+ ).then(
486
+ fn=lambda: [],
487
+ outputs=[generated_images_display]
488
+ )
489
+
490
+ return app
491
+
492
+
493
+ app = create_interface()
494
+
495
+ if __name__ == "__main__":
496
+ app.launch(
497
+ share=True,
498
+ debug=True,
499
+ show_error=True,
500
+ max_threads=3
501
+ )
pyproject.toml ADDED
@@ -0,0 +1,66 @@
+ [project]
+ name = "deepforest_agent"
+ version = "0.1.0"
+ description = "AI Agent for DeepForest object detection"
+ authors = [
+     {name = "Your Name", email = "you@example.com"}
+ ]
+ requires-python = ">=3.12"
+ readme = "README.md"
+ dependencies = [
+     "accelerate",
+     "albumentations<2.0",
+     "deepforest",
+     "fastapi",
+     "geopandas",
+     "google-genai",
+     "google-generativeai",
+     "gradio",
+     "gradio-image-annotation",
+     "langchain",
+     "langchain-community",
+     "langchain-google-genai",
+     "langchain-huggingface",
+     "langgraph",
+     "matplotlib",
+     "numpy",
+     "rtree",
+     "num2words",
+     "openai",
+     "opencv-python",
+     "outlines",
+     "pandas",
+     "pillow",
+     "scikit-learn",
+     "plotly",
+     "pydantic",
+     "pydantic-settings",
+     "pytest",
+     "pytest-cov",
+     "python-dotenv",
+     "pyyaml",
+     "qwen-vl-utils",
+     "rasterio",
+     "requests",
+     "scikit-image",
+     "seaborn",
+     "shapely",
+     "streamlit",
+     "torch",
+     "torchvision",
+     "tqdm",
+     "transformers",
+     "bitsandbytes",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pre-commit",
+     "pytest",
+     "pytest-profiling",
+     "yapf"
+ ]
+
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,43 @@
+ accelerate
+ albumentations<2.0
+ deepforest
+ fastapi
+ geopandas
+ google-genai
+ google-generativeai
+ gradio
+ gradio-image-annotation
+ langchain
+ langchain-community
+ langchain-google-genai
+ langchain-huggingface
+ langgraph
+ matplotlib
+ numpy
+ rtree
+ num2words
+ openai
+ opencv-python
+ outlines
+ pandas
+ scikit-learn
+ pillow
+ plotly
+ pydantic
+ pydantic-settings
+ pytest
+ pytest-cov
+ python-dotenv
+ pyyaml
+ qwen-vl-utils
+ rasterio
+ requests
+ scikit-image
+ seaborn
+ shapely
+ streamlit
+ torch
+ torchvision
+ tqdm
+ transformers
+ bitsandbytes
src/__init__.py ADDED
File without changes
src/deepforest_agent/__init__.py ADDED
File without changes
src/deepforest_agent/agents/__init__.py ADDED
File without changes
src/deepforest_agent/agents/deepforest_detector_agent.py ADDED
@@ -0,0 +1,403 @@
1
+ from typing import Dict, List, Any, Optional
2
+ import json
3
+ import re
4
+ import time
5
+
6
+ from deepforest_agent.utils.cache_utils import tool_call_cache
7
+ from deepforest_agent.models.smollm3_3b import SmolLM3ModelManager
8
+ from deepforest_agent.tools.tool_handler import handle_tool_call, extract_all_tool_calls
9
+ from deepforest_agent.conf.config import Config
10
+ from deepforest_agent.prompts.prompt_templates import create_detector_system_prompt_with_reasoning, get_deepforest_tool_schema
11
+ from deepforest_agent.utils.state_manager import session_state_manager
12
+ from deepforest_agent.utils.logging_utils import multi_agent_logger
13
+ from deepforest_agent.utils.parsing_utils import parse_deepforest_agent_response_with_reasoning
14
+ from deepforest_agent.utils.rtree_spatial_utils import DetectionSpatialAnalyzer
15
+ from deepforest_agent.utils.detection_narrative_generator import DetectionNarrativeGenerator
16
+
17
+
18
+
19
+ class DeepForestDetectorAgent:
20
+ """
21
+ DeepForest detector agent responsible for executing object detection.
22
+ Uses SmolLM3-3B model for tool calling.
23
+ """
24
+
25
+ def __init__(self):
26
+ """Initialize the DeepForest Detector Agent."""
27
+ self.agent_config = Config.AGENT_CONFIGS["deepforest_detector"]
28
+ self.model_manager = SmolLM3ModelManager(Config.AGENT_MODELS["deepforest_detector"])
29
+
30
+ def _filter_models_based_on_visual(self, visual_objects: List[str], original_models: List[str]):
31
+ """
32
+ Filter original model names based on visual agent's detected objects.
33
+ Remove models that weren't visually detected.
34
+
35
+ Args:
36
+ visual_objects (List[str]): Objects detected by visual agent
37
+ original_models (List[str]): Original model list from tool call
38
+ """
39
+ pass
40
+
41
+ def execute_detection_with_context(
42
+ self,
43
+ user_message: str,
44
+ session_id: str,
45
+ visual_objects_detected: List[str],
46
+ memory_context: str
47
+ ) -> Dict[str, Any]:
48
+ """
49
+ Execute DeepForest detection with R-tree spatial analysis and narrative generation.
50
+
51
+ Args:
52
+ user_message (str): User's query
53
+ session_id (str): Unique session identifier for this user
54
+ visual_objects_detected (List[str]): Objects detected by visual agent
55
+ memory_context (str): Context from memory agent
56
+
57
+ Returns:
58
+ Dictionary with detection results, R-tree analysis, and narrative
59
+ """
60
+ # Validate session exists
61
+ if not session_state_manager.session_exists(session_id):
62
+ return {
63
+ "detection_summary": f"Session {session_id} not found.",
64
+ "detections_list": [],
65
+ "total_detections": 0,
66
+ "status": "error",
67
+ "error": f"Session {session_id} not found",
68
+ "detection_narrative": "No detection narrative available due to session error."
69
+ }
70
+
71
+ try:
72
+ tool_generation_start = time.perf_counter()
73
+
74
+ system_prompt = create_detector_system_prompt_with_reasoning(
75
+ user_message, memory_context, visual_objects_detected
76
+ )
77
+
78
+ messages = [
79
+ {"role": "system", "content": system_prompt},
80
+ {"role": "user", "content": user_message}
81
+ ]
82
+
83
+ deepforest_tool_schema = get_deepforest_tool_schema()
84
+
85
+ response = self.model_manager.generate_response(
86
+ messages=messages,
87
+ max_new_tokens=self.agent_config["max_new_tokens"],
88
+ temperature=self.agent_config["temperature"],
89
+ top_p=self.agent_config["top_p"],
90
+ tools=[deepforest_tool_schema]
91
+ )
92
+
93
+ tool_generation_time = time.perf_counter() - tool_generation_start
94
+
95
+ print(f"Session {session_id} - Detector Raw Response: {response}")
96
+
97
+ multi_agent_logger.log_agent_execution(
98
+ session_id=session_id,
99
+ agent_name="detector",
100
+ agent_input=f"User: {user_message}",
101
+ agent_output=response,
102
+ execution_time=tool_generation_time
103
+ )
104
+
105
+ parsed_response = self._parse_response_with_reasoning(response)
106
+
107
+ if "error" in parsed_response:
108
+ multi_agent_logger.log_error(
109
+ session_id=session_id,
110
+ error_type="tool_call_parsing_error",
111
+ error_message=parsed_response["error"]
112
+ )
113
+
114
+ return {
115
+ "detection_summary": f"Tool call parsing failed: {parsed_response['error']}",
116
+ "detections_list": [],
117
+ "total_detections": 0,
118
+ "status": "error",
119
+ "error": parsed_response["error"],
120
+ "detection_narrative": "No detection narrative available due to parsing error."
121
+ }
122
+
123
+ reasoning = parsed_response["reasoning"]
124
+ tool_calls = parsed_response["tool_calls"]
125
+
126
+ print(f"Session {session_id} - Reasoning: {reasoning}")
127
+ print(f"Session {session_id} - Found {len(tool_calls)} tool calls")
128
+
129
+ all_results = []
130
+ combined_detection_summary = []
131
+ combined_detections_list = []
132
+ total_detections = 0
133
+
134
+ for i, tool_call in enumerate(tool_calls):
135
+ print(f"Session {session_id} - Executing tool call {i+1}/{len(tool_calls)}")
136
+
137
+ tool_name = tool_call["name"]
138
+ tool_arguments = tool_call["arguments"]
139
+
140
+ cached_result = tool_call_cache.get_cached_result(tool_name, tool_arguments)
141
+
142
+ if cached_result:
143
+ print(f"Session {session_id} - Tool call {i+1}: Using cached results")
144
+
145
+ if cached_result.get("annotated_image"):
146
+ session_state_manager.set(session_id, "annotated_image", cached_result["annotated_image"])
147
+
148
+ cache_key = cached_result["cache_info"]["cache_key"]
149
+ session_state_manager.add_tool_call_to_history(
150
+ session_id, tool_name, tool_arguments, cache_key
151
+ )
152
+
153
+ multi_agent_logger.log_tool_call(
154
+ session_id=session_id,
155
+ tool_name=tool_name,
156
+ tool_arguments=tool_arguments,
157
+ tool_result=cached_result,
158
+ execution_time=0.0,
159
+ cache_hit=True,
160
+ reasoning=f"Tool call {i+1}: {reasoning}"
161
+ )
162
+
163
+ tool_result = {
164
+ "tool_call_number": i + 1,
165
+ "tool_name": tool_name,
166
+ "tool_arguments": tool_arguments,
167
+ "cache_key": cache_key,
168
+ "detection_summary": cached_result["detection_summary"],
169
+ "detections_list": cached_result.get("detections_list", []),
170
+ "total_detections": len(cached_result.get("detections_list", [])),
171
+ "status": "success",
172
+ "cache_hit": True
173
+ }
174
+
175
+ all_results.append(tool_result)
176
+ combined_detection_summary.append(cached_result["detection_summary"])
177
+ combined_detections_list.extend(cached_result.get("detections_list", []))
178
+ total_detections += len(cached_result.get("detections_list", []))
179
+
180
+ else:
181
+ print(f"Session {session_id} - Tool call {i+1}: Cache MISS, executing tool")
182
+
183
+ tool_execution_start = time.perf_counter()
184
+ execution_result = handle_tool_call(tool_name, tool_arguments, session_id)
185
+
186
+ tool_execution_time = time.perf_counter() - tool_execution_start
187
+
188
+ if isinstance(execution_result, dict) and "detection_summary" in execution_result:
189
+ cache_result = {
190
+ "detection_summary": execution_result["detection_summary"],
191
+ "detections_list": execution_result.get("detections_list", []),
192
+ "total_detections": execution_result.get("total_detections", 0),
193
+ "status": "success"
194
+ }
195
+
196
+ annotated_image = session_state_manager.get(session_id, "annotated_image")
197
+ if annotated_image:
198
+ cache_result["annotated_image"] = annotated_image
199
+
200
+ cache_key = tool_call_cache.store_result(tool_name, tool_arguments, cache_result)
201
+
202
+ session_state_manager.add_tool_call_to_history(
203
+ session_id, tool_name, tool_arguments, cache_key
204
+ )
205
+
206
+ multi_agent_logger.log_tool_call(
207
+ session_id=session_id,
208
+ tool_name=tool_name,
209
+ tool_arguments=tool_arguments,
210
+ tool_result=execution_result,
211
+ execution_time=tool_execution_time,
212
+ cache_hit=False,
213
+ reasoning=f"Tool call {i+1}: {reasoning}"
214
+ )
215
+
216
+ tool_result = {
217
+ "tool_call_number": i + 1,
218
+ "tool_name": tool_name,
219
+ "tool_arguments": tool_arguments,
220
+ "cache_key": cache_key,
221
+ "detection_summary": execution_result["detection_summary"],
222
+ "detections_list": execution_result.get("detections_list", []),
223
+ "total_detections": execution_result.get("total_detections", 0),
224
+ "status": "success",
225
+ "cache_hit": False
226
+ }
227
+
228
+ all_results.append(tool_result)
229
+ combined_detection_summary.append(execution_result["detection_summary"])
230
+ combined_detections_list.extend(execution_result.get("detections_list", []))
231
+ total_detections += execution_result.get("total_detections", 0)
232
+
233
+ else:
234
+ error_msg = str(execution_result) if isinstance(execution_result, str) else "Unknown tool execution error"
235
+ print(f"Session {session_id} - Tool call {i+1} execution failed: {error_msg}")
236
+
237
+ multi_agent_logger.log_error(
238
+ session_id=session_id,
239
+ error_type="tool_execution_error",
240
+ error_message=f"Tool call {i+1} execution failed after {tool_execution_time:.2f}s: {error_msg}"
241
+ )
242
+
243
+ tool_result = {
244
+ "tool_call_number": i + 1,
245
+ "tool_name": tool_name,
246
+ "tool_arguments": tool_arguments,
247
+ "detection_summary": f"Tool call {i+1} failed: {error_msg}",
248
+ "detections_list": [],
249
+ "total_detections": 0,
250
+ "status": "error",
251
+ "error": error_msg,
252
+ "cache_hit": False
253
+ }
254
+ all_results.append(tool_result)
255
+
256
+ final_detection_summary = " | ".join(combined_detection_summary) if combined_detection_summary else "No successful detections"
257
+
258
+ # Generate comprehensive R-tree based narrative
259
+ detection_narrative = self._generate_spatial_narrative(
260
+ combined_detections_list, session_id
261
+ )
262
+
263
+ # Log the detection narrative
264
+ multi_agent_logger.log_agent_execution(
265
+ session_id=session_id,
266
+ agent_name="detection_narrative",
267
+ agent_input=f"Detection narrative for {len(combined_detections_list)} detections",
268
+ agent_output=detection_narrative,
269
+ execution_time=0.0
270
+ )
271
+
272
+ result = {
273
+ "detection_summary": final_detection_summary,
274
+ "detections_list": combined_detections_list,
275
+ "total_detections": total_detections,
276
+ "status": "success",
277
+ "reasoning": reasoning,
278
+ "visual_objects_input": visual_objects_detected,
279
+ "tool_calls_executed": len(tool_calls),
280
+ "tool_results": all_results,
281
+ "detection_narrative": detection_narrative,
282
+ "raw_tool_response": response
283
+ }
284
+
285
+ print(f"Session {session_id} - Executed {len(tool_calls)} tool calls successfully")
286
+ print(f"Session {session_id} - Generated detection narrative ({len(detection_narrative)} characters)")
287
+ return result
288
+
289
+ except Exception as e:
290
+ error_msg = f"Error in detector agent for session {session_id}: {str(e)}"
291
+ print(f"Detector Agent Error: {error_msg}")
292
+
293
+ multi_agent_logger.log_error(
294
+ session_id=session_id,
295
+ error_type="detector_agent_exception",
296
+ error_message=error_msg
297
+ )
298
+
299
+ return {
300
+ "detection_summary": f"Detection agent error: {error_msg}",
301
+ "detections_list": [],
302
+ "total_detections": 0,
303
+ "status": "error",
304
+ "error": error_msg,
305
+ "visual_objects_input": visual_objects_detected,
306
+ "detection_narrative": f"Detection narrative generation failed due to error: {error_msg}"
307
+ }
308
+
309
+ def _generate_spatial_narrative(self, detections_list: List[Dict[str, Any]], session_id: str) -> str:
310
+ """
311
+ Generate comprehensive spatial narrative using R-tree analysis.
312
+
313
+ Args:
314
+ detections_list: Combined list of all detections
315
+ session_id: Session identifier for getting image dimensions
316
+
317
+ Returns:
318
+ Comprehensive detection narrative
319
+ """
320
+ if not detections_list:
321
+ return "No detections available for spatial narrative generation."
322
+
323
+ try:
324
+ # Get image dimensions
325
+ current_image = session_state_manager.get(session_id, "current_image")
326
+ if current_image:
327
+ image_width, image_height = current_image.size
328
+ else:
329
+ # Default dimensions if image not available
330
+ image_width, image_height = 1920, 1080
331
+
332
+ # Generate narrative using DetectionNarrativeGenerator
333
+ narrative_generator = DetectionNarrativeGenerator(image_width, image_height)
334
+ comprehensive_narrative = narrative_generator.generate_comprehensive_narrative(detections_list)
335
+
336
+ print(f"Session {session_id} - Generated comprehensive spatial narrative")
337
+
338
+ return comprehensive_narrative
339
+
340
+ except Exception as e:
341
+ error_msg = f"Error generating spatial narrative: {str(e)}"
342
+ print(f"Session {session_id} - {error_msg}")
343
+
344
+ # Just return the detection summary itself
345
+ total_count = len(detections_list)
346
+ label_counts = {}
347
+ classification_counts = {}
348
+
349
+ for detection in detections_list:
350
+ base_label = detection.get('label', 'unknown')
351
+ label_counts[base_label] = label_counts.get(base_label, 0) + 1
352
+
353
+ # Handle tree classifications
354
+ if base_label == 'tree':
355
+ classification_label = detection.get('classification_label')
356
+ classification_score = detection.get('classification_score')
357
+
358
+ # Only count valid classifications (not NaN)
359
+ if (classification_label and
360
+ classification_score is not None and
361
+ str(classification_label).lower() != 'nan' and
362
+ str(classification_score).lower() != 'nan'):
363
+
364
+ classification_counts[classification_label] = classification_counts.get(classification_label, 0) + 1
365
+
366
+ # Build simple summary
367
+ object_parts = []
368
+ for label, count in label_counts.items():
369
+ if label == 'tree' and classification_counts:
370
+ # Special handling for trees with classifications
371
+ total_trees = count
372
+ tree_part = f"{total_trees} trees are detected"
373
+
374
+ if classification_counts:
375
+ classification_parts = []
376
+ for class_label, class_count in classification_counts.items():
377
+ class_name = class_label.replace('_', ' ')
378
+ classification_parts.append(f"{class_count} {class_name}s")
379
+
380
+ tree_part += f". These {total_trees} trees are classified as {' and '.join(classification_parts)}"
381
+
382
+ object_parts.append(tree_part)
383
+ else:
384
+ label_name = label.replace('_', ' ')
385
+ object_parts.append(f"{count} {label_name}{'s' if count != 1 else ''}")
386
+
387
+ fallback_summary = f"DeepForest detected {total_count} objects: {', '.join(object_parts)}."
388
+
389
+ return fallback_summary
390
+
391
+ def _parse_response_with_reasoning(self, response: str) -> Dict[str, Any]:
392
+ """
393
+ Parse model response to extract reasoning and multiple tool calls.
394
+
395
+ Args:
396
+ response (str): Raw response from the model
397
+
398
+ Returns:
399
+ Dictionary containing either:
400
+ - {"reasoning": str, "tool_call": dict} on success
401
+ - {"error": str} on parsing failure
402
+ """
403
+ return parse_deepforest_agent_response_with_reasoning(response)
src/deepforest_agent/agents/ecology_analysis_agent.py ADDED
@@ -0,0 +1,92 @@
1
+ import json
2
+ from typing import Dict, List, Any, Optional, Generator
3
+
4
+ from deepforest_agent.models.llama32_3b_instruct import Llama32ModelManager
5
+ from deepforest_agent.conf.config import Config
6
+ from deepforest_agent.prompts.prompt_templates import create_ecology_synthesis_prompt
7
+ from deepforest_agent.utils.state_manager import session_state_manager
8
+
9
+
10
+ class EcologyAnalysisAgent:
11
+ """
12
+ Ecology analysis agent responsible for combining all data into comprehensive ecological insights.
13
+ Uses Llama-3.2-3B-Instruct model for detailed structured response generation with analysis.
14
+ """
15
+
16
+ def __init__(self):
17
+ """Initialize the Ecology Analysis Agent."""
18
+ self.agent_config = Config.AGENT_CONFIGS["ecology_analysis"]
19
+ self.model_manager = Llama32ModelManager(Config.AGENT_MODELS["ecology_analysis"])
20
+
21
+ def synthesize_analysis_streaming(
22
+ self,
23
+ user_message: str,
24
+ memory_context: str,
25
+ cached_json: Optional[Dict[str, Any]] = None,
26
+ current_json: Optional[Dict[str, Any]] = None,
27
+ session_id: Optional[str] = None
28
+ ) -> Generator[Dict[str, Any], None, None]:
29
+ """
30
+ Synthesize all agent outputs with streaming text generation.
31
+
32
+ Args:
33
+ user_message (str): The user's original query for the analysis.
34
+ memory_context (str): The context and conversation history provided
35
+ by a memory agent.
36
+ cached_json (Optional[Dict[str, Any]]): A dictionary of previously
37
+ cached JSON data, if available. Defaults to None.
38
+ current_json (Optional[Dict[str, Any]]): A dictionary of new JSON data
39
+ from the current analysis step. Defaults to None.
40
+ session_id (Optional[str]): A unique session identifier for tracking
41
+ and logging. Defaults to None.
42
+
43
+ Yields:
44
+ Dict[str, Any]: Dictionary containing:
45
+ - token: Generated text token
46
+ - is_complete: Whether generation is finished
47
+ """
48
+ if session_id and not session_state_manager.session_exists(session_id):
49
+ yield {
50
+ "token": f"Session {session_id} not found. Unable to synthesize analysis.",
51
+ "is_complete": True
52
+ }
53
+ return
54
+
55
+ try:
56
+ synthesis_prompt = create_ecology_synthesis_prompt(
57
+ user_message=user_message,
58
+ comprehensive_context=memory_context,
59
+ cached_json=cached_json,
60
+ current_json=current_json
61
+ )
62
+ print(f"Ecology Synthesis Prompt:\n{synthesis_prompt}\n")
63
+
64
+ messages = [
65
+ {"role": "system", "content": synthesis_prompt},
66
+ {"role": "user", "content": user_message}
67
+ ]
68
+
69
+ print(f"Session {session_id} - Ecology Agent: Starting streaming synthesis")
70
+
71
+ # Stream the response token by token
72
+ for token_data in self.model_manager.generate_response_streaming(
73
+ messages=messages,
74
+ max_new_tokens=self.agent_config["max_new_tokens"],
75
+ temperature=self.agent_config["temperature"],
76
+ top_p=self.agent_config["top_p"]
77
+ ):
78
+ yield token_data
79
+
80
+ if token_data["is_complete"]:
81
+ print(f"Session {session_id} - Ecology Agent: Streaming synthesis completed")
82
+ break
83
+
84
+ except Exception as e:
85
+ error_msg = f"Error in ecology synthesis for session {session_id}: {str(e)}"
86
+ print(f"Ecology Analysis Error: {error_msg}")
87
+
88
+ # Yield error_msg response as single token
89
+ yield {
90
+ "token": error_msg,
91
+ "is_complete": True
92
+ }
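Per the docstring above, `synthesize_analysis_streaming` yields dicts with `token` and `is_complete` keys. A minimal consumption sketch follows; the accumulation loop and the argument values are assumptions for illustration, not the orchestrator's actual wiring, and running it loads the full Llama-3.2-3B model.

```python
# Illustrative consumer of the streaming ecology synthesis (not orchestrator code).
agent = EcologyAnalysisAgent()

answer = ""
for chunk in agent.synthesize_analysis_streaming(
    user_message="How many dead trees are there?",
    memory_context="No previous conversation history available.",
    session_id=None,                  # skip session bookkeeping in this sketch
):
    answer += chunk["token"]          # append each streamed token (assumed incremental)
    if chunk["is_complete"]:          # generator signals completion
        break

print(answer)
```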
src/deepforest_agent/agents/memory_agent.py ADDED
@@ -0,0 +1,238 @@
1
+ from typing import Dict, List, Any, Optional
2
+ import re
3
+ import time
4
+ import json
5
+
6
+ from deepforest_agent.models.smollm3_3b import SmolLM3ModelManager
7
+ from deepforest_agent.conf.config import Config
8
+ from deepforest_agent.prompts.prompt_templates import format_memory_prompt
9
+ from deepforest_agent.utils.state_manager import session_state_manager
10
+ from deepforest_agent.utils.logging_utils import multi_agent_logger
11
+ from deepforest_agent.utils.parsing_utils import parse_memory_agent_response
12
+ from deepforest_agent.utils.cache_utils import tool_call_cache
13
+ from deepforest_agent.conf.config import Config
14
+
15
+ class MemoryAgent:
16
+ """
17
+ Memory agent responsible for analyzing conversation history in new format.
18
+ Uses SmolLM3-3B model for getting relevant context
19
+ """
20
+
21
+ def __init__(self):
22
+ """Initialize the Memory Agent with model manager and configuration."""
23
+ self.agent_config = Config.AGENT_CONFIGS["memory"]
24
+ self.model_manager = SmolLM3ModelManager(Config.AGENT_MODELS["memory"])
25
+
26
+ def _filter_conversation_history(self, conversation_history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
27
+ """
28
+ Filter conversation history to include user and assistant messages.
29
+
30
+ Args:
31
+ conversation_history: Full conversation history
32
+
33
+ Returns:
34
+ Filtered history with only user/assistant messages
35
+ """
36
+ filtered_history = []
37
+
38
+ for message in conversation_history:
39
+ if message.get("role") in ["user", "assistant"]:
40
+ content = message.get("content", "")
41
+ if isinstance(content, list):
42
+ text_parts = [item.get("text", "") for item in content if item.get("type") == "text"]
43
+ content = " ".join(text_parts)
44
+ elif isinstance(content, str):
45
+ content = content
46
+ else:
47
+ content = str(content)
48
+
49
+ filtered_history.append({
50
+ "role": message["role"],
51
+ "content": content
52
+ })
53
+
54
+ return filtered_history
55
+
56
+ def _get_conversation_history_context(self, session_id: str) -> str:
57
+ """
58
+ Get formatted conversation history with turn-based structure.
59
+
60
+ Args:
61
+ session_id: Session identifier
62
+
63
+ Returns:
64
+ Formatted conversation history with turn structure
65
+ """
66
+ conversation_history = session_state_manager.get(session_id, "conversation_history", [])
67
+
68
+ print(f"Session {session_id} - Conversation length: {len(conversation_history)}")
69
+
70
+ if not conversation_history:
71
+ return "No previous conversation history available."
72
+
73
+ # Build turn-based history
74
+ formatted_history = []
75
+ turn_number = 1
76
+
77
+ # Process conversation in pairs (user -> assistant)
78
+ i = 0
79
+ while i < len(conversation_history):
80
+ if i + 1 < len(conversation_history):
81
+ user_msg = conversation_history[i]
82
+ assistant_msg = conversation_history[i + 1]
83
+
84
+ if user_msg.get("role") == "user" and assistant_msg.get("role") == "assistant":
85
+ # Extract user query
86
+ user_content = user_msg.get("content", "")
87
+ if isinstance(user_content, list):
88
+ text_parts = [item.get("text", "") for item in user_content if item.get("type") == "text"]
89
+ user_query = " ".join(text_parts)
90
+ else:
91
+ user_query = str(user_content)
92
+
93
+ # Get stored context data for this turn
94
+ visual_context = session_state_manager.get(session_id, f"turn_{turn_number}_visual_context", "No visual analysis available")
95
+ detection_narrative = session_state_manager.get(session_id, f"turn_{turn_number}_detection_narrative", "No detection narrative available")
96
+ tool_cache_id = session_state_manager.get(session_id, f"turn_{turn_number}_tool_cache_id", "No tool cache ID")
97
+ tool_call_info = "No tool call information available"
98
+ if tool_cache_id:
99
+ try:
100
+ if tool_cache_id in tool_call_cache.cache_data:
101
+ cached_entry = tool_call_cache.cache_data[tool_cache_id]
102
+ tool_name = cached_entry.get("tool_name", "unknown")
103
+ stored_arguments = cached_entry.get("arguments", {})
104
+
105
+ all_arguments = Config.DEEPFOREST_DEFAULTS.copy()
106
+ all_arguments.update(stored_arguments)
107
+
108
+ # Format tool call info with all arguments
109
+ args_str = ", ".join([f"{k}={v}" for k, v in all_arguments.items()])
110
+ tool_call_info = f"Tool: {tool_name} called with arguments: {args_str}"
111
+ except Exception as e:
112
+ tool_call_info = f"Error retrieving tool call info: {str(e)}"
113
+
114
+ turn_text = f"--- Turn {turn_number}: ---\n"
115
+ turn_text += f"Turn {turn_number} User query: {user_query}\n"
116
+ turn_text += f"Turn {turn_number} Visual analysis full image or per tile: {visual_context}\n"
117
+ turn_text += f"Turn {turn_number} Tool cache ID: {tool_cache_id}\n"
118
+ turn_text += f"Turn {turn_number} Tool call details: {tool_call_info}\n"
119
+ turn_text += f"Turn {turn_number} Detection Data Analysis: {detection_narrative}\n"
120
+ turn_text += f"--- Turn {turn_number} Completed ---\n"
121
+
122
+ formatted_history.append(turn_text)
123
+ turn_number += 1
124
+ i += 2
125
+ else:
126
+ i += 1
127
+ else:
128
+ i += 1
129
+
130
+ if not formatted_history:
131
+ return "No complete conversation turns available."
132
+
133
+ print(f"Formatted {len(formatted_history)} conversation turns")
134
+ return "\n\n".join(formatted_history)
135
+
136
+ def process_conversation_history_structured(
137
+ self,
138
+ conversation_history: List[Dict[str, Any]],
139
+ latest_message: str,
140
+ session_id: str
141
+ ) -> Dict[str, Any]:
142
+ """
143
+ Process conversation history and extract relevant context with structured output.
144
+
145
+ Args:
146
+ conversation_history: Full conversation history
147
+ latest_message: Current user message requiring context analysis
148
+ session_id: Unique session identifier for this user
149
+
150
+ Returns:
151
+ Dict with structured output including tool_cache_id and relevant context
152
+ """
153
+ if not session_state_manager.session_exists(session_id):
154
+ return {
155
+ "answer_present": False,
156
+ "direct_answer": "NO",
157
+ "tool_cache_id": None,
158
+ "relevant_context": f"Session {session_id} not found. Current query: {latest_message}",
159
+ "raw_response": f"Session {session_id} not found"
160
+ }
161
+
162
+ filtered_history = self._filter_conversation_history(conversation_history)
163
+ conversation_context = self._get_conversation_history_context(session_id)
164
+
165
+ memory_prompt = format_memory_prompt(filtered_history, latest_message, conversation_context)
166
+ print(f"Memory Agent Prompt:\n{memory_prompt}\n")
167
+
168
+ messages = [
169
+ {"role": "system", "content": memory_prompt},
170
+ {"role": "user", "content": latest_message}
171
+ ]
172
+
173
+ memory_execution_start = time.perf_counter()
174
+
175
+ try:
176
+ response = self.model_manager.generate_response(
177
+ messages=messages,
178
+ max_new_tokens=self.agent_config["max_new_tokens"],
179
+ temperature=self.agent_config["temperature"],
180
+ top_p=self.agent_config["top_p"]
181
+ )
182
+
183
+ memory_execution_time = time.perf_counter() - memory_execution_start
184
+
185
+ print(f"Session {session_id} - Memory Agent: Raw response received")
186
+ print(f"Raw Response: {response}")
187
+
188
+ parsed_result = parse_memory_agent_response(response)
189
+
190
+ multi_agent_logger.log_agent_execution(
191
+ session_id=session_id,
192
+ agent_name="memory",
193
+ agent_input=f"Latest message: {latest_message}",
194
+ agent_output=response,
195
+ execution_time=memory_execution_time
196
+ )
197
+
198
+ print(f"Session {session_id} - Memory Agent: Analysis completed")
199
+ print(f"Has Answer: {parsed_result['answer_present']}")
200
+
201
+ return parsed_result
202
+
203
+ except Exception as e:
204
+ memory_execution_time = time.perf_counter() - memory_execution_start
205
+ error_msg = f"Error processing conversation history in session {session_id}: {str(e)}"
206
+ print(f"Session {session_id} - Memory Agent Error: {e}")
207
+
208
+ multi_agent_logger.log_error(
209
+ session_id=session_id,
210
+ error_type="memory_agent_error",
211
+ error_message=f"Memory agent failed after {memory_execution_time:.2f}s: {str(e)}"
212
+ )
213
+
214
+ return {
215
+ "answer_present": False,
216
+ "direct_answer": "NO",
217
+ "tool_cache_id": None,
218
+ "relevant_context": f"{error_msg}. Current query: {latest_message}",
219
+ "raw_response": str(e)
220
+ }
221
+
222
+ def store_turn_context(self, session_id: str, turn_number: int, visual_context: str,
223
+ detection_narrative: str, tool_cache_id: Optional[str]) -> None:
224
+ """
225
+ Store context data for a specific conversation turn.
226
+
227
+ Args:
228
+ session_id: Session identifier
229
+ turn_number: Turn number in conversation
230
+ visual_context: Visual analysis context
231
+ detection_narrative: Detection narrative
232
+ tool_cache_id: Tool cache identifier
233
+ """
234
+ session_state_manager.set(session_id, f"turn_{turn_number}_visual_context", visual_context)
235
+ session_state_manager.set(session_id, f"turn_{turn_number}_detection_narrative", detection_narrative)
236
+ session_state_manager.set(session_id, f"turn_{turn_number}_tool_cache_id", tool_cache_id or "No tool cache ID")
237
+
238
+ print(f"Session {session_id} - Stored context for turn {turn_number}")
src/deepforest_agent/agents/orchestrator.py ADDED
@@ -0,0 +1,795 @@
1
+ import time
2
+ import json
3
+ import torch
4
+ import gc
5
+ from typing import Dict, List, Any, Optional, Generator
6
+
7
+ from deepforest_agent.agents.memory_agent import MemoryAgent
8
+ from deepforest_agent.agents.deepforest_detector_agent import DeepForestDetectorAgent
9
+ from deepforest_agent.agents.visual_analysis_agent import VisualAnalysisAgent
10
+ from deepforest_agent.agents.ecology_analysis_agent import EcologyAnalysisAgent
11
+ from deepforest_agent.utils.state_manager import session_state_manager
12
+ from deepforest_agent.utils.cache_utils import tool_call_cache
13
+ from deepforest_agent.utils.image_utils import check_image_resolution_for_deepforest
14
+ from deepforest_agent.utils.logging_utils import multi_agent_logger
15
+ from deepforest_agent.utils.detection_narrative_generator import DetectionNarrativeGenerator
16
+
17
+
18
+ class AgentOrchestrator:
19
+ """
20
+ Orchestrates the multi-agent workflow: memory context, visual analysis, DeepForest detection, and ecological synthesis.
21
+ """
22
+
23
+ def __init__(self):
24
+ """Initialize the Agent Orchestrator."""
25
+ self.memory_agent = MemoryAgent()
26
+ self.detector_agent = DeepForestDetectorAgent()
27
+ self.visual_agent = VisualAnalysisAgent()
28
+ self.ecology_agent = EcologyAnalysisAgent()
29
+
30
+ self.execution_stats = {
31
+ "total_runs": 0,
32
+ "successful_runs": 0,
33
+ "average_execution_time": 0.0,
34
+ "memory_direct_answers": 0,
35
+ "deepforest_skipped": 0
36
+ }
37
+
38
+ def _log_gpu_memory(self, session_id: str, stage: str, agent_name: str):
39
+ """
40
+ Log current GPU memory usage.
41
+
42
+ Args:
43
+ session_id (str): Unique identifier for the user session being processed
44
+ stage (str): Workflow stage identifier (e.g., "before", "after", "cleanup")
45
+ agent_name (str): Name of the agent being monitored (e.g., "Visual Analysis",
46
+ "DeepForest Detection", "Memory Agent")
47
+ """
48
+ if torch.cuda.is_available():
49
+ allocated_gb = torch.cuda.memory_allocated() / 1024**3
50
+ cached_gb = torch.cuda.memory_reserved() / 1024**3
51
+
52
+ multi_agent_logger.log_agent_execution(
53
+ session_id=session_id,
54
+ agent_name=f"gpu_memory_{stage}",
55
+ agent_input=f"{agent_name} - {stage}",
56
+ agent_output=f"GPU Memory - Allocated: {allocated_gb:.2f} GB, Cached: {cached_gb:.2f} GB",
57
+ execution_time=0.0
58
+ )
59
+ print(f"Session {session_id} - {agent_name} {stage}: GPU Memory - Allocated: {allocated_gb:.2f} GB, Cached: {cached_gb:.2f} GB")
60
+
61
+ def cleanup_all_agents(self):
62
+ """Cleanup models to manage memory."""
63
+ print("Orchestrator cleanup:")
64
+ gc.collect()
65
+ if torch.cuda.is_available():
66
+ torch.cuda.empty_cache()
67
+ torch.cuda.synchronize()
68
+ torch.cuda.ipc_collect()
69
+ print(f"Final GPU memory after orchestrator cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
70
+
71
+ def _aggressive_gpu_cleanup(self, session_id: str, stage: str):
72
+ """
73
+ Perform aggressive GPU memory cleanup.
74
+
75
+ Args:
76
+ session_id (str): Unique identifier for the user session
77
+ stage (str): Workflow stage identifier for logging context
78
+ """
79
+ if torch.cuda.is_available():
80
+ for i in range(3):
81
+ gc.collect()
82
+ torch.cuda.empty_cache()
83
+
84
+ torch.cuda.ipc_collect()
85
+ torch.cuda.synchronize()
86
+
87
+ try:
88
+ torch.cuda.reset_peak_memory_stats()
89
+ torch.cuda.reset_accumulated_memory_stats()
90
+ except Exception:
91
+ pass
92
+
93
+ allocated = torch.cuda.memory_allocated() / 1024**3
94
+ cached = torch.cuda.memory_reserved() / 1024**3
95
+
96
+ print(f"Session {session_id} - {stage} aggressive cleanup: {allocated:.2f} GB allocated, {cached:.2f} GB cached")
97
+
98
+ def _format_detection_data_for_monitor(self, detection_narrative: str, detections_list: Optional[List[Dict[str, Any]]] = None) -> str:
99
+ """
100
+ Format detection data for monitor display.
101
+
102
+ Args:
103
+ detection_narrative: Generated detection context from DeepForest Data
104
+ detections_list: Full DeepForest detection data
105
+
106
+ Returns:
107
+ Formatted detection data for monitor
108
+ """
109
+ monitor_parts = []
110
+
111
+ if detections_list:
112
+ monitor_parts.append("=== DEEPFOREST DETECTIONS ===")
113
+ monitor_parts.append(json.dumps(detections_list, indent=2))
114
+ monitor_parts.append("")
115
+
116
+ if detection_narrative:
117
+ monitor_parts.append("=== DETECTION NARRATIVE ===")
118
+ monitor_parts.append(detection_narrative)
119
+
120
+ return "\n".join(monitor_parts) if monitor_parts else "No detection data available"
121
+
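A quick sketch of the monitor string this helper produces, using hypothetical values and without instantiating the orchestrator; the detection fields shown are assumed, not the tool's exact schema:

import json

detections = [{"label": "tree", "score": 0.91, "xmin": 12, "ymin": 40, "xmax": 88, "ymax": 130}]
narrative = "Detected 1 tree in the north-west corner."

# Mirrors the formatting above: JSON dump first, then the narrative section.
monitor_text = "\n".join([
    "=== DEEPFOREST DETECTIONS ===",
    json.dumps(detections, indent=2),
    "",
    "=== DETECTION NARRATIVE ===",
    narrative,
])
print(monitor_text)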
122
+ def _get_cached_detection_narrative(self, tool_cache_id: str) -> Optional[str]:
123
+ """
124
+ Retrieve detection narrative using tool cache ID from the tool_call_cache.
125
+
126
+ Args:
127
+ tool_cache_id: Tool cache identifier
128
+
129
+ Returns:
130
+ Detection context from DeepForest Data if found, None otherwise
131
+ """
132
+ try:
133
+ print(f"Looking up cached detection narrative for tool_cache_id: {tool_cache_id}")
134
+
135
+ # Handle multiple cache IDs
136
+ cache_ids = [cid.strip() for cid in tool_cache_id.split(",")] if tool_cache_id else []  # avoid shadowing built-in id
137
+ all_narratives = []
138
+
139
+ for cache_id in cache_ids:
140
+ if cache_id in tool_call_cache.cache_data:
141
+ cached_entry = tool_call_cache.cache_data[cache_id]
142
+ cached_result = cached_entry.get("result", {})
143
+ tool_name = cached_entry.get("tool_name", "unknown")
144
+ tool_arguments = cached_entry.get("arguments", {})
145
+
146
+ # Get all possible arguments including defaults from Config
147
+ from deepforest_agent.conf.config import Config
148
+ all_arguments = Config.DEEPFOREST_DEFAULTS.copy()
149
+ all_arguments.update(tool_arguments)
150
+
151
+ # Format tool call info with all arguments
152
+ args_str = ", ".join([f"{k}={v}" for k, v in all_arguments.items()])
153
+
154
+ # Check if we have detections_list to generate narrative from
155
+ detections_list = cached_result.get("detections_list", [])
156
+
157
+ if detections_list:
158
+ print(f"Found {len(detections_list)} cached detections for cache ID {cache_id}")
159
+
160
+ # Get image dimensions for narrative generation
161
+ try:
162
+ session_keys = list(session_state_manager._sessions.keys())
163
+ if session_keys:
164
+ current_image = session_state_manager.get(session_keys[0], "current_image")
165
+ if current_image:
166
+ image_width, image_height = current_image.size
167
+ else:
168
+ image_width, image_height = 0, 0
169
+ else:
170
+ image_width, image_height = 0, 0
171
+ except Exception:
172
+ image_width, image_height = 0, 0
173
+
174
+ # Generate fresh narrative from cached detection data
175
+ narrative_generator = DetectionNarrativeGenerator(image_width, image_height)
176
+ cached_detection_narrative = narrative_generator.generate_comprehensive_narrative(detections_list)
177
+
178
+ # Format with proper tool cache ID structure
179
+ formatted_narrative = f"**TOOL CACHE ID:** {cache_id}\nDeepForest tool run with arguments ({args_str}) and got the below narratives:\nDETECTION NARRATIVE:\n{cached_detection_narrative}"
180
+ all_narratives.append(formatted_narrative)
181
+ else:
182
+ detection_summary = cached_result.get("detection_summary", "")
183
+ if detection_summary:
184
+ formatted_summary = f"**TOOL CACHE ID:** {cache_id}\nDeepForest tool run with arguments ({args_str}) and got the below narratives:\nDETECTION NARRATIVE:\n{detection_summary}"
185
+ all_narratives.append(formatted_summary)
186
+
187
+ if all_narratives:
188
+ print(f"Generated {len(all_narratives)} cached detection narratives")
189
+ return "\n\n".join(all_narratives)
190
+
191
+ print(f"No cached data found for tool_cache_id(s): {tool_cache_id}")
192
+ return None
193
+
194
+ except Exception as e:
195
+ print(f"Error retrieving cached detection narrative for {tool_cache_id}: {e}")
196
+ return None
197
+
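The lookup accepts a single cache ID or a comma-separated list; a small sketch of the splitting step (the IDs are hypothetical):

tool_cache_id = "deepforest_tree_ab12, deepforest_bird_cd34"
cache_ids = [cid.strip() for cid in tool_cache_id.split(",")] if tool_cache_id else []
print(cache_ids)  # ['deepforest_tree_ab12', 'deepforest_bird_cd34']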
198
+ def process_user_message_streaming(
199
+ self,
200
+ user_message: str,
201
+ conversation_history: List[Dict[str, Any]],
202
+ session_id: str
203
+ ) -> Generator[Dict[str, Any], None, None]:
204
+ """
205
+ Orchestrate the multi-agent workflow with memory context and detection narrative flow.
206
+
207
+ Args:
208
+ user_message: Current user message/query to be processed
209
+ conversation_history: Full conversation history
210
+ session_id: Unique session identifier for this user's workflow
211
+
212
+ Yields:
213
+ Dict[str, Any]: Progress updates during processing
214
+ """
215
+ start_time = time.perf_counter()
216
+ self.execution_stats["total_runs"] += 1
217
+
218
+ print(f"Session {session_id} - Query: {user_message}")
219
+ print(f"Session {session_id} - Conversation history length: {len(conversation_history)}")
220
+
221
+ agent_results = {}
222
+ execution_summary = {
223
+ "agents_executed": [],
224
+ "execution_order": [],
225
+ "timings": {},
226
+ "status": "in_progress",
227
+ "session_id": session_id,
228
+ "workflow_type": "memory_narrative_flow",
229
+ "memory_provided_direct_answer": False,
230
+ "deepforest_executed": False
231
+ }
232
+
233
+ memory_context = ""
234
+ visual_context = ""
235
+ detection_narrative = ""
236
+ memory_tool_cache_id = None
237
+ current_tool_cache_id = None
238
+
239
+ try:
240
+ if not session_state_manager.session_exists(session_id):
241
+ raise ValueError(f"Session {session_id} not found")
242
+
243
+ session_state_manager.set_processing_state(session_id, True)
244
+ session_state_manager.reset_cancellation(session_id)
245
+
246
+ yield {
247
+ "stage": "memory",
248
+ "message": "Analyzing conversation memory and context...",
249
+ "type": "progress"
250
+ }
251
+
252
+ if session_state_manager.is_cancelled(session_id):
253
+ raise Exception("Processing cancelled by user")
254
+
255
+ print(f"\nSTEP 1: Memory Agent Processing (Session {session_id})")
256
+ self._log_gpu_memory(session_id, "before", "Memory Agent")
257
+ memory_start = time.perf_counter()
258
+
259
+ memory_result = self.memory_agent.process_conversation_history_structured(
260
+ conversation_history=conversation_history,
261
+ latest_message=user_message,
262
+ session_id=session_id
263
+ )
264
+
265
+ memory_time = time.perf_counter() - memory_start
266
+ self._log_gpu_memory(session_id, "after", "Memory Agent")
267
+ self._aggressive_gpu_cleanup(session_id, "after_memory_agent")
268
+ execution_summary["timings"]["memory_agent"] = memory_time
269
+ execution_summary["agents_executed"].append("memory")
270
+ execution_summary["execution_order"].append("memory")
271
+ agent_results["memory"] = memory_result
272
+
273
+ # Extract memory context and tool cache ID
274
+ memory_context = memory_result.get("relevant_context", "No memory context available")
275
+ tool_cache_id = memory_result.get("tool_cache_id")
276
+
277
+ print(f"Session {session_id} - Memory Agent: Completed in {memory_time:.2f}s")
278
+ print(f"Session {session_id} - Memory Has Answer: {memory_result['answer_present']}")
279
+ print(f"Session {session_id} - Tool Cache ID: {tool_cache_id}")
280
+
281
+ if memory_result["answer_present"]:
282
+ print(f"Session {session_id} - Memory has direct answer - using cached data for synthesis")
283
+
284
+ self.execution_stats["memory_direct_answers"] += 1
285
+ execution_summary["memory_provided_direct_answer"] = True
286
+
287
+ # Get cached detection narrative if available
288
+ cached_detection_narrative = ""
289
+ if tool_cache_id:
290
+ cached_detection_narrative = self._get_cached_detection_narrative(tool_cache_id) or ""
291
+
292
+ yield {
293
+ "stage": "ecology",
294
+ "message": "Using memory context and cached detection narrative for synthesis...",
295
+ "type": "progress"
296
+ }
297
+
298
+ if session_state_manager.is_cancelled(session_id):
299
+ raise Exception("Processing cancelled by user")
300
+
301
+ print(f"\nSTEP 2 (MEMORY PATH): Ecology Agent with Memory Context (Session {session_id})")
302
+ self._log_gpu_memory(session_id, "before", "Ecology Agent (Memory Path)")
303
+ ecology_start = time.perf_counter()
304
+
305
+ # Prepare comprehensive context
306
+ comprehensive_context = self._prepare_comprehensive_context(
307
+ memory_context=memory_context,
308
+ visual_context="",
309
+ detection_narrative=cached_detection_narrative,
310
+ tool_cache_id=tool_cache_id
311
+ )
312
+
313
+ final_response = ""
314
+ for token_result in self.ecology_agent.synthesize_analysis_streaming(
315
+ user_message=user_message,
316
+ memory_context=comprehensive_context,
317
+ cached_json=None,
318
+ current_json=None,
319
+ session_id=session_id
320
+ ):
321
+
322
+ if session_state_manager.is_cancelled(session_id):
323
+ raise Exception("Processing cancelled by user")
324
+
325
+ final_response += token_result["token"]
326
+
327
+ yield {
328
+ "stage": "ecology_streaming",
329
+ "message": final_response,
330
+ "type": "streaming",
331
+ "is_complete": token_result["is_complete"]
332
+ }
333
+
334
+ if token_result["is_complete"]:
335
+ ecology_time = time.perf_counter() - ecology_start
336
+ self._log_gpu_memory(session_id, "after", "Ecology Agent (Memory Path)")
337
+ execution_summary["timings"]["ecology_agent"] = ecology_time
338
+ execution_summary["agents_executed"].append("ecology")
339
+ execution_summary["execution_order"].append("ecology")
340
+ agent_results["ecology"] = {"final_response": final_response}
341
+ print(f"Session {session_id} - Ecology (Memory Path): Completed in {ecology_time:.2f}s")
342
+ break
343
+
344
+ total_time = time.perf_counter() - start_time
345
+ execution_summary["timings"]["total"] = total_time
346
+ execution_summary["status"] = "completed_via_memory"
347
+
348
+ detection_data_monitor = self._format_detection_data_for_monitor(
349
+ detection_narrative=cached_detection_narrative
350
+ )
351
+
352
+ yield {
353
+ "stage": "complete",
354
+ "message": final_response,
355
+ "type": "final",
356
+ "detection_data": detection_data_monitor,
357
+ "agent_results": agent_results,
358
+ "execution_summary": execution_summary,
359
+ "execution_time": total_time,
360
+ "status": "success"
361
+ }
362
+ return
363
+ else:
364
+ for result in self._execute_full_pipeline_with_narrative_flow(
365
+ user_message=user_message,
366
+ conversation_history=conversation_history,
367
+ session_id=session_id,
368
+ memory_context=memory_context,
369
+ memory_tool_cache_id=memory_result.get("tool_cache_id"),
370
+ start_time=start_time
371
+ ):
372
+ yield result
373
+ if result["type"] == "final":
374
+ return
375
+
376
+ except Exception as e:
377
+ error_msg = f"Orchestrator error (Session {session_id}): {str(e)}"
378
+ print(f"ORCHESTRATOR ERROR: {error_msg}")
379
+
380
+ try:
381
+ self._aggressive_gpu_cleanup(session_id, "emergency")
382
+ except Exception as cleanup_error:
383
+ print(f"Emergency cleanup error: {cleanup_error}")
384
+
385
+ partial_time = time.perf_counter() - start_time
386
+ execution_summary["timings"]["total"] = partial_time
387
+ execution_summary["status"] = "error"
388
+ execution_summary["error"] = error_msg
389
+
390
+ fallback_response = self._create_fallback_response(
391
+ user_message=user_message,
392
+ agent_results=agent_results,
393
+ error=error_msg,
394
+ session_id=session_id
395
+ )
396
+
397
+ yield {
398
+ "stage": "error",
399
+ "message": fallback_response,
400
+ "type": "final",
401
+ "detection_data": "Error occurred - no detection data available",
402
+ "agent_results": agent_results,
403
+ "execution_summary": execution_summary,
404
+ "execution_time": partial_time,
405
+ "status": "error",
406
+ "error": error_msg
407
+ }
408
+
409
+ finally:
410
+ session_state_manager.set_processing_state(session_id, False)
411
+
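A sketch of how a caller (e.g. the Gradio handler in app.py) might consume this generator; the session is assumed to have been registered with session_state_manager beforehand, and the query is illustrative:

from deepforest_agent.agents.orchestrator import AgentOrchestrator

orchestrator = AgentOrchestrator()

final_text = ""
for update in orchestrator.process_user_message_streaming(
    user_message="How many trees are visible in this image?",
    conversation_history=[],
    session_id="demo-session",
):
    if update["type"] == "progress":
        print(f"[{update['stage']}] {update['message']}")
    elif update["type"] == "streaming":
        final_text = update["message"]  # cumulative partial response
    elif update["type"] == "final":
        final_text = update["message"]
        print(update["execution_summary"]["status"])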
412
+ def _execute_full_pipeline_with_narrative_flow(
413
+ self,
414
+ user_message: str,
415
+ conversation_history: List[Dict[str, Any]],
416
+ session_id: str,
417
+ memory_context: str,
418
+ memory_tool_cache_id: Optional[str],
419
+ start_time: float
420
+ ) -> Generator[Dict[str, Any], None, None]:
421
+ """
422
+ Execute the complete pipeline using memory context, visual contexts, and detection narratives.
423
+
424
+ Args:
425
+ user_message: Current user query
426
+ conversation_history: Complete conversation context
427
+ session_id: Unique session identifier
428
+ memory_context: Context from memory agent
429
+ memory_tool_cache_id (Optional[str]): Cache identifier from memory agent
430
+ start_time: Start time for total execution calculation
431
+
432
+ Yields:
433
+ Dict[str, Any]: Progress updates during processing containing:
434
+ - stage (str): Current workflow stage ("visual_analysis", "detector", etc.)
435
+ - message (str): Human-readable progress message
436
+ - type (str): Update type ("progress", "streaming", "final")
437
+ - Additional stage-specific data (detection_data, agent_results, etc.)
438
+ """
439
+ agent_results = {}
440
+ execution_summary = {
441
+ "agents_executed": [],
442
+ "execution_order": [],
443
+ "timings": {},
444
+ "status": "in_progress",
445
+ "session_id": session_id,
446
+ "workflow_type": "Full Pipeline with Narrative Flow",
447
+ "memory_provided_direct_answer": False,
448
+ "deepforest_executed": False
449
+ }
450
+
451
+ visual_context = ""
452
+ detection_narrative = ""
453
+
454
+ yield {"stage": "visual_analysis", "message": "Analyzing image with unified full/tiled approach...", "type": "progress"}
455
+
456
+ if session_state_manager.is_cancelled(session_id):
457
+ raise Exception("Processing cancelled by user")
458
+
459
+ print(f"\nSTEP 1: Visual Analysis (Session {session_id})")
460
+ self._log_gpu_memory(session_id, "before", "Visual Analysis")
461
+ visual_start = time.perf_counter()
462
+
463
+ # Unified visual analysis
464
+ visual_analysis_result = self.visual_agent.analyze_full_image(
465
+ user_message=user_message,
466
+ session_id=session_id
467
+ )
468
+
469
+ visual_time = time.perf_counter() - visual_start
470
+ self._log_gpu_memory(session_id, "after", "Visual Analysis")
471
+ self._aggressive_gpu_cleanup(session_id, "after_visual_analysis")
472
+ execution_summary["timings"]["visual_analysis"] = visual_time
473
+ execution_summary["agents_executed"].append("visual_analysis")
474
+ execution_summary["execution_order"].append("visual_analysis")
475
+ agent_results["visual_analysis"] = visual_analysis_result
476
+
477
+ # Extract visual context
478
+ visual_context = visual_analysis_result.get("visual_analysis", "No visual analysis available")
479
+
480
+ print(f"Session {session_id} - Visual Analysis: {visual_analysis_result.get('status')}")
481
+ print(f"Session {session_id} - Analysis Type: {visual_analysis_result.get('analysis_type')}")
482
+
483
+ yield {"stage": "resolution_check", "message": "Checking image resolution for DeepForest suitability...", "type": "progress"}
484
+
485
+ if session_state_manager.is_cancelled(session_id):
486
+ raise Exception("Processing cancelled by user")
487
+
488
+ print(f"\nSTEP 2: Resolution Check (Session {session_id})")
489
+ resolution_start = time.perf_counter()
490
+
491
+ image_file_path = session_state_manager.get(session_id, "image_file_path")
492
+ resolution_result = None
493
+
494
+ if image_file_path:
495
+ resolution_result = check_image_resolution_for_deepforest(image_file_path)
496
+ resolution_time = time.perf_counter() - resolution_start
497
+
498
+ multi_agent_logger.log_resolution_check(
499
+ session_id=session_id,
500
+ image_file_path=image_file_path,
501
+ resolution_result=resolution_result,
502
+ execution_time=resolution_time
503
+ )
504
+ else:
505
+ resolution_result = {
506
+ "is_suitable": True,
507
+ "resolution_info": "No file path available for resolution check",
508
+ "error": None
509
+ }
510
+ resolution_time = time.perf_counter() - resolution_start
511
+
512
+ execution_summary["timings"]["resolution_check"] = resolution_time
513
+ execution_summary["agents_executed"].append("resolution_check")
514
+ execution_summary["execution_order"].append("resolution_check")
515
+ agent_results["resolution_check"] = resolution_result
516
+
517
+ # Determine if DeepForest should run
518
+ detection_result = None
519
+ image_quality_good = visual_analysis_result.get("image_quality_for_deepforest", "No").lower() == "yes"
520
+ resolution_suitable = resolution_result.get("is_suitable", True)
521
+
522
+ if resolution_suitable and image_quality_good:
523
+ yield {"stage": "detector", "message": "Quality and resolution good - executing DeepForest detection with narrative generation...", "type": "progress"}
524
+
525
+ if session_state_manager.is_cancelled(session_id):
526
+ raise Exception("Processing cancelled by user")
527
+
528
+ print(f"\nSTEP 3: DeepForest Detection with R-tree and Narrative (Session {session_id})")
529
+ self._log_gpu_memory(session_id, "before", "DeepForest Detection")
530
+ detector_start = time.perf_counter()
531
+
532
+ visual_objects = visual_analysis_result.get("deepforest_objects_present", [])
533
+
534
+ try:
535
+ detection_result = self.detector_agent.execute_detection_with_context(
536
+ user_message=user_message,
537
+ session_id=session_id,
538
+ visual_objects_detected=visual_objects,
539
+ memory_context=memory_context
540
+ )
541
+
542
+ detector_time = time.perf_counter() - detector_start
543
+ self._log_gpu_memory(session_id, "after", "DeepForest Detection")
544
+ self._aggressive_gpu_cleanup(session_id, "after_deepforest_detection")
545
+ execution_summary["timings"]["detector_agent"] = detector_time
546
+ execution_summary["agents_executed"].append("detector")
547
+ execution_summary["execution_order"].append("detector")
548
+ execution_summary["deepforest_executed"] = True
549
+ agent_results["detector"] = detection_result
550
+
551
+ # Extract detection narrative and tool cache ID from current run
552
+ current_detection_narrative = detection_result.get("detection_narrative", "No detection narrative available")
553
+
554
+ # Combine cached narratives from memory with current detection narrative
555
+ combined_narratives = []
556
+
557
+ # Add cached narratives from memory's tool cache IDs (if any)
558
+ if memory_tool_cache_id:
559
+ cached_narrative = self._get_cached_detection_narrative(memory_tool_cache_id)
560
+ if cached_narrative:
561
+ combined_narratives.append(cached_narrative)
562
+
563
+ # Add current detection narratives for ALL tool results
564
+ tool_results = detection_result.get("tool_results", [])
565
+ if tool_results:
566
+ for tool_result in tool_results:
567
+ cache_key = tool_result.get("cache_key")
568
+ tool_arguments = tool_result.get("tool_arguments", {})
569
+
570
+ if cache_key and tool_arguments:
571
+ # Get all possible arguments including defaults from Config
572
+ from deepforest_agent.conf.config import Config
573
+ all_arguments = Config.DEEPFOREST_DEFAULTS.copy()
574
+ all_arguments.update(tool_arguments)
575
+
576
+ # Format tool call info with all arguments
577
+ args_str = ", ".join([f"{k}={v}" for k, v in all_arguments.items()])
578
+
579
+ formatted_current = f"**TOOL CACHE ID:** {cache_key}\nDeepForest tool run with arguments ({args_str}) and got the below narratives:\nDETECTION NARRATIVE:\n{current_detection_narrative}"
580
+ combined_narratives.append(formatted_current)
581
+
582
+ # If no tool results but we have narrative, add it without formatting
583
+ if not tool_results and current_detection_narrative and current_detection_narrative != "No detection narrative available":
584
+ combined_narratives.append(current_detection_narrative)
585
+
586
+ # Combine all narratives
587
+ detection_narrative = "\n\n".join(combined_narratives) if combined_narratives else "No detection narrative available"
588
+
589
+ print(f"Session {session_id} - DeepForest Detection completed with narrative")
590
+
591
+ except Exception as detector_error:
592
+ print(f"Session {session_id} - DeepForest Detection FAILED: {detector_error}")
593
+ detection_result = None
594
+ detection_narrative = f"DeepForest detection failed: {str(detector_error)}"
595
+ else:
596
+ skip_reasons = []
597
+ if not resolution_suitable:
598
+ skip_reasons.append("insufficient resolution")
599
+ if not image_quality_good:
600
+ skip_reasons.append("poor image quality")
601
+
602
+ print(f"Session {session_id} - Skipping DeepForest detection: {', '.join(skip_reasons)}")
603
+ execution_summary["deepforest_executed"] = False
604
+ execution_summary["deepforest_skip_reason"] = ", ".join(skip_reasons)
605
+ detection_narrative = f"DeepForest detection was skipped due to: {', '.join(skip_reasons)}"
606
+
607
+ yield {"stage": "ecology", "message": "Synthesizing ecological insights from all contexts...", "type": "progress"}
608
+
609
+ if session_state_manager.is_cancelled(session_id):
610
+ raise Exception("Processing cancelled by user")
611
+
612
+ print(f"\nSTEP 4: Ecology Analysis with Comprehensive Context (Session {session_id})")
613
+ self._log_gpu_memory(session_id, "before", "Ecology Analysis")
614
+ ecology_start = time.perf_counter()
615
+
616
+ # Prepare comprehensive context for ecology agent
617
+ comprehensive_context = self._prepare_comprehensive_context(
618
+ memory_context=memory_context,
619
+ visual_context=visual_context,
620
+ detection_narrative=detection_narrative,
621
+ tool_cache_id=memory_tool_cache_id
622
+ )
623
+
624
+ final_response = ""
625
+ try:
626
+ for token_result in self.ecology_agent.synthesize_analysis_streaming(
627
+ user_message=user_message,
628
+ memory_context=comprehensive_context,
629
+ cached_json=None,
630
+ current_json=None,
631
+ session_id=session_id
632
+ ):
633
+ if session_state_manager.is_cancelled(session_id):
634
+ raise Exception("Processing cancelled by user")
635
+
636
+ final_response += token_result["token"]
637
+
638
+ yield {
639
+ "stage": "ecology_streaming",
640
+ "message": final_response,
641
+ "type": "streaming",
642
+ "is_complete": token_result["is_complete"]
643
+ }
644
+
645
+ if token_result["is_complete"]:
646
+ break
647
+
648
+ except Exception as ecology_error:
649
+ print(f"Session {session_id} - Ecology streaming error: {ecology_error}")
650
+ if not final_response:
651
+ final_response = f"Ecology analysis failed: {str(ecology_error)}"
652
+
653
+ finally:
654
+ ecology_time = time.perf_counter() - ecology_start
655
+ self._log_gpu_memory(session_id, "after", "Ecology Analysis")
656
+ self._aggressive_gpu_cleanup(session_id, "after_ecology_analysis")
657
+ execution_summary["timings"]["ecology_agent"] = ecology_time
658
+ execution_summary["agents_executed"].append("ecology")
659
+ execution_summary["execution_order"].append("ecology")
660
+ agent_results["ecology"] = {"final_response": final_response}
661
+
662
+ # Store context data for memory agent's next turn
663
+ current_turn = len(session_state_manager.get(session_id, "conversation_history", [])) // 2 + 1
664
+ all_tool_cache_ids = []
665
+ if memory_tool_cache_id:
666
+ all_tool_cache_ids.extend([cid.strip() for cid in memory_tool_cache_id.split(",")])
667
+
668
+ # Add all current tool cache IDs
669
+ tool_results = detection_result.get("tool_results", []) if detection_result else []
670
+ for tool_result in tool_results:
671
+ cache_key = tool_result.get("cache_key")
672
+ if cache_key:
673
+ all_tool_cache_ids.append(cache_key)
674
+
675
+ combined_tool_cache_id = ", ".join(all_tool_cache_ids) if all_tool_cache_ids else None
676
+ self.memory_agent.store_turn_context(
677
+ session_id=session_id,
678
+ turn_number=current_turn,
679
+ visual_context=visual_context,
680
+ detection_narrative=detection_narrative,
681
+ tool_cache_id=combined_tool_cache_id
682
+ )
683
+
684
+ # Final result
685
+ total_time = time.perf_counter() - start_time
686
+ execution_summary["timings"]["total"] = total_time
687
+ execution_summary["status"] = "completed_narrative_flow"
688
+
689
+ detection_data_monitor = self._format_detection_data_for_monitor(
690
+ detection_narrative=detection_narrative,
691
+ detections_list=detection_result.get("detections_list", []) if detection_result else None
692
+ )
693
+
694
+ print(f"Session {session_id} - NARRATIVE FLOW WORKFLOW COMPLETED")
695
+
696
+ yield {
697
+ "stage": "complete",
698
+ "message": final_response,
699
+ "type": "final",
700
+ "detection_data": detection_data_monitor,
701
+ "agent_results": agent_results,
702
+ "execution_summary": execution_summary,
703
+ "execution_time": total_time,
704
+ "status": "success"
705
+ }
706
+
707
+ def _prepare_comprehensive_context(
708
+ self,
709
+ memory_context: str,
710
+ visual_context: str,
711
+ detection_narrative: str,
712
+ tool_cache_id: Optional[str]
713
+ ) -> str:
714
+ """
715
+ Prepare comprehensive context combining all data sources with better formatting.
716
+
717
+ Args:
718
+ memory_context: Context from memory agent
719
+ visual_context: Visual analysis context
720
+ detection_narrative: R-tree based detection narrative
721
+ tool_cache_id: Tool cache reference if available
722
+
723
+ Returns:
724
+ Combined context string for ecology agent
725
+ """
726
+ context_parts = []
727
+
728
+ # Memory context section
729
+ if memory_context and memory_context != "No memory context available":
730
+ context_parts.append("--- START OF MEMORY CONTEXT ---")
731
+ context_parts.append(memory_context)
732
+ context_parts.append("--- END OF MEMORY CONTEXT ---")
733
+ context_parts.append("")
734
+
735
+ # Tool cache reference
736
+ if tool_cache_id:
737
+ context_parts.append(f"**TOOL CACHE ID:** {tool_cache_id}")
738
+ context_parts.append("")
739
+
740
+ # Detection narrative section
741
+ if detection_narrative and detection_narrative not in ["No detection narrative available", "No detection analysis available", ""]:
742
+ context_parts.append("--- START OF DETECTION ANALYSIS ---")
743
+ context_parts.append(detection_narrative)
744
+ context_parts.append("--- END OF DETECTION ANALYSIS ---")
745
+ context_parts.append("")
746
+
747
+ # Visual context section
748
+ if visual_context and visual_context != "No visual analysis available":
749
+ context_parts.append("--- START OF VISUAL ANALYSIS ---")
750
+ context_parts.append(visual_context)
751
+ context_parts.append("There may be information that are not clear or accurate in this visual analysis. So make sure to mention that this analysis is provided by a visual analysis agent and it may not be very accurate as there is no confidence score associated with it. You can only provide this analysis seperately in a different section and inform the user that you are not very confident about this analysis.")
752
+ context_parts.append("--- END OF VISUAL ANALYSIS ---")
753
+ context_parts.append("")
754
+
755
+ # If we have very little context, provide a meaningful message
756
+ if not context_parts or len("".join(context_parts)) < 50:
757
+ return "No comprehensive context available for this query. Please provide more information or try a different approach."
758
+
759
+ result_context = "\n".join(context_parts)
760
+
761
+ print(f"Prepared comprehensive context ({len(result_context)} characters)")
762
+ print(f"Context preview: {result_context[:200]}...")
763
+
764
+ return result_context
765
+
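Roughly, the combined context string follows the section layout below; only the section markers come from the code above, and the contents are hypothetical:

example_context = """--- START OF MEMORY CONTEXT ---
Turn 1 asked for a tree count on the same orthomosaic.
--- END OF MEMORY CONTEXT ---

**TOOL CACHE ID:** deepforest_tree_ab12cd34

--- START OF DETECTION ANALYSIS ---
Detected 57 trees; canopy is densest in the eastern half.
--- END OF DETECTION ANALYSIS ---

--- START OF VISUAL ANALYSIS ---
Aerial RGB image of mixed forest crossed by a dirt road.
--- END OF VISUAL ANALYSIS ---"""
print(example_context)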
766
+ def _create_fallback_response(
767
+ self,
768
+ user_message: str,
769
+ agent_results: Dict[str, Any],
770
+ error: str,
771
+ session_id: str
772
+ ) -> str:
773
+ """Create a fallback response when the orchestrator encounters errors."""
774
+ response_parts = []
775
+ response_parts.append(f"I encountered some processing issues but can provide analysis based on available data:")
776
+ response_parts.append("")
777
+
778
+ memory_result = agent_results.get("memory", {})
779
+ if memory_result and memory_result.get("relevant_context"):
780
+ response_parts.append(f"**Memory Context**: {memory_result['relevant_context']}")
781
+ response_parts.append("")
782
+
783
+ visual_result = agent_results.get("visual_analysis", {})
784
+ if visual_result and visual_result.get("visual_analysis"):
785
+ response_parts.append(f"**Visual Analysis**: {visual_result['visual_analysis']}")
786
+ response_parts.append("")
787
+
788
+ detector_result = agent_results.get("detector", {})
789
+ if detector_result and detector_result.get("detection_narrative"):
790
+ response_parts.append(f"**Detection Results**: {detector_result['detection_narrative']}")
791
+ response_parts.append("")
792
+
793
+ response_parts.append(f"Note: Workflow was interrupted ({error}). Please try your query again for full results.")
794
+
795
+ return "\n".join(response_parts)
src/deepforest_agent/agents/visual_analysis_agent.py ADDED
@@ -0,0 +1,307 @@
1
+ from typing import Dict, List, Any, Optional
2
+ from PIL import Image
3
+ import json
4
+ import re
5
+ import time
6
+ import torch
7
+ import gc
8
+
9
+ from deepforest_agent.models.qwen_vl_3b_instruct import QwenVL3BModelManager
10
+ from deepforest_agent.utils.image_utils import encode_pil_image_to_base64_url, determine_patch_size, get_image_dimensions_fast
11
+ from deepforest_agent.utils.state_manager import session_state_manager
12
+ from deepforest_agent.conf.config import Config
13
+ from deepforest_agent.utils.parsing_utils import (
14
+ parse_image_quality_for_deepforest,
15
+ parse_deepforest_objects_present,
16
+ parse_visual_analysis,
17
+ parse_additional_objects_json
18
+ )
19
+ from deepforest_agent.prompts.prompt_templates import create_full_image_quality_analysis_prompt, create_individual_tile_analysis_prompt
20
+ from deepforest_agent.utils.logging_utils import multi_agent_logger
21
+ from deepforest_agent.utils.tile_manager import tile_image_for_analysis
22
+
23
+
24
+ class VisualAnalysisAgent:
25
+ """
26
+ Visual analysis agent responsible for analyzing images with a unified full/tiled approach.
27
+ Uses the Qwen2.5-VL model for multimodal understanding.
28
+ """
29
+
30
+ def __init__(self):
31
+ """Initialize the Visual Analysis Agent."""
32
+ self.agent_config = Config.AGENT_CONFIGS["visual_analysis"]
33
+ self.model_manager = QwenVL3BModelManager(Config.AGENT_MODELS["visual_analysis"])
34
+
35
+ def analyze_full_image(self, user_message: str, session_id: str) -> Dict[str, Any]:
36
+ """
37
+ Analyze full image with automatic fallback to tiling on OOM.
38
+
39
+ Args:
40
+ user_message: User's query
41
+ session_id: Session identifier
42
+
43
+ Returns:
44
+ Dict with unified structure for both full and tiled analysis
45
+ """
46
+ if not session_state_manager.session_exists(session_id):
47
+ return {
48
+ "image_quality_for_deepforest": "No",
49
+ "deepforest_objects_present": [],
50
+ "additional_objects": [],
51
+ "visual_analysis": f"Session {session_id} not found.",
52
+ "status": "error",
53
+ "analysis_type": "error"
54
+ }
55
+
56
+ image = session_state_manager.get(session_id, "current_image")
57
+ if image is None:
58
+ return {
59
+ "image_quality_for_deepforest": "No",
60
+ "deepforest_objects_present": [],
61
+ "additional_objects": [],
62
+ "visual_analysis": f"No image available in session {session_id}.",
63
+ "status": "error",
64
+ "analysis_type": "error"
65
+ }
66
+
67
+ # Try full image analysis first
68
+ try:
69
+ print(f"Session {session_id} - Attempting full image analysis")
70
+ result = self._analyze_single_image(image, user_message, session_id, is_full_image=True)
71
+
72
+ if result["status"] == "success":
73
+ multi_agent_logger.log_agent_execution(
74
+ session_id=session_id,
75
+ agent_name="visual_analysis",
76
+ agent_input=f"Full image analysis for: {user_message}",
77
+ agent_output=result["visual_analysis"],
78
+ execution_time=0.0
79
+ )
80
+ return result
81
+
82
+ except Exception as e:
83
+ print(f"Session {session_id} - Full image analysis failed (likely OOM): {e}")
84
+ return self._analyze_with_tiling(user_message, session_id, str(e))
85
+
86
+ return self._analyze_with_tiling(user_message, session_id, "Full image analysis failed")
87
+
88
+ def _analyze_single_image(self, image: Image.Image, user_message: str, session_id: str,
89
+ is_full_image: bool = True, tile_location: str = "") -> Dict[str, Any]:
90
+ """
91
+ Analyze a single image (full image or tile) with unified structure.
92
+
93
+ Args:
94
+ image: PIL Image to analyze
95
+ user_message: User's query
96
+ session_id: Session identifier
97
+ is_full_image: Whether this is full image or tile
98
+ tile_location: Location description for tiles
99
+
100
+ Returns:
101
+ Unified analysis result
102
+ """
103
+ system_prompt = create_full_image_quality_analysis_prompt(user_message)
104
+ image_base64_url = encode_pil_image_to_base64_url(image)
105
+
106
+ messages = [
107
+ {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
108
+ {
109
+ "role": "user",
110
+ "content": [
111
+ {"type": "image", "image": image_base64_url},
112
+ {"type": "text", "text": user_message}
113
+ ]
114
+ }
115
+ ]
116
+
117
+ response = self.model_manager.generate_response(
118
+ messages=messages,
119
+ max_new_tokens=self.agent_config["max_new_tokens"],
120
+ temperature=self.agent_config["temperature"]
121
+ )
122
+
123
+ # Parse structured response
124
+ image_quality = parse_image_quality_for_deepforest(response)
125
+ deepforest_objects = parse_deepforest_objects_present(response)
126
+ additional_objects = parse_additional_objects_json(response)
127
+ raw_visual_analysis = parse_visual_analysis(response)
128
+
129
+ # Format visual analysis with consistent prefix
130
+ if is_full_image:
131
+ width, height = image.size
132
+ visual_analysis = f"Full image analysis of image ({width}x{height}) is done. Here's the analysis: {raw_visual_analysis}"
133
+ analysis_type = "full_image"
134
+ else:
135
+ visual_analysis = f"The visual analysis of tiled image on ({tile_location}) this location is done. Here's the analysis: {raw_visual_analysis}"
136
+ analysis_type = "tiled_image"
137
+
138
+ return {
139
+ "image_quality_for_deepforest": image_quality,
140
+ "deepforest_objects_present": deepforest_objects,
141
+ "additional_objects": additional_objects,
142
+ "visual_analysis": visual_analysis,
143
+ "status": "success",
144
+ "analysis_type": analysis_type,
145
+ "raw_response": response
146
+ }
147
+
148
+ def _analyze_with_tiling(self, user_message: str, session_id: str, error_msg: str) -> Dict[str, Any]:
149
+ """
150
+ Analyze image using tiling approach when full image fails.
151
+
152
+ Args:
153
+ user_message: User's query
154
+ session_id: Session identifier
155
+ error_msg: Original error message
156
+
157
+ Returns:
158
+ Combined analysis from tiled approach with same structure as full image
159
+ """
160
+ print(f"Session {session_id} - Falling back to tiled analysis due to: {error_msg}")
161
+
162
+ image = session_state_manager.get(session_id, "current_image")
163
+ image_file_path = session_state_manager.get(session_id, "image_file_path")
164
+
165
+ if not image:
166
+ return {
167
+ "image_quality_for_deepforest": "No",
168
+ "deepforest_objects_present": [],
169
+ "additional_objects": [],
170
+ "visual_analysis": "No image available for tiled analysis.",
171
+ "status": "error",
172
+ "analysis_type": "error"
173
+ }
174
+
175
+ # Determine appropriate patch size
176
+ if image_file_path:
177
+ patch_size = determine_patch_size(image_file_path, image.size)
178
+ else:
179
+ max_dim = max(image.size)
180
+ if max_dim >= 5000:
181
+ patch_size = 1500 if max_dim <= 7500 else 2000
182
+ else:
183
+ patch_size = 1000
184
+
185
+ print(f"Session {session_id} - Using patch size {patch_size} for tiled analysis")
186
+
187
+ try:
188
+ tiles, tile_metadata = tile_image_for_analysis(
189
+ image=image,
190
+ patch_size=patch_size,
191
+ patch_overlap=Config.DEEPFOREST_DEFAULTS["patch_overlap"],
192
+ image_file_path=image_file_path
193
+ )
194
+
195
+ print(f"Session {session_id} - Created {len(tiles)} tiles for analysis")
196
+
197
+ # Analyze all tiles and combine results
198
+ all_visual_analyses = []
199
+ all_additional_objects = []
200
+ tile_results = []
201
+
202
+ for i, (tile, metadata) in enumerate(zip(tiles, tile_metadata)):
203
+ try:
204
+ tile_coords = metadata.get("window_coords", {})
205
+ location_desc = f"x:{tile_coords.get('x', 0)}-{tile_coords.get('x', 0) + tile_coords.get('width', 0)}, y:{tile_coords.get('y', 0)}-{tile_coords.get('y', 0) + tile_coords.get('height', 0)}"
206
+
207
+ # Analyze individual tile
208
+ tile_result = self._analyze_single_image(
209
+ image=tile,
210
+ user_message=user_message,
211
+ session_id=session_id,
212
+ is_full_image=False,
213
+ tile_location=location_desc
214
+ )
215
+
216
+ if tile_result["status"] == "success":
217
+ all_visual_analyses.append(tile_result["visual_analysis"])
218
+ all_additional_objects.extend(tile_result["additional_objects"])
219
+
220
+ # Store tile result for potential reuse
221
+ tile_results.append({
222
+ "tile_id": i,
223
+ "location": location_desc,
224
+ "coordinates": tile_coords,
225
+ "visual_analysis": tile_result["visual_analysis"],
226
+ "additional_objects": tile_result["additional_objects"]
227
+ })
228
+
229
+ # Log individual tile analysis
230
+ multi_agent_logger.log_agent_execution(
231
+ session_id=session_id,
232
+ agent_name=f"visual_tile_{i}",
233
+ agent_input=f"Tile {i+1} analysis: {user_message}",
234
+ agent_output=tile_result["visual_analysis"],
235
+ execution_time=0.0
236
+ )
237
+
238
+ print(f"Session {session_id} - Analyzed tile {i+1}/{len(tiles)}")
239
+
240
+ # Memory cleanup
241
+ del tile
242
+ gc.collect()
243
+ if torch.cuda.is_available():
244
+ torch.cuda.empty_cache()
245
+
246
+ except Exception as tile_error:
247
+ print(f"Session {session_id} - Tile {i} analysis failed: {tile_error}")
248
+ continue
249
+
250
+ if all_visual_analyses:
251
+ # Store tile results for potential reuse
252
+ session_state_manager.set(session_id, "tile_analysis_results", tile_results)
253
+ session_state_manager.set(session_id, "tiled_patch_size", patch_size)
254
+
255
+ # Combine all tile analyses
256
+ combined_visual_analysis = " ".join(all_visual_analyses)
257
+
258
+ return {
259
+ "image_quality_for_deepforest": "Yes",
260
+ "deepforest_objects_present": ["tree", "bird", "livestock"],
261
+ "additional_objects": all_additional_objects,
262
+ "visual_analysis": combined_visual_analysis,
263
+ "status": "tiled_success",
264
+ "analysis_type": "tiled_combined",
265
+ "tile_count": len(tiles),
266
+ "successful_tiles": len(all_visual_analyses),
267
+ "patch_size_used": patch_size
268
+ }
269
+
270
+ except Exception as tiling_error:
271
+ print(f"Session {session_id} - Tiled analysis also failed: {tiling_error}")
272
+
273
+ # Final fallback - resolution-based assessment
274
+ resolution_result = session_state_manager.get(session_id, "resolution_result")
275
+ if resolution_result and resolution_result.get("is_suitable"):
276
+ width, height = image.size
277
+ return {
278
+ "image_quality_for_deepforest": "Yes",
279
+ "deepforest_objects_present": ["tree", "bird", "livestock"],
280
+ "additional_objects": [],
281
+ "visual_analysis": f"Full image analysis of image ({width}x{height}) is done. Here's the analysis: Large image analyzed using resolution-based assessment. Original error: {error_msg}",
282
+ "status": "resolution_fallback",
283
+ "analysis_type": "resolution_based"
284
+ }
285
+
286
+ # Complete failure
287
+ width, height = image.size
288
+ return {
289
+ "image_quality_for_deepforest": "No",
290
+ "deepforest_objects_present": [],
291
+ "additional_objects": [],
292
+ "visual_analysis": f"Full image analysis of image ({width}x{height}) failed. Analysis could not be completed due to: {error_msg}",
293
+ "status": "error",
294
+ "analysis_type": "failed"
295
+ }
296
+
297
+ def get_tile_analysis_results(self, session_id: str) -> List[Dict[str, Any]]:
298
+ """
299
+ Get stored tile analysis results for reuse.
300
+
301
+ Args:
302
+ session_id: Session identifier
303
+
304
+ Returns:
305
+ List of tile analysis results or empty list
306
+ """
307
+ return session_state_manager.get(session_id, "tile_analysis_results", [])
src/deepforest_agent/conf/__init__.py ADDED
File without changes
src/deepforest_agent/conf/config.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+
3
+ class Config:
4
+ """
5
+ Configuration class defining DeepForest model paths, visualization colors, and agent models.
6
+ """
7
+
8
+ DEEPFOREST_MODELS = {
9
+ "bird": "weecology/deepforest-bird",
10
+ "tree": "weecology/deepforest-tree",
11
+ "livestock": "weecology/deepforest-livestock"
12
+ }
13
+
14
+ DEEPFOREST_DEFAULTS = {
15
+ "patch_size": 400,
16
+ "patch_overlap": 0.05,
17
+ "iou_threshold": 0.15,
18
+ "thresh": 0.55,
19
+ "alive_dead_trees": False
20
+ }
21
+
22
+ COLORS = {
23
+ "bird": (0, 0, 255), # Red (BGR)
24
+ "tree": (0, 255, 0), # Green (BGR)
25
+ "livestock": (255, 0, 0), # Blue (BGR)
26
+ "alive_tree": (255, 255, 0), # Cyan (BGR)
27
+ "dead_tree": (0, 165, 255) # Orange (BGR)
28
+ }
29
+
30
+ AGENT_MODELS = {
31
+ "memory": "HuggingFaceTB/SmolLM3-3B",
32
+ "deepforest_detector": "HuggingFaceTB/SmolLM3-3B",
33
+ "visual_analysis": "Qwen/Qwen2.5-VL-3B-Instruct",
34
+ "ecology_analysis": "meta-llama/Llama-3.2-3B-Instruct"
35
+ }
36
+
37
+ # Agent-specific generation parameters
38
+ AGENT_CONFIGS = {
39
+ "memory": {
40
+ "max_new_tokens": 16000,
41
+ "temperature": 0.6,
42
+ "top_p": 0.95
43
+ },
44
+ "deepforest_detector": {
45
+ "max_new_tokens": 16000,
46
+ "temperature": 0.6,
47
+ "top_p": 0.95
48
+ },
49
+ "visual_analysis": {
50
+ "max_new_tokens": 5000,
51
+ "temperature": 0.1
52
+ },
53
+ "ecology_analysis": {
54
+ "max_new_tokens": 16000,
55
+ "temperature": 0.6,
56
+ "top_p": 0.95
57
+ }
58
+ }
59
+
60
+ NO_ALBUMENTATIONS = os.getenv("NO_ALBUMENTATIONS", "")
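A minimal sketch of how these settings are consumed by the agents elsewhere in this diff:

from deepforest_agent.conf.config import Config

memory_cfg = Config.AGENT_CONFIGS["memory"]
print(Config.AGENT_MODELS["memory"])                      # HuggingFaceTB/SmolLM3-3B
print(memory_cfg["max_new_tokens"], memory_cfg["temperature"], memory_cfg["top_p"])
print(Config.DEEPFOREST_DEFAULTS["patch_size"])           # 400
print(Config.DEEPFOREST_MODELS["tree"])                   # weecology/deepforest-tree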
src/deepforest_agent/models/__init__.py ADDED
File without changes
src/deepforest_agent/models/llama32_3b_instruct.py ADDED
@@ -0,0 +1,242 @@
1
+ import gc
2
+ from typing import Tuple, Dict, Any, Optional, List, Generator
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers.generation.streamers import TextIteratorStreamer
6
+ from threading import Thread
7
+
8
+ from deepforest_agent.conf.config import Config
9
+
10
+
11
+ class Llama32ModelManager:
12
+ """
13
+ Manages Llama-3.2-3B-Instruct model instances for text generation tasks.
14
+
15
+ Attributes:
16
+ model_id (str): HuggingFace model identifier
17
+ load_count (int): Number of times model has been loaded
18
+ """
19
+
20
+ def __init__(self, model_id: str = Config.AGENT_MODELS["ecology_analysis"]):
21
+ """
22
+ Initialize the Llama-3.2-3B model manager.
23
+
24
+ Args:
25
+ model_id (str, optional): HuggingFace model identifier.
26
+ Defaults to "meta-llama/Llama-3.2-3B-Instruct".
27
+ """
28
+ self.model_id = model_id
29
+ self.load_count = 0
30
+
31
+ def generate_response(
32
+ self,
33
+ messages: List[Dict[str, str]],
34
+ max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
35
+ temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
36
+ top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
37
+ tools: Optional[List[Dict[str, Any]]] = None
38
+ ) -> str:
39
+ """
40
+ Generate text response using Llama-3.2-3B-Instruct.
41
+
42
+ Args:
43
+ messages: List of message dictionaries with 'role' and 'content'
44
+ max_new_tokens: Maximum tokens to generate
45
+ temperature: Sampling temperature
46
+ top_p: Top-p sampling
47
+ tools (Optional[List[Dict[str, Any]]]): List of tools (not used for Llama)
48
+
49
+ Returns:
50
+ str: Generated response text
51
+
52
+ Raises:
53
+ Exception: If generation fails due to model issues, memory, or other errors
54
+ """
55
+ print(f"Loading Llama-3.2-3B for inference #{self.load_count + 1}")
56
+
57
+ model, tokenizer = self._load_model()
58
+ self.load_count += 1
59
+
60
+ try:
61
+ # Llama uses standard chat template without xml_tools
62
+ text = tokenizer.apply_chat_template(
63
+ messages,
64
+ tokenize=False,
65
+ add_generation_prompt=True
66
+ )
67
+
68
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
69
+
70
+ generated_ids = model.generate(
71
+ model_inputs.input_ids,
72
+ max_new_tokens=max_new_tokens,
73
+ temperature=temperature,
74
+ top_p=top_p,
75
+ do_sample=True,
76
+ pad_token_id=tokenizer.eos_token_id
77
+ )
78
+
79
+ generated_ids = [
80
+ output_ids[len(input_ids):]
81
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
82
+ ]
83
+
84
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
85
+ return response
86
+
87
+ except Exception as e:
88
+ print(f"Error during Llama-3.2-3B text generation: {e}")
89
+ raise e
90
+
91
+ finally:
92
+ print(f"Releasing Llama-3.2-3B GPU memory after inference")
93
+ if 'model' in locals():
94
+ if hasattr(model, 'cpu'):
95
+ model.cpu()
96
+ del model
97
+ if 'tokenizer' in locals():
98
+ del tokenizer
99
+ if 'model_inputs' in locals():
100
+ del model_inputs
101
+ if 'generated_ids' in locals():
102
+ del generated_ids
103
+
104
+ # Multiple garbage collection passes
105
+ for _ in range(3):
106
+ gc.collect()
107
+
108
+ if torch.cuda.is_available():
109
+ torch.cuda.empty_cache()
110
+ torch.cuda.ipc_collect()
111
+ torch.cuda.synchronize()
112
+ try:
113
+ torch.cuda.memory._record_memory_history(enabled=None)
114
+ except Exception:
115
+ pass
116
+ print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
117
+
118
+ def generate_response_streaming(
119
+ self,
120
+ messages: List[Dict[str, str]],
121
+ max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
122
+ temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
123
+ top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
124
+ ) -> Generator[Dict[str, Any], None, None]:
125
+ """
126
+ Generate text response with streaming (token by token).
127
+
128
+ Args:
129
+ messages: List of message dictionaries with 'role' and 'content'
130
+ max_new_tokens: Maximum tokens to generate
131
+ temperature: Sampling temperature
132
+ top_p: Top-p sampling
133
+
134
+ Yields:
135
+ Dict[str, Any]: Dictionary containing:
136
+ - token: The generated token/text chunk
137
+ - is_complete: Whether generation is finished
138
+
139
+ Raises:
140
+ Exception: If generation fails due to model issues, memory, or other errors
141
+ """
142
+ print(f"Loading Llama-3.2-3B for streaming inference #{self.load_count + 1}")
143
+
144
+ model, tokenizer = self._load_model()
145
+ self.load_count += 1
146
+
147
+ try:
148
+ text = tokenizer.apply_chat_template(
149
+ messages,
150
+ tokenize=False,
151
+ add_generation_prompt=True
152
+ )
153
+
154
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
155
+
156
+ streamer = TextIteratorStreamer(
157
+ tokenizer,
158
+ timeout=60.0,
159
+ skip_prompt=True,
160
+ skip_special_tokens=True
161
+ )
162
+
163
+ generation_kwargs = {
164
+ "input_ids": model_inputs.input_ids,
165
+ "max_new_tokens": max_new_tokens,
166
+ "temperature": temperature,
167
+ "top_p": top_p,
168
+ "do_sample": True,
169
+ "pad_token_id": tokenizer.eos_token_id,
170
+ "streamer": streamer
171
+ }
172
+
173
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
174
+ thread.start()
175
+
176
+ for new_text in streamer:
177
+ yield {"token": new_text, "is_complete": False}
178
+
179
+ thread.join()
180
+ yield {"token": "", "is_complete": True}
181
+
182
+ except Exception as e:
183
+ print(f"Error during Llama-3.2-3B streaming generation: {e}")
184
+ yield {"token": f"[Error: {str(e)}]", "is_complete": True}
185
+
186
+ finally:
187
+ print(f"Releasing Llama-3.2-3B GPU memory after inference")
188
+ if 'model' in locals():
189
+ if hasattr(model, 'cpu'):
190
+ model.cpu()
191
+ del model
192
+ if 'tokenizer' in locals():
193
+ del tokenizer
194
+ if 'model_inputs' in locals():
195
+ del model_inputs
196
+ if 'generated_ids' in locals():
197
+ del generated_ids
198
+
199
+ # Multiple garbage collection passes
200
+ for _ in range(3):
201
+ gc.collect()
202
+
203
+ if torch.cuda.is_available():
204
+ torch.cuda.empty_cache()
205
+ torch.cuda.ipc_collect()
206
+ torch.cuda.synchronize()
207
+ try:
208
+ torch.cuda.memory._record_memory_history(enabled=None)
209
+ except Exception:
210
+ pass
211
+ print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
212
+
213
+ def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
214
+ """
215
+ Private method for model and tokenizer loading.
216
+
217
+ Returns:
218
+ Tuple[AutoModelForCausalLM, AutoTokenizer]: Loaded model and tokenizer
219
+
220
+ Raises:
221
+ Exception: If model loading fails due to network, memory, or other issues
222
+ """
223
+ try:
224
+ tokenizer = AutoTokenizer.from_pretrained(
225
+ self.model_id,
226
+ trust_remote_code=True
227
+ )
228
+
229
+ # Llama models may need specific configurations
230
+ model = AutoModelForCausalLM.from_pretrained(
231
+ self.model_id,
232
+ torch_dtype="auto",
233
+ device_map="auto",
234
+ trust_remote_code=True,
235
+ low_cpu_mem_usage=True
236
+ )
237
+
238
+ return model, tokenizer
239
+
240
+ except Exception as e:
241
+ print(f"Error loading Llama-3.2-3B model: {e}")
242
+ raise e
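For reference, a minimal sketch of how this manager might be driven, assuming `Config.AGENT_MODELS["ecology_analysis"]` resolves to the Llama checkpoint; the message contents are illustrative placeholders:

```python
from deepforest_agent.models.llama32_3b_instruct import Llama32ModelManager

manager = Llama32ModelManager()  # defaults to Config.AGENT_MODELS["ecology_analysis"]

messages = [
    {"role": "system", "content": "You are an ecological analysis assistant."},
    {"role": "user", "content": "Summarize the detected tree health results."},
]

# Blocking call: loads the model, generates, then frees GPU memory in the finally block.
summary = manager.generate_response(messages, max_new_tokens=512)
print(summary)

# Streaming call: chunks arrive until is_complete is True.
for chunk in manager.generate_response_streaming(messages):
    if chunk["is_complete"]:
        break
    print(chunk["token"], end="", flush=True)
```

Because the model is reloaded on every call and torn down in `finally`, each request pays the full load cost in exchange for a near-zero steady-state GPU footprint.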
src/deepforest_agent/models/qwen_vl_3b_instruct.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from typing import Tuple, Dict, Any, Optional, List, Union
3
+ import torch
4
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
5
+ from PIL import Image
6
+ from qwen_vl_utils import process_vision_info
7
+
8
+ from deepforest_agent.conf.config import Config
9
+
10
+
11
+ class QwenVL3BModelManager:
12
+ """Manages Qwen2.5-VL-3B model instances for visual analysis tasks.
13
+
14
+ Attributes:
15
+ model_id (str): HuggingFace model identifier
16
+ load_count (int): Number of times model has been loaded
17
+ """
18
+
19
+ def __init__(self, model_id: str = Config.AGENT_MODELS["visual_analysis"]):
20
+ """
21
+ Initialize the Qwen2.5-VL-3B model manager.
22
+
23
+ Args:
24
+ model_id (str, optional): HuggingFace model identifier.
25
+ Defaults to "Qwen/Qwen2.5-VL-3B-Instruct".
26
+ """
27
+ self.model_id = model_id
28
+ self.load_count = 0
29
+
30
+ def _load_model(self) -> Tuple[Qwen2_5_VLForConditionalGeneration, AutoProcessor]:
31
+ """
32
+ Private method for model loading implementation.
33
+
34
+ Returns:
35
+ Tuple[Qwen2_5_VLForConditionalGeneration, AutoProcessor]:
36
+ Loaded model and processor instances
37
+
38
+ Raises:
39
+ Exception: If model or processor loading fails
40
+ """
41
+ try:
42
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
43
+ self.model_id,
44
+ torch_dtype="auto",
45
+ device_map="auto",
46
+ trust_remote_code=True
47
+ )
48
+
49
+ processor = AutoProcessor.from_pretrained(
50
+ self.model_id,
51
+ use_fast=True
52
+ )
53
+
54
+ return model, processor
55
+
56
+ except Exception as e:
57
+ print(f"Error loading Qwen VL model: {e}")
58
+ raise e
59
+
60
+ def generate_response(
61
+ self,
62
+ messages: List[Dict[str, Any]],
63
+ max_new_tokens: int = Config.AGENT_CONFIGS["visual_analysis"]["max_new_tokens"],
64
+ temperature: float = Config.AGENT_CONFIGS["visual_analysis"]["temperature"]
65
+ ) -> str:
66
+ """
67
+ Generate multimodal response.
68
+
69
+ Args:
70
+ messages: List of messages with text and images
71
+ max_new_tokens: Maximum tokens to generate
72
+ temperature: Sampling temperature
73
+
74
+ Returns:
75
+ str: Generated response text based on the input messages
76
+
77
+ Raises:
78
+ Exception: If text generation fails for any reason
79
+ """
80
+ print(f"Loading Qwen VL for inference #{self.load_count + 1}")
81
+
82
+ model, processor = self._load_model()
83
+ self.load_count += 1
84
+
85
+ try:
86
+ # Process vision info using qwen_vl_utils
87
+ text = processor.apply_chat_template(
88
+ messages, tokenize=False, add_generation_prompt=True
89
+ )
90
+
91
+ # Use process_vision_info for proper image handling
92
+ image_inputs, video_inputs = process_vision_info(messages)
93
+
94
+ inputs = processor(
95
+ text=[text],
96
+ images=image_inputs,
97
+ videos=video_inputs,
98
+ padding=True,
99
+ return_tensors="pt",
100
+ )
101
+ inputs = inputs.to(model.device)
102
+
103
+ generated_ids = model.generate(
104
+ **inputs,
105
+ max_new_tokens=max_new_tokens,
106
+ temperature=temperature,
107
+ do_sample=True if temperature > 0 else False
108
+ )
109
+
110
+ generated_ids_trimmed = [
111
+ out_ids[len(in_ids):]
112
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
113
+ ]
114
+
115
+ response = processor.batch_decode(
116
+ generated_ids_trimmed,
117
+ skip_special_tokens=True,
118
+ clean_up_tokenization_spaces=False
119
+ )[0]
120
+
121
+ return response
122
+
123
+ except Exception as e:
124
+ print(f"Error during Qwen VL generation: {e}")
125
+ raise e
126
+
127
+ finally:
128
+ print(f"Releasing Qwen VL GPU memory after inference")
129
+ if 'model' in locals():
130
+ if hasattr(model, 'cpu'):
131
+ model.cpu()
132
+ del model
133
+ if 'processor' in locals():
134
+ del processor
135
+ if 'inputs' in locals():
136
+ del inputs
137
+ if 'generated_ids' in locals():
138
+ del generated_ids
139
+
140
+ # Multiple garbage collection passes
141
+ for _ in range(3):
142
+ gc.collect()
143
+
144
+ if torch.cuda.is_available():
145
+ torch.cuda.empty_cache()
146
+ torch.cuda.ipc_collect()
147
+ torch.cuda.synchronize()
148
+ try:
149
+ torch.cuda.memory._record_memory_history(enabled=None)
150
+ except:
151
+ pass
152
+ print(f"GPU memory after VLM cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
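A hedged sketch of how the visual analysis agent might invoke this manager; `sample_aerial.jpg` and the prompt are placeholders, and the message layout follows the Qwen2.5-VL convention consumed by `process_vision_info`:

```python
from deepforest_agent.models.qwen_vl_3b_instruct import QwenVL3BModelManager

vlm = QwenVL3BModelManager()  # defaults to Config.AGENT_MODELS["visual_analysis"]

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "sample_aerial.jpg"},  # local path, URL, or PIL image
            {"type": "text", "text": "Describe the ecological objects visible in this aerial image."},
        ],
    }
]

analysis = vlm.generate_response(messages, max_new_tokens=1024, temperature=0.2)
print(analysis)
```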
src/deepforest_agent/models/smollm3_3b.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from typing import Tuple, Dict, Any, Optional, List, Generator
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers.generation.streamers import TextIteratorStreamer
6
+ from threading import Thread
7
+
8
+ from deepforest_agent.conf.config import Config
9
+
10
+ class SmolLM3ModelManager:
11
+ """
12
+ Manages SmolLM3-3B model instances
13
+
14
+ Attributes:
15
+ model_id (str): HuggingFace model identifier
16
+ load_count (int): Number of times model has been loaded
17
+ """
18
+
19
+ def __init__(self, model_id: str = Config.AGENT_MODELS["deepforest_detector"]):
20
+ """
21
+ Initialize the SmolLM3 model manager.
22
+
23
+ Args:
24
+ model_id (str, optional): HuggingFace model identifier.
25
+ Defaults to "HuggingFaceTB/SmolLM3-3B".
26
+ """
27
+ self.model_id = model_id
28
+ self.load_count = 0
29
+
30
+ def generate_response(
31
+ self,
32
+ messages: List[Dict[str, str]],
33
+ max_new_tokens: int = Config.AGENT_CONFIGS["deepforest_detector"]["max_new_tokens"],
34
+ temperature: float = Config.AGENT_CONFIGS["deepforest_detector"]["temperature"],
35
+ top_p: float = Config.AGENT_CONFIGS["deepforest_detector"]["top_p"],
36
+ tools: Optional[List[Dict[str, Any]]] = None
37
+ ) -> str:
38
+ """
39
+ Generate text response
40
+
41
+ Args:
42
+ messages: List of message dictionaries with 'role' and 'content'
43
+ max_new_tokens: Maximum tokens to generate
44
+ temperature: Sampling temperature
45
+ top_p: Top-p sampling
46
+ tools (Optional[List[Dict[str, Any]]]): List of tools
47
+
48
+ Raises:
49
+ Exception: If generation fails due to model issues, memory, or other errors
50
+ """
51
+ print(f"Loading SmolLM3 for inference #{self.load_count + 1}")
52
+
53
+ model, tokenizer = self._load_model()
54
+ self.load_count += 1
55
+
56
+ try:
57
+ if tools:
58
+ text = tokenizer.apply_chat_template(
59
+ messages,
60
+ xml_tools=tools,
61
+ tokenize=False,
62
+ add_generation_prompt=True
63
+ )
64
+ else:
65
+ text = tokenizer.apply_chat_template(
66
+ messages,
67
+ tokenize=False,
68
+ add_generation_prompt=True
69
+ )
70
+
71
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
72
+
73
+ generated_ids = model.generate(
74
+ model_inputs.input_ids,
75
+ max_new_tokens=max_new_tokens,
76
+ temperature=temperature,
77
+ top_p=top_p,
78
+ do_sample=True,
79
+ pad_token_id=tokenizer.eos_token_id
80
+ )
81
+
82
+ generated_ids = [
83
+ output_ids[len(input_ids):]
84
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
85
+ ]
86
+
87
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
88
+ return response
89
+
90
+ except Exception as e:
91
+ print(f"Error during SmolLM3 text generation: {e}")
92
+ raise e
93
+
94
+ finally:
95
+ print(f"Releasing SmolLM3 GPU memory after inference")
96
+ if 'model' in locals():
97
+ if hasattr(model, 'cpu'):
98
+ model.cpu()
99
+ del model
100
+ if 'tokenizer' in locals():
101
+ del tokenizer
102
+ if 'model_inputs' in locals():
103
+ del model_inputs
104
+ if 'generated_ids' in locals():
105
+ del generated_ids
106
+
107
+ # Multiple garbage collection passes
108
+ for _ in range(3):
109
+ gc.collect()
110
+
111
+ if torch.cuda.is_available():
112
+ torch.cuda.empty_cache()
113
+ torch.cuda.ipc_collect()
114
+ torch.cuda.synchronize()
115
+ try:
116
+ torch.cuda.memory._record_memory_history(enabled=None)
117
+ except Exception:
118
+ pass
119
+ print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
120
+
121
+ def generate_response_streaming(
122
+ self,
123
+ messages: List[Dict[str, str]],
124
+ max_new_tokens: int = Config.AGENT_CONFIGS["deepforest_detector"]["max_new_tokens"],
125
+ temperature: float = Config.AGENT_CONFIGS["deepforest_detector"]["temperature"],
126
+ top_p: float = Config.AGENT_CONFIGS["deepforest_detector"]["top_p"]
127
+ ) -> Generator[Dict[str, Any], None, None]:
128
+ """
129
+ Generate text response with streaming (token by token)
130
+
131
+ Args:
132
+ messages: List of message dictionaries with 'role' and 'content'
133
+ max_new_tokens: Maximum tokens to generate
134
+ temperature: Sampling temperature
135
+ top_p: Top-p sampling
136
+
137
+ Yields:
138
+ Dict[str, Any]: Dictionary containing:
139
+ - token: The generated token/text chunk
140
+ - is_complete: Whether generation is finished
141
+
142
+ Raises:
143
+ Exception: If generation fails due to model issues, memory, or other errors
144
+ """
145
+ print(f"Loading SmolLM3 for streaming inference #{self.load_count + 1}")
146
+
147
+ model, tokenizer = self._load_model()
148
+ self.load_count += 1
149
+
150
+ try:
151
+ text = tokenizer.apply_chat_template(
152
+ messages,
153
+ tokenize=False,
154
+ add_generation_prompt=True
155
+ )
156
+
157
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
158
+
159
+ streamer = TextIteratorStreamer(
160
+ tokenizer,
161
+ timeout=60.0,
162
+ skip_prompt=True,
163
+ skip_special_tokens=True
164
+ )
165
+
166
+ generation_kwargs = {
167
+ "input_ids": model_inputs.input_ids,
168
+ "max_new_tokens": max_new_tokens,
169
+ "temperature": temperature,
170
+ "top_p": top_p,
171
+ "do_sample": True,
172
+ "pad_token_id": tokenizer.eos_token_id,
173
+ "streamer": streamer
174
+ }
175
+
176
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
177
+ thread.start()
178
+
179
+ for new_text in streamer:
180
+ yield {"token": new_text, "is_complete": False}
181
+
182
+ thread.join()
183
+ yield {"token": "", "is_complete": True}
184
+
185
+ except Exception as e:
186
+ print(f"Error during SmolLM3 streaming generation: {e}")
187
+ yield {"token": f"[Error: {str(e)}]", "is_complete": True}
188
+
189
+ finally:
190
+ print(f"Releasing SmolLM3 GPU memory after inference")
191
+ if 'model' in locals():
192
+ if hasattr(model, 'cpu'):
193
+ model.cpu()
194
+ del model
195
+ if 'tokenizer' in locals():
196
+ del tokenizer
197
+ if 'model_inputs' in locals():
198
+ del model_inputs
199
+ if 'generated_ids' in locals():
200
+ del generated_ids
201
+
202
+ # Multiple garbage collection passes
203
+ for _ in range(3):
204
+ gc.collect()
205
+
206
+ if torch.cuda.is_available():
207
+ torch.cuda.empty_cache()
208
+ torch.cuda.ipc_collect()
209
+ torch.cuda.synchronize()
210
+ try:
211
+ torch.cuda.memory._record_memory_history(enabled=None)
212
+ except Exception:
213
+ pass
214
+ print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
215
+
216
+ def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
217
+ """
218
+ Private method for model and tokenizer loading.
219
+
220
+ Returns:
221
+ Tuple[AutoModelForCausalLM, AutoTokenizer]: Loaded model and tokenizer
222
+
223
+ Raises:
224
+ Exception: If model loading fails due to network, memory, or other issues
225
+ """
226
+ try:
227
+ tokenizer = AutoTokenizer.from_pretrained(
228
+ self.model_id,
229
+ trust_remote_code=True
230
+ )
231
+
232
+ model = AutoModelForCausalLM.from_pretrained(
233
+ self.model_id,
234
+ torch_dtype="auto",
235
+ device_map="auto",
236
+ trust_remote_code=True,
237
+ low_cpu_mem_usage=True
238
+ )
239
+
240
+ return model, tokenizer
241
+
242
+ except Exception as e:
243
+ print(f"Error loading SmolLM3 model: {e}")
244
+ raise e
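The main difference from the Llama manager is the optional `tools` argument, which is forwarded to the chat template as `xml_tools` so SmolLM3 can emit structured tool calls. A rough sketch, assuming the schema from `get_deepforest_tool_schema()` below; the prompt text is illustrative:

```python
from deepforest_agent.models.smollm3_3b import SmolLM3ModelManager
from deepforest_agent.prompts.prompt_templates import get_deepforest_tool_schema

detector_llm = SmolLM3ModelManager()  # defaults to Config.AGENT_MODELS["deepforest_detector"]

messages = [
    {"role": "system", "content": "Decide which DeepForest models and parameters to use."},
    {"role": "user", "content": "Count the dead trees in the uploaded orthomosaic."},
]

# Passing tools renders the schema through the chat template's xml_tools argument.
raw_output = detector_llm.generate_response(messages, tools=[get_deepforest_tool_schema()])
print(raw_output)  # downstream parsing extracts the tool call from this text
```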
src/deepforest_agent/prompts/__init__.py ADDED
File without changes
src/deepforest_agent/prompts/prompt_templates.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, List, Any
2
+ import json
3
+
4
+ from deepforest_agent.conf.config import Config
5
+
6
+ def get_deepforest_tool_schema() -> Dict[str, Any]:
7
+ """
8
+ Get the DeepForest tool schema for structured tool calling.
9
+
10
+ Returns:
11
+ Dict[str, Any]: Tool schema for run_deepforest_object_detection
12
+ """
13
+ deepforest_tool_schema = {
14
+ "name": "run_deepforest_object_detection",
15
+ "description": "Performs object detection on ecological images using DeepForest models to detect birds, trees, livestock, and assess tree health. Use this tool for any queries related to ecological objects, wildlife detection, forest analysis, or tree health assessment.",
16
+ "parameters": {
17
+ "type": "object",
18
+ "properties": {
19
+ "model_names": {
20
+ "type": "array",
21
+ "items": {"type": "string", "enum": ["tree", "bird", "livestock"]},
22
+ "description": "List of models to use for detection. Select based on user query: 'tree' for vegetation or forest, 'bird' for avian species, 'livestock' for farm animals. Default: ['tree', 'bird', 'livestock']. Always include 'tree' when alive_dead_trees is true.",
23
+ "default": ["tree", "bird", "livestock"]
24
+ },
25
+ "patch_size": {
26
+ "type": "integer",
27
+ "description": f"Window size in pixels (default {Config.DEEPFOREST_DEFAULTS['patch_size']}) The size for the crops used to cut the input image/raster into smaller pieces.",
28
+ "default": Config.DEEPFOREST_DEFAULTS["patch_size"]
29
+ },
30
+ "patch_overlap": {
31
+ "type": "number",
32
+ "description": f"The horizontal and vertical overlap among patches (must be between 0-1) (default {Config.DEEPFOREST_DEFAULTS['patch_overlap']})",
33
+ "default": Config.DEEPFOREST_DEFAULTS["patch_overlap"]
34
+ },
35
+ "iou_threshold": {
36
+ "type": "number",
37
+ "description": f"Minimum IoU overlap among predictions between windows to be suppressed (default {Config.DEEPFOREST_DEFAULTS['iou_threshold']})",
38
+ "default": Config.DEEPFOREST_DEFAULTS["iou_threshold"]
39
+ },
40
+ "thresh": {
41
+ "type": "number",
42
+ "description": f"Score threshold used to filter bboxes after soft-NMS is performed (default {Config.DEEPFOREST_DEFAULTS['thresh']})",
43
+ "default": Config.DEEPFOREST_DEFAULTS["thresh"]
44
+ },
45
+ "alive_dead_trees": {
46
+ "type": "boolean",
47
+ "description": f"Enable tree health classification to distinguish between alive and dead trees. Required for forest health analysis. When true, 'tree' must be included in model_names. (default {Config.DEEPFOREST_DEFAULTS['alive_dead_trees']})",
48
+ "default": Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
49
+ }
50
+ },
51
+ "required": ["model_names"]
52
+ }
53
+ }
54
+ return deepforest_tool_schema
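As an illustration of the schema above, a tool-call arguments payload for a query like "how many dead trees are there?" might look as follows; the wrapper format depends on the chat template, and the numeric values simply mirror the documented defaults:

```python
# Hypothetical tool call conforming to the schema; values are illustrative.
example_tool_call = {
    "name": "run_deepforest_object_detection",
    "arguments": {
        "model_names": ["tree"],       # dead/alive analysis requires the tree model
        "alive_dead_trees": True,      # enables the alive/dead crop classifier
        "patch_size": 400,
        "patch_overlap": 0.05,
        "iou_threshold": 0.15,
        "thresh": 0.55,
    },
}
```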
55
+
56
+ def format_memory_prompt(conversation_history: List[Dict[str, Any]], latest_message: str, conversation_context: str) -> str:
57
+ """
58
+ Format the memory analysis prompt for new conversation history format.
59
+
60
+ Args:
61
+ conversation_history: Filtered conversation history
62
+ latest_message: Current user message
63
+ conversation_context: Formatted conversation context with turn structure
64
+
65
+ Returns:
66
+ Formatted prompt for memory analysis
67
+ """
68
+ prompt = f"""You are a conversation memory manager for an ecological data analytics assistant. Your role is to analyze previous conversation turns and determine if you can answer the user's query.
69
+
70
+ The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. The user may also ask follow-up questions based on previous answers. That's why you have access to the previous conversation turns so that you can determine if the answer is already available or if you can provide context for the agents.
71
+
72
+ You should not make up any information that is not present in the previous conversation turns. You should only use the previous conversation turns to answer the user's query, and you must not distort or misstate what those turns contain. Make sure you are directly addressing the user query when you use the previous conversation turns.
73
+
74
+ When you provide RELEVANT_CONTEXT, include analysis and comprehensive detail about every piece of data you are providing. Do not just give quick and direct answers. The ecology agent will use this context to provide the final answer to the user, so make sure you are providing all the relevant context that can help the ecology agent give the best possible answer. Your tone should be very professional and analytical. You cannot make any assumptions or guesses.
75
+
76
+ You have access to previous conversation turns. Your task is to determine if the current user query can be answered using information from previous turns, and provide the tool cache ID if detection data needs to be retrieved.
77
+
78
+ Here is the Conversation History:
79
+ {conversation_context}
80
+
81
+ Latest user query: {latest_message}
82
+
83
+ Your response format:
84
+
85
+ **ANSWER_PRESENT:** [YES or NO]
86
+ [YES if the latest user query can be answered fully using information from previous conversation turns and if the latest query is exactly similar to any of the previous queries. Otherwise NO]
87
+
88
+ **TOOL_CACHE_ID:**
89
+ [Analyze tool call information and provide relevant Tool cache IDs from the previous turns that can answer the latest user query. If multiple turns are relevant, provide multiple Tool cache IDs separated by commas.]
90
+
91
+ **RELEVANT_CONTEXT:**
92
+ [Provide a comprehensive analysis using data from previous turns including visual analysis, detection narratives, and ecology responses that answer the user's query. Include specific turn references. If no previous context is relevant, state "No relevant context from previous conversations."]
93
+
94
+ /no_think"""
95
+
96
+ return prompt
97
+
98
+ def create_full_image_quality_analysis_prompt(user_message: str) -> str:
99
+ """
100
+ Create system prompt for full image quality assessment.
101
+
102
+ Args:
103
+ user_message: User's query
104
+
105
+ Returns:
106
+ System prompt for full image quality analysis
107
+ """
108
+ return f"""You are a computer vision expert. Your task is to analyze the ecological image with your image understanding ability. You will provide a comprehensive analysis of this ecological image.
109
+
110
+ The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. An ecological image is provided for you already to analyze. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. To answer the user's query, ecological image quality is very important. Otherwise, the DeepForest object detection will not work properly. That's why determining if the image is an ecological aerial/drone image with good quality is very important. You also have to analyze the ecological image completely and provide a comprehensive summary of what's in this ecological image and what's happening. The user likely wants to know about the objects present in this ecological image. It's going to help with the ecological analysis. User likes spatial details and specific location information.
111
+
112
+ So, make sure you are providing all the important details about this ecological image. Incorporate species identification, behavior observations, environmental conditions, and habitat characteristics if possible.
113
+
114
+ You should not make up any information that is not present in the image. You should only use the image to answer the user's query, and you must not distort or misstate what the image shows. Do not miss any important details. If possible, zoom in on the image to see more details. Give specific location details when you explain what is in this image, but do not invent false locations or explanations. Stay strictly within what you see in this image. Making up false information is not acceptable. Do not make assumptions or guesses. Analyze the image thoroughly before making any claims.
115
+
116
+ Your tone should be very professional and expert in visual analysis. User likes insightful and detailed analysis. So, don't worry about the length of your response. Make it as long as necessary to cover all important aspects of this image. The response must be related to the user query. User's query: "{user_message}". Try to answer the user query with your visual analysis. Follow the structure below in your response:
117
+
118
+ **IMAGE_QUALITY_FOR_DEEPFOREST:** [YES or NO]
119
+ YES if this is a good quality aerial/drone image with clear ecological objects (trees, birds, livestock) that would be suitable for automated DeepForest object detection. NO if image quality is poor, too close-up, wrong angle, blurry, or not an ecological aerial/drone image.
120
+
121
+ **DEEPFOREST_OBJECTS_PRESENT:** []
122
+ List the objects from ["bird", "tree", "livestock"] that are clearly visible in the image. Example: ["bird", "tree"].
123
+
124
+ **ADDITIONAL_OBJECTS:** [JSON array]
125
+ Any objects present in this image with rough coordinates that are not bird, tree or livestock. Do not include bird, tree or livestock coordinates here. Also do not make up any false objects. Only include objects that are clearly visible in the image and necessary according to the user query.
126
+
127
+ **VISUAL_ANALYSIS:**
128
+ [In this section, you will provide the comprehensive visual analysis of the image. You should start with a brief summary containing spatial analysis of what's in the image. Then, give a brief summary if it's an ecological aerial/drone image or not. Then, you should analyze the image completely and provide a comprehensive summary of what's in this image and what's happening. If possible, zoom in on the specific locations of the image to see more details. Answer what's present in the image mentioning specific objects and their counts if possible. But do not make up false counts or objects. Make sure to incorporate species identification, behavior observations, environmental conditions, and habitat characteristics according to the image. Answer the user query "{user_message}" with what you see in the image in detail with proper reasoning, insights, bounding box coordinates and evidence. It can be as long as necessary to cover all important aspects of the image. Do not hallucinate or guess this part and you must provide bounding box coordinates for all the objects you are mentioning in this section. You must mention that this analysis is provided by a visual analysis agent and it may not be very accurate as there is no confidence score associated with it.]
129
+ """
130
+
131
+ def create_individual_tile_analysis_prompt(user_message: str) -> str:
132
+ """
133
+ Create system prompt for individual tile analysis.
134
+
135
+ Args:
136
+ user_message: User's query
137
+
138
+ Returns:
139
+ System prompt for tile-by-tile analysis
140
+ """
141
+ return f"""You are a computer vision expert. Your task is to analyze the given tiled image of an ecological image with your image understanding ability. You will provide a comprehensive analysis of this tile section of the image.
142
+
143
+ The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. A tiled image is provided for you already to analyze. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. To answer the user's query, ecological image quality is very important. Otherwise, the DeepForest object detection will not work properly. That's why determining if the tiled image is an ecological aerial/drone image with good quality is very important. You also have to analyze this tile section of the image completely and provide a comprehensive summary of what's in this tile and what's happening. The user likely wants to know about the objects present in this tile section of the image. It's going to help with the ecological analysis. User likes spatial details and specific location information, which is easily missed in a large image.
144
+
145
+ So, make sure you are providing all the important details about this tile section of the image. Incorporate species identification, behavior observations, environmental conditions, and habitat characteristics if possible.
146
+
147
+ You should not make up any information that is not present in the tiled image. You should only use the tiled image to answer the user's query, and you must not distort or misstate what the tile shows. Do not miss any important details. If possible, zoom in on the tiled image to see more details. Give specific location details when you explain what is in this tile, but do not invent false locations or explanations. Stay strictly within what you see in this tile section of the image. Making up false information is not acceptable. Do not make assumptions or guesses. Analyze the image thoroughly before making any claims.
148
+
149
+ Your tone should be very professional and expert in visual analysis. User likes insightful and detailed analysis. So, don't worry about the length of your response. Make it as long as necessary to cover all important aspects of this tile section of the image. The response must be related to the user query. User's query: "{user_message}". Try to answer the user query with your visual analysis. Follow the structure below in your response:
150
+
151
+ **IMAGE_QUALITY_FOR_DEEPFOREST:** [YES or NO]
152
+ YES if this is a good quality aerial/drone tiled image with clear ecological objects (trees, birds, livestock) that would be suitable for automated DeepForest object detection. NO if image quality is poor, too close-up, wrong angle, blurry, or not an ecological aerial/drone image.
153
+
154
+ **DEEPFOREST_OBJECTS_PRESENT:** []
155
+ List the objects from ["bird", "tree", "livestock"] that are clearly visible in the image. Example: ["bird", "tree"].
156
+
157
+ **ADDITIONAL_OBJECTS:** [JSON array]
158
+ Any objects present in this tile section of the image with rough coordinates that are not bird, tree or livestock. Do not include bird, tree or livestock coordinates here. Also do not make up any false objects. Only include objects that are clearly visible in the image and necessary according to the user query.
159
+
160
+ **VISUAL_ANALYSIS:**
161
+ [In this section, you will provide the comprehensive visual analysis of this tile section of the image. You should start with a brief summary containing spatial analysis of what's in this tile section of the image. Then, give a brief summary if it's an ecological aerial/drone image or not. Then, you should analyze the tile completely and provide a comprehensive summary of what's in this tile and what's happening. If possible, zoom in on the specific locations of the tile to see more details. Answer what's present in the tile mentioning specific objects and their counts if possible. But do not make up false counts or objects. Make sure to incorporate species identification, behavior observations, environmental conditions, and habitat characteristics according to the tiled image. Answer the user query "{user_message}" with what you see in the image in detail with proper reasoning, insights, bounding box coordinates and evidence. It can be as long as necessary to cover all important aspects of the tiled image. Do not hallucinate or guess this part and you must provide bounding box coordinates for all the objects you are mentioning in this section. You must mention that this analysis is provided by a visual analysis agent and it may not be very accurate as there is no confidence score associated with it.]
162
+ """
163
+
164
+ def create_detector_system_prompt_with_reasoning(user_message: str, memory_context: str, visual_objects: List[str]) -> str:
165
+ """
166
+ Create the system prompt for the detector agent.
167
+
168
+ Args:
169
+ user_message (str): The original user question
170
+ memory_context (str): Context from memory agent
171
+ visual_objects (List[str]): Objects detected by visual agent
172
+
173
+ Returns:
174
+ System prompt for enhanced tool calling with all context included
175
+ """
176
+
177
+ return f"""You are a smart DeepForest Tool Calling Agent with reasoning capabilities. You will receive:
178
+
179
+ 1. **User Query**: {user_message}
180
+ 2. **Memory Context**: {memory_context}
181
+ 3. **Objects detected by visual analysis**: {visual_objects}
182
+
183
+ Your task is to call the "run_deepforest_object_detection" tool with intelligent parameter selection based on user query. You can always assume the image is provided. The image will be passed later during tool execution. So, right now based on available data and user query make the right choice. You may need to provide multiple tool calls if necessary according to User Query, and Memory Context.
184
+
185
+ REASONING PROCESS:
186
+
187
+ **STEP 1: PARAMETERS UNDERSTANDING**
188
+ You have to understand the query thoroughly to choose appropriate parameters. Remember these are the only parameters that are available. So, use your knowledge to utilize these parameters based on query.
189
+ - model_names (list): Choose models from this ["tree", "bird", "livestock"] list based on what user wants to detect. If alive_dead_trees is true for tree health or dead/alive trees make sure to add "tree" to the list along with other requested models.
190
+ - patch_size (int): Window size in pixels (default 400) The size for the crops used to cut the input image/raster into smaller pieces.
191
+ - patch_overlap (float): The horizontal and vertical overlap among patches (must be between 0-1) (default 0.05)
192
+ - iou_threshold (float): Minimum IoU overlap among predictions between windows to be suppressed (default 0.15)
193
+ - thresh (float): Score threshold used to filter bboxes after soft-NMS is performed (default 0.55)
194
+ - alive_dead_trees (bool): Whether to classify trees as alive/dead, needed for forest or tree health (default false). If you select this as true make sure to include "tree" to model_names list. If user wants to know about tree health, forest health, dead trees or alive trees, you must set this parameter to true.
195
+
196
+ **STEP 2: MEMORY CONTEXT INTEGRATION**
197
+ - Use memory context to clarify unclear queries
198
+ - If user query is vague, use conversation history to understand intent
199
+ - Select parameters based on intention from memory context
200
+
201
+ **STEP 3: VISUAL OBJECT FILTERING**
202
+ - The visual objects are: {visual_objects}
203
+ - After deciding on the tool arguments from Memory context and user query, in model_names validate the models if it's present in the visual objects. The models that are not present in the visual objects should be removed.
204
+
205
+ **STEP 4: PARAMETER REASONING WITH QUERY**
206
+ - Based on the user query and your parameter understanding, choose the parameters wisely. Think if you can use available model_names or other parameters to address the user query better.
207
+
208
+ **CRITICAL: ACCURATE REASONING ONLY**
209
+ - Base your reasoning only on the provided user query, memory context, and visual objects
210
+ - Do not assume capabilities or parameters not explicitly mentioned
211
+ - If visual objects list is empty or unclear, acknowledge this limitation
212
+ - Do not make up technical details about DeepForest that aren't in the parameter descriptions
213
+
214
+ Your response format:
215
+ **REASONING:** [Explain your visual filtering, memory integration, and parameter choices based only on provided information]
216
+
217
+ Then provide the tool calls using the schema.
218
+
219
+ Always provide clear reasoning for your parameter choices before making the tool call. Your reasoning helps users understand why you chose specific detection models and parameters for their query./no_think"""
220
+
221
+ def create_ecology_synthesis_prompt(
222
+ user_message: str,
223
+ comprehensive_context: str,
224
+ cached_json: Optional[Dict[str, Any]] = None,
225
+ current_json: Optional[Dict[str, Any]] = None
226
+ ) -> str:
227
+ """
228
+ Create system prompt for ecology agent with new context format.
229
+
230
+ Args:
231
+ user_message: User's original query
232
+ comprehensive_context: Comprehensive context from memory + visual + detection narrative
233
+ cached_json (Optional[Dict[str, Any]]): A dictionary of previously
234
+ cached JSON data, if available. Defaults to None.
235
+ current_json (Optional[Dict[str, Any]]): A dictionary of new JSON data
236
+ from the current analysis step. Defaults to None.
237
+
238
+ Returns:
239
+ System prompt for ecological synthesis
240
+ """
241
+ prompt = f"""You are a Geospatial Image Analysis and Interpretation Assistant. Your primary task is to interpret and reason about complex image data from multiple data sources to answer the user query. You must synthesize information from multiple data sources, including memory context (if there is anything relevant), visual analysis, and the DeepForest Detection Summary to construct your answers. You act as a bridge between the data and the user's understanding, translating technical information into clear, descriptive language and providing proper reasoning to support your findings.
242
+
243
+ The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. They're trying to understand the content of an image, specifically regarding the distribution of ecological objects like trees, birds, or other wildlife. The user is asking the agent to act as a helpful guide to understand the complex DeepForest analysis data and the visual analysis data. The user has provided a query: {user_message}.
244
+
245
+ Based on the provided context, you must synthesize all available information to provide a comprehensive answer to the user's query.
246
+
247
+ Context Data:
248
+ {comprehensive_context}
249
+
250
+ Your tone should be professional, helpful, and highly informative. Avoid being overly robotic or technical. Use simple language that a non-expert can understand easily. You must be empathetic and nonjudgmental, recognizing that the user may not be familiar with the technical details of geospatial analysis. Focus on ecological insights that directly answer the user's query.
251
+
252
+ Under no circumstances should you invent or hallucinate information that is not present in the multiple data sources. All your statements must be directly supported by the data you have been given. If the data is insufficient to answer the query, you must state that clearly and explain why. You must also not misrepresent the data sources. If you are unsure about any information, it is better to acknowledge the uncertainty than to provide potentially incorrect information. Analyze the multiple data sources thoroughly before making any claims. Never hallucinate detection coordinates, object labels, object counts, visual analysis, or confidence scores. Never mention cache keys or technical metadata in your response. Do not mix visual analysis with the detection analysis. You must inform the user that you are less confident about the visual analysis, as it has no confidence scores associated with it, but confident about the DeepForest detection data, which does. If there is any conflict between the visual analysis and the DeepForest detection data, you should always trust the DeepForest detection data more than the visual analysis. Always provide detection analysis with proper confidence score ranges and detailed reasoning. It can have multiple paragraphs and sections if necessary.
253
+
254
+ The response can be as long as necessary to cover all important aspects of the user's query. The opening paragraph should be the "Direct Answer" that immediately addresses the user query with proper reasoning from the detection analysis, memory context (if there is any), and visual analysis. Your response should be based on the available "DETECTION ANALYSIS", for which you have to provide a detailed breakdown with proper reasoning that addresses the user query: {user_message}. If multiple detection analyses exist for multiple tool calls, you can provide a comprehensive comparison in "Result Comparison". You can then mention relevant information from the visual analysis to address the user query, but remember you are not very confident about it. You must also provide "Spatial Distribution and Ecological Patterns", and translate detection results into "Ecological Interpretation from DeepForest Data". All of these sections should form a comprehensive and insightful answer that leverages all available data. Use markdown headings (##) to create distinct sections if the response is lengthy. Make sure to incorporate the multiple data sources into these sections without hallucinating. Bold important keywords to make them stand out. Separate the response into paragraphs for better readability. Use bullet points or numbered lists where appropriate to organize information clearly. Conclude with a clear and concise summary of your findings.
255
+ """
256
+
257
+ return prompt
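A rough sketch of how these templates might be combined; the actual wiring lives in `agents/orchestrator.py` and may differ, and the context string here is purely illustrative:

```python
from deepforest_agent.prompts.prompt_templates import create_ecology_synthesis_prompt
from deepforest_agent.models.llama32_3b_instruct import Llama32ModelManager

synthesis_prompt = create_ecology_synthesis_prompt(
    user_message="How healthy is this forest stand?",
    comprehensive_context=(
        "MEMORY CONTEXT: No relevant context from previous conversations.\n"
        "VISUAL ANALYSIS: Dense canopy with several brown crowns in the north-east corner.\n"
        "DETECTION ANALYSIS: DeepForest detected: from 182 trees, 171 are classified as "
        "alive tree and 11 are classified as dead tree."
    ),
)

ecology_llm = Llama32ModelManager()
for chunk in ecology_llm.generate_response_streaming(
    [{"role": "user", "content": synthesis_prompt}]  # the orchestrator may use the system role instead
):
    if chunk["is_complete"]:
        break
    print(chunk["token"], end="", flush=True)
```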
src/deepforest_agent/tools/__init__.py ADDED
File without changes
src/deepforest_agent/tools/deepforest_tool.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import tempfile
4
+ from typing import List, Optional, Tuple, Dict, Any
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import pandas as pd
9
+ from PIL import Image
10
+ from shapely.geometry import shape
11
+
12
+ from deepforest import main
13
+ from deepforest.model import CropModel
14
+ from deepforest_agent.conf.config import Config
15
+ from deepforest_agent.utils.image_utils import convert_rgb_to_bgr, convert_bgr_to_rgb, load_image_as_np_array, create_temp_image_file, cleanup_temp_file
16
+
17
+
18
+ class DeepForestPredictor:
19
+ """Predictor class for DeepForest object detection models."""
20
+
21
+ def __init__(self):
22
+ """Initialize the DeepForest predictor."""
23
+ pass
24
+
25
+ def _generate_detection_summary(self, predictions_df: pd.DataFrame,
26
+ alive_dead_trees: bool = False) -> str:
27
+ """
28
+ Generate summary of detection results.
29
+
30
+ Args:
31
+ predictions_df: DataFrame containing detection results
32
+ alive_dead_trees: Whether alive/dead tree classification was used
33
+
34
+ Returns:
35
+ DeepForest Detection Summary String
36
+ """
37
+ if predictions_df.empty:
38
+ return "No objects detected by DeepForest with the requested models."
39
+
40
+ detection_summary_parts = []
41
+ counts = predictions_df['label'].value_counts()
42
+
43
+ if 'classification_label' in predictions_df.columns:
44
+ non_tree_df = predictions_df[predictions_df['label'] != 'tree']
45
+ if not non_tree_df.empty:
46
+ non_tree_counts = non_tree_df['label'].value_counts()
47
+ for label, count in non_tree_counts.items():
48
+ label_str = str(label).replace('_', ' ')
49
+ if count == 1:
50
+ detection_summary_parts.append(f"{count} {label_str}")
51
+ else:
52
+ detection_summary_parts.append(f"{count} {label_str}s")
53
+
54
+ tree_df = predictions_df[predictions_df['label'] == 'tree']
55
+ if not tree_df.empty:
56
+ total_trees = len(tree_df)
57
+ classification_counts = tree_df['classification_label'].value_counts()
58
+
59
+ classification_parts = []
60
+ for class_label, count in classification_counts.items():
61
+ class_str = str(class_label).replace('_', ' ')
62
+ classification_parts.append(f"{count} are classified as {class_str}")
63
+
64
+ if total_trees == 1:
65
+ detection_summary_parts.append(f"from {total_trees} tree, {' and '.join(classification_parts)}")
66
+ else:
67
+ detection_summary_parts.append(f"from {total_trees} trees, {' and '.join(classification_parts)}")
68
+ else:
69
+ for label, count in counts.items():
70
+ label_str = str(label).replace('_', ' ')
71
+ if count == 1:
72
+ detection_summary_parts.append(f"{count} {label_str}")
73
+ else:
74
+ detection_summary_parts.append(f"{count} {label_str}s")
75
+
76
+ detection_summary = f"DeepForest detected: {', '.join(detection_summary_parts)}."
77
+
78
+ return detection_summary
79
+
80
+ @staticmethod
81
+ def _plot_boxes(image_array: np.ndarray, predictions: pd.DataFrame,
82
+ colors: dict, thickness: int = 2) -> np.ndarray:
83
+ """
84
+ Plot bounding boxes on image.
85
+
86
+ Args:
87
+ image_array: Input image as numpy array
88
+ predictions: DataFrame with detection results
89
+ colors: Color mapping for different labels
90
+ thickness: Line thickness for bounding boxes
91
+
92
+ Returns:
93
+ Image array with drawn bounding boxes
94
+ """
95
+ image = image_array.copy()
96
+ image = convert_rgb_to_bgr(image)
97
+
98
+ for _, row in predictions.iterrows():
99
+ xmin, ymin = int(row['xmin']), int(row['ymin'])
100
+ xmax, ymax = int(row['xmax']), int(row['ymax'])
101
+
102
+ if 'classification_label' in row and pd.notna(row['classification_label']):
103
+ label = str(row['classification_label'])
104
+ else:
105
+ label = str(row['label'])
106
+ color = colors.get(label.lower(), (200, 200, 200))
107
+
108
+ cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness)
109
+
110
+ text_x = xmin
111
+ text_y = ymin - 10 if ymin - 10 > 10 else ymin + 15
112
+ cv2.putText(image, label, (text_x, text_y),
113
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)
114
+
115
+ image = convert_bgr_to_rgb(image)
116
+
117
+ return image
118
+
119
+ def predict_objects(
120
+ self,
121
+ image_data_array: Optional[np.ndarray] = None,
122
+ image_file_path: Optional[str] = None,
123
+ model_names: Optional[List[str]] = None,
124
+ patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
125
+ patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
126
+ iou_threshold: float = Config.DEEPFOREST_DEFAULTS["iou_threshold"],
127
+ thresh: float = Config.DEEPFOREST_DEFAULTS["thresh"],
128
+ alive_dead_trees: bool = Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
129
+ ) -> Tuple[str, Optional[np.ndarray], List[Dict[str, Any]]]:
130
+ """
131
+ Predict objects using DeepForest models with predict_tile method of DeepForest models
132
+
133
+ Args:
134
+ image_data_array: Input image as numpy array (optional if image_file_path not provided)
135
+ image_file_path: Path to image file
136
+ model_names: List of model names to use for prediction
137
+ patch_size: Size of patches for tiled prediction
138
+ patch_overlap: Patch overlap among windows
139
+ iou_threshold: Minimum IoU overlap among predictions between windows to be suppressed
140
+ thresh: Score threshold used to filter bboxes after soft-NMS is performed
141
+ alive_dead_trees: Whether to classify trees as alive/dead
142
+
143
+ Returns:
144
+ Tuple containing:
145
+ - detection_summary: Human-readable summary of detections
146
+ - annotated_image_array: Image with bounding boxes drawn
147
+ - detections_list: List of detection data
148
+ """
149
+
150
+ if model_names is None:
151
+ model_names = ["tree", "bird", "livestock"]
152
+
153
+ if image_file_path is None and image_data_array is None:
154
+ raise ValueError("Either image_data_array or image_file_path must be provided")
155
+
156
+ temp_file_path = None
157
+ use_provided_path = image_file_path is not None
158
+
159
+ if not use_provided_path:
160
+ if image_data_array is not None:
161
+ temp_file_path = create_temp_image_file(image_data_array, suffix=".png")
162
+ working_file_path = temp_file_path
163
+ working_array = image_data_array
164
+ else:
165
+ raise ValueError("image_data_array cannot be None when use_provided_path is False")
166
+ else:
167
+ working_file_path = image_file_path
168
+ working_array = load_image_as_np_array(image_file_path)
169
+
170
+ all_predictions_df = pd.DataFrame({
171
+ "xmin": pd.Series(dtype=int),
172
+ "ymin": pd.Series(dtype=int),
173
+ "xmax": pd.Series(dtype=int),
174
+ "ymax": pd.Series(dtype=int),
175
+ "score": pd.Series(dtype=float),
176
+ "label": pd.Series(dtype=str),
177
+ "model_type": pd.Series(dtype=str)
178
+ })
179
+
180
+ model_instances = {}
181
+ for model_name_key in model_names:
182
+ model_path = Config.DEEPFOREST_MODELS.get(model_name_key)
183
+ if model_path is None:
184
+ print(f"Warning: Model '{model_name_key}' not found in "
185
+ f"Config.DEEPFOREST_MODELS. Skipping.")
186
+ continue
187
+
188
+ try:
189
+ model = main.deepforest()
190
+ model.load_model(model_name=model_path)
191
+ model_instances[model_name_key] = model
192
+ except Exception as e:
193
+ print(f"Error loading DeepForest model '{model_name_key}' "
194
+ f"from path '{model_path}': {e}. Skipping this model.")
195
+ continue
196
+
197
198
+
199
+ # Process each model
200
+ for model_type, model in model_instances.items():
201
+ current_predictions = pd.DataFrame()
202
+ try:
203
+ if model_type == "tree" and alive_dead_trees:
204
+ crop_model_instance = CropModel(num_classes=2)
205
+ current_predictions = model.predict_tile(
206
+ raster_path=working_file_path,
207
+ patch_size=patch_size,
208
+ patch_overlap=patch_overlap,
209
+ crop_model=crop_model_instance,
210
+ iou_threshold=iou_threshold,
211
+ thresh=thresh
212
+ )
213
+ else:
214
+ current_predictions = model.predict_tile(
215
+ raster_path=working_file_path,
216
+ patch_size=patch_size,
217
+ patch_overlap=patch_overlap,
218
+ iou_threshold=iou_threshold,
219
+ thresh=thresh
220
+ )
221
+
222
+ if not current_predictions.empty:
223
+ current_predictions['model_type'] = model_type
224
+
225
+ if 'label' in current_predictions.columns:
226
+ current_predictions['label'] = (
227
+ current_predictions['label'].apply(
228
+ lambda x: str(x).lower()
229
+ )
230
+ )
231
+
232
+ # Handle alive/dead tree classification results
233
+ if (alive_dead_trees and 'cropmodel_label' in
234
+ current_predictions.columns and model_type == "tree"):
235
+ current_predictions['classification_label'] = (
236
+ current_predictions.apply(
237
+ lambda row: (
238
+ 'alive_tree' if row['cropmodel_label'] == 0
239
+ else 'dead_tree' if row['cropmodel_label'] == 1
240
+ else row['label']
241
+ ),
242
+ axis=1
243
+ )
244
+ )
245
+ if 'cropmodel_score' in current_predictions.columns:
246
+ current_predictions['classification_score'] = current_predictions['cropmodel_score']
247
+ current_predictions = current_predictions.drop(columns=['cropmodel_score'], errors='ignore')
248
+
249
+ current_predictions = current_predictions.drop(
250
+ columns=['cropmodel_label'],
251
+ errors='ignore'
252
+ )
253
+
254
+ all_predictions_df = pd.concat(
255
+ [all_predictions_df, current_predictions],
256
+ ignore_index=True
257
+ )
258
+
259
+ except Exception as e:
260
+ print(f"Error during DeepForest prediction for model "
261
+ f"'{model_type}': {e}")
262
264
+
265
+ # Clean up any temporary image file once all models have run
+ if temp_file_path:
+ cleanup_temp_file(temp_file_path)
+
+ # Generate detection summary
266
+ detection_summary = self._generate_detection_summary(
267
+ all_predictions_df, alive_dead_trees
268
+ )
269
+
270
+ # Create annotated image with bounding boxes
271
+ annotated_image_array = None
272
+ if working_array.ndim == 2:
273
+ annotated_image_array = cv2.cvtColor(
274
+ working_array, cv2.COLOR_GRAY2RGB
275
+ )
276
+ elif (working_array.ndim == 3 and
277
+ working_array.shape[2] == 4):
278
+ annotated_image_array = cv2.cvtColor(
279
+ working_array, cv2.COLOR_RGBA2RGB
280
+ )
281
+ else:
282
+ annotated_image_array = working_array.copy()
283
+
284
+ if annotated_image_array.dtype != np.uint8:
285
+ annotated_image_array = annotated_image_array.astype(np.uint8)
286
+
287
+ annotated_image_array = self._plot_boxes(
288
+ annotated_image_array, all_predictions_df, Config.COLORS
289
+ )
290
+
291
+ output_df = all_predictions_df.copy()
292
+
293
+ essential_columns = ['xmin', 'ymin', 'xmax', 'ymax', 'score', 'label']
294
+ if 'classification_label' in output_df.columns:
295
+ essential_columns.append('classification_label')
296
+ if 'classification_score' in output_df.columns:
297
+ essential_columns.append('classification_score')
298
+
299
+ output_df = output_df[
300
+ [col for col in essential_columns if col in output_df.columns]
301
+ ]
302
+ detections_list = []
303
+ if not output_df.empty:
304
+ for _, row in output_df.iterrows():
305
+ record = {
306
+ "xmin": int(row['xmin']),
307
+ "ymin": int(row['ymin']),
308
+ "xmax": int(row['xmax']),
309
+ "ymax": int(row['ymax']),
310
+ "score": float(row['score']),
311
+ "label": str(row['label'])
312
+ }
313
+ if 'classification_label' in row:
314
+ record["classification_label"] = str(row['classification_label'])
315
+ if 'classification_score' in row:
316
+ try:
317
+ record["classification_score"] = float(row['classification_score'])
318
+ except (ValueError, TypeError):
319
+ pass
320
+
321
+ detections_list.append(record)
322
+
323
+ return detection_summary, annotated_image_array, detections_list
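A standalone usage sketch of the predictor; `orthomosaic.png` is a placeholder path, and DeepForest fetches the pretrained weights the first time each model is loaded:

```python
from deepforest_agent.tools.deepforest_tool import DeepForestPredictor

predictor = DeepForestPredictor()
summary, annotated, detections = predictor.predict_objects(
    image_file_path="orthomosaic.png",
    model_names=["tree", "bird"],
    patch_size=400,
    alive_dead_trees=True,   # adds alive/dead classification to tree boxes
)

print(summary)                   # e.g. "DeepForest detected: ..."
print(len(detections), "boxes")  # each entry has xmin/ymin/xmax/ymax, score, label
# `annotated` is an RGB numpy array with boxes drawn; save or display it as needed.
```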
src/deepforest_agent/tools/tool_handler.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from json import JSONDecoder
3
+ import re
4
+ from typing import Dict, Any, Optional, Union, List
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from deepforest_agent.tools.deepforest_tool import DeepForestPredictor
9
+ from deepforest_agent.utils.state_manager import session_state_manager
10
+ from deepforest_agent.utils.image_utils import validate_image_path
11
+ from deepforest_agent.conf.config import Config
12
+
13
+ deepforest_predictor = DeepForestPredictor()
14
+
15
+ def run_deepforest_object_detection(
16
+ session_id: str,
17
+ model_names: List[str] = ["tree", "bird", "livestock"],
18
+ patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
19
+ patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
20
+ iou_threshold: float = Config.DEEPFOREST_DEFAULTS["iou_threshold"],
21
+ thresh: float = Config.DEEPFOREST_DEFAULTS["thresh"],
22
+ alive_dead_trees: bool = Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
23
+ ) -> Dict[str, Any]:
24
+ """
25
+ Run DeepForest object detection on the image stored in this session's state.
26
+
27
+ Args:
28
+ session_id (str): Unique session identifier for this user
29
+ model_names: List of model names to use ("tree", "bird", "livestock")
30
+ patch_size: Patch size for each window in pixels (not geographic units). The size for the crops used to cut the input image/raster into smaller pieces.
31
+ patch_overlap: Patch overlap among windows. The horizontal and vertical overlap among patches (must be between 0-1).
32
+ iou_threshold: Minimum IoU overlap among predictions between windows to be suppressed.
33
+ thresh: Score threshold used to filter bboxes after soft-NMS is performed.
34
+ alive_dead_trees: Whether to classify trees as alive/dead
35
+
36
+ Returns:
37
+ Dictionary with detection_summary and detections_list
38
+ """
39
+ # Validate session exists
40
+ if not session_state_manager.session_exists(session_id):
41
+ return {
42
+ "detection_summary": f"Session {session_id} not found.",
43
+ "detections_list": [],
44
+ "status": "error"
45
+ }
46
+
47
+ image_file_path = session_state_manager.get(session_id, "image_file_path")
48
+ current_image = session_state_manager.get(session_id, "current_image")
49
+
50
+ if image_file_path is None and current_image is None:
51
+ return {
52
+ "detection_summary": f"No image available for detection in session {session_id}.",
53
+ "detections_list": [],
54
+ "status": "error"
55
+ }
56
+
57
+ if image_file_path and not validate_image_path(image_file_path):
58
+ print(f"Warning: Invalid image file path {image_file_path}, falling back to PIL image")
59
+ image_file_path = None
60
+
61
+ try:
62
+ if image_file_path:
63
+ print(f"DeepForest: Processing image from file path: {image_file_path}")
64
+ detection_summary, annotated_image, detections_list = deepforest_predictor.predict_objects(
65
+ image_file_path=image_file_path,
66
+ model_names=model_names,
67
+ patch_size=patch_size,
68
+ patch_overlap=patch_overlap,
69
+ iou_threshold=iou_threshold,
70
+ thresh=thresh,
71
+ alive_dead_trees=alive_dead_trees
72
+ )
73
+ else:
74
+ print(f"DeepForest: Processing PIL image (size: {current_image.size})")
75
+ image_array = np.array(current_image)
76
+ detection_summary, annotated_image, detections_list = deepforest_predictor.predict_objects(
77
+ image_data_array=image_array,
78
+ model_names=model_names,
79
+ patch_size=patch_size,
80
+ patch_overlap=patch_overlap,
81
+ iou_threshold=iou_threshold,
82
+ thresh=thresh,
83
+ alive_dead_trees=alive_dead_trees
84
+ )
85
+
86
+ if annotated_image is not None:
87
+ session_state_manager.set(session_id, "annotated_image", Image.fromarray(annotated_image))
88
+
89
+ result = {
90
+ "detection_summary": detection_summary,
91
+ "detections_list": detections_list,
92
+ "total_detections": len(detections_list),
93
+ "status": "success"
94
+ }
95
+
96
+ return result
97
+
98
+ except Exception as e:
99
+ error_msg = f"Error during image detection in session {session_id}: {str(e)}"
100
+ print(f"DeepForest Detection Error: {error_msg}")
101
+ return {
102
+ "detection_summary": error_msg,
103
+ "detections_list": [],
104
+ "total_detections": 0,
105
+ "status": "error"
106
+ }
107
+
108
+ def extract_all_tool_calls(text: str) -> List[Dict[str, Any]]:
109
+ """
110
+ Extract all tool call information from model output text.
111
+
112
+ Args:
113
+ text: The model's output text that may contain multiple tool calls
114
+
115
+ Returns:
116
+ List of dictionaries with tool call info (empty list if none found)
117
+ """
118
+ tool_calls = []
119
+
120
+ # Method 1: Wrapped in XML
121
+ xml_pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
122
+ xml_matches = re.findall(xml_pattern, text, re.DOTALL)
123
+
124
+ for match in xml_matches:
125
+ try:
126
+ result = json.loads(match.strip())
127
+ if isinstance(result, dict) and "name" in result and "arguments" in result:
128
+ print(f"Found valid XML tool call: {result}")
129
+ tool_calls.append(result)
130
+ except json.JSONDecodeError as e:
131
+ print(f"Failed to parse XML tool call JSON: {e}")
132
+ continue
133
+
134
+ # Method 2: If no XML format found, try raw JSON format
135
+ if not tool_calls:
136
+ decoder = JSONDecoder()
137
+ brace_start = 0
138
+
139
+ while True:
140
+ match = text.find('{', brace_start)
141
+ if match == -1:
142
+ break
143
+ try:
144
+ result, index = decoder.raw_decode(text[match:])
145
+ if isinstance(result, dict) and "name" in result and "arguments" in result:
146
+ print(f"Found valid raw JSON tool call: {result}")
147
+ tool_calls.append(result)
148
+ brace_start = match + index
149
+ else:
150
+ brace_start = match + 1
151
+ except ValueError:
152
+ brace_start = match + 1
153
+
154
+ print(f"Total tool calls extracted: {len(tool_calls)}")
155
+ return tool_calls
156
+
157
+ def handle_tool_call(tool_name: str, tool_arguments: Dict[str, Any], session_id: str) -> Union[str, Dict[str, Any]]:
158
+ """
159
+ Handle tool call execution from tool name and arguments.
160
+
161
+ Args:
162
+ tool_name (str): The name of the tool to be executed.
163
+ tool_arguments (Dict[str, Any]): A dictionary of arguments for the tool.
164
+ session_id: Unique session identifier for this user
165
+
166
+ Returns:
167
+ Either error message (str) or tool execution result (dict)
168
+ """
169
+ print(f"Tool Call Detected:")
170
+ print(f"Tool Name: {tool_name}")
171
+ print(f"Arguments: {tool_arguments}")
172
+
173
+ if tool_name == "run_deepforest_object_detection":
174
+ try:
175
+ result = run_deepforest_object_detection(session_id=session_id, **tool_arguments)
176
+
177
+ return result
178
+
179
+ except Exception as e:
180
+ error_msg = f"Error executing {tool_name} in session {session_id}: {str(e)}"
181
+ print(f"Tool Execution Failed: {error_msg}")
182
+
183
+ return error_msg
184
+ else:
185
+ error_msg = f"Unknown tool: {tool_name}"
186
+ print(f"Unknown Tool: {error_msg}")
187
+
188
+ return error_msg
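
Together, `extract_all_tool_calls` and `handle_tool_call` form the round trip from a model's text output to an executed detection. A minimal sketch, assuming a session with an uploaded image already exists under the illustrative id "demo-session":

```python
from deepforest_agent.tools.tool_handler import (
    extract_all_tool_calls,
    handle_tool_call,
)

# Text a model might emit; both XML-wrapped and raw JSON forms are parsed.
model_output = """
<tool_call>
{"name": "run_deepforest_object_detection",
 "arguments": {"model_names": ["tree"], "alive_dead_trees": true}}
</tool_call>
"""

for call in extract_all_tool_calls(model_output):
    result = handle_tool_call(call["name"], call["arguments"], session_id="demo-session")
    if isinstance(result, dict):
        # Success path: structured detection payload.
        print(result["status"], result.get("total_detections", 0))
        print(result["detection_summary"])
    else:
        # Failure path: handle_tool_call returns an error string.
        print("tool error:", result)
```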
src/deepforest_agent/utils/__init__.py ADDED
File without changes
src/deepforest_agent/utils/cache_utils.py ADDED
@@ -0,0 +1,306 @@
1
+ import hashlib
2
+ import json
3
+ import time
4
+ import tempfile
5
+ import os
6
+ from typing import Dict, Any, Optional, List
7
+ from PIL import Image
8
+ import pickle
9
+ import gzip
10
+ import base64
11
+
12
+ from deepforest_agent.conf.config import Config
13
+ from deepforest_agent.utils.image_utils import convert_pil_image_to_bytes
14
+
15
+ class ToolCallCache:
16
+ """
17
+ In-memory cache for tool-call results, with compressed on-disk storage for annotated images.
18
+ """
19
+
20
+ def __init__(self, cache_dir: Optional[str] = None):
21
+ """
22
+ Initialize the tool call cache with data handling.
23
+
24
+ Args:
25
+ cache_dir: Directory to store cached images. If None, uses system temp directory.
26
+ """
27
+ self.cache_data = {}
28
+
29
+ if cache_dir is None:
30
+ self.cache_dir = os.path.join(tempfile.gettempdir(), "deepforest_cache")
31
+ else:
32
+ self.cache_dir = cache_dir
33
+
34
+ os.makedirs(self.cache_dir, exist_ok=True)
35
+ print(f"Cache directory: {self.cache_dir}")
36
+
37
+ def _normalize_arguments(self, arguments: Dict[str, Any]) -> str:
38
+ """
39
+ Normalize tool arguments to create a consistent cache key.
40
+
41
+ Args:
42
+ arguments: Tool arguments to normalize
43
+
44
+ Returns:
45
+ Normalized JSON string of arguments sorted by key
46
+ """
47
+ normalized_args = Config.DEEPFOREST_DEFAULTS.copy()
48
+ normalized_args.update(arguments)
49
+ if "model_names" in arguments:
50
+ normalized_args["model_names"] = arguments["model_names"]
51
+
52
+ print(f"Cache normalization: {arguments} -> {normalized_args}")
53
+ return json.dumps(normalized_args, sort_keys=True, separators=(',', ':'))
54
+
55
+ def _create_cache_key(self, tool_name: str, arguments: Dict[str, Any]) -> str:
56
+ """
57
+ Create a unique cache key from tool name and arguments.
58
+
59
+ Args:
60
+ tool_name: Name of the tool being called
61
+ arguments: Arguments passed to the tool
62
+
63
+ Returns:
64
+ MD5 hash that uniquely identifies this tool call
65
+ """
66
+ cache_input = f"{tool_name}:{self._normalize_arguments(arguments)}"
67
+ return hashlib.md5(cache_input.encode('utf-8')).hexdigest()
68
+
69
+ def _store_image(self, image: Image.Image, cache_key: str) -> str:
70
+ """
71
+ Store PIL Image while preserving original characteristics.
72
+
73
+ Args:
74
+ image: PIL Image to store
75
+ cache_key: Unique identifier for this cache entry
76
+
77
+ Returns:
78
+ File path where the image was stored
79
+ """
80
+ if image is None:
81
+ return None
82
+
83
+ image_filename = f"cached_image_{cache_key}.pkl.gz"
84
+ image_path = os.path.join(self.cache_dir, image_filename)
85
+
86
+ try:
87
+ # Pickle for exact PIL Image preservation, compressed with gzip
88
+ with gzip.open(image_path, 'wb') as f:
89
+ pickle.dump(image, f, protocol=pickle.HIGHEST_PROTOCOL)
90
+
91
+ file_size_mb = os.path.getsize(image_path) / (1024 * 1024)
92
+ print(f"Image cached to {image_path} ({file_size_mb:.2f} MB)")
93
+
94
+ return image_path
95
+
96
+ except Exception as e:
97
+ print(f"Error storing image efficiently: {e}")
98
+ return self._fallback_image_storage(image)
99
+
100
+ def _load_image(self, image_path: str) -> Optional[Image.Image]:
101
+ """
102
+ Load PIL Image from storage.
103
+
104
+ Args:
105
+ image_path: File path where image was stored
106
+
107
+ Returns:
108
+ Reconstructed PIL Image, or None if loading fails
109
+ """
110
+ if not image_path or not os.path.exists(image_path):
111
+ return None
112
+
113
+ try:
114
+ with gzip.open(image_path, 'rb') as f:
115
+ image = pickle.load(f)
116
+
117
+ print(f"Image loaded from cache: {image_path}")
118
+ return image
119
+
120
+ except Exception as e:
121
+ print(f"Error loading cached image: {e}")
122
+ return None
123
+
124
+ def _fallback_image_storage(self, image: Image.Image) -> str:
125
+ """
126
+ Fallback that base64-encodes the image when compressed on-disk storage fails.
127
+
128
+ Args:
129
+ image: PIL Image to store
130
+
131
+ Returns:
132
+ Base64 encoded string of the image
133
+ """
134
+ img_bytes = convert_pil_image_to_bytes(image)
135
+
136
+ return base64.b64encode(img_bytes).decode('utf-8')
137
+
138
+ def get_cached_result(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
139
+ """
140
+ Retrieve cached result with data handling.
141
+
142
+ Args:
143
+ tool_name: Name of the tool being called
144
+ arguments: Arguments for the tool call
145
+
146
+ Returns:
147
+ Dictionary containing all cached data or None if not found
148
+ """
149
+ cache_key = self._create_cache_key(tool_name, arguments)
150
+
151
+ if cache_key not in self.cache_data:
152
+ print(f"Cache MISS: No cached result for {tool_name} with key {cache_key}")
153
+ return None
154
+
155
+ cached_entry = self.cache_data[cache_key]
156
+ cached_result = {}
157
+
158
+ if "detection_summary" in cached_entry["result"]:
159
+ cached_result["detection_summary"] = cached_entry["result"]["detection_summary"]
160
+ print(f"Cache: Retrieved detection_summary: {cached_result['detection_summary']}")
161
+
162
+ if "detections_list" in cached_entry["result"]:
163
+ cached_result["detections_list"] = cached_entry["result"]["detections_list"]
164
+ print(f"Cache: Retrieved {len(cached_result['detections_list'])} detections")
165
+
166
+ if "total_detections" in cached_entry["result"]:
167
+ cached_result["total_detections"] = cached_entry["result"]["total_detections"]
168
+
169
+ if "status" in cached_entry["result"]:
170
+ cached_result["status"] = cached_entry["result"]["status"]
171
+
172
+ if "annotated_image_path" in cached_entry["result"]:
173
+ cached_result["annotated_image"] = self._load_image(
174
+ cached_entry["result"]["annotated_image_path"]
175
+ )
176
+ if cached_result["annotated_image"]:
177
+ print(f"Cache: Retrieved annotated image ({cached_result['annotated_image'].size})")
178
+
179
+ cached_result["cache_info"] = {
180
+ "cached_at": cached_entry["timestamp"],
181
+ "cache_hit": True,
182
+ "cache_key": cache_key,
183
+ "tool_name": tool_name,
184
+ "arguments": arguments
185
+ }
186
+
187
+ print(f"Successfully retrieved all data for {tool_name}")
188
+ return cached_result
189
+
190
+ def store_result(self, tool_name: str, arguments: Dict[str, Any], result: Dict[str, Any]) -> str:
191
+ """
192
+ Store tool call result with data handling.
193
+
194
+ Args:
195
+ tool_name: Name of the tool that was executed
196
+ arguments: Arguments that were passed to the tool
197
+ result: Result dictionary containing:
198
+ - detection_summary (str): Text summary of what was detected
199
+ - detections_list (List): List of detection objects
200
+ - total_detections (int): Count of detections
201
+ - status (str): Success/error status
202
+ - annotated_image (PIL.Image, optional): Image with annotations
203
+
204
+ Returns:
205
+ Cache key that was used to store this result
206
+ """
207
+ cache_key = self._create_cache_key(tool_name, arguments)
208
+
209
+ storable_result = {}
210
+
211
+ if "detection_summary" in result:
212
+ storable_result["detection_summary"] = result["detection_summary"]
213
+ print(f"Detection_summary = {result['detection_summary']}")
214
+ else:
215
+ print("No detection_summary found in result to cache")
216
+
217
+ if "detections_list" in result:
218
+ storable_result["detections_list"] = result["detections_list"]
219
+ print(f"Detections_list with {len(result['detections_list'])} items")
220
+ else:
221
+ print("No detections_list found in result to cache")
222
+ storable_result["detections_list"] = []
223
+
224
+ if "total_detections" in result:
225
+ storable_result["total_detections"] = result["total_detections"]
226
+ else:
227
+ storable_result["total_detections"] = len(storable_result["detections_list"])
228
+
229
+ if "status" in result:
230
+ storable_result["status"] = result["status"]
231
+ else:
232
+ storable_result["status"] = "unknown"
233
+
234
+ if "annotated_image" in result and result["annotated_image"] is not None:
235
+ image_path = self._store_image(result["annotated_image"], cache_key)
236
+ if image_path:
237
+ storable_result["annotated_image_path"] = image_path
238
+ print(f"Annotated_image stored efficiently")
239
+ else:
240
+ print("No annotated_image to store")
241
+
242
+ self.cache_data[cache_key] = {
243
+ "tool_name": tool_name,
244
+ "arguments": arguments.copy(),
245
+ "result": storable_result,
246
+ "timestamp": time.time(),
247
+ "cache_key": cache_key
248
+ }
249
+
250
+ print(f"Successfully cached all data for {tool_name} with key {cache_key}")
251
+ return cache_key
252
+
253
+ def get_cache_stats(self) -> Dict[str, Any]:
254
+ """
255
+ Get detailed statistics about cached data.
256
+
257
+ Returns:
258
+ Dictionary with comprehensive cache statistics
259
+ """
260
+ total_images = 0
261
+ total_detections = 0
262
+ cache_size_mb = 0
263
+
264
+ for entry in self.cache_data.values():
265
+ result = entry["result"]
266
+
267
+ if "annotated_image_path" in result:
268
+ total_images += 1
269
+ # Calculate file size if image exists
270
+ if os.path.exists(result["annotated_image_path"]):
271
+ cache_size_mb += os.path.getsize(result["annotated_image_path"]) / (1024 * 1024)
272
+
273
+ # Count total detections across all cached results
274
+ total_detections += result.get("total_detections", 0)
275
+
276
+ return {
277
+ "total_entries": len(self.cache_data),
278
+ "total_images_cached": total_images,
279
+ "total_detections_cached": total_detections,
280
+ "cache_size_mb": round(cache_size_mb, 2),
281
+ "cache_directory": self.cache_dir,
282
+ "tools_cached": set(entry["tool_name"] for entry in self.cache_data.values())
283
+ }
284
+
285
+ def cleanup_cache_files(self):
286
+ """
287
+ Clean up cached image files from disk.
288
+
289
+ Returns:
290
+ The total number of files that were successfully removed.
291
+ """
292
+ files_removed = 0
293
+ for entry in self.cache_data.values():
294
+ if "annotated_image_path" in entry["result"]:
295
+ image_path = entry["result"]["annotated_image_path"]
296
+ if os.path.exists(image_path):
297
+ try:
298
+ os.remove(image_path)
299
+ files_removed += 1
300
+ except Exception as e:
301
+ print(f"Error removing cached image {image_path}: {e}")
302
+
303
+ print(f"Cleaned up {files_removed} cached image files")
304
+ return files_removed
305
+
306
+ tool_call_cache = ToolCallCache()
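
The module-level `tool_call_cache` is meant to wrap repeated detections so identical arguments are served from memory rather than re-running DeepForest. A minimal sketch of that pattern with an illustrative session id; note that `store_result` only persists an annotated image if the result dictionary carries one under the `annotated_image` key.

```python
from deepforest_agent.tools.tool_handler import run_deepforest_object_detection
from deepforest_agent.utils.cache_utils import tool_call_cache

tool_name = "run_deepforest_object_detection"
arguments = {"model_names": ["tree", "bird"], "thresh": 0.3}

# Serve identical requests from memory instead of re-running detection.
cached = tool_call_cache.get_cached_result(tool_name, arguments)
if cached is not None:
    result = cached                  # includes cache_info and any reloaded image
else:
    result = run_deepforest_object_detection(session_id="demo-session", **arguments)
    tool_call_cache.store_result(tool_name, arguments, result)

print(tool_call_cache.get_cache_stats())

# On shutdown, remove the compressed image files written to disk.
tool_call_cache.cleanup_cache_files()
```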
src/deepforest_agent/utils/detection_narrative_generator.py ADDED
@@ -0,0 +1,445 @@
1
+ import numpy as np
2
+ from typing import List, Dict, Any
3
+ from collections import Counter, defaultdict
4
+
5
+ from deepforest_agent.utils.rtree_spatial_utils import DetectionSpatialAnalyzer
6
+
7
+
8
+ class DetectionNarrativeGenerator:
9
+ """
10
+ Generates natural language narratives from DeepForest detection results with proper classification handling.
11
+ """
12
+
13
+ def __init__(self, image_width: int, image_height: int):
14
+ """
15
+ Initialize narrative generator with image dimensions.
16
+
17
+ Args:
18
+ image_width: Width of the image in pixels
19
+ image_height: Height of the image in pixels
20
+ """
21
+ self.image_width = image_width
22
+ self.image_height = image_height
23
+ self.spatial_analyzer = DetectionSpatialAnalyzer(image_width, image_height)
24
+
25
+ def generate_comprehensive_narrative(self, detections_list: List[Dict[str, Any]]) -> str:
26
+ """
27
+ Generate comprehensive detection narrative using spatial analysis with proper classification handling.
28
+
29
+ Args:
30
+ detections_list: List of detection dictionaries from DeepForest
31
+
32
+ Returns:
33
+ Natural language narrative describing all aspects of detections
34
+ """
35
+ if not detections_list:
36
+ return "No objects were detected by DeepForest in this image."
37
+
38
+ # Add detections to spatial analyzer
39
+ self.spatial_analyzer.add_detections(detections_list)
40
+
41
+ # Get comprehensive statistics
42
+ stats = self.spatial_analyzer.get_detection_statistics()
43
+ grid_analysis = self.spatial_analyzer.get_grid_analysis()
44
+
45
+ narrative_parts = []
46
+
47
+ # 1. Overall Summary with proper classification handling
48
+ narrative_parts.append(self._generate_overall_summary(detections_list))
49
+
50
+ # 2. Confidence Analysis
51
+ narrative_parts.append(self._generate_confidence_analysis(detections_list))
52
+
53
+ # 3. Spatial Distribution Analysis
54
+ narrative_parts.append(self._generate_spatial_distribution_narrative(grid_analysis, detections_list))
55
+
56
+ # 4. Spatial Relationships Analysis using R-tree indexing
57
+ narrative_parts.append(self._generate_spatial_relationships_narrative(detections_list))
58
+
59
+ # 5. Object Coverage Analysis
60
+ narrative_parts.append(self._generate_coverage_analysis(detections_list))
61
+
62
+ return "\n\n".join(narrative_parts)
63
+
64
+ def _generate_overall_summary(self, detections_list: List[Dict[str, Any]]) -> str:
65
+ """
66
+ Generate overall detection summary with proper classification handling.
67
+
68
+ Args:
69
+ detections_list (List[Dict[str, Any]]): List of all detection results
70
+
71
+ Returns:
72
+ str: Formatted summary section including:
73
+ - Total detection count and average confidence
74
+ - Base object counts (birds, trees, livestock)
75
+ - Tree classification breakdown (alive trees, dead trees)
76
+ """
77
+ total_count = len(detections_list)
78
+
79
+ # Calculate overall confidence
80
+ scores = [s for s in (d.get("score") for d in detections_list) if s is not None and np.isfinite(s)]
81
+ overall_confidence = float(np.mean(scores)) if scores else 0.0
82
+
83
+ # Proper object counting with classification handling
84
+ base_label_counts = {} # bird, tree, livestock
85
+ classification_counts = {} # alive_tree, dead_tree
86
+
87
+ for detection in detections_list:
88
+ base_label = detection.get('label', 'unknown')
89
+ base_label_counts[base_label] = base_label_counts.get(base_label, 0) + 1
90
+
91
+ # Handle tree classifications
92
+ if base_label == 'tree':
93
+ classification_label = detection.get('classification_label')
94
+ classification_score = detection.get('classification_score')
95
+
96
+ # Only count valid classifications (not NaN or None)
97
+ if (classification_label and
98
+ classification_score is not None and
99
+ str(classification_label).lower() != 'nan' and
100
+ str(classification_score).lower() != 'nan'):
101
+
102
+ classification_counts[classification_label] = classification_counts.get(classification_label, 0) + 1
103
+
104
+ summary = f"**Overall Detection Summary**\n"
105
+ summary += f"In the whole image, {total_count} objects were detected with an average confidence of {overall_confidence:.3f}.\n\n"
106
+
107
+ # Object breakdown with proper classification display
108
+ object_parts = []
109
+ for label, count in base_label_counts.items():
110
+ label_name = label.replace('_', ' ')
111
+
112
+ if label == 'tree' and classification_counts:
113
+ # Special handling for trees with classifications
114
+ total_trees = count
115
+ classified_trees = sum(classification_counts.values())
116
+
117
+ if classified_trees > 0:
118
+ tree_part = f"{total_trees} trees are detected"
119
+ classification_parts = []
120
+ for class_label, class_count in classification_counts.items():
121
+ class_name = class_label.replace('_', ' ')
122
+ classification_parts.append(f"{class_count} {class_name}s")
123
+
124
+ tree_part += f". Of these, {classified_trees} are classified as {' and '.join(classification_parts)}"
125
+ object_parts.append(tree_part)
126
+ else:
127
+ object_parts.append(f"{count} {label_name}{'s' if count != 1 else ''}")
128
+ else:
129
+ object_parts.append(f"{count} {label_name}{'s' if count != 1 else ''}")
130
+
131
+ summary += "Whole image Object breakdown: " + ", ".join(object_parts) + "."
132
+
133
+ return summary
134
+
135
+ def _generate_confidence_analysis(self, detections_list: List[Dict[str, Any]]) -> str:
136
+ """
137
+ Generate confidence-based analysis with proper classification handling.
138
+
139
+ Args:
140
+ detections_list (List[Dict[str, Any]]): List of all detection results
141
+
142
+ Returns:
143
+ str: Formatted confidence analysis section including:
144
+ - Object counts per confidence range
145
+ - Base object type breakdown within each range
146
+ - Tree classification details (alive/dead) within each range
147
+ """
148
+ # Group by confidence ranges
149
+ confidence_groups = {
150
+ "Detections with High Confidence Score (0.7-1.0)": [],
151
+ "Detections with Medium Confidence Score (0.3-0.7)": [],
152
+ "Detections with Low Confidence Score (0.0-0.3)": []
153
+ }
154
+
155
+ for detection in detections_list:
156
+ score = detection.get('score', 0.0)
157
+ if score >= 0.7:
158
+ confidence_groups["Detections with High Confidence Score (0.7-1.0)"].append(detection)
159
+ elif score >= 0.3:
160
+ confidence_groups["Detections with Medium Confidence Score (0.3-0.7)"].append(detection)
161
+ else:
162
+ confidence_groups["Detections with Low Confidence Score (0.0-0.3)"].append(detection)
163
+
164
+ narrative = f"**Whole image Confidence Score Analysis**\n"
165
+
166
+ for conf_range, detections in confidence_groups.items():
167
+ if not detections:
168
+ narrative += f"{conf_range}: No objects detected\n"
169
+ continue
170
+
171
+ count = len(detections)
172
+ narrative += f"{conf_range}: {count} objects detected in the whole image\n"
173
+
174
+ # Count by base labels and classifications
175
+ base_counts = {}
176
+ class_counts = {}
177
+
178
+ for detection in detections:
179
+ base_label = detection.get('label', 'unknown')
180
+ base_counts[base_label] = base_counts.get(base_label, 0) + 1
181
+
182
+ if base_label == 'tree':
183
+ classification_label = detection.get('classification_label')
184
+ if (classification_label and
185
+ str(classification_label).lower() != 'nan'):
186
+ class_counts[classification_label] = class_counts.get(classification_label, 0) + 1
187
+
188
+ # Display breakdown
189
+ breakdown_parts = []
190
+ for label, label_count in base_counts.items():
191
+ if label == 'tree' and class_counts:
192
+ tree_part = f"{label_count} trees"
193
+ class_parts = []
194
+ for class_label, class_count in class_counts.items():
195
+ class_name = class_label.replace('_', ' ')
196
+ class_parts.append(f"{class_count} {class_name}s")
197
+ if class_parts:
198
+ tree_part += f" ({', '.join(class_parts)})"
199
+ breakdown_parts.append(tree_part)
200
+ else:
201
+ label_name = label.replace('_', ' ')
202
+ breakdown_parts.append(f"{label_count} {label_name}{'s' if label_count != 1 else ''}")
203
+
204
+ narrative += f" - {', '.join(breakdown_parts)}\n"
205
+
206
+ return narrative
207
+
208
+ def _generate_spatial_distribution_narrative(self, grid_analysis: Dict[str, Dict[str, Any]], detections_list: List[Dict[str, Any]]) -> str:
209
+ """
210
+ Generate spatial distribution narrative using 9-grid analysis
211
+
212
+ Args:
213
+ grid_analysis (Dict[str, Dict[str, Any]]): Pre-computed grid analysis from spatial_analyzer
214
+ containing detection counts and confidence analysis for each grid section
215
+ detections_list (List[Dict[str, Any]]): Original detection list for additional processing
216
+
217
+ Returns:
218
+ str: Formatted spatial distribution section including:
219
+ - Grid-by-grid object analysis with confidence breakdowns
220
+ - Tree classification details within each grid section
221
+ - Density pattern identification (dense vs sparse regions)
222
+ """
223
+ narrative = f"**Spatial Distribution Analysis**\n"
224
+ narrative += f"The image is divided into nine grid sections for spatial analysis:\n\n"
225
+
226
+ # Grid-by-grid analysis
227
+ for grid_name, grid_data in grid_analysis.items():
228
+ total_dets = grid_data['total_detections']
229
+ conf_analysis = grid_data['confidence_analysis']
230
+
231
+ if total_dets == 0:
232
+ narrative += f"{grid_name}: No objects detected\n"
233
+ continue
234
+
235
+ narrative += f"{grid_name}: {total_dets} objects detected\n"
236
+
237
+ # Per confidence category analysis
238
+ for conf_category, conf_data in conf_analysis.items():
239
+ if conf_data['count'] > 0:
240
+ # Count base labels and classifications for this grid/confidence
241
+ grid_detections = [d for d in detections_list
242
+ if self._detection_in_grid(d, grid_data['bounds'])]
243
+
244
+ conf_range = self._get_confidence_range(conf_category)
245
+ conf_detections = [d for d in grid_detections
246
+ if conf_range[0] <= d.get('score', 0) < conf_range[1] or
247
+ (conf_range[1] == 1.0 and d.get('score', 0) == 1.0)]
248
+
249
+ base_counts, class_counts = self._count_labels_with_classification(conf_detections)
250
+
251
+ # Display object breakdown
252
+ object_desc = []
253
+ for label, count in base_counts.items():
254
+ if label == 'tree' and label in class_counts:
255
+ tree_desc = f"{count} trees"
256
+ if class_counts[label]:
257
+ class_parts = []
258
+ for class_label, class_count in class_counts[label].items():
259
+ class_name = class_label.replace('_', ' ')
260
+ class_parts.append(f"{class_count} {class_name}s")
261
+ tree_desc += f" ({', '.join(class_parts)})"
262
+ object_desc.append(tree_desc)
263
+ else:
264
+ label_name = label.replace('_', ' ')
265
+ object_desc.append(f"{count} {label_name}{'s' if count != 1 else ''}")
266
+
267
+ # Simple description
268
+ narrative += f" - {conf_category}: {', '.join(object_desc)}\n"
269
+
270
+ narrative += "\n"
271
+
272
+ # Overall density patterns
273
+ grid_counts = {name: data['total_detections'] for name, data in grid_analysis.items()}
274
+ avg_count = sum(grid_counts.values()) / len(grid_counts) if grid_counts else 0
275
+
276
+ dense_regions = [name for name, count in grid_counts.items() if count > avg_count]
277
+ sparse_regions = [name for name, count in grid_counts.items() if count < avg_count]
278
+
279
+ if dense_regions or sparse_regions:
280
+ narrative += "**Density Patterns:**\n"
281
+ if dense_regions:
282
+ narrative += f"Dense regions: {', '.join(dense_regions)}\n"
283
+ if sparse_regions:
284
+ narrative += f"Sparse regions: {', '.join(sparse_regions)}\n"
285
+
286
+ return narrative
287
+
288
+ def _generate_coverage_analysis(self, detections_list: List[Dict[str, Any]]) -> str:
289
+ """
290
+ Generate object coverage analysis broken down by object type.
291
+
292
+ Args:
293
+ detections_list (List[Dict[str, Any]]): List of all detection results
294
+
295
+ Returns:
296
+ str: Formatted coverage analysis including:
297
+ - Percentage coverage for each object type (birds, trees, livestock)
298
+ - Tree classification coverage breakdown (alive trees vs dead trees)
299
+ - Total area calculations relative to full image
300
+ """
301
+ narrative = f"**Object Coverage Analysis**\n"
302
+
303
+ total_image_area = self.image_width * self.image_height
304
+
305
+ # Calculate coverage by object type
306
+ base_coverage = {}
307
+ classification_coverage = {}
308
+
309
+ for detection in detections_list:
310
+ width = detection.get('xmax', 0) - detection.get('xmin', 0)
311
+ height = detection.get('ymax', 0) - detection.get('ymin', 0)
312
+ area = width * height
313
+
314
+ base_label = detection.get('label', 'unknown')
315
+ base_coverage[base_label] = base_coverage.get(base_label, 0) + area
316
+
317
+ # Handle tree classifications
318
+ if base_label == 'tree':
319
+ classification_label = detection.get('classification_label')
320
+ if (classification_label and
321
+ str(classification_label).lower() != 'nan'):
322
+ classification_coverage[classification_label] = classification_coverage.get(classification_label, 0) + area
323
+
324
+ # Display coverage percentages
325
+ coverage_parts = []
326
+ for label, area in base_coverage.items():
327
+ coverage_percent = (area / total_image_area) * 100
328
+
329
+ if label == 'tree' and classification_coverage:
330
+ # Show tree breakdown
331
+ tree_coverage = f"{label}s: {coverage_percent:.2f}%"
332
+
333
+ class_parts = []
334
+ for class_label, class_area in classification_coverage.items():
335
+ class_percent = (class_area / total_image_area) * 100
336
+ class_name = class_label.replace('_', ' ')
337
+ class_parts.append(f"{class_name}s: {class_percent:.2f}%")
338
+
339
+ if class_parts:
340
+ tree_coverage += f" ({', '.join(class_parts)})"
341
+ coverage_parts.append(tree_coverage)
342
+ else:
343
+ label_name = label.replace('_', ' ')
344
+ coverage_parts.append(f"{label_name}s: {coverage_percent:.2f}%")
345
+
346
+ narrative += ", ".join(coverage_parts) + " of the total image area."
347
+
348
+ return narrative
349
+
350
+ def _generate_spatial_relationships_narrative(self, detections_list: List[Dict[str, Any]]) -> str:
351
+ """
352
+ Generate spatial relationships narrative using R-tree indexing.
353
+
354
+ Args:
355
+ detections_list (List[Dict[str, Any]]): List of all detection results
356
+
357
+ Returns:
358
+ str: Formatted spatial relationships section including:
359
+ - Count of high-confidence objects analyzed
360
+ - R-tree based intersection and proximity analysis
361
+ - Natural language descriptions of object relationships
362
+ - Confidence threshold information (>= 0.3)
363
+ """
364
+ spatial_relationships = self.spatial_analyzer.analyze_spatial_relationships_with_indexing(confidence_threshold=0.3)
365
+
366
+ if not spatial_relationships:
367
+ return "**Spatial Relationships Analysis (Confidence ≥ 0.3)**\nNo objects with sufficient confidence found for spatial relationship analysis."
368
+
369
+ narrative = f"**Spatial Relationships Analysis in the whole image (Confidence ≥ 0.3)**\n"
370
+
371
+ # Generate narrative using the spatial analyzer
372
+ spatial_narrative = self.spatial_analyzer.generate_spatial_narrative(confidence_threshold=0.3)
373
+ narrative += spatial_narrative
374
+
375
+ return narrative
376
+
377
+ def _detection_in_grid(self, detection: Dict[str, Any], grid_bounds: Dict[str, float]) -> bool:
378
+ """
379
+ Check if detection overlaps with grid bounds.
380
+
381
+ Args:
382
+ detection (Dict[str, Any]): Detection dictionary with 'xmin', 'ymin', 'xmax', 'ymax' keys
383
+ grid_bounds (Dict[str, float]): Grid section bounds with 'x_min', 'y_min', 'x_max', 'y_max' keys
384
+
385
+ Returns:
386
+ bool: True if detection bounding box overlaps with grid bounds, False otherwise
387
+ """
388
+ det_xmin = detection.get('xmin', 0)
389
+ det_ymin = detection.get('ymin', 0)
390
+ det_xmax = detection.get('xmax', 0)
391
+ det_ymax = detection.get('ymax', 0)
392
+
393
+ return not (det_xmax <= grid_bounds['x_min'] or det_xmin >= grid_bounds['x_max'] or
394
+ det_ymax <= grid_bounds['y_min'] or det_ymin >= grid_bounds['y_max'])
395
+
396
+ def _get_confidence_range(self, conf_category: str) -> tuple:
397
+ """
398
+ Get confidence range tuple from category string.
399
+
400
+ Args:
401
+ conf_category (str): Category name containing "High", "Medium", or other confidence indicator
402
+
403
+ Returns:
404
+ tuple: (min_confidence, max_confidence) as floats
405
+ - High: (0.7, 1.0)
406
+ - Medium: (0.3, 0.7)
407
+ - Low/Other: (0.0, 0.3)
408
+ """
409
+ if "High" in conf_category:
410
+ return (0.7, 1.0)
411
+ elif "Medium" in conf_category:
412
+ return (0.3, 0.7)
413
+ else:
414
+ return (0.0, 0.3)
415
+
416
+ def _count_labels_with_classification(self, detections: List[Dict[str, Any]]) -> tuple:
417
+ """
418
+ Count base labels and classifications separately.
419
+
420
+ Args:
421
+ detections (List[Dict[str, Any]]): List of detection dictionaries
422
+
423
+ Returns:
424
+ tuple: (base_counts, class_counts) where:
425
+ - base_counts (Dict[str, int]): Count of each base object type
426
+ - class_counts (Dict[str, Dict[str, int]]): Nested count structure for
427
+ tree classifications under 'tree' key
428
+ """
429
+ base_counts = {}
430
+ class_counts = {}
431
+
432
+ for detection in detections:
433
+ base_label = detection.get('label', 'unknown')
434
+ base_counts[base_label] = base_counts.get(base_label, 0) + 1
435
+
436
+ if base_label == 'tree':
437
+ classification_label = detection.get('classification_label')
438
+ if (classification_label and
439
+ str(classification_label).lower() != 'nan'):
440
+
441
+ if base_label not in class_counts:
442
+ class_counts[base_label] = {}
443
+ class_counts[base_label][classification_label] = class_counts[base_label].get(classification_label, 0) + 1
444
+
445
+ return base_counts, class_counts
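
`DetectionNarrativeGenerator` turns raw detection records into the multi-section narrative the downstream agents consume. A minimal sketch with two hand-made records; it assumes the `rtree` dependency used by `DetectionSpatialAnalyzer` is installed, and the classification label string is illustrative.

```python
from deepforest_agent.utils.detection_narrative_generator import (
    DetectionNarrativeGenerator,
)

# Hand-made records in the same shape predict_objects returns.
detections = [
    {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 140, "score": 0.91,
     "label": "tree", "classification_label": "alive_tree",
     "classification_score": 0.88},
    {"xmin": 300, "ymin": 40, "xmax": 330, "ymax": 70, "score": 0.42,
     "label": "bird"},
]

generator = DetectionNarrativeGenerator(image_width=800, image_height=600)
narrative = generator.generate_comprehensive_narrative(detections)
print(narrative)  # summary, confidence, spatial, relationship and coverage sections
```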
src/deepforest_agent/utils/image_utils.py ADDED
@@ -0,0 +1,465 @@
1
+ import base64
2
+ import io
3
+ import os
4
+ from typing import Dict, Any, List, Literal, Optional, Tuple
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ import tempfile
9
+ import rasterio
10
+
11
+ from deepforest_agent.conf.config import Config
12
+
13
+ def load_image_as_np_array(image_path: str) -> np.ndarray:
14
+ """
15
+ Load an image from a file path as a NumPy array.
16
+
17
+ Args:
18
+ image_path: Path to the image file
19
+
20
+ Returns:
21
+ RGB image as numpy array, or None if not found
22
+
23
+ Raises:
24
+ FileNotFoundError: If image file is not found at any expected path
25
+ """
26
+ if not os.path.exists(image_path):
27
+ raise FileNotFoundError(
28
+ f"Image not found at any expected path: {image_path}"
29
+ )
30
+
31
+ img = Image.open(image_path)
32
+ if img.mode != 'RGB':
33
+ img = img.convert('RGB')
34
+ return np.array(img)
35
+
36
+
37
+ def load_pil_image_from_path(image_path: str) -> Optional[Image.Image]:
38
+ """
39
+ Load PIL Image from file path.
40
+
41
+ Args:
42
+ image_path: Path to the image file
43
+
44
+ Returns:
45
+ PIL Image object, or None if loading fails
46
+
47
+ Raises:
48
+ FileNotFoundError: If image file is not found
49
+ Exception: If image cannot be loaded or converted
50
+ """
51
+ if not os.path.exists(image_path):
52
+ raise FileNotFoundError(f"Image not found at path: {image_path}")
53
+
54
+ try:
55
+ img = Image.open(image_path)
56
+ if img.mode != 'RGB':
57
+ img = img.convert('RGB')
58
+ return img
59
+ except Exception as e:
60
+ print(f"Error loading PIL image from {image_path}: {e}")
61
+ return None
62
+
63
+
64
+ def create_temp_image_file(image_array: np.ndarray, suffix: str = ".png") -> str:
65
+ """
66
+ Create a temporary image file from numpy array.
67
+
68
+ Args:
69
+ image_array: Image as numpy array
70
+ suffix: File extension (default: ".png")
71
+
72
+ Returns:
73
+ Path to temporary file
74
+
75
+ Raises:
76
+ Exception: If temporary file creation fails
77
+ """
78
+ try:
79
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
80
+ temp_file_path = tmp_file.name
81
+
82
+ pil_image = Image.fromarray(image_array)
83
+ pil_image.save(temp_file_path, format='PNG')
84
+
85
+ print(f"Created temporary image file: {temp_file_path}")
86
+ return temp_file_path
87
+
88
+ except Exception as e:
89
+ print(f"Error creating temporary image file: {e}")
90
+ raise e
91
+
92
+
93
+ def cleanup_temp_file(file_path: str) -> bool:
94
+ """
95
+ Clean up temporary file.
96
+
97
+ Args:
98
+ file_path: Path to file to remove
99
+
100
+ Returns:
101
+ True if successful, False otherwise
102
+ """
103
+ if file_path and os.path.exists(file_path):
104
+ try:
105
+ os.remove(file_path)
106
+ print(f"Cleaned up temporary file: {file_path}")
107
+ return True
108
+ except OSError as e:
109
+ print(f"Error cleaning up temporary file {file_path}: {e}")
110
+ return False
111
+ return False
112
+
113
+
114
+ def validate_image_path(image_path: str) -> bool:
115
+ """
116
+ Validate if image path exists and is a valid image file.
117
+
118
+ Args:
119
+ image_path: Path to validate
120
+
121
+ Returns:
122
+ True if valid image path, False otherwise
123
+ """
124
+ if not image_path or not os.path.exists(image_path):
125
+ return False
126
+
127
+ try:
128
+ with Image.open(image_path) as img:
129
+ img.verify()
130
+ return True
131
+ except Exception:
132
+ return False
133
+
134
+
135
+ def get_image_info(image_path: str) -> Optional[Dict[str, Any]]:
136
+ """
137
+ Get basic information about an image file.
138
+
139
+ Args:
140
+ image_path: Path to image file
141
+
142
+ Returns:
143
+ Dictionary with image info or None if error
144
+ """
145
+ try:
146
+ with Image.open(image_path) as img:
147
+ return {
148
+ "size": img.size,
149
+ "mode": img.mode,
150
+ "format": img.format,
151
+ "file_size_bytes": os.path.getsize(image_path)
152
+ }
153
+ except Exception as e:
154
+ print(f"Error getting image info for {image_path}: {e}")
155
+ return None
156
+
157
+
158
+ def encode_image_to_base64_url(image_array: np.ndarray, format: str = 'PNG',
159
+ quality: int = 80) -> Optional[str]:
160
+ """
161
+ Encode a NumPy image array to a base64 data URL.
162
+
163
+ Args:
164
+ image_array: Image as numpy array
165
+ format: Output format ('PNG' or 'JPEG')
166
+ quality: JPEG quality (only used for JPEG format)
167
+
168
+ Returns:
169
+ Base64 encoded data URL string, or None if encoding fails
170
+ """
171
+ if image_array is None:
172
+ return None
173
+
174
+ try:
175
+ pil_image = Image.fromarray(image_array)
176
+ if pil_image.mode == 'RGBA':
177
+ background = Image.new("RGB", pil_image.size, (255, 255, 255))
178
+ background.paste(pil_image, mask=pil_image.split()[3])
179
+ pil_image = background
180
+ elif pil_image.mode != 'RGB':
181
+ pil_image = pil_image.convert('RGB')
182
+
183
+ byte_arr = io.BytesIO()
184
+ if format.lower() == 'jpeg':
185
+ pil_image.save(byte_arr, format='JPEG', quality=quality)
186
+ elif format.lower() == 'png':
187
+ pil_image.save(byte_arr, format='PNG')
188
+ else:
189
+ raise ValueError(f"Unsupported format: {format}. Choose 'jpeg' or 'png'.")
190
+
191
+ encoded_string = base64.b64encode(byte_arr.getvalue()).decode('utf-8')
192
+ return f"data:image/{format.lower()};base64,{encoded_string}"
193
+ except Exception as e:
194
+ print(f"Error encoding image to base64: {e}")
195
+ return None
196
+
197
+
198
+ def convert_pil_image_to_bytes(image: Image.Image) -> bytes:
199
+ """
200
+ Convert a PIL Image to bytes in PNG format.
201
+
202
+ Args:
203
+ image: PIL Image object
204
+
205
+ Returns:
206
+ Image bytes in PNG format
207
+ """
208
+ img_byte_arr = io.BytesIO()
209
+
210
+ if image.mode != 'RGB':
211
+ image = image.convert('RGB')
212
+ image.save(img_byte_arr, format='PNG')
213
+ img_bytes = img_byte_arr.getvalue()
214
+
215
+ return img_bytes
216
+
217
+
218
+ def encode_pil_image_to_base64_url(image: Image.Image) -> str:
219
+ """
220
+ Encode a PIL Image directly to a base64 data URL.
221
+
222
+ Args:
223
+ image: PIL Image object
224
+
225
+ Returns:
226
+ Base64 encoded PNG data URL string
227
+ """
228
+ img_bytes = convert_pil_image_to_bytes(image)
229
+ img_str = base64.b64encode(img_bytes).decode()
230
+ data_url = f"data:image/png;base64,{img_str}"
231
+ return data_url
232
+
233
+
234
+ def decode_base64_to_pil_image(base64_data: str) -> Image.Image:
235
+ """
236
+ Decode base64 data to a PIL Image.
237
+
238
+ Handles both data URL format and raw base64 strings.
239
+
240
+ Args:
241
+ base64_data: Base64 encoded image data, either as data URL
242
+ (data:image/png;base64,iVBORw0...) or raw base64 string
243
+
244
+ Returns:
245
+ PIL Image object
246
+
247
+ Raises:
248
+ ValueError: If base64 data is invalid or cannot be decoded
249
+ """
250
+ try:
251
+ if base64_data.startswith('data:image'):
252
+ # Extract base64 part after the comma
253
+ base64_string = base64_data.split(',')[1]
254
+ else:
255
+ # Raw base64 data
256
+ base64_string = base64_data
257
+
258
+ image_bytes = base64.b64decode(base64_string)
259
+ pil_image = Image.open(io.BytesIO(image_bytes))
260
+
261
+ return pil_image
262
+
263
+ except Exception as e:
264
+ raise ValueError(f"Failed to decode base64 data to PIL Image: {e}")
265
+
266
+
267
+ def decode_base64_url_to_np_array(image_url: str) -> Optional[np.ndarray]:
268
+ """
269
+ Decode a base64 data URL to a NumPy array.
270
+
271
+ Args:
272
+ image_url: Base64 data URL (data:image/png;base64,iVBORw0...)
273
+
274
+ Returns:
275
+ RGB image as numpy array, or None if decoding fails
276
+ """
277
+ if not image_url.startswith('data:image'):
278
+ print(f"Invalid data URL format: {image_url[:50]}...")
279
+ return None
280
+
281
+ try:
282
+ pil_image = decode_base64_to_pil_image(image_url)
283
+
284
+ if pil_image.mode != 'RGB':
285
+ pil_image = pil_image.convert('RGB')
286
+
287
+ return np.array(pil_image)
288
+
289
+ except ValueError as e:
290
+ print(f"Error extracting image from data URL: {e}")
291
+ return None
292
+ except Exception as e:
293
+ print(f"Unexpected error processing image URL: {e}")
294
+ return None
295
+
296
+
297
+ def convert_rgb_to_bgr(image_array: np.ndarray) -> np.ndarray:
298
+ """
299
+ Convert an RGB NumPy image array to BGR format.
300
+
301
+ Args:
302
+ image_array: RGB image as numpy array
303
+
304
+ Returns:
305
+ BGR image as numpy array
306
+ """
307
+ if (image_array.ndim == 3 and image_array.shape[2] == 3 and
308
+ image_array.dtype == np.uint8):
309
+ return cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
310
+ return image_array
311
+
312
+
313
+ def convert_bgr_to_rgb(image_array: np.ndarray) -> np.ndarray:
314
+ """
315
+ Convert a BGR NumPy image array to RGB format.
316
+
317
+ Args:
318
+ image_array: BGR image as numpy array
319
+
320
+ Returns:
321
+ RGB image as numpy array
322
+ """
323
+ if (image_array.ndim == 3 and image_array.shape[2] == 3 and
324
+ image_array.dtype == np.uint8):
325
+ return cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
326
+ return image_array
327
+
328
+ def check_image_resolution_for_deepforest(image_path: str, max_resolution_cm: float = 10.0) -> Dict[str, Any]:
329
+ """
330
+ Resolution check for DeepForest suitability.
331
+
332
+ For GeoTIFF files: Check if pixel resolution is <= 10cm
333
+ For other formats: Allow processing with warning
334
+
335
+ Args:
336
+ image_path: Path to the image file
337
+ max_resolution_cm: Coarsest acceptable pixel size in cm/pixel (default: 10.0)
338
+
339
+ Returns:
340
+ Dict containing:
341
+ - is_suitable: bool - Whether resolution is suitable for DeepForest
342
+ - resolution_cm: float or None - Actual resolution in cm/pixel
343
+ - resolution_info: str - Resolution info
344
+ - is_georeferenced: bool - Whether image is a GeoTIFF
345
+ - warning: str or None - Warning message if any
346
+ """
347
+ try:
348
+ with rasterio.open(image_path) as src:
349
+ if src.crs is None:
350
+ return _non_geotiff_result(image_path, "No coordinate system found")
351
+ if src.crs.is_geographic:
352
+ return _non_geotiff_result(image_path, "Geographic coordinates detected")
353
+ transform = src.transform
354
+ if transform.is_identity:
355
+ return _non_geotiff_result(image_path, "No spatial transformation found")
356
+
357
+ # Calculate pixel size
358
+ pixel_width = abs(transform.a)
359
+ pixel_height = abs(transform.e)
360
+ pixel_size = max(pixel_width, pixel_height)
361
+
362
+ # Convert to centimeters based on CRS units
363
+ crs_units = src.crs.to_dict().get('units', '').lower()
364
+
365
+ if crs_units in ['m', 'metre', 'meter']:
366
+ resolution_cm = pixel_size * 100
367
+ elif 'foot' in crs_units or crs_units == 'ft':
368
+ resolution_cm = pixel_size * 30.48
369
+ else:
370
+ return {
371
+ "is_suitable": True,
372
+ "resolution_cm": None,
373
+ "resolution_info": f"Unknown units '{crs_units}' - proceeding optimistically",
374
+ "is_georeferenced": True,
375
+ "warning": f"Cannot determine pixel size units: {crs_units}"
376
+ }
377
+
378
+ is_suitable = resolution_cm <= max_resolution_cm
379
+
380
+ return {
381
+ "is_suitable": is_suitable,
382
+ "resolution_cm": resolution_cm,
383
+ "resolution_info": f"{resolution_cm:.1f} cm/pixel ({'suitable' if is_suitable else 'insufficient'} for DeepForest)",
384
+ "is_georeferenced": True,
385
+ "warning": None if is_suitable else f"Resolution {resolution_cm:.1f} cm/pixel exceeds {max_resolution_cm} cm/pixel threshold"
386
+ }
387
+
388
+ except rasterio.RasterioIOError:
389
+ return _non_geotiff_result(image_path, "Not a GeoTIFF file")
390
+ except Exception as e:
391
+ return _non_geotiff_result(image_path, f"Error reading file: {str(e)}")
392
+
393
+
394
+ def _non_geotiff_result(image_path: str, reason: str) -> Dict[str, Any]:
395
+ """
396
+ Helper function for non-GeoTIFF images to allow processing with warning.
397
+
398
+ Args:
399
+ image_path: Path to the image file
400
+ reason: Reason why it's not treated as GeoTIFF
401
+
402
+ Returns:
403
+ Dict with suitable=True but warning about using GeoTIFF
404
+ """
405
+ file_ext = os.path.splitext(image_path)[1].lower()
406
+
407
+ return {
408
+ "is_suitable": True,
409
+ "resolution_cm": None,
410
+ "resolution_info": f"Non-geospatial image ({file_ext}) - proceeding without resolution check",
411
+ "is_georeferenced": False,
412
+ "warning": f"For optimal DeepForest results, use GeoTIFF images with ≤10 cm/pixel resolution. Current: {reason.lower()}"
413
+ }
414
+
415
+ def determine_patch_size(image_file_path: str, image_dimensions: Optional[Tuple[int, int]] = None) -> int:
416
+ """
417
+ Determine patch size based on image file type and dimensions for OOM fallback strategy.
418
+
419
+ Args:
420
+ image_file_path: Path to the image file
421
+ image_dimensions: Optional tuple of (width, height) if known
422
+
423
+ Returns:
424
+ int: Patch size optimized for image type and size
425
+ """
426
+ # Get image dimensions if not provided
427
+ if image_dimensions is None:
428
+ try:
429
+ with Image.open(image_file_path) as img:
430
+ width, height = img.size
431
+ except Exception:
432
+ return Config.DEEPFOREST_DEFAULTS["patch_size"]
433
+ else:
434
+ width, height = image_dimensions
435
+
436
+ # Determine maximum dimension
437
+ max_dimension = max(width, height)
438
+
439
+ # For large dimensions, use larger patch sizes to handle OOM
440
+ if max_dimension > 7500:
441
+ return 2000
442
+ else:
443
+ return 1500
444
+
445
+ def get_image_dimensions_fast(image_path: str) -> Optional[Tuple[int, int]]:
446
+ """
447
+ Get image dimensions quickly without loading full image into memory.
448
+
449
+ Args:
450
+ image_path: Path to image file
451
+
452
+ Returns:
453
+ Tuple of (width, height) or None if cannot determine
454
+ """
455
+ try:
456
+ # Try with PIL first
457
+ with Image.open(image_path) as img:
458
+ return img.size
459
+ except Exception:
460
+ try:
461
+ # Fallback to rasterio for GeoTIFF files
462
+ with rasterio.open(image_path) as src:
463
+ return (src.width, src.height)
464
+ except Exception:
465
+ return None
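
These helpers support the pre-flight checks made before invoking DeepForest: confirm the raster resolution is usable and pick a patch size from the pixel dimensions. A minimal sketch; the GeoTIFF path is an illustrative assumption.

```python
from deepforest_agent.utils.image_utils import (
    check_image_resolution_for_deepforest,
    determine_patch_size,
    get_image_dimensions_fast,
)

path = "site_orthomosaic.tif"  # hypothetical local GeoTIFF

resolution = check_image_resolution_for_deepforest(path)
print(resolution["resolution_info"])
if resolution["warning"]:
    print("note:", resolution["warning"])

# Pick a patch size from the image's pixel dimensions before detection.
dims = get_image_dimensions_fast(path)
patch_size = determine_patch_size(path, image_dimensions=dims)
print(f"dimensions={dims}, patch_size={patch_size}")

if resolution["is_suitable"]:
    # Hand path and patch_size on to the detection tool here.
    pass
```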
src/deepforest_agent/utils/logging_utils.py ADDED
@@ -0,0 +1,449 @@
1
+ import os
2
+ import time
3
+ from datetime import datetime, timezone
4
+ from typing import Dict, Any, Optional, List
5
+ from pathlib import Path
6
+ import threading
7
+ import json as json_module
8
+
9
+
10
+ class MultiAgentLogger:
11
+ """
12
+ Logging system for conversation-style logs.
13
+ """
14
+
15
+ def __init__(self, logs_dir: str = "logs"):
16
+ """
17
+ Initialize the multi-agent logger.
18
+
19
+ Args:
20
+ logs_dir: Directory to store log files
21
+ """
22
+ self.logs_dir = Path(logs_dir)
23
+ self.logs_dir.mkdir(exist_ok=True)
24
+ self._lock = threading.Lock()
25
+
26
+ print(f"Logging initialized. Logs directory: {self.logs_dir.absolute()}")
27
+
28
+ def _get_log_file_path(self, session_id: str) -> Path:
29
+ """
30
+ Get the log file path for a specific session.
31
+
32
+ Args:
33
+ session_id: Unique session identifier
34
+
35
+ Returns:
36
+ Path object for the session's log file
37
+ """
38
+ date_str = datetime.now().strftime("%Y%m%d")
39
+ filename = f"session_{session_id}_{date_str}.log"
40
+ return self.logs_dir / filename
41
+
42
+ def _write_log_entry(self, session_id: str, agent_name: str, content: str) -> None:
43
+ """
44
+ Write a log entry to the session's log file.
45
+
46
+ Args:
47
+ session_id: Session identifier
48
+ agent_name: Current agent in the process
49
+ content: Current agent response
50
+ """
51
+ with self._lock:
52
+ log_file_path = self._get_log_file_path(session_id)
53
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
54
+
55
+ try:
56
+ with open(log_file_path, 'a', encoding='utf-8') as f:
57
+ if agent_name == "SESSION_START":
58
+ f.write(f"=== SESSION {session_id} STARTED ===\n\n")
59
+ elif agent_name == "SESSION_EVENT":
60
+ f.write(f"{timestamp} - {content}\n\n")
61
+ else:
62
+ f.write(f"{timestamp} - {agent_name}: {content}\n\n")
63
+ f.flush()
64
+ except Exception as e:
65
+ print(f"Error writing to log file {log_file_path}: {e}")
66
+
67
+ def log_session_event(self, session_id: str, event_type: str, details: Optional[Dict[str, Any]] = None) -> None:
68
+ """
69
+ Log session lifecycle events (creation, image upload, clearing, etc.).
70
+
71
+ Args:
72
+ session_id: Session identifier
73
+ event_type: Type of session event
74
+ details: Additional event details
75
+ """
76
+ if event_type == "session_created":
77
+ self._write_log_entry(session_id, "SESSION_START", "")
78
+ if details:
79
+ image_size = details.get("image_size", "unknown")
80
+ image_mode = details.get("image_mode", "unknown")
81
+ self._write_log_entry(session_id, "SESSION_EVENT", f"Image uploaded: {image_size}, mode: {image_mode}")
82
+ else:
83
+ self._write_log_entry(session_id, "SESSION_EVENT", "Image uploaded: unknown")
84
+ elif event_type == "conversation_cleared":
85
+ self._write_log_entry(session_id, "SESSION_EVENT", "Conversation cleared")
86
+ elif event_type == "multi_agent_workflow_started":
87
+ self._write_log_entry(session_id, "SESSION_EVENT", "Multi-agent workflow started")
88
+
89
+ def log_user_query(self, session_id: str, user_message: str, message_context: Optional[Dict[str, Any]] = None) -> None:
90
+ """
91
+ Log user queries and context.
92
+
93
+ Args:
94
+ session_id: Session identifier
95
+ user_message: User's input message
96
+ message_context: Additional context (conversation length, etc.)
97
+ """
98
+ self._write_log_entry(session_id, "USER", user_message)
99
+
100
+ def log_agent_execution(
101
+ self,
102
+ session_id: str,
103
+ agent_name: str,
104
+ agent_input: str,
105
+ agent_output: str,
106
+ execution_time: float,
107
+ additional_data: Optional[Dict[str, Any]] = None
108
+ ) -> None:
109
+ """
110
+ Log individual agent execution details.
111
+
112
+ Args:
113
+ session_id: Session identifier
114
+ agent_name: Name of the agent (memory, detector, visual, ecology)
115
+ agent_input: Input provided to the agent
116
+ agent_output: Output generated by the agent
117
+ execution_time: Time taken for agent execution in seconds
118
+ additional_data: Agent-specific additional data
119
+ """
120
+
121
+ if agent_name == "memory":
122
+ formatted_name = "Memory Agent"
123
+ elif agent_name == "detector":
124
+ formatted_name = "DeepForest Detector Agent"
125
+ elif agent_name == "visual":
126
+ formatted_name = "Visual Agent"
127
+ elif agent_name == "ecology":
128
+ formatted_name = "Ecology Agent"
129
+ else:
130
+ formatted_name = agent_name.title()
131
+
132
+ formatted_name_with_time = f"{formatted_name} ({execution_time:.2f}s)"
133
+
134
+ content = agent_output
135
+ self._write_log_entry(session_id, formatted_name_with_time, content)
136
+
137
+ def log_tool_call(
138
+ self,
139
+ session_id: str,
140
+ tool_name: str,
141
+ tool_arguments: Dict[str, Any],
142
+ tool_result: Dict[str, Any],
143
+ execution_time: float,
144
+ cache_hit: bool,
145
+ reasoning: Optional[str] = None
146
+ ) -> None:
147
+ """
148
+ Log tool calls, their results, and cache information.
149
+
150
+ Args:
151
+ session_id: Session identifier
152
+ tool_name: Name of the tool that was called
153
+ tool_arguments: Arguments passed to the tool
154
+ tool_result: Result returned by the tool
155
+ execution_time: Time taken for tool execution
156
+ cache_hit: Whether this was served from cache
157
+ reasoning: AI's reasoning for this tool call
158
+ """
159
+ if cache_hit:
160
+ status = "Cache Hit (0.00s)"
161
+ else:
162
+ status = f"Cache Miss - Executed DeepForest detection ({execution_time:.2f}s)"
163
+
164
+ content = f"{status}\n"
165
+ content += f"Detection Summary: {tool_result.get('detection_summary', 'No summary')}\n"
166
+
167
+ detections = tool_result.get('detections_list', [])
168
+ if detections:
169
+ content += f"Detection Data: {detections}"
170
+
171
+ self._write_log_entry(session_id, "DeepForest Function execution", content)
172
+
173
+ def log_error(self, session_id: str, error_type: str, error_message: str, context: Optional[Dict[str, Any]] = None) -> None:
174
+ """
175
+ Log errors in simple format.
176
+
177
+ Args:
178
+ session_id: Session identifier
179
+ error_type: Type/category of error
180
+ error_message: Error message
181
+ context: Additional context about where the error occurred
182
+ """
183
+ self._write_log_entry(session_id, "ERROR", f"{error_type}: {error_message}")
184
+
185
+ def log_resolution_check(
186
+ self,
187
+ session_id: str,
188
+ image_file_path: str,
189
+ resolution_result: Dict[str, Any],
190
+ execution_time: float
191
+ ) -> None:
192
+ """
193
+ Log image resolution check results.
194
+
195
+ Args:
196
+ session_id: Session identifier
197
+ image_file_path: Path to the image that was checked
198
+ resolution_result: Results from simplified resolution check
199
+ execution_time: Time taken for resolution check
200
+ """
201
+ is_suitable = resolution_result.get("is_suitable", True)
202
+ resolution_info = resolution_result.get("resolution_info", "No resolution info")
203
+ is_georeferenced = resolution_result.get("is_georeferenced", False)
204
+ resolution_cm = resolution_result.get("resolution_cm")
205
+ warning = resolution_result.get("warning")
206
+
207
+ content = f"Image Resolution Check ({execution_time:.3f}s)\n"
208
+ content += f"File: {image_file_path}\n"
209
+ content += f"Result: {'Suitable' if is_suitable else 'Insufficient'} for DeepForest\n"
210
+ content += f"Details: {resolution_info}\n"
211
+ content += f"Type: {'GeoTIFF' if is_georeferenced else 'Regular image'}\n"
212
+
213
+ if resolution_cm is not None:
214
+ content += f"Resolution: {resolution_cm:.2f} cm/pixel\n"
215
+
216
+ if warning:
217
+ content += f"Warning: {warning}\n"
218
+
219
+ if not is_suitable:
220
+ content += "Impact: DeepForest detection will be skipped due to insufficient resolution"
221
+ elif warning:
222
+ content += "Impact: DeepForest detection will proceed with noted warning"
223
+ else:
224
+ content += "Impact: Resolution suitable for DeepForest detection"
225
+
226
+ self._write_log_entry(session_id, "Resolution Check", content)
227
+
228
+ def log_deepforest_skip(
229
+ self,
230
+ session_id: str,
231
+ skip_reasons: List[str],
232
+ resolution_result: Optional[Dict[str, Any]] = None,
233
+ visual_result: Optional[Dict[str, Any]] = None
234
+ ) -> None:
235
+ """
236
+ Log when DeepForest detection is skipped and why.
237
+
238
+ Args:
239
+ session_id: Session identifier
240
+ skip_reasons: List of reasons why DeepForest was skipped
241
+ resolution_result: Resolution check results (optional)
242
+ visual_result: Visual analysis results (optional)
243
+ """
244
+ content = "DeepForest Detection Skipped\n"
245
+ content += f"Reasons: {', '.join(skip_reasons)}\n"
246
+
247
+ # Add detailed reason breakdown
248
+ if "insufficient resolution" in ' '.join(skip_reasons).lower():
249
+ if resolution_result:
250
+ resolution_info = resolution_result.get("resolution_info", "No details")
251
+ content += f"Resolution Details: {resolution_info}\n"
252
+
253
+ if "poor image quality" in ' '.join(skip_reasons).lower():
254
+ if visual_result:
255
+ quality_assessment = visual_result.get("image_quality_for_deepforest", "Unknown")
256
+ content += f"Visual Quality Assessment: {quality_assessment}\n"
257
+
258
+ content += "Impact: Analysis will rely on visual analysis only"
259
+
260
+ self._write_log_entry(session_id, "DeepForest Skip Decision", content)
261
+
262
+ def log_tile_analysis(self, session_id: str, tile_id: int, result: Dict[str, Any], execution_time: float) -> None:
263
+ """
264
+ Log individual tile analysis results.
265
+
266
+ Args:
267
+ session_id: Session identifier
268
+ tile_id: Tile identifier
269
+ result: Tile analysis result
270
+ execution_time: Time taken for tile analysis
271
+ """
272
+ content = f"Tile {tile_id} Analysis ({execution_time:.2f}s)\n"
273
+
274
+ coordinates = result.get('coordinates', {})
275
+ content += f"Coordinates: x={coordinates.get('x', 0)}, y={coordinates.get('y', 0)}, "
276
+ content += f"width={coordinates.get('width', 0)}, height={coordinates.get('height', 0)}\n"
277
+
278
+ additional_objects = result.get('additional_objects', [])
279
+ if additional_objects:
280
+ content += f"Additional Objects: {len(additional_objects)} objects detected\n"
281
+ for obj in additional_objects:
282
+ label = obj.get('label', 'unknown')
283
+ bbox = obj.get('bbox', 'no coordinates')
284
+ content += f" - {label} at {bbox}\n"
285
+ else:
286
+ content += "Additional Objects: None detected\n"
287
+
288
+ visual_analysis = result.get('visual_analysis', '')
289
+ if visual_analysis:
290
+ content += f"Visual Analysis: {visual_analysis}\n"
291
+
292
+ assigned_detections = result.get('assigned_detections', [])
293
+ content += f"Assigned DeepForest Detections: {len(assigned_detections)}\n"
294
+
295
+ if 'error' in result:
296
+ content += f"Error: {result['error']}\n"
297
+
298
+ self._write_log_entry(session_id, f"Tile {tile_id} Analysis", content)
299
+
300
+ def log_spatial_relationships(
301
+ self,
302
+ session_id: str,
303
+ spatial_relationships: List[Dict[str, Any]],
304
+ execution_time: float
305
+ ) -> None:
306
+ """Log spatial relationships analysis results.
307
+
308
+ Args:
309
+ session_id: The unique identifier for the current session.
310
+ spatial_relationships: A list of dictionaries, where each
311
+ dictionary contains details about an object's spatial
312
+ relationships, including its grid region and intersecting
313
+ objects.
314
+ execution_time: The time taken to perform the spatial
315
+ relationships analysis, in seconds.
316
+ """
317
+ relationships_count = len(spatial_relationships)
318
+ content = f"Spatial Relationships Analysis ({execution_time:.3f}s)\n"
319
+ content += f"Analyzed {relationships_count} objects with confidence ≥ 0.3\n"
320
+
321
+ # Group by regions
322
+ by_region = {}
323
+ for rel in spatial_relationships:
324
+ region = rel['grid_region']
325
+ by_region[region] = by_region.get(region, 0) + 1
326
+
327
+ content += f"Distribution by region: {dict(by_region)}\n"
328
+ content += f"Objects with neighbors: {sum(1 for r in spatial_relationships if r['intersecting_objects'])}\n"
329
+
330
+ self._write_log_entry(session_id, "Spatial Relationships Analysis", content)
331
+
332
+ def log_detection_narrative(
333
+ self,
334
+ session_id: str,
335
+ detection_narrative: str,
336
+ detections_count: int,
337
+ execution_time: float
338
+ ) -> None:
339
+ """Log detection narrative generation.
340
+
341
+ Args:
342
+ session_id: The unique identifier for the current session.
343
+ detection_narrative: The string containing the generated narrative.
344
+ detections_count: The total number of detections used to
345
+ generate the narrative.
346
+ execution_time: The time taken for narrative generation, in seconds.
347
+ """
348
+ narrative_length = len(detection_narrative)
349
+ content = f"Detection Narrative Generation ({execution_time:.3f}s)\n"
350
+ content += f"Generated narrative for {detections_count} detections\n"
351
+ content += f"Narrative length: {narrative_length} characters\n"
352
+ content += f"Narrative content:\n{detection_narrative}"
353
+
354
+ self._write_log_entry(session_id, "Detection Narrative", content)
355
+
356
+ def log_visual_analysis_unified(
357
+ self,
358
+ session_id: str,
359
+ analysis_type: str,
360
+ visual_analysis: str,
361
+ additional_objects_count: int,
362
+ execution_time: float
363
+ ) -> None:
364
+ """Log unified visual analysis results.
365
+
366
+ Args:
367
+ session_id: The unique identifier for the current session.
368
+ analysis_type: A string specifying the type of visual analysis
369
+ performed (e.g., 'segmentation', 'classification').
370
+ visual_analysis: The string containing the final analysis result.
371
+ additional_objects_count: The number of objects detected beyond
372
+ the initial set.
373
+ execution_time: The time taken for the visual analysis, in seconds.
374
+ """
375
+ content = f"Visual Analysis - {analysis_type} ({execution_time:.3f}s)\n"
376
+ content += f"Additional objects detected: {additional_objects_count}\n"
377
+ content += f"Analysis: {visual_analysis}"
378
+
379
+ self._write_log_entry(session_id, f"Visual Analysis ({analysis_type})", content)
380
+
381
+ def get_session_log_summary(self, session_id: str) -> Dict[str, Any]:
382
+ """
383
+ Get a summary of all logged events for a session.
384
+
385
+ Args:
386
+ session_id: Session identifier
387
+
388
+ Returns:
389
+ Dictionary containing session log summary
390
+ """
391
+ log_file_path = self._get_log_file_path(session_id)
392
+
393
+ if not log_file_path.exists():
394
+ return {"error": f"No log file found for session {session_id}"}
395
+
396
+ try:
397
+ with open(log_file_path, 'r', encoding='utf-8') as f:
398
+ content = f.read()
399
+
400
+ return {
401
+ "session_id": session_id,
402
+ "log_file": str(log_file_path),
403
+ "content_preview": content
404
+ }
405
+ except Exception as e:
406
+ return {"error": f"Error reading log file: {str(e)}"}
407
+
408
+ def get_all_session_logs(self) -> List[str]:
409
+ """
410
+ Get a list of all session IDs that have log files.
411
+
412
+ Returns:
413
+ List of session IDs with existing log files
414
+ """
415
+ session_ids = []
416
+
417
+ for log_file in self.logs_dir.glob("session_*.log"):
418
+ filename = log_file.stem
419
+ parts = filename.split("_")
420
+ if len(parts) >= 2:
421
+ session_id = parts[1]
422
+ session_ids.append(session_id)
423
+
424
+ return sorted(set(session_ids))
425
+
426
+ def cleanup_old_logs(self, days_to_keep: int = 7) -> int:
427
+ """
428
+ Clean up log files older than specified days.
429
+
430
+ Args:
431
+ days_to_keep: Number of days of logs to retain
432
+
433
+ Returns:
434
+ Number of log files deleted
435
+ """
436
+ cutoff_time = time.time() - (days_to_keep * 24 * 60 * 60)
437
+ deleted_count = 0
438
+
439
+ for log_file in self.logs_dir.glob("session_*.log"):
440
+ if log_file.stat().st_mtime < cutoff_time:
441
+ try:
442
+ log_file.unlink()
443
+ deleted_count += 1
444
+ except Exception as e:
445
+ print(f"Error deleting old log file {log_file}: {e}")
446
+
447
+ return deleted_count
448
+
449
+ multi_agent_logger = MultiAgentLogger()
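For reviewers, a minimal usage sketch of the module-level `multi_agent_logger` singleton defined above; the session ID, query text, and timings are made-up example values, not agent output:

```python
from deepforest_agent.utils.logging_utils import multi_agent_logger

session_id = "abc123def456"  # example only; real IDs come from the session state manager

# Record session lifecycle events and the user's query
multi_agent_logger.log_session_event(
    session_id, "session_created",
    details={"image_size": (4000, 3000), "image_mode": "RGB"},
)
multi_agent_logger.log_user_query(session_id, "How many trees are in this image?")

# Record one agent step with its timing
multi_agent_logger.log_agent_execution(
    session_id, "detector",
    agent_input="tree detection request",
    agent_output="12 trees detected",
    execution_time=3.42,
)

# Inspect or prune the on-disk session logs
print(multi_agent_logger.get_session_log_summary(session_id)["log_file"])
multi_agent_logger.cleanup_old_logs(days_to_keep=7)
```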
src/deepforest_agent/utils/parsing_utils.py ADDED
@@ -0,0 +1,238 @@
1
+ import json
2
+ import re
3
+ from typing import Dict, List, Any, Optional
4
+
5
+
6
+ def parse_image_quality_for_deepforest(response: str) -> str:
7
+ """
8
+ Parse IMAGE_QUALITY_FOR_DEEPFOREST from response.
9
+
10
+ Args:
11
+ response: Model response text
12
+
13
+ Returns:
14
+ "Yes" or "No"
15
+ """
16
+ quality_match = re.search(r'(?:\*\*)?IMAGE_QUALITY_FOR_DEEPFOREST[:\*\s]+\[?(YES|NO|Yes|No|yes|no)\]?', response, re.IGNORECASE)
17
+ if quality_match:
18
+ quality_value = quality_match.group(1).upper()
19
+ return "Yes" if quality_value == "YES" else "No"
20
+ return "No"
21
+
22
+ def parse_deepforest_objects_present(response: str) -> List[str]:
23
+ """
24
+ Parse DEEPFOREST_OBJECTS_PRESENT from response.
25
+
26
+ Args:
27
+ response: Model response text
28
+
29
+ Returns:
30
+ List of objects present
31
+ """
32
+ objects_match = re.search(r'(?:\*\*)?DEEPFOREST_OBJECTS_PRESENT[:\*\s]+(\[.*?\])', response, re.DOTALL)
33
+ if objects_match:
34
+ try:
35
+ objects_str = objects_match.group(1)
36
+ objects_str = re.sub(r'[`\'"]', '"', objects_str)
37
+ objects_list = json.loads(objects_str)
38
+
39
+ allowed_objects = ["bird", "tree", "livestock"]
40
+ validated_objects = [obj for obj in objects_list if obj in allowed_objects]
41
+ return validated_objects
42
+ except json.JSONDecodeError:
43
+ objects_str = objects_match.group(1)
44
+ manual_objects = re.findall(r'"(bird|tree|livestock)"', objects_str)
45
+ return list(set(manual_objects))
46
+ return []
47
+
48
+
49
+ def parse_additional_objects_json(response: str) -> List[Dict[str, Any]]:
50
+ """
51
+ Parse ADDITIONAL_OBJECTS_JSON from response.
52
+
53
+ Args:
54
+ response: Model response text
55
+
56
+ Returns:
57
+ List of additional objects with coordinates
58
+ """
59
+ additional_match = re.search(r'(?:\*\*)?ADDITIONAL_OBJECTS_JSON[:\*\s]+(.*?)(?=\n(?:\*\*)?(?:VISUAL_ANALYSIS|IMAGE_QUALITY|DEEPFOREST_OBJECTS)|$)', response, re.DOTALL)
60
+ if additional_match:
61
+ try:
62
+ additional_str = additional_match.group(1).strip()
63
+ if additional_str.startswith('```json'):
64
+ additional_str = additional_str[7:]
65
+ if additional_str.startswith('```'):
66
+ additional_str = additional_str[3:]
67
+ if additional_str.endswith('```'):
68
+ additional_str = additional_str[:-3]
69
+
70
+ additional_str = additional_str.strip()
71
+
72
+ if additional_str.startswith('[') and additional_str.endswith(']'):
73
+ additional_objects = json.loads(additional_str)
74
+ if isinstance(additional_objects, list):
75
+ return additional_objects
76
+ else:
77
+ additional_objects = []
78
+ for line in additional_str.split('\n'):
79
+ line = line.strip().rstrip(',')
80
+ if line and line.startswith('{') and line.endswith('}'):
81
+ try:
82
+ obj = json.loads(line)
83
+ additional_objects.append(obj)
84
+ except json.JSONDecodeError:
85
+ continue
86
+ return additional_objects
87
+
88
+ except Exception as e:
89
+ print(f"Error parsing additional objects JSON: {e}")
90
+ return []
91
+
92
+
93
+ def parse_visual_analysis(response: str) -> str:
94
+ """
95
+ Parse VISUAL_ANALYSIS from response.
96
+
97
+ Args:
98
+ response: Model response text
99
+
100
+ Returns:
101
+ Visual analysis text
102
+ """
103
+ analysis_match = re.search(r'(?:\*\*)?VISUAL_ANALYSIS[:\*\s]+(.*?)(?=\n(?:\*\*)?(?:IMAGE_QUALITY|DEEPFOREST_OBJECTS|ADDITIONAL_OBJECTS)|$)', response, re.IGNORECASE | re.DOTALL)
104
+ if analysis_match:
105
+ return analysis_match.group(1).strip()
106
+ else:
107
+ fallback_match = re.search(r'(?:\*\*)?VISUAL_ANALYSIS[:\*\s]+(.*)', response, re.IGNORECASE | re.DOTALL)
108
+ if fallback_match:
109
+ return fallback_match.group(1).strip()
110
+ return response
111
+
112
+
113
+ def parse_deepforest_agent_response_with_reasoning(response: str) -> Dict[str, Any]:
114
+ """
115
+ Parse DeepForest detector agent response with reasoning.
116
+
117
+ Args:
118
+ response: Model response text
119
+
120
+ Returns:
121
+ Dictionary with reasoning and tool calls
122
+ """
123
+ from deepforest_agent.tools.tool_handler import extract_all_tool_calls
124
+
125
+ try:
126
+ tool_calls = extract_all_tool_calls(response)
127
+
128
+ if not tool_calls:
129
+ return {"error": "No valid tool calls found in response"}
130
+
131
+ reasoning_text = ""
132
+ first_json_match = re.search(r'\{[^}]*"name"[^}]*"arguments"[^}]*\}', response)
133
+
134
+ if first_json_match:
135
+ reasoning_text = response[:first_json_match.start()].strip()
136
+ reasoning_text = re.sub(r'^(REASONING:|Reasoning:|Analysis:|\*\*REASONING:\*\*)', '', reasoning_text).strip()
137
+
138
+ if not reasoning_text:
139
+ reasoning_text = "Tool calls generated based on analysis"
140
+
141
+ return {
142
+ "reasoning": reasoning_text,
143
+ "tool_calls": tool_calls
144
+ }
145
+
146
+ except Exception as e:
147
+ return {"error": f"Unexpected error parsing response: {str(e)}"}
148
+
149
+ def parse_memory_agent_response(response: str) -> Dict[str, Any]:
150
+ """
151
+ Parse memory agent structured response format with new TOOL_CACHE_ID field.
152
+
153
+ Args:
154
+ response: Model response text
155
+
156
+ Returns:
157
+ Dictionary with answer_present, direct_answer, tool_cache_id, and relevant_context
158
+ """
159
+ try:
160
+ # Parse ANSWER_PRESENT
161
+ answer_present_match = re.search(r'(?:\*\*)?ANSWER_PRESENT:(?:\*\*)?\s*\[?(YES|NO)\]?', response, re.IGNORECASE)
162
+ answer_present = False
163
+ if answer_present_match:
164
+ answer_present = answer_present_match.group(1).upper() == "YES"
165
+
166
+ # Parse TOOL_CACHE_ID
167
+ tool_cache_id_match = re.search(r'(?:\*\*)?TOOL_CACHE_ID:(?:\*\*)?\s*(.*?)(?=\n(?:\*\*)?(?:RELEVANT_CONTEXT|$))', response, re.IGNORECASE | re.DOTALL)
168
+ tool_cache_id = None
169
+
170
+ if tool_cache_id_match:
171
+ tool_cache_id_text = tool_cache_id_match.group(1).strip()
172
+
173
+ # Extract all cache IDs using multiple patterns
174
+ cache_ids = []
175
+
176
+ # Pattern 1: IDs within brackets [id1, id2, ...]
177
+ bracket_pattern = r'\[([^\[\]]*)\]'
178
+ bracket_matches = re.findall(bracket_pattern, tool_cache_id_text)
179
+ for bracket_content in bracket_matches:
180
+ if bracket_content.strip(): # Skip empty brackets
181
+ # Extract hex IDs from bracket content
182
+ hex_ids = re.findall(r'([a-fA-F0-9]{8,})', bracket_content)
183
+ cache_ids.extend(hex_ids)
184
+
185
+ # Pattern 2: Direct hex IDs (not in brackets)
186
+ # Remove bracketed content first, then find remaining hex IDs
187
+ text_without_brackets = re.sub(r'\[[^\[\]]*\]', '', tool_cache_id_text)
188
+ direct_hex_ids = re.findall(r'([a-fA-F0-9]{8,})', text_without_brackets)
189
+ cache_ids.extend(direct_hex_ids)
190
+
191
+ # Pattern 3: Standalone hex IDs on separate lines (check the whole response)
192
+ standalone_pattern = r'^([a-fA-F0-9]{8,})$'
193
+ standalone_matches = re.findall(standalone_pattern, response, re.MULTILINE)
194
+ cache_ids.extend(standalone_matches)
195
+
196
+ # Remove duplicates while preserving order
197
+ seen = set()
198
+ unique_cache_ids = []
199
+ for cache_id in cache_ids:
200
+ if cache_id not in seen:
201
+ seen.add(cache_id)
202
+ unique_cache_ids.append(cache_id)
203
+
204
+ if unique_cache_ids:
205
+ tool_cache_id = ", ".join(unique_cache_ids) if len(unique_cache_ids) > 1 else unique_cache_ids[0]
206
+ elif tool_cache_id_text and tool_cache_id_text.lower() not in ["", "empty", "none", "no tool cache id"]:
207
+ tool_cache_id = tool_cache_id_text
208
+
209
+ # Parse RELEVANT_CONTEXT
210
+ context_match = re.search(
211
+ r'(?:\*\*)?RELEVANT_CONTEXT:(?:\*\*)?\s*(.*?)(?=\n\*\*[A-Z_]+:|\Z)',
212
+ response,
213
+ re.IGNORECASE | re.DOTALL
214
+ )
215
+
216
+ relevant_context = ""
217
+ if context_match:
218
+ relevant_context = context_match.group(1).strip()
219
+ elif not answer_present:
220
+ relevant_context = response
221
+
222
+ return {
223
+ "answer_present": answer_present,
224
+ "direct_answer": "YES" if answer_present else "NO",
225
+ "tool_cache_id": tool_cache_id,
226
+ "relevant_context": relevant_context,
227
+ "raw_response": response
228
+ }
229
+
230
+ except Exception as e:
231
+ print(f"Error parsing memory response: {e}")
232
+ return {
233
+ "answer_present": False,
234
+ "direct_answer": "NO",
235
+ "tool_cache_id": None,
236
+ "relevant_context": response,
237
+ "raw_response": response
238
+ }
src/deepforest_agent/utils/rtree_spatial_utils.py ADDED
@@ -0,0 +1,394 @@
1
+ import numpy as np
2
+ from typing import List, Dict, Any, Tuple, Optional
3
+ from rtree import index
4
+ import pandas as pd
5
+
6
+
7
+ class DetectionSpatialAnalyzer:
8
+ """
9
+ Spatial analyzer using R-tree for DeepForest detection results.
10
+ """
11
+
12
+ def __init__(self, image_width: int, image_height: int):
13
+ """
14
+ Initialize spatial analyzer with image dimensions.
15
+
16
+ Args:
17
+ image_width: Width of the image in pixels
18
+ image_height: Height of the image in pixels
19
+ """
20
+ self.image_width = image_width
21
+ self.image_height = image_height
22
+ self.spatial_index = index.Index()
23
+ self.detections = []
24
+
25
+ def add_detections(self, detections_list: List[Dict[str, Any]]) -> None:
26
+ """
27
+ Add detections to R-tree spatial index.
28
+
29
+ Args:
30
+ detections_list: List of detection dictionaries with coordinates
31
+ """
32
+ for i, detection in enumerate(detections_list):
33
+ xmin = detection.get('xmin', 0)
34
+ ymin = detection.get('ymin', 0)
35
+ xmax = detection.get('xmax', 0)
36
+ ymax = detection.get('ymax', 0)
37
+
38
+ # Validate box ordering - swap if necessary
39
+ if xmin > xmax:
40
+ xmin, xmax = xmax, xmin
41
+ if ymin > ymax:
42
+ ymin, ymax = ymax, ymin
43
+
44
+ # Clamp to image bounds
45
+ xmin = max(0, min(xmin, self.image_width))
46
+ ymin = max(0, min(ymin, self.image_height))
47
+ xmax = max(0, min(xmax, self.image_width))
48
+ ymax = max(0, min(ymax, self.image_height))
49
+
50
+ # Skip invalid boxes (zero area after validation)
51
+ if xmin >= xmax or ymin >= ymax:
52
+ continue
53
+
54
+ # Add to R-tree index
55
+ self.spatial_index.insert(i, (xmin, ymin, xmax, ymax))
56
+
57
+ # Store detection with spatial info
58
+ detection_copy = detection.copy()
59
+ detection_copy['detection_id'] = i
60
+ detection_copy['centroid_x'] = (xmin + xmax) / 2
61
+ detection_copy['centroid_y'] = (ymin + ymax) / 2
62
+ detection_copy['area'] = (xmax - xmin) * (ymax - ymin)
63
+ self.detections.append(detection_copy)
64
+
65
+ def get_grid_analysis(self) -> Dict[str, Dict[str, Any]]:
66
+ """
67
+ Analyze detections using 3x3 grid system.
68
+
69
+ Returns:
70
+ Dictionary with analysis for each grid cell
71
+ """
72
+ grid_width = self.image_width / 3
73
+ grid_height = self.image_height / 3
74
+
75
+ grid_names = {
76
+ (0, 0): "Top-Left (Northwest)", (1, 0): "Top-Center (North)", (2, 0): "Top-Right (Northeast)",
77
+ (0, 1): "Middle-Left (West)", (1, 1): "Center", (2, 1): "Middle-Right (East)",
78
+ (0, 2): "Bottom-Left (Southwest)", (1, 2): "Bottom-Center (South)", (2, 2): "Bottom-Right (Southeast)"
79
+ }
80
+
81
+ grid_analysis = {}
82
+
83
+ for (grid_x, grid_y), grid_name in grid_names.items():
84
+ # Define grid bounds
85
+ x_min = grid_x * grid_width
86
+ y_min = grid_y * grid_height
87
+ x_max = (grid_x + 1) * grid_width
88
+ y_max = (grid_y + 1) * grid_height
89
+
90
+ # Query R-tree for intersecting detections
91
+ intersecting_ids = list(self.spatial_index.intersection((x_min, y_min, x_max, y_max)))
92
+ grid_detections = [self.detections[i] for i in intersecting_ids]
93
+
94
+ # Analyze by confidence categories
95
+ confidence_analysis = self._analyze_confidence_categories(grid_detections)
96
+
97
+ grid_analysis[grid_name] = {
98
+ "total_detections": len(grid_detections),
99
+ "confidence_analysis": confidence_analysis,
100
+ "bounds": {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
101
+ }
102
+
103
+ return grid_analysis
104
+
105
+ def _analyze_confidence_categories(self, detections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
106
+ """
107
+ Analyze detections by confidence categories.
108
+
109
+ Args:
110
+ detections: List of detection dictionaries
111
+
112
+ Returns:
113
+ Analysis by confidence categories (Low, Medium, High)
114
+ """
115
+ categories = {
116
+ "Detections with Low Confidence Score (0.0-0.3)": {"detections": [], "range": (0.0, 0.3)},
117
+ "Detections with Medium Confidence Score (0.3-0.7)": {"detections": [], "range": (0.3, 0.7)},
118
+ "Detections with High Confidence Score (0.7-1.0)": {"detections": [], "range": (0.7, 1.0)}
119
+ }
120
+
121
+ for detection in detections:
122
+ score = detection.get('score', 0.0)
123
+ if score < 0.3:
124
+ categories["Detections with Low Confidence Score (0.0-0.3)"]["detections"].append(detection)
125
+ elif score < 0.7:
126
+ categories["Detections with Medium Confidence Score (0.3-0.7)"]["detections"].append(detection)
127
+ else:
128
+ categories["Detections with High Confidence Score (0.7-1.0)"]["detections"].append(detection)
129
+
130
+ # Calculate statistics for each category
131
+ analysis = {}
132
+ for category_name, category_data in categories.items():
133
+ cat_detections = category_data["detections"]
134
+ if cat_detections:
135
+ areas = [d['area'] for d in cat_detections]
136
+ analysis[category_name] = {
137
+ "count": len(cat_detections),
138
+ "avg_area": np.mean(areas),
139
+ "min_area": np.min(areas),
140
+ "max_area": np.max(areas),
141
+ "total_area_covered": np.sum(areas),
142
+ "labels": [d.get('label', 'unknown') for d in cat_detections]
143
+ }
144
+ else:
145
+ analysis[category_name] = {
146
+ "count": 0,
147
+ "avg_area": 0,
148
+ "min_area": 0,
149
+ "max_area": 0,
150
+ "total_area_covered": 0,
151
+ "labels": []
152
+ }
153
+
154
+ return analysis
155
+
156
+ def analyze_spatial_relationships_with_indexing(self, confidence_threshold: float = 0.3) -> List[Dict[str, Any]]:
157
+ """
158
+ Analyze spatial relationships using R-tree indexing for confidence >= 0.3 detections.
159
+
160
+ Args:
161
+ confidence_threshold: Minimum confidence score (default: 0.3)
162
+
163
+ Returns:
164
+ List of spatial relationship dictionaries with intersection and nearest neighbor data
165
+ """
166
+ # Filter detections by confidence threshold
167
+ high_confidence_detections = [
168
+ d for d in self.detections
169
+ if d.get('score', 0.0) >= confidence_threshold
170
+ ]
171
+
172
+ if not high_confidence_detections:
173
+ return []
174
+
175
+ relationships = []
176
+
177
+ for detection in high_confidence_detections:
178
+ # Get bounding box coordinates directly
179
+ xmin = detection.get('xmin', 0)
180
+ ymin = detection.get('ymin', 0)
181
+ xmax = detection.get('xmax', 0)
182
+ ymax = detection.get('ymax', 0)
183
+ detection_id = detection.get('detection_id', 0)
184
+
185
+ # Get object label (handle classification labels for trees)
186
+ if 'classification_label' in detection and detection['classification_label'] and str(detection['classification_label']).lower() != 'nan':
187
+ object_label = detection['classification_label']
188
+ else:
189
+ object_label = detection.get('label', 'unknown')
190
+
191
+ # Find intersecting objects using spatial index
192
+ intersecting_ids = list(self.spatial_index.intersection((xmin, ymin, xmax, ymax)))
193
+
194
+ # Remove self from intersections
195
+ intersecting_ids = [idx for idx in intersecting_ids if idx != detection_id]
196
+
197
+ # Get details of intersecting objects
198
+ intersecting_objects = []
199
+ for idx in intersecting_ids:
200
+ if idx < len(self.detections):
201
+ intersecting_detection = self.detections[idx]
202
+ if intersecting_detection.get('score', 0.0) >= confidence_threshold:
203
+ if 'classification_label' in intersecting_detection and intersecting_detection['classification_label'] and str(intersecting_detection['classification_label']).lower() != 'nan':
204
+ intersecting_label = intersecting_detection['classification_label']
205
+ else:
206
+ intersecting_label = intersecting_detection.get('label', 'unknown')
207
+ intersecting_objects.append(intersecting_label)
208
+
209
+ # Find nearest neighbor using spatial index
210
+ nearest_ids = list(self.spatial_index.nearest((xmin, ymin, xmax, ymax), 2)) # 2 to get self + nearest
211
+ nearest_neighbor = None
212
+
213
+ for idx in nearest_ids:
214
+ if idx != detection_id and idx < len(self.detections):
215
+ nearest_detection = self.detections[idx]
216
+ if nearest_detection.get('score', 0.0) >= confidence_threshold:
217
+ if 'classification_label' in nearest_detection and nearest_detection['classification_label'] and str(nearest_detection['classification_label']).lower() != 'nan':
218
+ nearest_label = nearest_detection['classification_label']
219
+ else:
220
+ nearest_label = nearest_detection.get('label', 'unknown')
221
+ nearest_neighbor = nearest_label
222
+ break
223
+
224
+ # Determine grid region
225
+ grid_region = self._determine_grid_region(detection)
226
+
227
+ # Count intersecting objects by type
228
+ object_counts = {}
229
+ for obj_label in intersecting_objects:
230
+ object_counts[obj_label] = object_counts.get(obj_label, 0) + 1
231
+
232
+ relationships.append({
233
+ 'object_type': object_label,
234
+ 'object_location': f"({ymin}, {xmin})",
235
+ 'grid_region': grid_region,
236
+ 'intersecting_objects': object_counts,
237
+ 'nearest_neighbor': nearest_neighbor,
238
+ 'confidence_score': detection.get('score', 0.0),
239
+ 'total_intersections': len(intersecting_objects)
240
+ })
241
+
242
+ return relationships
243
+
244
+ def _determine_grid_region(self, detection: Dict[str, Any]) -> str:
245
+ """
246
+ Determine which grid region a detection belongs to based on its centroid.
247
+
248
+ Args:
249
+ detection: Detection dictionary with coordinates
250
+
251
+ Returns:
252
+ Grid region name (e.g., "northern", "northwestern", "central")
253
+ """
254
+ centroid_x = detection.get('centroid_x', 0)
255
+ centroid_y = detection.get('centroid_y', 0)
256
+
257
+ grid_width = self.image_width / 3
258
+ grid_height = self.image_height / 3
259
+
260
+ # Determine grid position
261
+ grid_x = int(centroid_x // grid_width)
262
+ grid_y = int(centroid_y // grid_height)
263
+
264
+ # Ensure within bounds
265
+ grid_x = min(2, max(0, grid_x))
266
+ grid_y = min(2, max(0, grid_y))
267
+
268
+ grid_names = {
269
+ (0, 0): "northwestern", (1, 0): "northern", (2, 0): "northeastern",
270
+ (0, 1): "western", (1, 1): "central", (2, 1): "eastern",
271
+ (0, 2): "southwestern", (1, 2): "southern", (2, 2): "southeastern"
272
+ }
273
+
274
+ return grid_names.get((grid_x, grid_y), "central")
275
+
276
+ def generate_spatial_narrative(self, confidence_threshold: float = 0.3) -> str:
277
+ """
278
+ Generate narrative description of spatial relationships using R-tree analysis.
279
+
280
+ Args:
281
+ confidence_threshold: Minimum confidence score for analysis (default: 0.3)
282
+
283
+ Returns:
284
+ Natural language narrative of spatial relationships
285
+ """
286
+ relationships = self.analyze_spatial_relationships_with_indexing(confidence_threshold)
287
+
288
+ if not relationships:
289
+ return f"No objects with confidence score >= {confidence_threshold} found for spatial relationship analysis."
290
+
291
+ narrative_parts = []
292
+
293
+ # Process each relationship and only include different object types
294
+ for rel in relationships:
295
+ object_type = rel['object_type']
296
+ confidence_score = rel['confidence_score']
297
+ grid_region = rel['grid_region']
298
+ object_location = rel['object_location']
299
+
300
+ # Only process intersecting objects that are DIFFERENT from the main object
301
+ different_intersecting = {}
302
+ for intersecting_type, count in rel['intersecting_objects'].items():
303
+ if intersecting_type != object_type: # Only different object types
304
+ different_intersecting[intersecting_type] = count
305
+
306
+ # Generate narrative for intersecting different objects
307
+ if different_intersecting:
308
+ intersecting_parts = []
309
+ for obj_label, count in different_intersecting.items():
310
+ if count == 1:
311
+ intersecting_parts.append(f"{count} {obj_label.replace('_', ' ')}")
312
+ else:
313
+ intersecting_parts.append(f"{count} {obj_label.replace('_', ' ')}s")
314
+
315
+ intersecting_desc = ", ".join(intersecting_parts)
316
+
317
+ narrative_parts.append(
318
+ f"I am about {confidence_score*100:.1f}% confident that, in {grid_region} region, "
319
+ f"{intersecting_desc} found overlapping around the {object_type.replace('_', ' ')} "
320
+ f"object at location (top, left) = {object_location}.\n"
321
+ )
322
+
323
+ # Only add nearest neighbor information if it's a DIFFERENT object type
324
+ if rel['nearest_neighbor'] and rel['nearest_neighbor'] != object_type:
325
+ narrative_parts.append(
326
+ f"I am about {confidence_score*100:.1f}% confident that, in {grid_region} region, "
327
+ f"around the {object_type.replace('_', ' ')} at location (top, left) = {object_location} "
328
+ f"the nearest neighbor is a {rel['nearest_neighbor'].replace('_', ' ')}.\n"
329
+ )
330
+
331
+ if narrative_parts:
332
+ # Remove duplicates while preserving order
333
+ unique_narratives = []
334
+ seen = set()
335
+ for part in narrative_parts:
336
+ if part not in seen:
337
+ unique_narratives.append(part)
338
+ seen.add(part)
339
+
340
+ return " ".join(unique_narratives)
341
+ else:
342
+ return f"Spatial analysis completed for {len(relationships)} objects with confidence >= {confidence_threshold}, but no significant spatial relationships between different object types detected."
343
+
344
+ def get_detection_statistics(self) -> Dict[str, Any]:
345
+ """
346
+ Get comprehensive detection statistics.
347
+
348
+ Returns:
349
+ Dictionary with overall statistics
350
+ """
351
+ if not self.detections:
352
+ return {"total_count": 0}
353
+
354
+ # Basic counts and confidence
355
+ total_count = len(self.detections)
356
+ scores = [d.get('score', 0.0) for d in self.detections]
357
+ overall_confidence = np.mean(scores)
358
+
359
+ # Size statistics
360
+ areas = [d['area'] for d in self.detections]
361
+ avg_area = np.mean(areas)
362
+ min_area = np.min(areas)
363
+ max_area = np.max(areas)
364
+ total_area = np.sum(areas)
365
+
366
+ # Label distribution
367
+ labels = [d.get('label', 'unknown') for d in self.detections]
368
+ # Handle classification labels for trees
369
+ classified_labels = []
370
+ for d in self.detections:
371
+ if 'classification_label' in d and d['classification_label'] and str(d['classification_label']).lower() != 'nan':
372
+ classified_labels.append(d['classification_label'])
373
+ else:
374
+ classified_labels.append(d.get('label', 'unknown'))
375
+
376
+ from collections import Counter
377
+ label_counts = Counter(classified_labels)
378
+
379
+ return {
380
+ "total_count": total_count,
381
+ "overall_confidence": overall_confidence,
382
+ "size_stats": {
383
+ "avg_area": avg_area,
384
+ "min_area": min_area,
385
+ "max_area": max_area,
386
+ "total_area_covered": total_area
387
+ },
388
+ "label_distribution": dict(label_counts),
389
+ "confidence_distribution": {
390
+ "low_count": len([s for s in scores if s < 0.3]),
391
+ "medium_count": len([s for s in scores if 0.3 <= s < 0.7]),
392
+ "high_count": len([s for s in scores if s >= 0.7])
393
+ }
394
+ }
src/deepforest_agent/utils/state_manager.py ADDED
@@ -0,0 +1,574 @@
1
+ import threading
2
+ import uuid
3
+ import time
4
+ from typing import Optional, Any, Dict, List
5
+
6
+ from deepforest_agent.utils.cache_utils import tool_call_cache
7
+
8
+
9
+ class SessionStateManager:
10
+ """
11
+ Session-based state manager with thread ID for the DeepForest Agent.
12
+
13
+ This class manages state for multiple concurrent users with each user
14
+ having their own session containing current image, conversation
15
+ history, and session information.
16
+
17
+ Attributes:
18
+ _lock (threading.Lock): Thread synchronization lock
19
+ _sessions (Dict[str, Dict[str, Any]]): Dictionary mapping session_ids to session state
20
+ _cleanup_interval (int): Time in seconds after which inactive sessions are cleaned up
21
+ """
22
+
23
+ def __init__(self, cleanup_interval: int = 3600) -> None:
24
+ """
25
+ Initialize the session state manager.
26
+
27
+ Args:
28
+ cleanup_interval (int): Time in seconds after which inactive sessions
29
+ are eligible for cleanup (default: 1 hour)
30
+ """
31
+ self._lock = threading.Lock()
32
+ self._sessions = {}
33
+ self._cleanup_interval = cleanup_interval
34
+
35
+ def create_session(self, image: Any = None) -> str:
36
+ """
37
+ Create a new session with initial image.
38
+
39
+ Args:
40
+ image (Any, optional): Initial image for the session
41
+
42
+ Returns:
43
+ str: Unique session ID
44
+ """
45
+ session_id = str(uuid.uuid4())[:12]
46
+
47
+ with self._lock:
48
+ self._sessions[session_id] = {
49
+ "current_image": image,
50
+ "conversation_history": [],
51
+ "annotated_image": None,
52
+ "thread_id": session_id,
53
+ "first_message": True,
54
+ "created_at": time.time(),
55
+ "last_accessed": time.time(),
56
+ "is_cancelled": False,
57
+ "is_processing": False,
58
+ "tool_call_history": [],
59
+ "visual_analysis_history": []
60
+ }
61
+
62
+ return session_id
63
+
64
+ def get_session_state(self, session_id: str) -> Dict[str, Any]:
65
+ """
66
+ Get complete state for a specific session.
67
+
68
+ Args:
69
+ session_id (str): The session ID to retrieve
70
+
71
+ Returns:
72
+ Dict[str, Any]: Copy of session state dictionary
73
+
74
+ Raises:
75
+ KeyError: If session_id doesn't exist
76
+ """
77
+ with self._lock:
78
+ if session_id not in self._sessions:
79
+ raise KeyError(f"Session {session_id} not found")
80
+
81
+ self._sessions[session_id]["last_accessed"] = time.time()
82
+
83
+ # Return a copy to prevent external modification
84
+ return self._sessions[session_id].copy()
85
+
86
+ def get(self, session_id: str, key: str, default: Any = None) -> Any:
87
+ """
88
+ Get a value from session state.
89
+
90
+ Args:
91
+ session_id (str): The session ID
92
+ key (str): The state key to retrieve
93
+ default (Any, optional): Default value if key not found.
94
+
95
+ Returns:
96
+ Any: The value associated with the key, or default if not found
97
+
98
+ Raises:
99
+ KeyError: If session_id doesn't exist
100
+ """
101
+ with self._lock:
102
+ if session_id not in self._sessions:
103
+ raise KeyError(f"Session {session_id} not found")
104
+
105
+ self._sessions[session_id]["last_accessed"] = time.time()
106
+
107
+ return self._sessions[session_id].get(key, default)
108
+
109
+ def set(self, session_id: str, key: str, value: Any) -> None:
110
+ """
111
+ Set a value in session state.
112
+
113
+ Args:
114
+ session_id (str): The session ID
115
+ key (str): The state key to set
116
+ value (Any): The value to store
117
+
118
+ Raises:
119
+ KeyError: If session_id doesn't exist
120
+ """
121
+ with self._lock:
122
+ if session_id not in self._sessions:
123
+ raise KeyError(f"Session {session_id} not found")
124
+
125
+ self._sessions[session_id][key] = value
126
+ self._sessions[session_id]["last_accessed"] = time.time()
127
+
128
+ def update(self, session_id: str, updates: Dict[str, Any]) -> None:
129
+ """
130
+ Update multiple values in session state.
131
+
132
+ Args:
133
+ session_id (str): The session ID
134
+ updates (Dict[str, Any]): Dictionary of key-value pairs to update
135
+
136
+ Raises:
137
+ KeyError: If session_id doesn't exist
138
+ """
139
+ with self._lock:
140
+ if session_id not in self._sessions:
141
+ raise KeyError(f"Session {session_id} not found")
142
+
143
+ self._sessions[session_id].update(updates)
144
+ self._sessions[session_id]["last_accessed"] = time.time()
145
+
146
+ def set_processing_state(self, session_id: str, is_processing: bool) -> None:
147
+ """
148
+ Set processing state for a session.
149
+
150
+ Args:
151
+ session_id (str): The session ID
152
+ is_processing (bool): Whether processing is active
153
+ """
154
+ with self._lock:
155
+ if session_id in self._sessions:
156
+ self._sessions[session_id]["is_processing"] = is_processing
157
+ self._sessions[session_id]["last_accessed"] = time.time()
158
+
159
+ def cancel_session(self, session_id: str) -> None:
160
+ """
161
+ Cancel processing for a session.
162
+
163
+ Args:
164
+ session_id (str): The session ID to cancel
165
+ """
166
+ with self._lock:
167
+ if session_id in self._sessions:
168
+ self._sessions[session_id]["is_cancelled"] = True
169
+ self._sessions[session_id]["is_processing"] = False
170
+ self._sessions[session_id]["last_accessed"] = time.time()
171
+
172
+ def is_cancelled(self, session_id: str) -> bool:
173
+ """
174
+ Check if session is cancelled.
175
+
176
+ Args:
177
+ session_id (str): The session ID to check
178
+
179
+ Returns:
180
+ bool: True if cancelled
181
+ """
182
+ with self._lock:
183
+ if session_id not in self._sessions:
184
+ return True
185
+ return self._sessions[session_id].get("is_cancelled", False)
186
+
187
+ def reset_cancellation(self, session_id: str) -> None:
188
+ """
189
+ Reset cancellation flag for a session.
190
+
191
+ Args:
192
+ session_id (str): The session ID to reset
193
+ """
194
+ with self._lock:
195
+ if session_id in self._sessions:
196
+ self._sessions[session_id]["is_cancelled"] = False
197
+ self._sessions[session_id]["last_accessed"] = time.time()
198
+
199
+ def add_tool_call_to_history(self, session_id: str, tool_name: str, arguments: Dict[str, Any], cache_key: str) -> None:
200
+ """
201
+ Add a tool call to the session's tool call history.
202
+
203
+ Args:
204
+ session_id (str): The session ID
205
+ tool_name (str): Name of the tool that was called
206
+ arguments (Dict[str, Any]): Arguments passed to the tool
207
+ cache_key (str): Cache key used for this tool call
208
+
209
+ Raises:
210
+ KeyError: If session_id doesn't exist
211
+ """
212
+ with self._lock:
213
+ if session_id not in self._sessions:
214
+ raise KeyError(f"Session {session_id} not found")
215
+
216
+ tool_call_entry = {
217
+ "tool_name": tool_name,
218
+ "arguments": arguments.copy(),
219
+ "cache_key": cache_key,
220
+ "timestamp": time.time(),
221
+ "call_number": len(self._sessions[session_id]["tool_call_history"]) + 1
222
+ }
223
+
224
+ self._sessions[session_id]["tool_call_history"].append(tool_call_entry)
225
+ self._sessions[session_id]["last_accessed"] = time.time()
226
+
227
+ def get_tool_call_history(self, session_id: str) -> List[Dict[str, Any]]:
228
+ """
229
+ Get the tool call history for a specific session.
230
+
231
+ Args:
232
+ session_id (str): The session ID
233
+
234
+ Returns:
235
+ List[Dict[str, Any]]: List of tool calls made in this session
236
+
237
+ Raises:
238
+ KeyError: If session_id doesn't exist
239
+ """
240
+ with self._lock:
241
+ if session_id not in self._sessions:
242
+ raise KeyError(f"Session {session_id} not found")
243
+
244
+ self._sessions[session_id]["last_accessed"] = time.time()
245
+ return self._sessions[session_id]["tool_call_history"].copy()
246
+
247
+ def add_visual_analysis_to_history(self, session_id: str, visual_analysis: str, additional_objects: Optional[List[Dict[str, Any]]] = None) -> None:
248
+ """
249
+ Add a visual analysis response to the session's history.
250
+
251
+ Args:
252
+ session_id (str): The session ID
253
+ visual_analysis (str): Visual analysis text from visual agent
254
+ additional_objects (Optional[List[Dict[str, Any]]]): Additional objects detected by visual agent
255
+
256
+ Raises:
257
+ KeyError: If session_id doesn't exist
258
+ """
259
+ with self._lock:
260
+ if session_id not in self._sessions:
261
+ raise KeyError(f"Session {session_id} not found")
262
+
263
+ visual_entry = {
264
+ "visual_analysis": visual_analysis,
265
+ "additional_objects": additional_objects or [],
266
+ "timestamp": time.time(),
267
+ "turn_number": len(self._sessions[session_id]["visual_analysis_history"]) + 1
268
+ }
269
+
270
+ self._sessions[session_id]["visual_analysis_history"].append(visual_entry)
271
+ self._sessions[session_id]["last_accessed"] = time.time()
272
+
273
+ def get_visual_analysis_history(self, session_id: str) -> List[Dict[str, Any]]:
274
+ """
275
+ Get all visual analysis responses from previous turns.
276
+
277
+ Args:
278
+ session_id (str): The session ID
279
+
280
+ Returns:
281
+ List[Dict[str, Any]]: List of visual analysis entries with text and additional objects
282
+
283
+ Raises:
284
+ KeyError: If session_id doesn't exist
285
+ """
286
+ with self._lock:
287
+ if session_id not in self._sessions:
288
+ raise KeyError(f"Session {session_id} not found")
289
+
290
+ self._sessions[session_id]["last_accessed"] = time.time()
291
+
292
+ return self._sessions[session_id]["visual_analysis_history"].copy()
293
+
294
+ def get_formatted_tool_call_history(self, session_id: str) -> str:
295
+ """
296
+ Get formatted tool call history for memory agent context.
297
+
298
+ Args:
299
+ session_id (str): The session ID
300
+
301
+ Returns:
302
+ str: Formatted tool call history string
303
+
304
+ Raises:
305
+ KeyError: If session_id doesn't exist
306
+ """
307
+ try:
308
+ tool_calls = self.get_tool_call_history(session_id)
309
+ if not tool_calls:
310
+ return "No previous tool calls in this session."
311
+
312
+ formatted_history = []
313
+ for tool_call in tool_calls:
314
+ call_info = f"Tool Call #{tool_call.get('call_number', 'N/A')}: "
315
+ call_info += f"{tool_call.get('tool_name', 'unknown')} "
316
+ call_info += f"with args {tool_call.get('arguments', {})}"
317
+ formatted_history.append(call_info)
318
+
319
+ return "\n".join(formatted_history)
320
+ except KeyError:
321
+ return f"Session {session_id} not found - no tool call history available."
322
+
323
+ def store_conversation_turn_context(
324
+ self,
325
+ session_id: str,
326
+ turn_number: int,
327
+ user_query: str,
328
+ visual_context: str,
329
+ detection_narrative: str,
330
+ tool_cache_id: Optional[str],
331
+ ecology_response: str
332
+ ) -> None:
333
+ """
334
+ Store complete turn context for memory agent.
335
+
336
+ Args:
337
+ session_id (str): The session ID
338
+ turn_number (int): Sequential number of this conversation turn (1-indexed)
339
+ user_query (str): The original user question of the current turn
340
+ visual_context (str): Complete visual analysis output from the visual agent
341
+ detection_narrative (str): Comprehensive spatial analysis narrative generated
342
+ from DeepForest detection results
343
+ tool_cache_id (Optional[str]): Cache identifier for DeepForest tool execution
344
+ results
345
+ ecology_response (str): Final synthesized ecological analysis response
346
+ """
347
+ turn_data = {
348
+ "user_query": user_query,
349
+ "visual_context": visual_context,
350
+ "detection_narrative": detection_narrative,
351
+ "tool_cache_id": tool_cache_id or "No tool cache ID",
352
+ "ecology_response": ecology_response,
353
+ "timestamp": time.time()
354
+ }
355
+
356
+ self.set(session_id, f"conversation_turn_{turn_number}", turn_data)
357
+
358
+ # Update turn counter
359
+ current_turns = self.get(session_id, "total_turns", 0)
360
+ self.set(session_id, "total_turns", max(current_turns, turn_number))
361
+
362
+ def get_cache_stats_for_session(self, session_id: str) -> Dict[str, Any]:
363
+ """
364
+ Get cache statistics specific to this session.
365
+
366
+ Args:
367
+ session_id (str): The session ID
368
+
369
+ Returns:
370
+ Dict[str, Any]: Cache statistics for this session
371
+
372
+ Raises:
373
+ KeyError: If session_id doesn't exist
374
+ """
375
+ with self._lock:
376
+ if session_id not in self._sessions:
377
+ raise KeyError(f"Session {session_id} not found")
378
+
379
+ session_tool_calls = self._sessions[session_id]["tool_call_history"]
380
+
381
+ return {
382
+ "session_id": session_id,
383
+ "total_tool_calls": len(session_tool_calls),
384
+ "tool_calls": session_tool_calls,
385
+ "global_cache_stats": tool_call_cache.get_cache_stats()
386
+ }
387
+
388
+ def clear_session_cache_data(self, session_id: str) -> None:
389
+ """
390
+ Clear tool call history for a specific session.
391
+
392
+ Note: This only clears the session's record of tool calls,
393
+ not the global cache itself.
394
+
395
+ Args:
396
+ session_id (str): The session ID
397
+
398
+ Raises:
399
+ KeyError: If session_id doesn't exist
400
+ """
401
+ with self._lock:
402
+ if session_id not in self._sessions:
403
+ raise KeyError(f"Session {session_id} not found")
404
+
405
+ self._sessions[session_id]["tool_call_history"] = []
406
+ self._sessions[session_id]["last_accessed"] = time.time()
407
+
408
+ def clear_conversation(self, session_id: str) -> None:
409
+ """
410
+ Clear conversation-specific state for a session.
411
+
412
+ current_image and thread_id are preserved so that users can
413
+ start a new conversation without re-uploading the image.
414
+
415
+ Args:
416
+ session_id (str): The session ID to clear
417
+
418
+ Raises:
419
+ KeyError: If session_id doesn't exist
420
+ """
421
+ with self._lock:
422
+ if session_id not in self._sessions:
423
+ raise KeyError(f"Session {session_id} not found")
424
+
425
+ self._sessions[session_id].update({
426
+ "conversation_history": [],
427
+ "annotated_image": None,
428
+ "first_message": True,
429
+ "last_accessed": time.time(),
430
+ "is_cancelled": True,
431
+ "is_processing": False,
432
+ "tool_call_history": [],
433
+ "visual_analysis_history": []
434
+ })
435
+
436
+ def reset_for_new_image(self, session_id: str, image: Any) -> None:
437
+ """
438
+ Reset session state for new image upload.
439
+
440
+ Args:
441
+ session_id (str): The session ID
442
+ image (Any): The new image object (typically PIL Image)
443
+
444
+ Raises:
445
+ KeyError: If session_id doesn't exist
446
+ """
447
+ with self._lock:
448
+ if session_id not in self._sessions:
449
+ raise KeyError(f"Session {session_id} not found")
450
+
451
+ self._sessions[session_id].update({
452
+ "current_image": image,
453
+ "conversation_history": [],
454
+ "annotated_image": None,
455
+ "first_message": True,
456
+ "last_accessed": time.time(),
457
+ "tool_call_history": [],
458
+ "visual_analysis_history": []
459
+ })
460
+
461
+ def add_to_conversation(self, session_id: str, message: Dict[str, Any]) -> None:
462
+ """
463
+ Add a message to conversation history for a specific session.
464
+
465
+ Args:
466
+ session_id (str): The session ID
467
+ message (Dict[str, Any]): Message dictionary with role and content
468
+
469
+ Raises:
470
+ KeyError: If session_id doesn't exist
471
+ """
472
+ with self._lock:
473
+ if session_id not in self._sessions:
474
+ raise KeyError(f"Session {session_id} not found")
475
+
476
+ self._sessions[session_id]["conversation_history"].append(message)
477
+ self._sessions[session_id]["last_accessed"] = time.time()
478
+
479
+ def get_conversation_length(self, session_id: str) -> int:
480
+ """
481
+ Get the length of conversation history for a session.
482
+
483
+ Args:
484
+ session_id (str): The session ID
485
+
486
+ Returns:
487
+ int: Number of messages in conversation history
488
+
489
+ Raises:
490
+ KeyError: If session_id doesn't exist
491
+ """
492
+ with self._lock:
493
+ if session_id not in self._sessions:
494
+ raise KeyError(f"Session {session_id} not found")
495
+
496
+ self._sessions[session_id]["last_accessed"] = time.time()
497
+ return len(self._sessions[session_id]["conversation_history"])
498
+
499
+ def session_exists(self, session_id: str) -> bool:
500
+ """
501
+ Check if a session exists.
502
+
503
+ Args:
504
+ session_id (str): The session ID to check
505
+
506
+ Returns:
507
+ bool: True if session exists, False otherwise
508
+ """
509
+ with self._lock:
510
+ return session_id in self._sessions
511
+
512
+ def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
513
+ """
514
+ Get information about all active sessions.
515
+
516
+ Returns:
517
+ Dict[str, Dict[str, Any]]: Dictionary mapping session_ids to session info
518
+ """
519
+ with self._lock:
520
+ session_info = {}
521
+ for session_id, session_data in self._sessions.items():
522
+ session_info[session_id] = {
523
+ "thread_id": session_data.get("thread_id"),
524
+ "created_at": session_data.get("created_at"),
525
+ "last_accessed": session_data.get("last_accessed"),
526
+ "conversation_length": len(session_data.get("conversation_history", [])),
527
+ "has_image": session_data.get("current_image") is not None,
528
+ "has_annotated_image": session_data.get("annotated_image") is not None,
529
+ "tool_calls_count": len(session_data.get("tool_call_history", []))
530
+ }
531
+ return session_info
532
+
533
+ def cleanup_inactive_sessions(self) -> int:
534
+ """
535
+ Remove sessions that haven't been accessed recently.
536
+
537
+ Returns:
538
+ int: Number of sessions cleaned up
539
+ """
540
+ current_time = time.time()
541
+ cleaned_count = 0
542
+
543
+ with self._lock:
544
+ inactive_sessions = []
545
+ for session_id, session_data in self._sessions.items():
546
+ last_accessed = session_data.get("last_accessed", 0)
547
+ if current_time - last_accessed > self._cleanup_interval:
548
+ inactive_sessions.append(session_id)
549
+
550
+ for session_id in inactive_sessions:
551
+ del self._sessions[session_id]
552
+ cleaned_count += 1
553
+
554
+ return cleaned_count
555
+
556
+ def delete_session(self, session_id: str) -> bool:
557
+ """
558
+ Manually delete a specific session.
559
+
560
+ Args:
561
+ session_id (str): The session ID to delete
562
+
563
+ Returns:
564
+ bool: True if session was deleted, False if it didn't exist
565
+ """
566
+ with self._lock:
567
+ if session_id in self._sessions:
568
+ del self._sessions[session_id]
569
+ return True
570
+ return False
571
+
572
+
573
+ # Global session manager instance
574
+ session_state_manager = SessionStateManager()
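A minimal sketch of the session lifecycle handled by the global `session_state_manager`; the blank PIL image stands in for a real upload:

```python
from PIL import Image
from deepforest_agent.utils.state_manager import session_state_manager

image = Image.new("RGB", (512, 512))  # placeholder for an uploaded aerial image
session_id = session_state_manager.create_session(image)

session_state_manager.add_to_conversation(
    session_id, {"role": "user", "content": "What do you see in this image?"}
)
print(session_state_manager.get_conversation_length(session_id))  # 1

# Clearing the conversation keeps the uploaded image and thread ID
session_state_manager.clear_conversation(session_id)
print(session_state_manager.get(session_id, "current_image") is image)  # True

# Sessions idle longer than the cleanup interval (default: one hour) can be pruned
session_state_manager.cleanup_inactive_sessions()
```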
src/deepforest_agent/utils/tile_manager.py ADDED
@@ -0,0 +1,211 @@
1
+ import numpy as np
2
+ from typing import Tuple, Dict, List, Any, Optional
3
+ from PIL import Image
4
+ import rasterio as rio
5
+ from rasterio.windows import Window
6
+ from deepforest import preprocess
7
+ try:
8
+ import slidingwindow
9
+ SLIDINGWINDOW_AVAILABLE = True
10
+ except ImportError:
11
+ SLIDINGWINDOW_AVAILABLE = False
12
+ print("Warning: slidingwindow not available, falling back to deepforest preprocess")
13
+
14
+ from deepforest_agent.conf.config import Config
15
+
16
+
17
+ def tile_image_for_analysis(
18
+ image: Image.Image,
19
+ patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
20
+ patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
21
+ image_file_path: Optional[str] = None,
22
+ ) -> Tuple[List[Image.Image], List[Dict[str, Any]]]:
23
+ """
24
+ Tile an image for visual analysis.
25
+
26
+ Args:
27
+ image (Image.Image): PIL Image to tile
28
+ patch_size (int): Size of each tile in pixels (default: 400)
29
+ patch_overlap (float): Overlap between tiles as fraction 0-1 (default: 0.05)
30
+ image_file_path (Optional[str]): Path to raster file for memory-efficient dimension reading
31
+
32
+ Returns:
33
+ Tuple containing:
34
+ - List[Image.Image]: List of PIL Image tiles
35
+ - List[Dict[str, Any]]: List of tile metadata with coordinates
36
+
37
+ Raises:
38
+ ValueError: If patch_overlap > 1 or image is too small for patch_size
39
+ Exception: If tiling process fails
40
+ """
41
+ try:
42
+ # Use slidingwindow for all image types if available
43
+ if SLIDINGWINDOW_AVAILABLE:
44
+ height = width = None
45
+ method = "unknown"
46
+
47
+ if image_file_path:
48
+ try:
49
+ # Get raster shape without keeping file open
50
+ with rio.open(image_file_path) as src:
51
+ height = src.shape[0]
52
+ width = src.shape[1]
53
+ method = "slidingwindow_raster"
54
+ print(f"Using raster dimensions: {width}x{height} from file path")
55
+ except Exception as raster_error:
56
+ print(f"Raster reading failed: {raster_error}, using PIL image dimensions")
57
+ height = width = None
58
+
59
+ # If raster reading failed or no file path, get dimensions from PIL image
60
+ if height is None or width is None:
61
+ width, height = image.size
62
+ method = "slidingwindow_pil"
63
+ print(f"Using PIL dimensions: {width}x{height} from image object")
64
+
65
+ try:
66
+ # Generate windows using slidingwindow for any image type
67
+ windows = slidingwindow.generateForSize(
68
+ height=height,
69
+ width=width,
70
+ dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
71
+ maxWindowSize=patch_size,
72
+ overlapPercent=patch_overlap
73
+ )
74
+
75
+ print(f"Generated {len(windows)} tiles using slidingwindow with method: {method}")
76
+
77
+ tiles = []
78
+ tile_metadata = []
79
+
80
+ for i, window in enumerate(windows):
81
+ x = window.x
82
+ y = window.y
83
+ w = window.w
84
+ h = window.h
85
+
86
+ # Extract actual image data for this tile
87
+ if method == "slidingwindow_raster" and image_file_path:
88
+ try:
89
+ with rio.open(image_file_path) as src:
90
+ window_data = src.read(window=Window(x, y, w, h))
91
+ if window_data.ndim == 3:
92
+ window_data = window_data.transpose(1, 2, 0)
93
+
94
+ if window_data.dtype != np.uint8:
95
+ if window_data.max() <= 1.0:
96
+ window_data = (window_data * 255).astype(np.uint8)
97
+ else:
98
+ window_data = window_data.astype(np.uint8)
99
+
100
+ tile_pil = Image.fromarray(window_data)
101
+ print(f"Tile {i}: Read raster data {window_data.shape} -> PIL {tile_pil.size}")
102
+
103
+ except Exception as raster_read_error:
104
+ print(f"Failed to read raster tile {i}: {raster_read_error}")
105
+ tile_pil = image.crop((x, y, x + w, y + h))
106
+ print(f"Tile {i}: Fallback PIL crop -> {tile_pil.size}")
107
+ else:
108
+ tile_pil = image.crop((x, y, x + w, y + h))
109
+ print(f"Tile {i}: PIL crop ({x},{y},{x+w},{y+h}) -> {tile_pil.size}")
110
+
111
+ tiles.append(tile_pil)
112
+
113
+ # Create tile metadata with tile info
114
+ metadata = {
115
+ "tile_index": i,
116
+ "window_coords": {
117
+ "x": x,
118
+ "y": y,
119
+ "width": w,
120
+ "height": h
121
+ },
122
+ "tile_size": tile_pil.size,
123
+ "original_image_size": (width, height),
124
+ "method": method,
125
+ "actual_crop_bounds": (x, y, x + w, y + h)
126
+ }
127
+ tile_metadata.append(metadata)
128
+
129
+ print(f"Successfully created {len(tiles)} tiles using slidingwindow method")
130
+ return tiles, tile_metadata
131
+
132
+ except Exception as slidingwindow_error:
133
+ print(f"Slidingwindow method failed: {slidingwindow_error}, falling back to deepforest preprocess")
134
+
135
+ # Fallback to deepforest preprocess method only if slidingwindow failed
136
+ print(f"Using PIL-based tiling for image with size {image.size}")
137
+
138
+ numpy_image = np.array(image)
139
+
140
+ if numpy_image.shape[2] == 4:
141
+ numpy_image = numpy_image[:, :, :3]
142
+ elif numpy_image.shape[2] != 3:
143
+ raise ValueError(f"Image must have 3 channels (RGB), got {numpy_image.shape[2]}")
144
+
145
+ numpy_image = numpy_image.transpose(2, 0, 1)
146
+ numpy_image = numpy_image / 255.0
147
+ numpy_image = numpy_image.astype(np.float32)
148
+
149
+ print(f"Tiling image with shape {numpy_image.shape} using patch_size={patch_size}, patch_overlap={patch_overlap}")
150
+
151
+ windows = preprocess.compute_windows(numpy_image, patch_size, patch_overlap)
152
+
153
+ print(f"Generated {len(windows)} tiles for analysis using deepforest preprocess")
154
+
155
+ tiles = []
156
+ tile_metadata = []
157
+
158
+ for i, window in enumerate(windows):
159
+ tile_array = numpy_image[window.indices()]
160
+ tile_array = tile_array.transpose(1, 2, 0)
161
+ if tile_array.dtype != np.uint8:
162
+ tile_array = (tile_array * 255).astype(np.uint8) if tile_array.max() <= 1.0 else tile_array.astype(np.uint8)
163
+
164
+ tile_pil = Image.fromarray(tile_array)
165
+ tiles.append(tile_pil)
166
+
167
+ x, y, w, h = window.getRect()
168
+ print(f"DeepForest tile {i}: array shape {tile_array.shape} -> PIL {tile_pil.size}")
169
+
170
+ # Create tile metadata
171
+ metadata = {
172
+ "tile_index": i,
173
+ "window_coords": {
174
+ "x": x,
175
+ "y": y,
176
+ "width": w,
177
+ "height": h
178
+ },
179
+ "tile_size": tile_pil.size,
180
+ "original_image_size": image.size,
181
+ "method": "deepforest_preprocess"
182
+ }
183
+ tile_metadata.append(metadata)
184
+
185
+ if not tiles:
186
+ raise Exception("No tiles were created - check image dimensions and parameters")
187
+
188
+ # Check for empty or invalid tiles
189
+ valid_tiles = []
190
+ valid_metadata = []
191
+ for i, tile in enumerate(tiles):
192
+ if tile.size[0] > 0 and tile.size[1] > 0:
193
+ valid_tiles.append(tile)
194
+ valid_metadata.append(tile_metadata[i])
195
+ else:
196
+ print(f"Warning: Tile {i} has invalid size {tile.size}, skipping")
197
+
198
+ if not valid_tiles:
199
+ raise Exception("No valid tiles were created")
200
+
201
+ if len(valid_tiles) != len(tiles):
202
+ print(f"Filtered {len(tiles)} -> {len(valid_tiles)} valid tiles")
203
+ tiles = valid_tiles
204
+ tile_metadata = valid_metadata
205
+
206
+ print(f"Successfully created {len(tiles)} tiles for multi-image analysis using fallback method")
207
+ return tiles, tile_metadata
208
+
209
+ except Exception as e:
210
+ print(f"Error during image tiling: {e}")
211
+ raise e
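Reviewer note: a minimal sketch of calling the tiling helper on a sample raster. The data path below mirrors the test fixtures and is an assumption, as are the explicit patch parameters (the function defaults come from Config).

```python
# Minimal sketch of tiling a raster for analysis.
# Assumptions: data/OSBS_029.tif exists locally (it mirrors the test fixtures)
# and the patch parameters below are reasonable for the image size.
from PIL import Image
from deepforest_agent.utils.tile_manager import tile_image_for_analysis

image = Image.open("data/OSBS_029.tif").convert("RGB")
tiles, tile_metadata = tile_image_for_analysis(
    image,
    patch_size=400,        # pixels per tile edge
    patch_overlap=0.05,    # 5% overlap between neighbouring tiles
    image_file_path="data/OSBS_029.tif",  # lets the helper read windows straight from the raster
)

# Each metadata entry records where the tile sits in the original image.
for meta in tile_metadata:
    coords = meta["window_coords"]
    print(meta["tile_index"], (coords["x"], coords["y"]), meta["tile_size"], meta["method"])
```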
tests/test_deepforest_tool.py ADDED
@@ -0,0 +1,465 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from matplotlib import pyplot as plt
4
+
5
+ from deepforest_agent.conf.config import Config
6
+ from deepforest_agent.tools.deepforest_tool import DeepForestPredictor
7
+ from deepforest_agent.utils.image_utils import load_image_as_np_array
8
+
9
+ TEST_IMAGE_PATH_SMALL = "data/AWPE Pigeon Lake 2020 DJI_0005.JPG"
10
+ TEST_IMAGE_PATH_LARGE = "data/OSBS_029.tif"
11
+
12
+ deepforest_predictor = DeepForestPredictor()
13
+
14
+
15
+ def display_image_for_test(image_array: np.ndarray, title: str = "Test Image"):
16
+ """
17
+ Display an image using matplotlib for visual inspection during testing.
18
+
19
+ Args:
20
+ image_array: Image as numpy array
21
+ title: Title for the plot
22
+ """
23
+ plt.imshow(image_array)
24
+ plt.axis('off')
25
+ plt.title(title)
26
+ plt.show()
27
+
28
+
29
+ def test_deepforest_predict_objects_basic_detection_bird():
30
+ """Test basic bird detection with default parameters on a small image."""
31
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
32
+ if image_array is None:
33
+ return
34
+
35
+ summary, annotated_image, detections_list = (
36
+ deepforest_predictor.predict_objects(
37
+ image_data_array=image_array,
38
+ model_names=["bird"]
39
+ )
40
+ )
41
+
42
+ assert "DeepForest detected" in summary or "No objects detected" in summary
43
+ assert ("bird" in summary or "No objects detected" in summary)
44
+ assert annotated_image is not None
45
+ assert isinstance(annotated_image, np.ndarray)
46
+ assert annotated_image.shape[:2] == image_array.shape[:2]
47
+
48
+ assert isinstance(detections_list, list)
49
+ if detections_list:
50
+ bird_labels_found = any(
51
+ detection["label"] == "bird" for detection in detections_list if 'label' in detection
52
+ )
53
+ assert bird_labels_found
54
+
55
+ display_image_for_test(annotated_image, "Bird Detection Test")
56
+
57
+
58
+ def test_deepforest_predict_objects_basic_detection_tree():
59
+ """Test basic tree detection with default parameters on a small image."""
60
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
61
+ if image_array is None:
62
+ return
63
+
64
+ summary, annotated_image, detections_list = (
65
+ deepforest_predictor.predict_objects(
66
+ image_data_array=image_array,
67
+ model_names=["tree"]
68
+ )
69
+ )
70
+
71
+ assert "DeepForest detected" in summary or "No objects detected" in summary
72
+ assert "tree" in summary or "No objects detected" in summary
73
+ assert annotated_image is not None
74
+ assert isinstance(annotated_image, np.ndarray)
75
+ assert annotated_image.shape[:2] == image_array.shape[:2]
76
+
77
+ assert isinstance(detections_list, list)
78
+ if detections_list:
79
+ tree_labels_found = any(
80
+ detection["label"] == "tree" for detection in detections_list if 'label' in detection
81
+ )
82
+ assert tree_labels_found
83
+
84
+ display_image_for_test(annotated_image, "Tree Detection Test")
85
+
86
+
87
+ def test_deepforest_predict_objects_multiple_models():
88
+ """Test detection using multiple models simultaneously."""
89
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
90
+ if image_array is None:
91
+ return
92
+
93
+ summary, annotated_image, detections_list = (
94
+ deepforest_predictor.predict_objects(
95
+ image_data_array=image_array,
96
+ model_names=["bird", "tree", "livestock"]
97
+ )
98
+ )
99
+
100
+ assert "DeepForest detected" in summary or "No objects detected" in summary
101
+ assert annotated_image is not None
102
+ assert isinstance(annotated_image, np.ndarray)
103
+ assert annotated_image.shape[:2] == image_array.shape[:2]
104
+
105
+ assert isinstance(detections_list, list)
106
+ if detections_list:
107
+ labels = {detection['label'] for detection in detections_list if 'label' in detection}
108
+ assert "bird" in labels or "tree" in labels or "livestock" in labels
109
+
110
+ display_image_for_test(annotated_image, "Multiple Models Test")
111
+
112
+
113
+ def test_deepforest_predict_objects_large_image_processing():
114
+ """Test processing of large images using tiled prediction."""
115
+ summary, annotated_image, detections_list = (
116
+ deepforest_predictor.predict_objects(
117
+ image_file_path=TEST_IMAGE_PATH_LARGE,
118
+ model_names=["tree"],
119
+ patch_size=Config.DEEPFOREST_DEFAULTS["patch_size"],
120
+ patch_overlap=Config.DEEPFOREST_DEFAULTS["patch_overlap"],
121
+ iou_threshold=Config.DEEPFOREST_DEFAULTS["iou_threshold"],
122
+ thresh=Config.DEEPFOREST_DEFAULTS["thresh"]
123
+ )
124
+ )
125
+
126
+ assert "DeepForest detected" in summary or "No objects detected" in summary
127
+ assert annotated_image is not None
128
+ assert isinstance(annotated_image, np.ndarray)
129
+
130
+ assert isinstance(detections_list, list)
131
+ if detections_list:
132
+ assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
133
+
134
+ display_image_for_test(annotated_image, "Large Image Processing Test")
135
+
136
+
137
+ def test_deepforest_predict_objects_custom_patch_size():
138
+ """Test detection with custom patch size parameter."""
139
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
140
+ if image_array is None:
141
+ return
142
+
143
+ summary, annotated_image, detections_list = (
144
+ deepforest_predictor.predict_objects(
145
+ image_data_array=image_array,
146
+ model_names=["tree"],
147
+ patch_size=800,
148
+ patch_overlap=Config.DEEPFOREST_DEFAULTS["patch_overlap"],
149
+ iou_threshold=Config.DEEPFOREST_DEFAULTS["iou_threshold"],
150
+ thresh=Config.DEEPFOREST_DEFAULTS["thresh"]
151
+ )
152
+ )
153
+
154
+ assert "DeepForest detected" in summary or "No objects detected" in summary
155
+ assert annotated_image is not None
156
+ assert isinstance(annotated_image, np.ndarray)
157
+ assert annotated_image.shape[:2] == image_array.shape[:2]
158
+
159
+ assert isinstance(detections_list, list)
160
+ if detections_list:
161
+ assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
162
+
163
+ display_image_for_test(annotated_image, "Custom Patch Size Test")
164
+
165
+
166
+ def test_deepforest_predict_objects_multiple_custom_parameters():
167
+ """Test detection with multiple custom parameters."""
168
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
169
+ if image_array is None:
170
+ return
171
+
172
+ summary, annotated_image, detections_list = (
173
+ deepforest_predictor.predict_objects(
174
+ image_data_array=image_array,
175
+ model_names=["tree"],
176
+ patch_size=600,
177
+ patch_overlap=0.1,
178
+ iou_threshold=0.3,
179
+ thresh=0.3
180
+ )
181
+ )
182
+
183
+ assert "DeepForest detected" in summary or "No objects detected" in summary
184
+ assert annotated_image is not None
185
+ assert isinstance(annotated_image, np.ndarray)
186
+ assert annotated_image.shape[:2] == image_array.shape[:2]
187
+
188
+ assert isinstance(detections_list, list)
189
+ if detections_list:
190
+ assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
191
+
192
+ display_image_for_test(annotated_image, "Multiple Custom Parameters Test")
193
+
194
+
195
+ def test_deepforest_predict_objects_alive_dead_trees():
196
+ """Test alive/dead tree classification detection."""
197
+ summary, annotated_image, detections_list = (
198
+ deepforest_predictor.predict_objects(
199
+ image_file_path=TEST_IMAGE_PATH_LARGE,
200
+ model_names=["tree"],
201
+ alive_dead_trees=True
202
+ )
203
+ )
204
+
205
+ assert "DeepForest detected" in summary or "No objects detected" in summary
206
+ assert annotated_image is not None
207
+ assert isinstance(annotated_image, np.ndarray)
208
+
209
+ print(summary)
210
+
211
+ assert isinstance(detections_list, list)
212
+ if detections_list:
213
+ tree_detections = [d for d in detections_list if d.get('label') == 'tree']
214
+ assert len(tree_detections) > 0, "Expected at least one tree detection"
215
+
216
+ # Check for classification_label field in tree detections
217
+ classification_labels = {d.get('classification_label') for d in tree_detections
218
+ if 'classification_label' in d}
219
+ assert ('alive_tree' in classification_labels or 'dead_tree' in classification_labels), \
220
+ f"Expected alive_tree or dead_tree in classification labels, got: {classification_labels}"
221
+
222
+ # Check that summary mentions classification results
223
+ assert (("alive" in summary and "tree" in summary) or
224
+ ("dead" in summary and "tree" in summary) or
225
+ ("No objects detected" in summary)), \
226
+ f"Summary should mention alive/dead classification: {summary}"
227
+
228
+ display_image_for_test(annotated_image, "Alive/Dead Tree Detection Test")
229
+
230
+
231
+ def test_deepforest_predict_objects_no_detections():
232
+ """Test the function gracefully handles cases with no detections."""
233
+ blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
234
+
235
+ summary, annotated_image, detections_list = (
236
+ deepforest_predictor.predict_objects(
237
+ image_data_array=blank_image,
238
+ model_names=["tree"],
239
+ thresh=1.0
240
+ )
241
+ )
242
+
243
+ assert "No objects detected by DeepForest" in summary
244
+ assert annotated_image is not None
245
+ assert isinstance(annotated_image, np.ndarray)
246
+ assert annotated_image.shape[:2] == blank_image.shape[:2]
247
+
248
+ assert isinstance(detections_list, list)
249
+ assert len(detections_list) == 0
250
+
251
+ display_image_for_test(annotated_image, "No Detections Test")
252
+
253
+
254
+ def test_deepforest_predict_objects_custom_thresholds():
255
+ """Test detection with custom threshold parameters."""
256
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
257
+ if image_array is None:
258
+ return
259
+
260
+ summary, annotated_image, detections_list = (
261
+ deepforest_predictor.predict_objects(
262
+ image_data_array=image_array,
263
+ model_names=["tree"],
264
+ thresh=0.9,
265
+ iou_threshold=0.5
266
+ )
267
+ )
268
+
269
+ assert ("DeepForest detected" in summary or
270
+ "No objects detected" in summary)
271
+ assert annotated_image is not None
272
+ assert isinstance(annotated_image, np.ndarray)
273
+ assert annotated_image.shape[:2] == image_array.shape[:2]
274
+
275
+ assert isinstance(detections_list, list)
276
+ if detections_list:
277
+ assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
278
+
279
+ display_image_for_test(annotated_image, "Custom Thresholds Test")
280
+
281
+
282
+ def test_deepforest_predict_objects_unsupported_model_name():
283
+ """Test behavior with an unsupported model name."""
284
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
285
+ if image_array is None:
286
+ return
287
+
288
+ summary, annotated_image, detections_list = (
289
+ deepforest_predictor.predict_objects(
290
+ image_data_array=image_array,
291
+ model_names=["tree", "nonexistent_model"]
292
+ )
293
+ )
294
+
295
+ assert ("DeepForest detected" in summary or
296
+ "No objects detected" in summary)
297
+ assert annotated_image is not None
298
+ assert isinstance(annotated_image, np.ndarray)
299
+ assert annotated_image.shape[:2] == image_array.shape[:2]
300
+
301
+ assert isinstance(detections_list, list)
302
+ if detections_list:
303
+ labels = {detection['label'] for detection in detections_list if 'label' in detection}
304
+ assert 'tree' in labels
305
+ assert 'nonexistent_model' not in labels
306
+
307
+ display_image_for_test(annotated_image, "Unsupported Model Test")
308
+
309
+
310
+ def test_plot_boxes_basic():
311
+ """Test _plot_boxes with some sample bounding box data."""
312
+ img = np.zeros((100, 100, 3), dtype=np.uint8) + 255
313
+ predictions = pd.DataFrame([
314
+ {'xmin': 10, 'ymin': 10, 'xmax': 30, 'ymax': 30,
315
+ 'label': 'bird', 'score': 0.9},
316
+ {'xmin': 50, 'ymin': 50, 'xmax': 70, 'ymax': 70,
317
+ 'label': 'tree', 'score': 0.8}
318
+ ])
319
+
320
+ annotated_img = DeepForestPredictor._plot_boxes(
321
+ img, predictions, Config.COLORS
322
+ )
323
+ assert annotated_img.shape == img.shape
324
+ assert not np.array_equal(annotated_img, img)
325
+
326
+ display_image_for_test(annotated_img, "Plot Boxes Basic Test")
327
+
328
+
329
+ def test_plot_boxes_empty_predictions():
330
+ """Test _plot_boxes with empty predictions DataFrame."""
331
+ img = np.zeros((100, 100, 3), dtype=np.uint8) + 255
332
+
333
+ predictions = pd.DataFrame({
334
+ "xmin": pd.Series(dtype=float),
335
+ "ymin": pd.Series(dtype=float),
336
+ "xmax": pd.Series(dtype=float),
337
+ "ymax": pd.Series(dtype=float),
338
+ "label": pd.Series(dtype=str),
339
+ "score": pd.Series(dtype=float)
340
+ })
341
+
342
+ annotated_img = DeepForestPredictor._plot_boxes(
343
+ img, predictions, Config.COLORS
344
+ )
345
+ assert np.array_equal(annotated_img, img)
346
+
347
+ display_image_for_test(annotated_img, "Empty Predictions Test")
348
+
349
+
350
+ def test_deepforest_predict_objects_default_parameters():
351
+ """Test that default parameters work correctly with tiled prediction."""
352
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
353
+ if image_array is None:
354
+ return
355
+
356
+ summary, annotated_image, detections_list = (
357
+ deepforest_predictor.predict_objects(
358
+ image_data_array=image_array,
359
+ model_names=["tree"]
360
+ )
361
+ )
362
+
363
+ assert ("DeepForest detected" in summary or "No objects detected" in summary)
364
+ assert annotated_image is not None
365
+ assert isinstance(annotated_image, np.ndarray)
366
+ assert annotated_image.shape[:2] == image_array.shape[:2]
367
+
368
+ assert isinstance(detections_list, list)
369
+
370
+ print("Default parameters test completed successfully")
371
+ display_image_for_test(annotated_image, "Default Parameters Test")
372
+
373
+
374
+ def test_generate_detection_summary():
375
+ """Test the _generate_detection_summary method directly."""
376
+ # Test with empty DataFrame
377
+ empty_df = pd.DataFrame()
378
+ summary = deepforest_predictor._generate_detection_summary(empty_df)
379
+ assert "No objects detected" in summary
380
+
381
+ # Test with basic detections
382
+ predictions_df = pd.DataFrame([
383
+ {'label': 'tree', 'score': 0.9},
384
+ {'label': 'tree', 'score': 0.8},
385
+ {'label': 'bird', 'score': 0.7}
386
+ ])
387
+ summary = deepforest_predictor._generate_detection_summary(predictions_df)
388
+ assert "DeepForest detected" in summary
389
+ assert "2 trees" in summary
390
+ assert "1 bird" in summary
391
+
392
+ print("Detection summary tests completed successfully")
393
+
394
+
395
+ def test_detections_list_structure():
396
+ """Test that detections_list has the correct structure."""
397
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
398
+ if image_array is None:
399
+ return
400
+
401
+ summary, annotated_image, detections_list = (
402
+ deepforest_predictor.predict_objects(
403
+ image_data_array=image_array,
404
+ model_names=["tree"]
405
+ )
406
+ )
407
+
408
+ assert isinstance(detections_list, list)
409
+
410
+ if detections_list:
411
+ for detection in detections_list:
412
+ assert isinstance(detection, dict)
413
+ assert 'xmin' in detection
414
+ assert 'ymin' in detection
415
+ assert 'xmax' in detection
416
+ assert 'ymax' in detection
417
+ assert 'score' in detection
418
+ assert 'label' in detection
419
+
420
+ assert isinstance(detection['xmin'], int)
421
+ assert isinstance(detection['ymin'], int)
422
+ assert isinstance(detection['xmax'], int)
423
+ assert isinstance(detection['ymax'], int)
424
+ assert isinstance(detection['score'], float)
425
+ assert isinstance(detection['label'], str)
426
+
427
+ print("Detections list structure test completed successfully")
428
+
429
+
430
+ def test_error_handling_invalid_model():
431
+ """Test error handling when all models are invalid."""
432
+ image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
433
+ if image_array is None:
434
+ return
435
+
436
+ summary, annotated_image, detections_list = (
437
+ deepforest_predictor.predict_objects(
438
+ image_data_array=image_array,
439
+ model_names=["invalid_model_1", "invalid_model_2"]
440
+ )
441
+ )
442
+
443
+ assert "No objects detected" in summary
444
+ assert annotated_image is not None
445
+ assert isinstance(annotated_image, np.ndarray)
446
+ assert isinstance(detections_list, list)
447
+ assert len(detections_list) == 0
448
+
449
+ print("Error handling test completed successfully")
450
+
451
+
452
+ def test_input_validation():
453
+ """Test input validation for the predict_objects method."""
454
+ # Test with neither image_data_array nor image_file_path provided
455
+ try:
456
+ deepforest_predictor.predict_objects(
457
+ image_data_array=None,
458
+ image_file_path=None,
459
+ model_names=["tree"]
460
+ )
461
+ assert False, "Should have raised ValueError"
462
+ except ValueError as e:
463
+ assert "Either image_data_array or image_file_path must be provided" in str(e)
464
+
465
+ print("Input validation test completed successfully")