Spaces:
No application file
No application file
Add Initial implementation of the deepforest-agent
#1
by
SamiaHaque
- opened
- LICENSE +21 -0
- README.md +100 -12
- app.py +501 -0
- pyproject.toml +66 -0
- requirements.txt +43 -0
- src/__init__.py +0 -0
- src/deepforest_agent/__init__.py +0 -0
- src/deepforest_agent/agents/__init__.py +0 -0
- src/deepforest_agent/agents/deepforest_detector_agent.py +403 -0
- src/deepforest_agent/agents/ecology_analysis_agent.py +92 -0
- src/deepforest_agent/agents/memory_agent.py +238 -0
- src/deepforest_agent/agents/orchestrator.py +795 -0
- src/deepforest_agent/agents/visual_analysis_agent.py +307 -0
- src/deepforest_agent/conf/__init__.py +0 -0
- src/deepforest_agent/conf/config.py +60 -0
- src/deepforest_agent/models/__init__.py +0 -0
- src/deepforest_agent/models/llama32_3b_instruct.py +242 -0
- src/deepforest_agent/models/qwen_vl_3b_instruct.py +152 -0
- src/deepforest_agent/models/smollm3_3b.py +244 -0
- src/deepforest_agent/prompts/__init__.py +0 -0
- src/deepforest_agent/prompts/prompt_templates.py +257 -0
- src/deepforest_agent/tools/__init__.py +0 -0
- src/deepforest_agent/tools/deepforest_tool.py +323 -0
- src/deepforest_agent/tools/tool_handler.py +188 -0
- src/deepforest_agent/utils/__init__.py +0 -0
- src/deepforest_agent/utils/cache_utils.py +306 -0
- src/deepforest_agent/utils/detection_narrative_generator.py +445 -0
- src/deepforest_agent/utils/image_utils.py +465 -0
- src/deepforest_agent/utils/logging_utils.py +449 -0
- src/deepforest_agent/utils/parsing_utils.py +238 -0
- src/deepforest_agent/utils/rtree_spatial_utils.py +394 -0
- src/deepforest_agent/utils/state_manager.py +574 -0
- src/deepforest_agent/utils/tile_manager.py +211 -0
- tests/test_deepforest_tool.py +465 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 DeepForest Agent
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,100 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DeepForest Multi-Agent System
|
2 |
+
|
3 |
+
The DeepForest Multi-Agent System provides ecological image analysis by orchestrating multiple AI agents that work together to understand ecological images. Simply upload an image of a forest, wildlife habitat, or ecological scene, and ask questions in natural language.
|
4 |
+
|
5 |
+
## Installation
|
6 |
+
|
7 |
+
### 1. Clone the repository
|
8 |
+
|
9 |
+
```bash
|
10 |
+
git clone https://github.com/weecology/deepforest-agent.git
|
11 |
+
cd deepforest-agent
|
12 |
+
```
|
13 |
+
|
14 |
+
### 2. Create and activate a Conda environment
|
15 |
+
|
16 |
+
```bash
|
17 |
+
conda create -n deepforest_agent python=3.12.11
|
18 |
+
conda activate deepforest_agent
|
19 |
+
```
|
20 |
+
|
21 |
+
### 3. Install dependencies
|
22 |
+
|
23 |
+
```bash
|
24 |
+
pip install -r requirements.txt
|
25 |
+
pip install -e .
|
26 |
+
```
|
27 |
+
|
28 |
+
### 4. Configure the HuggingFace Token
|
29 |
+
Create a `.env` file in the root directory of the deepforest-agent project and add your HuggingFace token like below:
|
30 |
+
|
31 |
+
```bash
|
32 |
+
HF_TOKEN="your_huggingface_token_here"
|
33 |
+
```
|
34 |
+
|
35 |
+
You can obtain your token from [HuggingFace Access Token](https://huggingface.co/settings/tokens). Make sure the Token type is "Write".
|
36 |
+
|
37 |
+
## Usage
|
38 |
+
|
39 |
+
The DeepForest Agent runs through a Gradio web interface. To start the interface, execute:
|
40 |
+
|
41 |
+
```bash
|
42 |
+
python app.py
|
43 |
+
```
|
44 |
+
|
45 |
+
A link like http://127.0.0.1:7860 will appear in the terminal. Open it in your browser to interact with the agent. A public Gradio link may also be provided if available.
|
46 |
+
|
47 |
+
**Sample Recording of Running the System:** [Drive Link](https://drive.google.com/file/d/1gNMn-xJd48Ld3TZU4oiYvTbiWaiLsc8G/view?usp=sharing)
|
48 |
+
|
49 |
+
|
50 |
+
### How to Use
|
51 |
+
|
52 |
+
1. Upload an ecological image (aerial/drone photography works best)
|
53 |
+
2. Ask questions about wildlife, forest health, or ecological patterns. For example:
|
54 |
+
- How many trees are detected, and how many of them are alive vs dead?
|
55 |
+
- How many birds are around each dead tree?
|
56 |
+
- What objects are in the northwest region of the image?
|
57 |
+
- Do any birds overlap with livestock in this image?
|
58 |
+
- What percentage of the image is covered by trees vs birds vs livestock?
|
59 |
+
3. Get comprehensive analysis combining computer vision and ecological insights. The gallery shows the annotated image with objects and the detection monitor presents the summary of DeepForest detection.
|
60 |
+
|
61 |
+
|
62 |
+
## Features
|
63 |
+
|
64 |
+
- **Multi-Species Detection**: Automatically detects trees, birds, and livestock using specialized DeepForest models
|
65 |
+
- **Tree Health Assessment**: Identifies alive and dead trees using DeepForest Tree Detector whenever user asks.
|
66 |
+
- **Visual Analysis**: Dual analysis of original and annotated images using Qwen2.5-VL-3B-Instruct model
|
67 |
+
- **Memory Context**: Maintains conversation history for contextual understanding across multiple queries
|
68 |
+
- **Tiling Image for Visual Agent:** Larger images are tiled and processed individually for the visual agent.
|
69 |
+
- **R-Tree Spatial Indexing:** Stores DeepForest Results in an R-Tree spatial index structure and use spatial queries to retrieve relevant information and present it to the user.
|
70 |
+
- **Ecological Insights**: Synthesizes detection data with visual analysis and memory context for comprehensive ecological understanding
|
71 |
+
- **Streaming Responses**: Real-time updates as each agent processes your query
|
72 |
+
|
73 |
+
|
74 |
+
## Requirements
|
75 |
+
|
76 |
+
### Hardware Requirements
|
77 |
+
- **GPU**: GPU with at least 24GB VRAM (recommended for optimal performance). The system is optimized for GPU execution. Running on CPU will take significantly longer processing times
|
78 |
+
- **Storage**: At least 35GB free space for model downloads
|
79 |
+
|
80 |
+
### API Requirements
|
81 |
+
- **HuggingFace Token**: Required for model access.
|
82 |
+
|
83 |
+
|
84 |
+
## Image Processing Times
|
85 |
+
|
86 |
+
- **Standard Images**: Most ecological images process within 30 seconds on GPU
|
87 |
+
- **Large GeoTIFF Files**: Larger geospatial images may require significant time for complete analysis
|
88 |
+
|
89 |
+
|
90 |
+
## Models Used
|
91 |
+
|
92 |
+
- **SmolLM3-3B**: For Memory Agent to get context, and for Detector Agent to call the tool with appropriate parameters
|
93 |
+
- **Qwen2.5-VL-3B-Instruct**: Used in Visual agent for multimodal image-text understanding
|
94 |
+
- **Llama-3.2-3B-Instruct**: For Ecology agents for text understanding and generation
|
95 |
+
- **DeepForest Models**: For tree, bird, and livestock detection. Also used for alive/dead tree classification.
|
96 |
+
|
97 |
+
|
98 |
+
## Multi-Agent Workflow
|
99 |
+
|
100 |
+
[](https://mermaid.live/edit#pako:eNqlV9ty4jgQ_RWVt2pexmQI9_hhtwiQK4SLgVyceVBMO7giLEq2kzAk_75tSTYiM9k8LA8UpruPuk93H9tby-cLsBzrUdD1kky79xHBT9ubxSDIbM04XcTkfEUf4Scplf4mx15HAE2AuBDHIY_IN3IehUlIWfgL_0zQ9FNhHEv_jkJyIUKccQpio80dae56Q-EvIU4ETbhwyCBlSVhqP0KUkGsungLGXzJUkegw9d2VwT1vACsuNkT6O8RdcdYfVEvVY-3ck24nW-12RmPSDQX4CWlH8QuIf95N0JPM--0W4jdy6vVeMSV0nHLOSIdijuS8q2FPJeyZN4FEhPAMyr4gXUgQOyPligosKHzOuTiTEedez-eMPxYJ9xldUVI9qGDKeWbuJkqQkDDWoecy9M47CSPKyATiNY9iIC9hsiS6rg6PEnjdZ0gVc8XfyIU3D-MUY9sIsEHg_PTxC0SleX9H14U86nJ7kjKmer6LcVPfx44HKdsn7XJHWt9zw-iRwcfQDl-tGRRzcakzIyUyHA5ITwgu3sjAO6GMPVD_ySHTkEHpmMZIaQ6iYwcyw6t8BtVBmXusCBnRxF8SF0dRB1zJgKEncXBAe9gpGYATuabYI2D5QA6l68jDdB_CCPNHEqQnco5Tmacwkm59k4O-_GuM8xBzlsoB6CzBfyIBFzgUsD7hAkccOQwT-hCyMMnPHMvIyVYVMsYuoY2cco6VX3WJAahiGeyzPym67HruU7g2TyumUZ_lyrOmXp8_Euk7ARrjLGnzVJpnelhKw4htdi1EXpc_fz9Ytn3u_XYolv3ZSs7lMdfeKUSQ0Z8vGJItO6iSwjnS_tfS_2a7c1PLts_DTZ4ODpVa1rMweSO3v63ofi9vdqOoogZhjBW127j-4KeY3X_wqZWyrWTx2Jukkek-QF1lsUMSAfDjIRSLHwz1IE64_5QLpFbIzo6MnYK46WpFcbe_4fpwscDlTyBPu6O1s-u5SHUxoCSMDLnSvl0llbdmzrdq0kfepJRlR9w1zQQchXwBr0g97kaSrsn3Pwka0blyke-DWojxOF8c5w9VfC_OmACjmSlehuu8nrFeg8mOiEwzBCwhirMzPxdW9T2T8a7rjUS21R_D9_VxMtHeei30Xky9fTXFnLVu7v74Ko-pXqLZTug_aO6e4rvIPl3tZn2m6pjPvfwm8EvJkM4g63DCyf6dIN8rvVjXnkLd3Smm_Aki8rBRP_K10nt1o0domoqoKDSTravsh2bEvG3rxblRU3XrzdYL82lAPgDIoY2eQcSy1biLOPYFwq0av7s7rxvGa0Y3xfx-R7oiniEslLTHxuAMUBV2U6e-7-4UlPlfnGxQs9skCBlz_oLDoB6AaelqS1CFelA3Lb3cEqCtbFoucrRWUIeWaZl_GoN7oU0-1MA3TdhnjVcOKsGhabrLgw6DFhyZdfmMxnEXArKSTVGPSObhNj5EYYezy6NWOb8svYSLZOlU1q8fYJ7lcJswqroCpubToP4lzEIL_v_OJ1Z9LhbGJK-AgqNDaFS_ggK1fHu1SaYLnHL5qNFqfYXDjUfTvakpcI78SvPhs9IMNNKzT83GmaYL-9Lu2wP7yh7aI7MtptPcvrbbbfv42O50bNT0PdpNx9HIHo9t1LgPfJo-s5k9R8DrPaJMh67tuvZ0at_c2LitJg2WjW8K4cJyAspisK0ViBXNrq1tBnBvoWyt4N5y8GcEKQaxe-s-ese4NY3uOF9ZTiJSjBQ8fVzmF6lUkW5I8TVkVYALfGkA0eFplFhOrXwoMSxna71azmGzfFArN8qNZr1Rb7RqRw3b2lhOqVKttA6alWat1qgf1hvVZu3dtn7JcxsH5VqzUa81Kq1mvdpoVGwLFpmmDNQrkHwTev8XjGAuiw)
|
app.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# This allows imports to work when app.py is in root but modules are in src/
|
9 |
+
current_dir = Path(__file__).parent.absolute()
|
10 |
+
src_dir = current_dir / "src"
|
11 |
+
|
12 |
+
if not src_dir.exists():
|
13 |
+
raise RuntimeError(f"Source directory not found: {src_dir}")
|
14 |
+
|
15 |
+
# Add to Python path if not already there
|
16 |
+
if str(src_dir) not in sys.path:
|
17 |
+
sys.path.insert(0, str(src_dir))
|
18 |
+
|
19 |
+
print(f"App running from: {current_dir}")
|
20 |
+
print(f"Source directory: {src_dir}")
|
21 |
+
print(f"Python path includes src: {str(src_dir) in sys.path}")
|
22 |
+
|
23 |
+
from deepforest_agent.agents.orchestrator import AgentOrchestrator
|
24 |
+
from deepforest_agent.utils.state_manager import session_state_manager
|
25 |
+
from deepforest_agent.utils.image_utils import (
|
26 |
+
encode_pil_image_to_base64_url,
|
27 |
+
load_pil_image_from_path,
|
28 |
+
get_image_info,
|
29 |
+
validate_image_path
|
30 |
+
)
|
31 |
+
from deepforest_agent.utils.logging_utils import multi_agent_logger
|
32 |
+
|
33 |
+
|
34 |
+
def upload_image(image_path):
|
35 |
+
"""
|
36 |
+
Handle image upload and initialize a new session for the multi-agent workflow.
|
37 |
+
|
38 |
+
This function is triggered when a user uploads an image. It creates a new
|
39 |
+
session with isolated state and updates the UI to show the chat interface
|
40 |
+
and monitoring components.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
image_path (str or None): The file path to uploaded image from Gradio
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
tuple: A tuple containing 9 Gradio component updates:
|
47 |
+
- gr.Chatbot: Chat interface (visible/hidden)
|
48 |
+
- image: Uploaded image state
|
49 |
+
- str: Upload status message
|
50 |
+
- gr.Textbox: Message input field (visible/hidden)
|
51 |
+
- gr.Button: Send button (visible/hidden)
|
52 |
+
- gr.Button: Clear button (visible/hidden)
|
53 |
+
- gr.Gallery: Generated images gallery (visible/hidden)
|
54 |
+
- str: Monitor text with session information
|
55 |
+
- str: Session ID for this user
|
56 |
+
"""
|
57 |
+
if image_path is None:
|
58 |
+
return (
|
59 |
+
gr.Chatbot(visible=False),
|
60 |
+
None, # uploaded_image_state
|
61 |
+
"No image uploaded",
|
62 |
+
gr.Textbox(visible=False),
|
63 |
+
gr.Button(visible=False), # send_btn
|
64 |
+
gr.Button(visible=False), # clear_btn
|
65 |
+
gr.Gallery(visible=False),
|
66 |
+
"No image uploaded",
|
67 |
+
None # session_id
|
68 |
+
)
|
69 |
+
|
70 |
+
if not validate_image_path(image_path):
|
71 |
+
return (
|
72 |
+
gr.Chatbot(visible=False),
|
73 |
+
None,
|
74 |
+
"Invalid image file or path not accessible",
|
75 |
+
gr.Textbox(visible=False),
|
76 |
+
gr.Button(visible=False),
|
77 |
+
gr.Button(visible=False),
|
78 |
+
gr.Gallery(visible=False),
|
79 |
+
"Invalid image file for analysis.",
|
80 |
+
None
|
81 |
+
)
|
82 |
+
|
83 |
+
try:
|
84 |
+
pil_image = load_pil_image_from_path(image_path)
|
85 |
+
if pil_image is None:
|
86 |
+
raise Exception("Failed to load image")
|
87 |
+
image_info = get_image_info(image_path)
|
88 |
+
except Exception as e:
|
89 |
+
return (
|
90 |
+
gr.Chatbot(visible=False),
|
91 |
+
None,
|
92 |
+
f"Error loading image: {str(e)}",
|
93 |
+
gr.Textbox(visible=False),
|
94 |
+
gr.Button(visible=False),
|
95 |
+
gr.Button(visible=False),
|
96 |
+
gr.Gallery(visible=False),
|
97 |
+
"Error loading image for analysis.",
|
98 |
+
None
|
99 |
+
)
|
100 |
+
|
101 |
+
# Create new session for this user
|
102 |
+
session_id = session_state_manager.create_session(pil_image)
|
103 |
+
session_state_manager.set(session_id, "image_file_path", image_path)
|
104 |
+
|
105 |
+
detection_monitor = ""
|
106 |
+
|
107 |
+
multi_agent_logger.log_session_event(
|
108 |
+
session_id=session_id,
|
109 |
+
event_type="session_created",
|
110 |
+
details={
|
111 |
+
"image_size": image_info.get("size") if image_info else pil_image.size,
|
112 |
+
"image_mode": image_info.get("mode") if image_info else pil_image.mode,
|
113 |
+
"image_path": image_path,
|
114 |
+
"file_size_bytes": image_info.get("file_size_bytes") if image_info else "unknown"
|
115 |
+
}
|
116 |
+
)
|
117 |
+
|
118 |
+
return (
|
119 |
+
gr.Chatbot(visible=True, value=[]),
|
120 |
+
pil_image,
|
121 |
+
f"Image uploaded successfully! Size: {pil_image.size}",
|
122 |
+
gr.Textbox(visible=True),
|
123 |
+
gr.Button(visible=True), # send_btn
|
124 |
+
gr.Button(visible=True), # clear_btn
|
125 |
+
gr.Gallery(visible=True, value=[]),
|
126 |
+
detection_monitor,
|
127 |
+
session_id # Return session ID
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
def process_message_streaming(user_message, chatbot_history, generated_images, detection_monitor, session_id):
|
132 |
+
"""
|
133 |
+
Process user message through the multi-agent workflow with streaming updates.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
user_message (str): The user's input message
|
137 |
+
chatbot_history (list): Current chat history for display
|
138 |
+
generated_images (list): List of annotated images in PIL Image objects
|
139 |
+
detection_monitor (str): Current detection data monitoring text
|
140 |
+
session_id (str): Unique session identifier for this user
|
141 |
+
|
142 |
+
Yields:
|
143 |
+
tuple: A tuple containing 6 updated components:
|
144 |
+
- chatbot_history: Updated conversation history
|
145 |
+
- msg_input_clear: Empty string to clear message input field
|
146 |
+
- generated_images: Updated list of annotated images
|
147 |
+
- detection_monitor: Updated detection data monitor
|
148 |
+
- send_btn: Button component with interactive state
|
149 |
+
- msg_input: Input field component with interactive state
|
150 |
+
"""
|
151 |
+
if not user_message.strip():
|
152 |
+
yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
|
153 |
+
return
|
154 |
+
|
155 |
+
# Check if session exists
|
156 |
+
if session_id is None or not session_state_manager.session_exists(session_id):
|
157 |
+
error_msg = "Session expired or invalid. Please upload an image to start a new session."
|
158 |
+
chatbot_history.append({"role": "user", "content": user_message})
|
159 |
+
chatbot_history.append({"role": "assistant", "content": error_msg})
|
160 |
+
yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
|
161 |
+
return
|
162 |
+
|
163 |
+
# Check if image is available in session
|
164 |
+
current_image = session_state_manager.get(session_id, "current_image")
|
165 |
+
if current_image is None:
|
166 |
+
error_msg = "No image found in your session. Please upload an image first."
|
167 |
+
chatbot_history.append({"role": "user", "content": user_message})
|
168 |
+
chatbot_history.append({"role": "assistant", "content": error_msg})
|
169 |
+
yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
|
170 |
+
return
|
171 |
+
|
172 |
+
total_execution_start = time.perf_counter()
|
173 |
+
|
174 |
+
multi_agent_logger.log_user_query(
|
175 |
+
session_id=session_id,
|
176 |
+
user_message=user_message
|
177 |
+
)
|
178 |
+
|
179 |
+
try:
|
180 |
+
if session_state_manager.get(session_id, "first_message", True):
|
181 |
+
image_base64_url = encode_pil_image_to_base64_url(current_image)
|
182 |
+
user_msg = {
|
183 |
+
"role": "user",
|
184 |
+
"content": [
|
185 |
+
{"type": "image", "image": image_base64_url},
|
186 |
+
{"type": "text", "text": user_message}
|
187 |
+
]
|
188 |
+
}
|
189 |
+
session_state_manager.set(session_id, "first_message", False)
|
190 |
+
else:
|
191 |
+
user_msg = {
|
192 |
+
"role": "user",
|
193 |
+
"content": [
|
194 |
+
{"type": "text", "text": user_message}
|
195 |
+
]
|
196 |
+
}
|
197 |
+
|
198 |
+
session_state_manager.add_to_conversation(session_id, user_msg)
|
199 |
+
chatbot_history.append({"role": "user", "content": user_message})
|
200 |
+
|
201 |
+
chatbot_history.append({"role": "assistant", "content": "Starting analysis..."})
|
202 |
+
|
203 |
+
yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=False), gr.Textbox(interactive=False)
|
204 |
+
|
205 |
+
conversation_history = session_state_manager.get(session_id, "conversation_history", [])
|
206 |
+
|
207 |
+
print(f"Session {session_id} - User message: {user_message}")
|
208 |
+
|
209 |
+
orchestrator = AgentOrchestrator()
|
210 |
+
|
211 |
+
start_time = time.perf_counter()
|
212 |
+
|
213 |
+
try:
|
214 |
+
# Process with streaming updates
|
215 |
+
final_result = None
|
216 |
+
|
217 |
+
for result in orchestrator.process_user_message_streaming(
|
218 |
+
user_message=user_message,
|
219 |
+
conversation_history=conversation_history,
|
220 |
+
session_id=session_id
|
221 |
+
):
|
222 |
+
if result["type"] == "progress":
|
223 |
+
chatbot_history[-1] = {"role": "assistant", "content": result["message"]}
|
224 |
+
|
225 |
+
yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=False), gr.Textbox(interactive=False)
|
226 |
+
|
227 |
+
elif result["type"] == "memory_direct":
|
228 |
+
final_response = result["message"]
|
229 |
+
chatbot_history[-1] = {"role": "assistant", "content": final_response}
|
230 |
+
|
231 |
+
updated_detection_monitor = result.get("detection_data", "")
|
232 |
+
|
233 |
+
final_result = result
|
234 |
+
|
235 |
+
yield chatbot_history, "", generated_images, updated_detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
|
236 |
+
break
|
237 |
+
|
238 |
+
elif result["type"] == "streaming":
|
239 |
+
# Update the last message with streaming response
|
240 |
+
chatbot_history[-1] = {"role": "assistant", "content": result["message"]}
|
241 |
+
|
242 |
+
yield chatbot_history, "", generated_images, detection_monitor, gr.Button(interactive=False), gr.Textbox(interactive=False)
|
243 |
+
|
244 |
+
if result.get("is_complete", False):
|
245 |
+
final_response = result["message"]
|
246 |
+
|
247 |
+
elif result["type"] == "final":
|
248 |
+
final_response = result["message"]
|
249 |
+
chatbot_history[-1] = {"role": "assistant", "content": final_response}
|
250 |
+
|
251 |
+
final_result = result
|
252 |
+
break
|
253 |
+
|
254 |
+
if final_result:
|
255 |
+
total_execution_time = time.perf_counter() - total_execution_start
|
256 |
+
|
257 |
+
execution_summary = final_result.get("execution_summary", {})
|
258 |
+
agent_results = final_result.get("agent_results", {})
|
259 |
+
execution_time = final_result.get("execution_time", 0)
|
260 |
+
|
261 |
+
assistant_msg = {
|
262 |
+
"role": "assistant",
|
263 |
+
"content": [{"type": "text", "text": final_response}]
|
264 |
+
}
|
265 |
+
session_state_manager.add_to_conversation(session_id, assistant_msg)
|
266 |
+
|
267 |
+
multi_agent_logger.log_agent_execution(
|
268 |
+
session_id=session_id,
|
269 |
+
agent_name="ecology",
|
270 |
+
agent_input="Final synthesis of all agent outputs",
|
271 |
+
agent_output=final_response,
|
272 |
+
execution_time=total_execution_time
|
273 |
+
)
|
274 |
+
|
275 |
+
annotated_image = session_state_manager.get(session_id, "annotated_image")
|
276 |
+
if annotated_image:
|
277 |
+
generated_images.append(annotated_image)
|
278 |
+
|
279 |
+
updated_detection_monitor = final_result.get("detection_data", "")
|
280 |
+
|
281 |
+
yield chatbot_history, "", generated_images, updated_detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
|
282 |
+
|
283 |
+
finally:
|
284 |
+
orchestrator.cleanup_all_agents()
|
285 |
+
|
286 |
+
except Exception as e:
|
287 |
+
total_execution_time = time.perf_counter() - total_execution_start
|
288 |
+
error_msg = f"Workflow error: {str(e)}"
|
289 |
+
print(f"MAIN APP ERROR (Session {session_id}): {error_msg}")
|
290 |
+
|
291 |
+
multi_agent_logger.log_error(
|
292 |
+
session_id=session_id,
|
293 |
+
error_type="app_workflow_error",
|
294 |
+
error_message=f"Workflow failed after {total_execution_time:.2f}s: {str(e)}"
|
295 |
+
)
|
296 |
+
|
297 |
+
if chatbot_history and chatbot_history[-1]["role"] == "assistant":
|
298 |
+
chatbot_history[-1] = {"role": "assistant", "content": error_msg}
|
299 |
+
else:
|
300 |
+
chatbot_history.append({"role": "assistant", "content": error_msg})
|
301 |
+
|
302 |
+
error_detection_monitor = "ERROR: Workflow failed - no detection data available"
|
303 |
+
|
304 |
+
yield chatbot_history, "", generated_images, error_detection_monitor, gr.Button(interactive=True), gr.Textbox(interactive=True)
|
305 |
+
|
306 |
+
def clear_chat(session_id):
|
307 |
+
"""
|
308 |
+
Clear chat history and cancel any ongoing processing for the session.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
session_id (str): The session identifier to clear. Must correspond to
|
312 |
+
an existing active session.
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
tuple: A tuple containing 5 updated components:
|
316 |
+
- chatbot_history: Empty list clearing chat display
|
317 |
+
- generated_images: Empty list clearing image gallery
|
318 |
+
- monitor_message: Status message indicating successful clear
|
319 |
+
operation and session ID
|
320 |
+
- send_btn: Re-enabled send button component
|
321 |
+
- msg_input: Re-enabled message input component
|
322 |
+
|
323 |
+
"""
|
324 |
+
if session_id and session_state_manager.session_exists(session_id):
|
325 |
+
session_state_manager.cancel_session(session_id)
|
326 |
+
session_state_manager.clear_conversation(session_id)
|
327 |
+
|
328 |
+
multi_agent_logger.log_session_event(
|
329 |
+
session_id=session_id,
|
330 |
+
event_type="conversation_cleared"
|
331 |
+
)
|
332 |
+
|
333 |
+
return (
|
334 |
+
[], # chatbot
|
335 |
+
[], # generated_images
|
336 |
+
"",
|
337 |
+
gr.Button(interactive=True), # Re-enable send button
|
338 |
+
gr.Textbox(interactive=True) # Re-enable message input
|
339 |
+
)
|
340 |
+
else:
|
341 |
+
return (
|
342 |
+
[], # chatbot
|
343 |
+
[], # generated_images
|
344 |
+
"",
|
345 |
+
gr.Button(interactive=True), # Re-enable send button
|
346 |
+
gr.Textbox(interactive=True) # Re-enable message input
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
def create_interface():
|
351 |
+
"""
|
352 |
+
Create and configure the complete Gradio web interface with streaming support.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
gr.Blocks: Complete Gradio application interface
|
356 |
+
"""
|
357 |
+
|
358 |
+
with gr.Blocks(
|
359 |
+
title="DeepForest Multi-Agent System",
|
360 |
+
theme=gr.themes.Default(
|
361 |
+
spacing_size=gr.themes.sizes.spacing_sm,
|
362 |
+
radius_size=gr.themes.sizes.radius_none,
|
363 |
+
primary_hue=gr.themes.colors.emerald,
|
364 |
+
secondary_hue=gr.themes.colors.lime
|
365 |
+
)
|
366 |
+
) as app:
|
367 |
+
|
368 |
+
# Gradio State variables
|
369 |
+
uploaded_image_state = gr.State(None)
|
370 |
+
generated_images_state = gr.State([])
|
371 |
+
session_id_state = gr.State(None)
|
372 |
+
|
373 |
+
gr.Markdown("# DeepForest Multi-Agent System")
|
374 |
+
gr.Markdown("*DeepForest with SmolLM3-3B + Qwen-VL-3B-Instruct + Llama 3.2-3B-Instruct*")
|
375 |
+
|
376 |
+
with gr.Row():
|
377 |
+
# Left column
|
378 |
+
with gr.Column(scale=1):
|
379 |
+
image_upload = gr.Image(
|
380 |
+
type="filepath",
|
381 |
+
label="Upload Ecological Image",
|
382 |
+
height=300
|
383 |
+
)
|
384 |
+
upload_status = gr.Textbox(
|
385 |
+
label="Upload Status",
|
386 |
+
value="Upload an image to begin analysis",
|
387 |
+
interactive=False
|
388 |
+
)
|
389 |
+
|
390 |
+
# Right column
|
391 |
+
with gr.Column(scale=2):
|
392 |
+
chatbot = gr.Chatbot(
|
393 |
+
label="Multi-Agent Ecological Analysis",
|
394 |
+
height=400,
|
395 |
+
visible=False,
|
396 |
+
show_copy_button=True,
|
397 |
+
type='messages'
|
398 |
+
)
|
399 |
+
|
400 |
+
with gr.Row():
|
401 |
+
msg_input = gr.Textbox(
|
402 |
+
placeholder="Ask about wildlife, forest health, ecological patterns...",
|
403 |
+
scale=4,
|
404 |
+
visible=False
|
405 |
+
)
|
406 |
+
send_btn = gr.Button("Analyze", scale=1, visible=False, variant="primary")
|
407 |
+
clear_btn = gr.Button("Clear", scale=1, visible=False)
|
408 |
+
|
409 |
+
with gr.Row():
|
410 |
+
generated_images_display = gr.Gallery(
|
411 |
+
label="Annotated Images after DeepForest Detection",
|
412 |
+
columns=2,
|
413 |
+
height=400,
|
414 |
+
visible=False,
|
415 |
+
show_label=True
|
416 |
+
)
|
417 |
+
|
418 |
+
with gr.Row():
|
419 |
+
with gr.Column():
|
420 |
+
gr.Markdown("### Detection Data Monitor")
|
421 |
+
|
422 |
+
detection_data_monitor = gr.Textbox(
|
423 |
+
label="Detection Data Monitor",
|
424 |
+
value="Upload an image and ask a question to see detection data",
|
425 |
+
interactive=False,
|
426 |
+
show_copy_button=True
|
427 |
+
)
|
428 |
+
|
429 |
+
with gr.Row(visible=False) as example_row:
|
430 |
+
gr.Markdown("""
|
431 |
+
**Multi-agent test questions:**
|
432 |
+
- How many trees are detected, and how many of them are alive vs dead?
|
433 |
+
- How many birds are around each dead tree?
|
434 |
+
- What objects are in the northwest region of the image?
|
435 |
+
- Do any birds overlap with livestock in this image?
|
436 |
+
- What percentage of the image is covered by trees vs birds vs livestock?
|
437 |
+
""")
|
438 |
+
|
439 |
+
# Image upload
|
440 |
+
image_upload.change(
|
441 |
+
fn=upload_image,
|
442 |
+
inputs=[image_upload],
|
443 |
+
outputs=[
|
444 |
+
chatbot,
|
445 |
+
uploaded_image_state,
|
446 |
+
upload_status,
|
447 |
+
msg_input,
|
448 |
+
send_btn,
|
449 |
+
clear_btn,
|
450 |
+
generated_images_display,
|
451 |
+
detection_data_monitor,
|
452 |
+
session_id_state
|
453 |
+
]
|
454 |
+
).then(
|
455 |
+
fn=lambda: gr.Row(visible=True),
|
456 |
+
outputs=[example_row]
|
457 |
+
)
|
458 |
+
|
459 |
+
# Send button with streaming
|
460 |
+
send_btn.click(
|
461 |
+
fn=process_message_streaming,
|
462 |
+
inputs=[msg_input, chatbot, generated_images_state, detection_data_monitor, session_id_state],
|
463 |
+
outputs=[chatbot, msg_input, generated_images_state, detection_data_monitor, send_btn, msg_input]
|
464 |
+
).then(
|
465 |
+
fn=lambda images: images,
|
466 |
+
inputs=[generated_images_state],
|
467 |
+
outputs=[generated_images_display]
|
468 |
+
)
|
469 |
+
|
470 |
+
# Enter key with streaming
|
471 |
+
msg_input.submit(
|
472 |
+
fn=process_message_streaming,
|
473 |
+
inputs=[msg_input, chatbot, generated_images_state, detection_data_monitor, session_id_state],
|
474 |
+
outputs=[chatbot, msg_input, generated_images_state, detection_data_monitor, send_btn, msg_input]
|
475 |
+
).then(
|
476 |
+
fn=lambda images: images,
|
477 |
+
inputs=[generated_images_state],
|
478 |
+
outputs=[generated_images_display]
|
479 |
+
)
|
480 |
+
|
481 |
+
clear_btn.click(
|
482 |
+
fn=clear_chat,
|
483 |
+
inputs=[session_id_state],
|
484 |
+
outputs=[chatbot, generated_images_state, detection_data_monitor, send_btn, msg_input]
|
485 |
+
).then(
|
486 |
+
fn=lambda: [],
|
487 |
+
outputs=[generated_images_display]
|
488 |
+
)
|
489 |
+
|
490 |
+
return app
|
491 |
+
|
492 |
+
|
493 |
+
app = create_interface()
|
494 |
+
|
495 |
+
if __name__ == "__main__":
|
496 |
+
app.launch(
|
497 |
+
share=True,
|
498 |
+
debug=True,
|
499 |
+
show_error=True,
|
500 |
+
max_threads=3
|
501 |
+
)
|
pyproject.toml
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "deepforest_agent"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "AI Agent for DeepForest object detection"
|
5 |
+
authors = [
|
6 |
+
{name = "Your Name", email = "you@example.com"}
|
7 |
+
]
|
8 |
+
requires-python = ">=3.12"
|
9 |
+
readme = "README.md"
|
10 |
+
dependencies = [
|
11 |
+
"accelerate",
|
12 |
+
"albumentations<2.0",
|
13 |
+
"deepforest",
|
14 |
+
"fastapi",
|
15 |
+
"geopandas",
|
16 |
+
"google-genai",
|
17 |
+
"google-generativeai",
|
18 |
+
"gradio",
|
19 |
+
"gradio-image-annotation",
|
20 |
+
"langchain",
|
21 |
+
"langchain-community",
|
22 |
+
"langchain-google-genai",
|
23 |
+
"langchain-huggingface",
|
24 |
+
"langgraph",
|
25 |
+
"matplotlib",
|
26 |
+
"numpy",
|
27 |
+
"rtree",
|
28 |
+
"num2words",
|
29 |
+
"openai",
|
30 |
+
"opencv-python",
|
31 |
+
"outlines",
|
32 |
+
"pandas",
|
33 |
+
"pillow",
|
34 |
+
"scikit-learn",
|
35 |
+
"plotly",
|
36 |
+
"pydantic",
|
37 |
+
"pydantic-settings",
|
38 |
+
"pytest",
|
39 |
+
"pytest-cov",
|
40 |
+
"python-dotenv",
|
41 |
+
"pyyaml",
|
42 |
+
"qwen-vl-utils",
|
43 |
+
"rasterio",
|
44 |
+
"requests",
|
45 |
+
"scikit-image",
|
46 |
+
"seaborn",
|
47 |
+
"shapely",
|
48 |
+
"streamlit",
|
49 |
+
"torch",
|
50 |
+
"torchvision",
|
51 |
+
"tqdm",
|
52 |
+
"transformers",
|
53 |
+
"bitsandbytes",
|
54 |
+
]
|
55 |
+
|
56 |
+
[project.optional-dependencies]
|
57 |
+
dev = [
|
58 |
+
"pre-commit",
|
59 |
+
"pytest",
|
60 |
+
"pytest-profiling",
|
61 |
+
"yapf"
|
62 |
+
]
|
63 |
+
|
64 |
+
[build-system]
|
65 |
+
requires = ["setuptools>=61.0"]
|
66 |
+
build-backend = "setuptools.build_meta"
|
requirements.txt
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate
|
2 |
+
albumentations<2.0
|
3 |
+
deepforest
|
4 |
+
fastapi
|
5 |
+
geopandas
|
6 |
+
google-genai
|
7 |
+
google-generativeai
|
8 |
+
gradio
|
9 |
+
gradio-image-annotation
|
10 |
+
langchain
|
11 |
+
langchain-community
|
12 |
+
langchain-google-genai
|
13 |
+
langchain-huggingface
|
14 |
+
langgraph
|
15 |
+
matplotlib
|
16 |
+
numpy
|
17 |
+
rtree
|
18 |
+
num2words
|
19 |
+
openai
|
20 |
+
opencv-python
|
21 |
+
outlines
|
22 |
+
pandas
|
23 |
+
scikit-learn
|
24 |
+
pillow
|
25 |
+
plotly
|
26 |
+
pydantic
|
27 |
+
pydantic-settings
|
28 |
+
pytest
|
29 |
+
pytest-cov
|
30 |
+
python-dotenv
|
31 |
+
pyyaml
|
32 |
+
qwen-vl-utils
|
33 |
+
rasterio
|
34 |
+
requests
|
35 |
+
scikit-image
|
36 |
+
seaborn
|
37 |
+
shapely
|
38 |
+
streamlit
|
39 |
+
torch
|
40 |
+
torchvision
|
41 |
+
tqdm
|
42 |
+
transformers
|
43 |
+
bitsandbytes
|
src/__init__.py
ADDED
File without changes
|
src/deepforest_agent/__init__.py
ADDED
File without changes
|
src/deepforest_agent/agents/__init__.py
ADDED
File without changes
|
src/deepforest_agent/agents/deepforest_detector_agent.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any, Optional
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
|
6 |
+
from deepforest_agent.utils.cache_utils import tool_call_cache
|
7 |
+
from deepforest_agent.models.smollm3_3b import SmolLM3ModelManager
|
8 |
+
from deepforest_agent.tools.tool_handler import handle_tool_call, extract_all_tool_calls
|
9 |
+
from deepforest_agent.conf.config import Config
|
10 |
+
from deepforest_agent.prompts.prompt_templates import create_detector_system_prompt_with_reasoning, get_deepforest_tool_schema
|
11 |
+
from deepforest_agent.utils.state_manager import session_state_manager
|
12 |
+
from deepforest_agent.utils.logging_utils import multi_agent_logger
|
13 |
+
from deepforest_agent.utils.parsing_utils import parse_deepforest_agent_response_with_reasoning
|
14 |
+
from deepforest_agent.utils.rtree_spatial_utils import DetectionSpatialAnalyzer
|
15 |
+
from deepforest_agent.utils.detection_narrative_generator import DetectionNarrativeGenerator
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class DeepForestDetectorAgent:
|
20 |
+
"""
|
21 |
+
DeepForest detector agent responsible for executing object detection.
|
22 |
+
Uses SmolLM3-3B model for tool calling.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
"""Initialize the DeepForest Detector Agent."""
|
27 |
+
self.agent_config = Config.AGENT_CONFIGS["deepforest_detector"]
|
28 |
+
self.model_manager = SmolLM3ModelManager(Config.AGENT_MODELS["deepforest_detector"])
|
29 |
+
|
30 |
+
def _filter_models_based_on_visual(self, visual_objects: List[str], original_models: List[str]):
|
31 |
+
"""
|
32 |
+
Filter original model names based on visual agent's detected objects.
|
33 |
+
Remove models that weren't visually detected.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
visual_objects (List[str]): Objects detected by visual agent
|
37 |
+
original_models (List[str]): Original model list from tool call
|
38 |
+
"""
|
39 |
+
pass
|
40 |
+
|
41 |
+
def execute_detection_with_context(
|
42 |
+
self,
|
43 |
+
user_message: str,
|
44 |
+
session_id: str,
|
45 |
+
visual_objects_detected: List[str],
|
46 |
+
memory_context: str
|
47 |
+
) -> Dict[str, Any]:
|
48 |
+
"""
|
49 |
+
Execute DeepForest detection with R-tree spatial analysis and narrative generation.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
user_message (str): User's query
|
53 |
+
session_id (str): Unique session identifier for this user
|
54 |
+
visual_objects_detected (List[str]): Objects detected by visual agent
|
55 |
+
memory_context (str): Context from memory agent
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
Dictionary with detection results, R-tree analysis, and narrative
|
59 |
+
"""
|
60 |
+
# Validate session exists
|
61 |
+
if not session_state_manager.session_exists(session_id):
|
62 |
+
return {
|
63 |
+
"detection_summary": f"Session {session_id} not found.",
|
64 |
+
"detections_list": [],
|
65 |
+
"total_detections": 0,
|
66 |
+
"status": "error",
|
67 |
+
"error": f"Session {session_id} not found",
|
68 |
+
"detection_narrative": "No detection narrative available due to session error."
|
69 |
+
}
|
70 |
+
|
71 |
+
try:
|
72 |
+
tool_generation_start = time.perf_counter()
|
73 |
+
|
74 |
+
system_prompt = create_detector_system_prompt_with_reasoning(
|
75 |
+
user_message, memory_context, visual_objects_detected
|
76 |
+
)
|
77 |
+
|
78 |
+
messages = [
|
79 |
+
{"role": "system", "content": system_prompt},
|
80 |
+
{"role": "user", "content": user_message}
|
81 |
+
]
|
82 |
+
|
83 |
+
deepforest_tool_schema = get_deepforest_tool_schema()
|
84 |
+
|
85 |
+
response = self.model_manager.generate_response(
|
86 |
+
messages=messages,
|
87 |
+
max_new_tokens=self.agent_config["max_new_tokens"],
|
88 |
+
temperature=self.agent_config["temperature"],
|
89 |
+
top_p=self.agent_config["top_p"],
|
90 |
+
tools=[deepforest_tool_schema]
|
91 |
+
)
|
92 |
+
|
93 |
+
tool_generation_time = time.perf_counter() - tool_generation_start
|
94 |
+
|
95 |
+
print(f"Session {session_id} - Detector Raw Response: {response}")
|
96 |
+
|
97 |
+
multi_agent_logger.log_agent_execution(
|
98 |
+
session_id=session_id,
|
99 |
+
agent_name="detector",
|
100 |
+
agent_input=f"User: {user_message}",
|
101 |
+
agent_output=response,
|
102 |
+
execution_time=tool_generation_time
|
103 |
+
)
|
104 |
+
|
105 |
+
parsed_response = self._parse_response_with_reasoning(response)
|
106 |
+
|
107 |
+
if "error" in parsed_response:
|
108 |
+
multi_agent_logger.log_error(
|
109 |
+
session_id=session_id,
|
110 |
+
error_type="tool_call_parsing_error",
|
111 |
+
error_message=parsed_response["error"]
|
112 |
+
)
|
113 |
+
|
114 |
+
return {
|
115 |
+
"detection_summary": f"Tool call parsing failed: {parsed_response['error']}",
|
116 |
+
"detections_list": [],
|
117 |
+
"total_detections": 0,
|
118 |
+
"status": "error",
|
119 |
+
"error": parsed_response["error"],
|
120 |
+
"detection_narrative": "No detection narrative available due to parsing error."
|
121 |
+
}
|
122 |
+
|
123 |
+
reasoning = parsed_response["reasoning"]
|
124 |
+
tool_calls = parsed_response["tool_calls"]
|
125 |
+
|
126 |
+
print(f"Session {session_id} - Reasoning: {reasoning}")
|
127 |
+
print(f"Session {session_id} - Found {len(tool_calls)} tool calls")
|
128 |
+
|
129 |
+
all_results = []
|
130 |
+
combined_detection_summary = []
|
131 |
+
combined_detections_list = []
|
132 |
+
total_detections = 0
|
133 |
+
|
134 |
+
for i, tool_call in enumerate(tool_calls):
|
135 |
+
print(f"Session {session_id} - Executing tool call {i+1}/{len(tool_calls)}")
|
136 |
+
|
137 |
+
tool_name = tool_call["name"]
|
138 |
+
tool_arguments = tool_call["arguments"]
|
139 |
+
|
140 |
+
cached_result = tool_call_cache.get_cached_result(tool_name, tool_arguments)
|
141 |
+
|
142 |
+
if cached_result:
|
143 |
+
print(f"Session {session_id} - Tool call {i+1}: Using cached results")
|
144 |
+
|
145 |
+
if cached_result.get("annotated_image"):
|
146 |
+
session_state_manager.set(session_id, "annotated_image", cached_result["annotated_image"])
|
147 |
+
|
148 |
+
cache_key = cached_result["cache_info"]["cache_key"]
|
149 |
+
session_state_manager.add_tool_call_to_history(
|
150 |
+
session_id, tool_name, tool_arguments, cache_key
|
151 |
+
)
|
152 |
+
|
153 |
+
multi_agent_logger.log_tool_call(
|
154 |
+
session_id=session_id,
|
155 |
+
tool_name=tool_name,
|
156 |
+
tool_arguments=tool_arguments,
|
157 |
+
tool_result=cached_result,
|
158 |
+
execution_time=0.0,
|
159 |
+
cache_hit=True,
|
160 |
+
reasoning=f"Tool call {i+1}: {reasoning}"
|
161 |
+
)
|
162 |
+
|
163 |
+
tool_result = {
|
164 |
+
"tool_call_number": i + 1,
|
165 |
+
"tool_name": tool_name,
|
166 |
+
"tool_arguments": tool_arguments,
|
167 |
+
"cache_key": cache_key,
|
168 |
+
"detection_summary": cached_result["detection_summary"],
|
169 |
+
"detections_list": cached_result.get("detections_list", []),
|
170 |
+
"total_detections": len(cached_result.get("detections_list", [])),
|
171 |
+
"status": "success",
|
172 |
+
"cache_hit": True
|
173 |
+
}
|
174 |
+
|
175 |
+
all_results.append(tool_result)
|
176 |
+
combined_detection_summary.append(cached_result["detection_summary"])
|
177 |
+
combined_detections_list.extend(cached_result.get("detections_list", []))
|
178 |
+
total_detections += len(cached_result.get("detections_list", []))
|
179 |
+
|
180 |
+
else:
|
181 |
+
print(f"Session {session_id} - Tool call {i+1}: Cache MISS, executing tool")
|
182 |
+
|
183 |
+
tool_execution_start = time.perf_counter()
|
184 |
+
execution_result = handle_tool_call(tool_name, tool_arguments, session_id)
|
185 |
+
|
186 |
+
tool_execution_time = time.perf_counter() - tool_execution_start
|
187 |
+
|
188 |
+
if isinstance(execution_result, dict) and "detection_summary" in execution_result:
|
189 |
+
cache_result = {
|
190 |
+
"detection_summary": execution_result["detection_summary"],
|
191 |
+
"detections_list": execution_result.get("detections_list", []),
|
192 |
+
"total_detections": execution_result.get("total_detections", 0),
|
193 |
+
"status": "success"
|
194 |
+
}
|
195 |
+
|
196 |
+
annotated_image = session_state_manager.get(session_id, "annotated_image")
|
197 |
+
if annotated_image:
|
198 |
+
cache_result["annotated_image"] = annotated_image
|
199 |
+
|
200 |
+
cache_key = tool_call_cache.store_result(tool_name, tool_arguments, cache_result)
|
201 |
+
|
202 |
+
session_state_manager.add_tool_call_to_history(
|
203 |
+
session_id, tool_name, tool_arguments, cache_key
|
204 |
+
)
|
205 |
+
|
206 |
+
multi_agent_logger.log_tool_call(
|
207 |
+
session_id=session_id,
|
208 |
+
tool_name=tool_name,
|
209 |
+
tool_arguments=tool_arguments,
|
210 |
+
tool_result=execution_result,
|
211 |
+
execution_time=tool_execution_time,
|
212 |
+
cache_hit=False,
|
213 |
+
reasoning=f"Tool call {i+1}: {reasoning}"
|
214 |
+
)
|
215 |
+
|
216 |
+
tool_result = {
|
217 |
+
"tool_call_number": i + 1,
|
218 |
+
"tool_name": tool_name,
|
219 |
+
"tool_arguments": tool_arguments,
|
220 |
+
"cache_key": cache_key,
|
221 |
+
"detection_summary": execution_result["detection_summary"],
|
222 |
+
"detections_list": execution_result.get("detections_list", []),
|
223 |
+
"total_detections": execution_result.get("total_detections", 0),
|
224 |
+
"status": "success",
|
225 |
+
"cache_hit": False
|
226 |
+
}
|
227 |
+
|
228 |
+
all_results.append(tool_result)
|
229 |
+
combined_detection_summary.append(execution_result["detection_summary"])
|
230 |
+
combined_detections_list.extend(execution_result.get("detections_list", []))
|
231 |
+
total_detections += execution_result.get("total_detections", 0)
|
232 |
+
|
233 |
+
else:
|
234 |
+
error_msg = str(execution_result) if isinstance(execution_result, str) else "Unknown tool execution error"
|
235 |
+
print(f"Session {session_id} - Tool call {i+1} execution failed: {error_msg}")
|
236 |
+
|
237 |
+
multi_agent_logger.log_error(
|
238 |
+
session_id=session_id,
|
239 |
+
error_type="tool_execution_error",
|
240 |
+
error_message=f"Tool call {i+1} execution failed after {tool_execution_time:.2f}s: {error_msg}"
|
241 |
+
)
|
242 |
+
|
243 |
+
tool_result = {
|
244 |
+
"tool_call_number": i + 1,
|
245 |
+
"tool_name": tool_name,
|
246 |
+
"tool_arguments": tool_arguments,
|
247 |
+
"detection_summary": f"Tool call {i+1} failed: {error_msg}",
|
248 |
+
"detections_list": [],
|
249 |
+
"total_detections": 0,
|
250 |
+
"status": "error",
|
251 |
+
"error": error_msg,
|
252 |
+
"cache_hit": False
|
253 |
+
}
|
254 |
+
all_results.append(tool_result)
|
255 |
+
|
256 |
+
final_detection_summary = " | ".join(combined_detection_summary) if combined_detection_summary else "No successful detections"
|
257 |
+
|
258 |
+
# Generate comprehensive R-tree based narrative
|
259 |
+
detection_narrative = self._generate_spatial_narrative(
|
260 |
+
combined_detections_list, session_id
|
261 |
+
)
|
262 |
+
|
263 |
+
# Log the detection narrative
|
264 |
+
multi_agent_logger.log_agent_execution(
|
265 |
+
session_id=session_id,
|
266 |
+
agent_name="detection_narrative",
|
267 |
+
agent_input=f"Detection narrative for {len(combined_detections_list)} detections",
|
268 |
+
agent_output=detection_narrative,
|
269 |
+
execution_time=0.0
|
270 |
+
)
|
271 |
+
|
272 |
+
result = {
|
273 |
+
"detection_summary": final_detection_summary,
|
274 |
+
"detections_list": combined_detections_list,
|
275 |
+
"total_detections": total_detections,
|
276 |
+
"status": "success",
|
277 |
+
"reasoning": reasoning,
|
278 |
+
"visual_objects_input": visual_objects_detected,
|
279 |
+
"tool_calls_executed": len(tool_calls),
|
280 |
+
"tool_results": all_results,
|
281 |
+
"detection_narrative": detection_narrative,
|
282 |
+
"raw_tool_response": response
|
283 |
+
}
|
284 |
+
|
285 |
+
print(f"Session {session_id} - Executed {len(tool_calls)} tool calls successfully")
|
286 |
+
print(f"Session {session_id} - Generated detection narrative ({len(detection_narrative)} characters)")
|
287 |
+
return result
|
288 |
+
|
289 |
+
except Exception as e:
|
290 |
+
error_msg = f"Error in detector agent for session {session_id}: {str(e)}"
|
291 |
+
print(f"Detector Agent Error: {error_msg}")
|
292 |
+
|
293 |
+
multi_agent_logger.log_error(
|
294 |
+
session_id=session_id,
|
295 |
+
error_type="detector_agent_exception",
|
296 |
+
error_message=error_msg
|
297 |
+
)
|
298 |
+
|
299 |
+
return {
|
300 |
+
"detection_summary": f"Detection agent error: {error_msg}",
|
301 |
+
"detections_list": [],
|
302 |
+
"total_detections": 0,
|
303 |
+
"status": "error",
|
304 |
+
"error": error_msg,
|
305 |
+
"visual_objects_input": visual_objects_detected,
|
306 |
+
"detection_narrative": f"Detection narrative generation failed due to error: {error_msg}"
|
307 |
+
}
|
308 |
+
|
309 |
+
def _generate_spatial_narrative(self, detections_list: List[Dict[str, Any]], session_id: str) -> str:
|
310 |
+
"""
|
311 |
+
Generate comprehensive spatial narrative using R-tree analysis.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
detections_list: Combined list of all detections
|
315 |
+
session_id: Session identifier for getting image dimensions
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
Comprehensive detection narrative
|
319 |
+
"""
|
320 |
+
if not detections_list:
|
321 |
+
return "No detections available for spatial narrative generation."
|
322 |
+
|
323 |
+
try:
|
324 |
+
# Get image dimensions
|
325 |
+
current_image = session_state_manager.get(session_id, "current_image")
|
326 |
+
if current_image:
|
327 |
+
image_width, image_height = current_image.size
|
328 |
+
else:
|
329 |
+
# Default dimensions if image not available
|
330 |
+
image_width, image_height = 1920, 1080
|
331 |
+
|
332 |
+
# Generate narrative using DetectionNarrativeGenerator
|
333 |
+
narrative_generator = DetectionNarrativeGenerator(image_width, image_height)
|
334 |
+
comprehensive_narrative = narrative_generator.generate_comprehensive_narrative(detections_list)
|
335 |
+
|
336 |
+
print(f"Session {session_id} - Generated comprehensive spatial narrative")
|
337 |
+
|
338 |
+
return comprehensive_narrative
|
339 |
+
|
340 |
+
except Exception as e:
|
341 |
+
error_msg = f"Error generating spatial narrative: {str(e)}"
|
342 |
+
print(f"Session {session_id} - {error_msg}")
|
343 |
+
|
344 |
+
# Just return the detection summary itself
|
345 |
+
total_count = len(detections_list)
|
346 |
+
label_counts = {}
|
347 |
+
classification_counts = {}
|
348 |
+
|
349 |
+
for detection in detections_list:
|
350 |
+
base_label = detection.get('label', 'unknown')
|
351 |
+
label_counts[base_label] = label_counts.get(base_label, 0) + 1
|
352 |
+
|
353 |
+
# Handle tree classifications
|
354 |
+
if base_label == 'tree':
|
355 |
+
classification_label = detection.get('classification_label')
|
356 |
+
classification_score = detection.get('classification_score')
|
357 |
+
|
358 |
+
# Only count valid classifications (not NaN)
|
359 |
+
if (classification_label and
|
360 |
+
classification_score is not None and
|
361 |
+
str(classification_label).lower() != 'nan' and
|
362 |
+
str(classification_score).lower() != 'nan'):
|
363 |
+
|
364 |
+
classification_counts[classification_label] = classification_counts.get(classification_label, 0) + 1
|
365 |
+
|
366 |
+
# Build simple summary
|
367 |
+
object_parts = []
|
368 |
+
for label, count in label_counts.items():
|
369 |
+
if label == 'tree' and classification_counts:
|
370 |
+
# Special handling for trees with classifications
|
371 |
+
total_trees = count
|
372 |
+
tree_part = f"{total_trees} trees are detected"
|
373 |
+
|
374 |
+
if classification_counts:
|
375 |
+
classification_parts = []
|
376 |
+
for class_label, class_count in classification_counts.items():
|
377 |
+
class_name = class_label.replace('_', ' ')
|
378 |
+
classification_parts.append(f"{class_count} {class_name}s")
|
379 |
+
|
380 |
+
tree_part += f". These {total_trees} trees are classified as {' and '.join(classification_parts)}"
|
381 |
+
|
382 |
+
object_parts.append(tree_part)
|
383 |
+
else:
|
384 |
+
label_name = label.replace('_', ' ')
|
385 |
+
object_parts.append(f"{count} {label_name}{'s' if count != 1 else ''}")
|
386 |
+
|
387 |
+
fallback_summary = f"DeepForest detected {total_count} objects: {', '.join(object_parts)}."
|
388 |
+
|
389 |
+
return fallback_summary
|
390 |
+
|
391 |
+
def _parse_response_with_reasoning(self, response: str) -> Dict[str, Any]:
|
392 |
+
"""
|
393 |
+
Parse model response to extract reasoning and multiple tool calls.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
response (str): Raw response from the model
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
Dictionary containing either:
|
400 |
+
- {"reasoning": str, "tool_call": dict} on success
|
401 |
+
- {"error": str} on parsing failure
|
402 |
+
"""
|
403 |
+
return parse_deepforest_agent_response_with_reasoning(response)
|
src/deepforest_agent/agents/ecology_analysis_agent.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Dict, List, Any, Optional, Generator
|
3 |
+
|
4 |
+
from deepforest_agent.models.llama32_3b_instruct import Llama32ModelManager
|
5 |
+
from deepforest_agent.conf.config import Config
|
6 |
+
from deepforest_agent.prompts.prompt_templates import create_ecology_synthesis_prompt
|
7 |
+
from deepforest_agent.utils.state_manager import session_state_manager
|
8 |
+
|
9 |
+
|
10 |
+
class EcologyAnalysisAgent:
|
11 |
+
"""
|
12 |
+
Ecology analysis agent responsible for combining all data into comprehensive ecological insights.
|
13 |
+
Uses Llama-3.2-3B-Instruct model for detailed structured response generation with analysis.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
"""Initialize the Ecology Analysis Agent."""
|
18 |
+
self.agent_config = Config.AGENT_CONFIGS["ecology_analysis"]
|
19 |
+
self.model_manager = Llama32ModelManager(Config.AGENT_MODELS["ecology_analysis"])
|
20 |
+
|
21 |
+
def synthesize_analysis_streaming(
|
22 |
+
self,
|
23 |
+
user_message: str,
|
24 |
+
memory_context: str,
|
25 |
+
cached_json: Optional[Dict[str, Any]] = None,
|
26 |
+
current_json: Optional[Dict[str, Any]] = None,
|
27 |
+
session_id: Optional[str] = None
|
28 |
+
) -> Generator[Dict[str, Any], None, None]:
|
29 |
+
"""
|
30 |
+
Synthesize all agent outputs with streaming text generation.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
user_message (str): The user's original query for the analysis.
|
34 |
+
memory_context (str): The context and conversation history provided
|
35 |
+
by a memory agent.
|
36 |
+
cached_json (Optional[Dict[str, Any]]): A dictionary of previously
|
37 |
+
cached JSON data, if available. Defaults to None.
|
38 |
+
current_json (Optional[Dict[str, Any]]): A dictionary of new JSON data
|
39 |
+
from the current analysis step. Defaults to None.
|
40 |
+
session_id (Optional[str]): A unique session identifier for tracking
|
41 |
+
and logging. Defaults to None.
|
42 |
+
|
43 |
+
Yields:
|
44 |
+
Dict[str, Any]: Dictionary containing:
|
45 |
+
- token: Generated text token
|
46 |
+
- is_complete: Whether generation is finished
|
47 |
+
"""
|
48 |
+
if session_id and not session_state_manager.session_exists(session_id):
|
49 |
+
yield {
|
50 |
+
"token": f"Session {session_id} not found. Unable to synthesize analysis.",
|
51 |
+
"is_complete": True
|
52 |
+
}
|
53 |
+
return
|
54 |
+
|
55 |
+
try:
|
56 |
+
synthesis_prompt = create_ecology_synthesis_prompt(
|
57 |
+
user_message=user_message,
|
58 |
+
comprehensive_context=memory_context,
|
59 |
+
cached_json=cached_json,
|
60 |
+
current_json=current_json
|
61 |
+
)
|
62 |
+
print(f"Ecology Synthesis Prompt:\n{synthesis_prompt}\n")
|
63 |
+
|
64 |
+
messages = [
|
65 |
+
{"role": "system", "content": synthesis_prompt},
|
66 |
+
{"role": "user", "content": user_message}
|
67 |
+
]
|
68 |
+
|
69 |
+
print(f"Session {session_id} - Ecology Agent: Starting streaming synthesis")
|
70 |
+
|
71 |
+
# Stream the response token by token
|
72 |
+
for token_data in self.model_manager.generate_response_streaming(
|
73 |
+
messages=messages,
|
74 |
+
max_new_tokens=self.agent_config["max_new_tokens"],
|
75 |
+
temperature=self.agent_config["temperature"],
|
76 |
+
top_p=self.agent_config["top_p"]
|
77 |
+
):
|
78 |
+
yield token_data
|
79 |
+
|
80 |
+
if token_data["is_complete"]:
|
81 |
+
print(f"Session {session_id} - Ecology Agent: Streaming synthesis completed")
|
82 |
+
break
|
83 |
+
|
84 |
+
except Exception as e:
|
85 |
+
error_msg = f"Error in ecology synthesis for session {session_id}: {str(e)}"
|
86 |
+
print(f"Ecology Analysis Error: {error_msg}")
|
87 |
+
|
88 |
+
# Yield error_msg response as single token
|
89 |
+
yield {
|
90 |
+
"token": error_msg,
|
91 |
+
"is_complete": True
|
92 |
+
}
|
src/deepforest_agent/agents/memory_agent.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any, Optional
|
2 |
+
import re
|
3 |
+
import time
|
4 |
+
import json
|
5 |
+
|
6 |
+
from deepforest_agent.models.smollm3_3b import SmolLM3ModelManager
|
7 |
+
from deepforest_agent.conf.config import Config
|
8 |
+
from deepforest_agent.prompts.prompt_templates import format_memory_prompt
|
9 |
+
from deepforest_agent.utils.state_manager import session_state_manager
|
10 |
+
from deepforest_agent.utils.logging_utils import multi_agent_logger
|
11 |
+
from deepforest_agent.utils.parsing_utils import parse_memory_agent_response
|
12 |
+
from deepforest_agent.utils.cache_utils import tool_call_cache
|
13 |
+
from deepforest_agent.conf.config import Config
|
14 |
+
|
15 |
+
class MemoryAgent:
|
16 |
+
"""
|
17 |
+
Memory agent responsible for analyzing conversation history in new format.
|
18 |
+
Uses SmolLM3-3B model for getting relevant context
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
"""Initialize the Memory Agent with model manager and configuration."""
|
23 |
+
self.agent_config = Config.AGENT_CONFIGS["memory"]
|
24 |
+
self.model_manager = SmolLM3ModelManager(Config.AGENT_MODELS["memory"])
|
25 |
+
|
26 |
+
def _filter_conversation_history(self, conversation_history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
27 |
+
"""
|
28 |
+
Filter conversation history to include user and assistant messages.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
conversation_history: Full conversation history
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
Filtered history with only user/assistant messages
|
35 |
+
"""
|
36 |
+
filtered_history = []
|
37 |
+
|
38 |
+
for message in conversation_history:
|
39 |
+
if message.get("role") in ["user", "assistant"]:
|
40 |
+
content = message.get("content", "")
|
41 |
+
if isinstance(content, list):
|
42 |
+
text_parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
43 |
+
content = " ".join(text_parts)
|
44 |
+
elif isinstance(content, str):
|
45 |
+
content = content
|
46 |
+
else:
|
47 |
+
content = str(content)
|
48 |
+
|
49 |
+
filtered_history.append({
|
50 |
+
"role": message["role"],
|
51 |
+
"content": content
|
52 |
+
})
|
53 |
+
|
54 |
+
return filtered_history
|
55 |
+
|
56 |
+
def _get_conversation_history_context(self, session_id: str) -> str:
|
57 |
+
"""
|
58 |
+
Get formatted conversation history with turn-based structure.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
session_id: Session identifier
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Formatted conversation history with turn structure
|
65 |
+
"""
|
66 |
+
conversation_history = session_state_manager.get(session_id, "conversation_history", [])
|
67 |
+
|
68 |
+
print(f"Session {session_id} - Conversation length: {len(conversation_history)}")
|
69 |
+
|
70 |
+
if not conversation_history:
|
71 |
+
return "No previous conversation history available."
|
72 |
+
|
73 |
+
# Build turn-based history
|
74 |
+
formatted_history = []
|
75 |
+
turn_number = 1
|
76 |
+
|
77 |
+
# Process conversation in pairs (user -> assistant)
|
78 |
+
i = 0
|
79 |
+
while i < len(conversation_history):
|
80 |
+
if i + 1 < len(conversation_history):
|
81 |
+
user_msg = conversation_history[i]
|
82 |
+
assistant_msg = conversation_history[i + 1]
|
83 |
+
|
84 |
+
if user_msg.get("role") == "user" and assistant_msg.get("role") == "assistant":
|
85 |
+
# Extract user query
|
86 |
+
user_content = user_msg.get("content", "")
|
87 |
+
if isinstance(user_content, list):
|
88 |
+
text_parts = [item.get("text", "") for item in user_content if item.get("type") == "text"]
|
89 |
+
user_query = " ".join(text_parts)
|
90 |
+
else:
|
91 |
+
user_query = str(user_content)
|
92 |
+
|
93 |
+
# Get stored context data for this turn
|
94 |
+
visual_context = session_state_manager.get(session_id, f"turn_{turn_number}_visual_context", "No visual analysis available")
|
95 |
+
detection_narrative = session_state_manager.get(session_id, f"turn_{turn_number}_detection_narrative", "No detection narrative available")
|
96 |
+
tool_cache_id = session_state_manager.get(session_id, f"turn_{turn_number}_tool_cache_id", "No tool cache ID")
|
97 |
+
tool_call_info = "No tool call information available"
|
98 |
+
if tool_cache_id:
|
99 |
+
try:
|
100 |
+
if tool_cache_id in tool_call_cache.cache_data:
|
101 |
+
cached_entry = tool_call_cache.cache_data[tool_cache_id]
|
102 |
+
tool_name = cached_entry.get("tool_name", "unknown")
|
103 |
+
stored_arguments = cached_entry.get("arguments", {})
|
104 |
+
|
105 |
+
all_arguments = Config.DEEPFOREST_DEFAULTS.copy()
|
106 |
+
all_arguments.update(stored_arguments)
|
107 |
+
|
108 |
+
# Format tool call info with all arguments
|
109 |
+
args_str = ", ".join([f"{k}={v}" for k, v in all_arguments.items()])
|
110 |
+
tool_call_info = f"Tool: {tool_name} called with arguments: {args_str}"
|
111 |
+
except Exception as e:
|
112 |
+
tool_call_info = f"Error retrieving tool call info: {str(e)}"
|
113 |
+
|
114 |
+
turn_text = f"--- Turn {turn_number}: ---\n"
|
115 |
+
turn_text += f"Turn {turn_number} User query: {user_query}\n"
|
116 |
+
turn_text += f"Turn {turn_number} Visual analysis full image or per tile: {visual_context}\n"
|
117 |
+
turn_text += f"Turn {turn_number} Tool cache ID: {tool_cache_id}\n"
|
118 |
+
turn_text += f"Turn {turn_number} Tool call details: {tool_call_info}\n"
|
119 |
+
turn_text += f"Turn {turn_number} Detection Data Analysis: {detection_narrative}\n"
|
120 |
+
turn_text += f"--- Turn {turn_number} Completed ---\n"
|
121 |
+
|
122 |
+
formatted_history.append(turn_text)
|
123 |
+
turn_number += 1
|
124 |
+
i += 2
|
125 |
+
else:
|
126 |
+
i += 1
|
127 |
+
else:
|
128 |
+
i += 1
|
129 |
+
|
130 |
+
if not formatted_history:
|
131 |
+
return "No complete conversation turns available."
|
132 |
+
|
133 |
+
print(f"Formatted {len(formatted_history)} conversation turns")
|
134 |
+
return "\n\n".join(formatted_history)
|
135 |
+
|
136 |
+
def process_conversation_history_structured(
|
137 |
+
self,
|
138 |
+
conversation_history: List[Dict[str, Any]],
|
139 |
+
latest_message: str,
|
140 |
+
session_id: str
|
141 |
+
) -> Dict[str, Any]:
|
142 |
+
"""
|
143 |
+
Process conversation history and extract relevant context with structured output.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
conversation_history: Full conversation history
|
147 |
+
latest_message: Current user message requiring context analysis
|
148 |
+
session_id: Unique session identifier for this user
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
Dict with structured output including tool_cache_id and relevant context
|
152 |
+
"""
|
153 |
+
if not session_state_manager.session_exists(session_id):
|
154 |
+
return {
|
155 |
+
"answer_present": False,
|
156 |
+
"direct_answer": "NO",
|
157 |
+
"tool_cache_id": None,
|
158 |
+
"relevant_context": f"Session {session_id} not found. Current query: {latest_message}",
|
159 |
+
"raw_response": f"Session {session_id} not found"
|
160 |
+
}
|
161 |
+
|
162 |
+
filtered_history = self._filter_conversation_history(conversation_history)
|
163 |
+
conversation_context = self._get_conversation_history_context(session_id)
|
164 |
+
|
165 |
+
memory_prompt = format_memory_prompt(filtered_history, latest_message, conversation_context)
|
166 |
+
print(f"Memory Agent Prompt:\n{memory_prompt}\n")
|
167 |
+
|
168 |
+
messages = [
|
169 |
+
{"role": "system", "content": memory_prompt},
|
170 |
+
{"role": "user", "content": latest_message}
|
171 |
+
]
|
172 |
+
|
173 |
+
memory_execution_start = time.perf_counter()
|
174 |
+
|
175 |
+
try:
|
176 |
+
response = self.model_manager.generate_response(
|
177 |
+
messages=messages,
|
178 |
+
max_new_tokens=self.agent_config["max_new_tokens"],
|
179 |
+
temperature=self.agent_config["temperature"],
|
180 |
+
top_p=self.agent_config["top_p"]
|
181 |
+
)
|
182 |
+
|
183 |
+
memory_execution_time = time.perf_counter() - memory_execution_start
|
184 |
+
|
185 |
+
print(f"Session {session_id} - Memory Agent: Raw response received")
|
186 |
+
print(f"Raw Response: {response}")
|
187 |
+
|
188 |
+
parsed_result = parse_memory_agent_response(response)
|
189 |
+
|
190 |
+
multi_agent_logger.log_agent_execution(
|
191 |
+
session_id=session_id,
|
192 |
+
agent_name="memory",
|
193 |
+
agent_input=f"Latest message: {latest_message}",
|
194 |
+
agent_output=response,
|
195 |
+
execution_time=memory_execution_time
|
196 |
+
)
|
197 |
+
|
198 |
+
print(f"Session {session_id} - Memory Agent: Analysis completed")
|
199 |
+
print(f"Has Answer: {parsed_result['answer_present']}")
|
200 |
+
|
201 |
+
return parsed_result
|
202 |
+
|
203 |
+
except Exception as e:
|
204 |
+
memory_execution_time = time.perf_counter() - memory_execution_start
|
205 |
+
error_msg = f"Error processing conversation history in session {session_id}: {str(e)}"
|
206 |
+
print(f"Session {session_id} - Memory Agent Error: {e}")
|
207 |
+
|
208 |
+
multi_agent_logger.log_error(
|
209 |
+
session_id=session_id,
|
210 |
+
error_type="memory_agent_error",
|
211 |
+
error_message=f"Memory agent failed after {memory_execution_time:.2f}s: {str(e)}"
|
212 |
+
)
|
213 |
+
|
214 |
+
return {
|
215 |
+
"answer_present": False,
|
216 |
+
"direct_answer": "NO",
|
217 |
+
"tool_cache_id": None,
|
218 |
+
"relevant_context": f"{error_msg}. Current query: {latest_message}",
|
219 |
+
"raw_response": str(e)
|
220 |
+
}
|
221 |
+
|
222 |
+
def store_turn_context(self, session_id: str, turn_number: int, visual_context: str,
|
223 |
+
detection_narrative: str, tool_cache_id: Optional[str]) -> None:
|
224 |
+
"""
|
225 |
+
Store context data for a specific conversation turn.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
session_id: Session identifier
|
229 |
+
turn_number: Turn number in conversation
|
230 |
+
visual_context: Visual analysis context
|
231 |
+
detection_narrative: Detection narrative
|
232 |
+
tool_cache_id: Tool cache identifier
|
233 |
+
"""
|
234 |
+
session_state_manager.set(session_id, f"turn_{turn_number}_visual_context", visual_context)
|
235 |
+
session_state_manager.set(session_id, f"turn_{turn_number}_detection_narrative", detection_narrative)
|
236 |
+
session_state_manager.set(session_id, f"turn_{turn_number}_tool_cache_id", tool_cache_id or "No tool cache ID")
|
237 |
+
|
238 |
+
print(f"Session {session_id} - Stored context for turn {turn_number}")
|
src/deepforest_agent/agents/orchestrator.py
ADDED
@@ -0,0 +1,795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import gc
|
5 |
+
from typing import Dict, List, Any, Optional, Generator
|
6 |
+
|
7 |
+
from deepforest_agent.agents.memory_agent import MemoryAgent
|
8 |
+
from deepforest_agent.agents.deepforest_detector_agent import DeepForestDetectorAgent
|
9 |
+
from deepforest_agent.agents.visual_analysis_agent import VisualAnalysisAgent
|
10 |
+
from deepforest_agent.agents.ecology_analysis_agent import EcologyAnalysisAgent
|
11 |
+
from deepforest_agent.utils.state_manager import session_state_manager
|
12 |
+
from deepforest_agent.utils.cache_utils import tool_call_cache
|
13 |
+
from deepforest_agent.utils.image_utils import check_image_resolution_for_deepforest
|
14 |
+
from deepforest_agent.utils.logging_utils import multi_agent_logger
|
15 |
+
from deepforest_agent.utils.detection_narrative_generator import DetectionNarrativeGenerator
|
16 |
+
|
17 |
+
|
18 |
+
class AgentOrchestrator:
|
19 |
+
"""
|
20 |
+
Orchestrates the multi-agent workflow with memory context + visual contexts + DeepForest detection context + ecological synthesis.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
"""Initialize the Agent Orchestrator."""
|
25 |
+
self.memory_agent = MemoryAgent()
|
26 |
+
self.detector_agent = DeepForestDetectorAgent()
|
27 |
+
self.visual_agent = VisualAnalysisAgent()
|
28 |
+
self.ecology_agent = EcologyAnalysisAgent()
|
29 |
+
|
30 |
+
self.execution_stats = {
|
31 |
+
"total_runs": 0,
|
32 |
+
"successful_runs": 0,
|
33 |
+
"average_execution_time": 0.0,
|
34 |
+
"memory_direct_answers": 0,
|
35 |
+
"deepforest_skipped": 0
|
36 |
+
}
|
37 |
+
|
38 |
+
def _log_gpu_memory(self, session_id: str, stage: str, agent_name: str):
|
39 |
+
"""
|
40 |
+
Log current GPU memory usage.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
session_id (str): Unique identifier for the user session being processed
|
44 |
+
stage (str): Workflow stage identifier (e.g., "before", "after", "cleanup")
|
45 |
+
agent_name (str): Name of the agent being monitored (e.g., "Visual Analysis",
|
46 |
+
"DeepForest Detection", "Memory Agent")
|
47 |
+
"""
|
48 |
+
if torch.cuda.is_available():
|
49 |
+
allocated_gb = torch.cuda.memory_allocated() / 1024**3
|
50 |
+
cached_gb = torch.cuda.memory_reserved() / 1024**3
|
51 |
+
|
52 |
+
multi_agent_logger.log_agent_execution(
|
53 |
+
session_id=session_id,
|
54 |
+
agent_name=f"gpu_memory_{stage}",
|
55 |
+
agent_input=f"{agent_name} - {stage}",
|
56 |
+
agent_output=f"GPU Memory - Allocated: {allocated_gb:.2f} GB, Cached: {cached_gb:.2f} GB",
|
57 |
+
execution_time=0.0
|
58 |
+
)
|
59 |
+
print(f"Session {session_id} - {agent_name} {stage}: GPU Memory - Allocated: {allocated_gb:.2f} GB, Cached: {cached_gb:.2f} GB")
|
60 |
+
|
61 |
+
def cleanup_all_agents(self):
|
62 |
+
"""Cleanup models to manage memory."""
|
63 |
+
print("Orchestrator cleanup:")
|
64 |
+
gc.collect()
|
65 |
+
if torch.cuda.is_available():
|
66 |
+
torch.cuda.empty_cache()
|
67 |
+
torch.cuda.synchronize()
|
68 |
+
torch.cuda.ipc_collect()
|
69 |
+
print(f"Final GPU memory after orchestrator cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
70 |
+
|
71 |
+
def _aggressive_gpu_cleanup(self, session_id: str, stage: str):
|
72 |
+
"""
|
73 |
+
Perform aggressive GPU memory cleanup.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
session_id (str): Unique identifier for the user session
|
77 |
+
stage (str): Workflow stage identifier for logging context
|
78 |
+
"""
|
79 |
+
if torch.cuda.is_available():
|
80 |
+
for i in range(3):
|
81 |
+
gc.collect()
|
82 |
+
torch.cuda.empty_cache()
|
83 |
+
|
84 |
+
torch.cuda.ipc_collect()
|
85 |
+
torch.cuda.synchronize()
|
86 |
+
|
87 |
+
try:
|
88 |
+
torch.cuda.reset_peak_memory_stats()
|
89 |
+
torch.cuda.reset_accumulated_memory_stats()
|
90 |
+
except:
|
91 |
+
pass
|
92 |
+
|
93 |
+
allocated = torch.cuda.memory_allocated() / 1024**3
|
94 |
+
cached = torch.cuda.memory_reserved() / 1024**3
|
95 |
+
|
96 |
+
print(f"Session {session_id} - {stage} aggressive cleanup: {allocated:.2f} GB allocated, {cached:.2f} GB cached")
|
97 |
+
|
98 |
+
def _format_detection_data_for_monitor(self, detection_narrative: str, detections_list: Optional[List[Dict[str, Any]]] = None) -> str:
|
99 |
+
"""
|
100 |
+
Format detection data for monitor display.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
detection_narrative: Generated detection context from DeepForest Data
|
104 |
+
detections_list: Full DeepForest detection data
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
Formatted detection data for monitor
|
108 |
+
"""
|
109 |
+
monitor_parts = []
|
110 |
+
|
111 |
+
if detections_list:
|
112 |
+
monitor_parts.append("=== DEEPFOREST DETECTIONS ===")
|
113 |
+
monitor_parts.append(json.dumps(detections_list, indent=2))
|
114 |
+
monitor_parts.append("")
|
115 |
+
|
116 |
+
if detection_narrative:
|
117 |
+
monitor_parts.append("=== DETECTION NARRATIVE ===")
|
118 |
+
monitor_parts.append(detection_narrative)
|
119 |
+
|
120 |
+
return "\n".join(monitor_parts) if monitor_parts else "No detection data available"
|
121 |
+
|
122 |
+
def _get_cached_detection_narrative(self, tool_cache_id: str) -> Optional[str]:
|
123 |
+
"""
|
124 |
+
Retrieve detection narrative using tool cache ID from the tool_call_cache.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
tool_cache_id: Tool cache identifier
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
Detection context from DeepForest Data if found, None otherwise
|
131 |
+
"""
|
132 |
+
try:
|
133 |
+
print(f"Looking up cached detection narrative for tool_cache_id: {tool_cache_id}")
|
134 |
+
|
135 |
+
# Handle multiple cache IDs
|
136 |
+
cache_ids = [id.strip() for id in tool_cache_id.split(",")] if tool_cache_id else []
|
137 |
+
all_narratives = []
|
138 |
+
|
139 |
+
for cache_id in cache_ids:
|
140 |
+
if cache_id in tool_call_cache.cache_data:
|
141 |
+
cached_entry = tool_call_cache.cache_data[cache_id]
|
142 |
+
cached_result = cached_entry.get("result", {})
|
143 |
+
tool_name = cached_entry.get("tool_name", "unknown")
|
144 |
+
tool_arguments = cached_entry.get("arguments", {})
|
145 |
+
|
146 |
+
# Get all possible arguments including defaults from Config
|
147 |
+
from deepforest_agent.conf.config import Config
|
148 |
+
all_arguments = Config.DEEPFOREST_DEFAULTS.copy()
|
149 |
+
all_arguments.update(tool_arguments)
|
150 |
+
|
151 |
+
# Format tool call info with all arguments
|
152 |
+
args_str = ", ".join([f"{k}={v}" for k, v in all_arguments.items()])
|
153 |
+
|
154 |
+
# Check if we have detections_list to generate narrative from
|
155 |
+
detections_list = cached_result.get("detections_list", [])
|
156 |
+
|
157 |
+
if detections_list:
|
158 |
+
print(f"Found {len(detections_list)} cached detections for cache ID {cache_id}")
|
159 |
+
|
160 |
+
# Get image dimensions for narrative generation
|
161 |
+
try:
|
162 |
+
session_keys = list(session_state_manager._sessions.keys())
|
163 |
+
if session_keys:
|
164 |
+
current_image = session_state_manager.get(session_keys[0], "current_image")
|
165 |
+
if current_image:
|
166 |
+
image_width, image_height = current_image.size
|
167 |
+
else:
|
168 |
+
image_width, image_height = 0, 0
|
169 |
+
else:
|
170 |
+
image_width, image_height = 0, 0
|
171 |
+
except:
|
172 |
+
image_width, image_height = 0, 0
|
173 |
+
|
174 |
+
# Generate fresh narrative from cached detection data
|
175 |
+
narrative_generator = DetectionNarrativeGenerator(image_width, image_height)
|
176 |
+
cached_detection_narrative = narrative_generator.generate_comprehensive_narrative(detections_list)
|
177 |
+
|
178 |
+
# Format with proper tool cache ID structure
|
179 |
+
formatted_narrative = f"**TOOL CACHE ID:** {cache_id}\nDeepForest tool run with arguments ({args_str}) and got the below narratives:\nDETECTION NARRATIVE:\n{cached_detection_narrative}"
|
180 |
+
all_narratives.append(formatted_narrative)
|
181 |
+
else:
|
182 |
+
detection_summary = cached_result.get("detection_summary", "")
|
183 |
+
if detection_summary:
|
184 |
+
formatted_summary = f"**TOOL CACHE ID:** {cache_id}\nDeepForest tool run with arguments ({args_str}) and got the below narratives:\nDETECTION NARRATIVE:\n{detection_summary}"
|
185 |
+
all_narratives.append(formatted_summary)
|
186 |
+
|
187 |
+
if all_narratives:
|
188 |
+
print(f"Generated {len(all_narratives)} cached detection narratives")
|
189 |
+
return "\n\n".join(all_narratives)
|
190 |
+
|
191 |
+
print(f"No cached data found for tool_cache_id(s): {tool_cache_id}")
|
192 |
+
return None
|
193 |
+
|
194 |
+
except Exception as e:
|
195 |
+
print(f"Error retrieving cached detection narrative for {tool_cache_id}: {e}")
|
196 |
+
return None
|
197 |
+
|
198 |
+
def process_user_message_streaming(
|
199 |
+
self,
|
200 |
+
user_message: str,
|
201 |
+
conversation_history: List[Dict[str, Any]],
|
202 |
+
session_id: str
|
203 |
+
) -> Generator[Dict[str, Any], None, None]:
|
204 |
+
"""
|
205 |
+
Orchestrate the multi-agent workflow with memory context and detection narrative flow.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
user_message: Current user message/query to be processed
|
209 |
+
conversation_history: Full conversation history
|
210 |
+
session_id: Unique session identifier for this user's workflow
|
211 |
+
|
212 |
+
Yields:
|
213 |
+
Dict[str, Any]: Progress updates during processing
|
214 |
+
"""
|
215 |
+
start_time = time.perf_counter()
|
216 |
+
self.execution_stats["total_runs"] += 1
|
217 |
+
|
218 |
+
print(f"Session {session_id} - Query: {user_message}")
|
219 |
+
print(f"Session {session_id} - Conversation history length: {len(conversation_history)}")
|
220 |
+
|
221 |
+
agent_results = {}
|
222 |
+
execution_summary = {
|
223 |
+
"agents_executed": [],
|
224 |
+
"execution_order": [],
|
225 |
+
"timings": {},
|
226 |
+
"status": "in_progress",
|
227 |
+
"session_id": session_id,
|
228 |
+
"workflow_type": "memory_narrative_flow",
|
229 |
+
"memory_provided_direct_answer": False,
|
230 |
+
"deepforest_executed": False
|
231 |
+
}
|
232 |
+
|
233 |
+
memory_context = ""
|
234 |
+
visual_context = ""
|
235 |
+
detection_narrative = ""
|
236 |
+
memory_tool_cache_id = None
|
237 |
+
current_tool_cache_id = None
|
238 |
+
|
239 |
+
try:
|
240 |
+
if not session_state_manager.session_exists(session_id):
|
241 |
+
raise ValueError(f"Session {session_id} not found")
|
242 |
+
|
243 |
+
session_state_manager.set_processing_state(session_id, True)
|
244 |
+
session_state_manager.reset_cancellation(session_id)
|
245 |
+
|
246 |
+
yield {
|
247 |
+
"stage": "memory",
|
248 |
+
"message": "Analyzing conversation memory and context...",
|
249 |
+
"type": "progress"
|
250 |
+
}
|
251 |
+
|
252 |
+
if session_state_manager.is_cancelled(session_id):
|
253 |
+
raise Exception("Processing cancelled by user")
|
254 |
+
|
255 |
+
print(f"\nSTEP 1: Memory Agent Processing (Session {session_id})")
|
256 |
+
self._log_gpu_memory(session_id, "before", "Memory Agent")
|
257 |
+
memory_start = time.perf_counter()
|
258 |
+
|
259 |
+
memory_result = self.memory_agent.process_conversation_history_structured(
|
260 |
+
conversation_history=conversation_history,
|
261 |
+
latest_message=user_message,
|
262 |
+
session_id=session_id
|
263 |
+
)
|
264 |
+
|
265 |
+
memory_time = time.perf_counter() - memory_start
|
266 |
+
self._log_gpu_memory(session_id, "after", "Memory Agent")
|
267 |
+
self._aggressive_gpu_cleanup(session_id, "after_memory_agent")
|
268 |
+
execution_summary["timings"]["memory_agent"] = memory_time
|
269 |
+
execution_summary["agents_executed"].append("memory")
|
270 |
+
execution_summary["execution_order"].append("memory")
|
271 |
+
agent_results["memory"] = memory_result
|
272 |
+
|
273 |
+
# Extract memory context and tool cache ID
|
274 |
+
memory_context = memory_result.get("relevant_context", "No memory context available")
|
275 |
+
tool_cache_id = memory_result.get("tool_cache_id")
|
276 |
+
|
277 |
+
print(f"Session {session_id} - Memory Agent: Completed in {memory_time:.2f}s")
|
278 |
+
print(f"Session {session_id} - Memory Has Answer: {memory_result['answer_present']}")
|
279 |
+
print(f"Session {session_id} - Tool Cache ID: {tool_cache_id}")
|
280 |
+
|
281 |
+
if memory_result["answer_present"]:
|
282 |
+
print(f"Session {session_id} - Memory has direct answer - using cached data for synthesis")
|
283 |
+
|
284 |
+
self.execution_stats["memory_direct_answers"] += 1
|
285 |
+
execution_summary["memory_provided_direct_answer"] = True
|
286 |
+
|
287 |
+
# Get cached detection narrative if available
|
288 |
+
cached_detection_narrative = ""
|
289 |
+
if tool_cache_id:
|
290 |
+
cached_detection_narrative = self._get_cached_detection_narrative(tool_cache_id) or ""
|
291 |
+
|
292 |
+
yield {
|
293 |
+
"stage": "ecology",
|
294 |
+
"message": "Using memory context and cached detection narrative for synthesis...",
|
295 |
+
"type": "progress"
|
296 |
+
}
|
297 |
+
|
298 |
+
if session_state_manager.is_cancelled(session_id):
|
299 |
+
raise Exception("Processing cancelled by user")
|
300 |
+
|
301 |
+
print(f"\nSTEP 2 (MEMORY PATH): Ecology Agent with Memory Context (Session {session_id})")
|
302 |
+
self._log_gpu_memory(session_id, "before", "Ecology Agent (Memory Path)")
|
303 |
+
ecology_start = time.perf_counter()
|
304 |
+
|
305 |
+
# Prepare comprehensive context
|
306 |
+
comprehensive_context = self._prepare_comprehensive_context(
|
307 |
+
memory_context=memory_context,
|
308 |
+
visual_context="",
|
309 |
+
detection_narrative=cached_detection_narrative,
|
310 |
+
tool_cache_id=tool_cache_id
|
311 |
+
)
|
312 |
+
|
313 |
+
final_response = ""
|
314 |
+
for token_result in self.ecology_agent.synthesize_analysis_streaming(
|
315 |
+
user_message=user_message,
|
316 |
+
memory_context=comprehensive_context,
|
317 |
+
cached_json=None,
|
318 |
+
current_json=None,
|
319 |
+
session_id=session_id
|
320 |
+
):
|
321 |
+
|
322 |
+
if session_state_manager.is_cancelled(session_id):
|
323 |
+
raise Exception("Processing cancelled by user")
|
324 |
+
|
325 |
+
final_response += token_result["token"]
|
326 |
+
|
327 |
+
yield {
|
328 |
+
"stage": "ecology_streaming",
|
329 |
+
"message": final_response,
|
330 |
+
"type": "streaming",
|
331 |
+
"is_complete": token_result["is_complete"]
|
332 |
+
}
|
333 |
+
|
334 |
+
if token_result["is_complete"]:
|
335 |
+
ecology_time = time.perf_counter() - ecology_start
|
336 |
+
self._log_gpu_memory(session_id, "after", "Ecology Agent (Memory Path)")
|
337 |
+
execution_summary["timings"]["ecology_agent"] = ecology_time
|
338 |
+
execution_summary["agents_executed"].append("ecology")
|
339 |
+
execution_summary["execution_order"].append("ecology")
|
340 |
+
agent_results["ecology"] = {"final_response": final_response}
|
341 |
+
print(f"Session {session_id} - Ecology (Memory Path): Completed in {ecology_time:.2f}s")
|
342 |
+
break
|
343 |
+
|
344 |
+
total_time = time.perf_counter() - start_time
|
345 |
+
execution_summary["timings"]["total"] = total_time
|
346 |
+
execution_summary["status"] = "completed_via_memory"
|
347 |
+
|
348 |
+
detection_data_monitor = self._format_detection_data_for_monitor(
|
349 |
+
detection_narrative=cached_detection_narrative
|
350 |
+
)
|
351 |
+
|
352 |
+
yield {
|
353 |
+
"stage": "complete",
|
354 |
+
"message": final_response,
|
355 |
+
"type": "final",
|
356 |
+
"detection_data": detection_data_monitor,
|
357 |
+
"agent_results": agent_results,
|
358 |
+
"execution_summary": execution_summary,
|
359 |
+
"execution_time": total_time,
|
360 |
+
"status": "success"
|
361 |
+
}
|
362 |
+
return
|
363 |
+
else:
|
364 |
+
for result in self._execute_full_pipeline_with_narrative_flow(
|
365 |
+
user_message=user_message,
|
366 |
+
conversation_history=conversation_history,
|
367 |
+
session_id=session_id,
|
368 |
+
memory_context=memory_context,
|
369 |
+
memory_tool_cache_id=memory_result.get("tool_cache_id"),
|
370 |
+
start_time=start_time
|
371 |
+
):
|
372 |
+
yield result
|
373 |
+
if result["type"] == "final":
|
374 |
+
return
|
375 |
+
|
376 |
+
except Exception as e:
|
377 |
+
error_msg = f"Orchestrator error (Session {session_id}): {str(e)}"
|
378 |
+
print(f"ORCHESTRATOR ERROR: {error_msg}")
|
379 |
+
|
380 |
+
try:
|
381 |
+
self._aggressive_gpu_cleanup(session_id, "emergency")
|
382 |
+
except Exception as cleanup_error:
|
383 |
+
print(f"Emergency cleanup error: {cleanup_error}")
|
384 |
+
|
385 |
+
partial_time = time.perf_counter() - start_time
|
386 |
+
execution_summary["timings"]["total"] = partial_time
|
387 |
+
execution_summary["status"] = "error"
|
388 |
+
execution_summary["error"] = error_msg
|
389 |
+
|
390 |
+
fallback_response = self._create_fallback_response(
|
391 |
+
user_message=user_message,
|
392 |
+
agent_results=agent_results,
|
393 |
+
error=error_msg,
|
394 |
+
session_id=session_id
|
395 |
+
)
|
396 |
+
|
397 |
+
yield {
|
398 |
+
"stage": "error",
|
399 |
+
"message": fallback_response,
|
400 |
+
"type": "final",
|
401 |
+
"detection_data": "Error occurred - no detection data available",
|
402 |
+
"agent_results": agent_results,
|
403 |
+
"execution_summary": execution_summary,
|
404 |
+
"execution_time": partial_time,
|
405 |
+
"status": "error",
|
406 |
+
"error": error_msg
|
407 |
+
}
|
408 |
+
|
409 |
+
finally:
|
410 |
+
session_state_manager.set_processing_state(session_id, False)
|
411 |
+
|
412 |
+
def _execute_full_pipeline_with_narrative_flow(
|
413 |
+
self,
|
414 |
+
user_message: str,
|
415 |
+
conversation_history: List[Dict[str, Any]],
|
416 |
+
session_id: str,
|
417 |
+
memory_context: str,
|
418 |
+
memory_tool_cache_id: Optional[str],
|
419 |
+
start_time: float
|
420 |
+
) -> Generator[Dict[str, Any], None, None]:
|
421 |
+
"""
|
422 |
+
Execute the complete pipeline using memory context, visual contexts, and detection narratives.
|
423 |
+
|
424 |
+
Args:
|
425 |
+
user_message: Current user query
|
426 |
+
conversation_history: Complete conversation context
|
427 |
+
session_id: Unique session identifier
|
428 |
+
memory_context: Context from memory agent
|
429 |
+
memory_tool_cache_id (Optional[str]): Cache identifier from memory agent
|
430 |
+
start_time: Start time for total execution calculation
|
431 |
+
|
432 |
+
Yields:
|
433 |
+
Dict[str, Any]: Progress updates during processing containing:
|
434 |
+
- stage (str): Current workflow stage ("visual_analysis", "detector", etc.)
|
435 |
+
- message (str): Human-readable progress message
|
436 |
+
- type (str): Update type ("progress", "streaming", "final")
|
437 |
+
- Additional stage-specific data (detection_data, agent_results, etc.)
|
438 |
+
"""
|
439 |
+
agent_results = {}
|
440 |
+
execution_summary = {
|
441 |
+
"agents_executed": [],
|
442 |
+
"execution_order": [],
|
443 |
+
"timings": {},
|
444 |
+
"status": "in_progress",
|
445 |
+
"session_id": session_id,
|
446 |
+
"workflow_type": "Full Pipeline with Narrative Flow",
|
447 |
+
"memory_provided_direct_answer": False,
|
448 |
+
"deepforest_executed": False
|
449 |
+
}
|
450 |
+
|
451 |
+
visual_context = ""
|
452 |
+
detection_narrative = ""
|
453 |
+
|
454 |
+
yield {"stage": "visual_analysis", "message": "Analyzing image with unified full/tiled approach...", "type": "progress"}
|
455 |
+
|
456 |
+
if session_state_manager.is_cancelled(session_id):
|
457 |
+
raise Exception("Processing cancelled by user")
|
458 |
+
|
459 |
+
print(f"\nSTEP 1: Visual Analysis (Session {session_id})")
|
460 |
+
self._log_gpu_memory(session_id, "before", "Visual Analysis")
|
461 |
+
visual_start = time.perf_counter()
|
462 |
+
|
463 |
+
# Unified visual analysis
|
464 |
+
visual_analysis_result = self.visual_agent.analyze_full_image(
|
465 |
+
user_message=user_message,
|
466 |
+
session_id=session_id
|
467 |
+
)
|
468 |
+
|
469 |
+
visual_time = time.perf_counter() - visual_start
|
470 |
+
self._log_gpu_memory(session_id, "after", "Visual Analysis")
|
471 |
+
self._aggressive_gpu_cleanup(session_id, "after_visual_analysis")
|
472 |
+
execution_summary["timings"]["visual_analysis"] = visual_time
|
473 |
+
execution_summary["agents_executed"].append("visual_analysis")
|
474 |
+
execution_summary["execution_order"].append("visual_analysis")
|
475 |
+
agent_results["visual_analysis"] = visual_analysis_result
|
476 |
+
|
477 |
+
# Extract visual context
|
478 |
+
visual_context = visual_analysis_result.get("visual_analysis", "No visual analysis available")
|
479 |
+
|
480 |
+
print(f"Session {session_id} - Visual Analysis: {visual_analysis_result.get('status')}")
|
481 |
+
print(f"Session {session_id} - Analysis Type: {visual_analysis_result.get('analysis_type')}")
|
482 |
+
|
483 |
+
yield {"stage": "resolution_check", "message": "Checking image resolution for DeepForest suitability...", "type": "progress"}
|
484 |
+
|
485 |
+
if session_state_manager.is_cancelled(session_id):
|
486 |
+
raise Exception("Processing cancelled by user")
|
487 |
+
|
488 |
+
print(f"\nSTEP 2: Resolution Check (Session {session_id})")
|
489 |
+
resolution_start = time.perf_counter()
|
490 |
+
|
491 |
+
image_file_path = session_state_manager.get(session_id, "image_file_path")
|
492 |
+
resolution_result = None
|
493 |
+
|
494 |
+
if image_file_path:
|
495 |
+
resolution_result = check_image_resolution_for_deepforest(image_file_path)
|
496 |
+
resolution_time = time.perf_counter() - resolution_start
|
497 |
+
|
498 |
+
multi_agent_logger.log_resolution_check(
|
499 |
+
session_id=session_id,
|
500 |
+
image_file_path=image_file_path,
|
501 |
+
resolution_result=resolution_result,
|
502 |
+
execution_time=resolution_time
|
503 |
+
)
|
504 |
+
else:
|
505 |
+
resolution_result = {
|
506 |
+
"is_suitable": True,
|
507 |
+
"resolution_info": "No file path available for resolution check",
|
508 |
+
"error": None
|
509 |
+
}
|
510 |
+
resolution_time = time.perf_counter() - resolution_start
|
511 |
+
|
512 |
+
execution_summary["timings"]["resolution_check"] = resolution_time
|
513 |
+
execution_summary["agents_executed"].append("resolution_check")
|
514 |
+
execution_summary["execution_order"].append("resolution_check")
|
515 |
+
agent_results["resolution_check"] = resolution_result
|
516 |
+
|
517 |
+
# Determine if DeepForest should run
|
518 |
+
detection_result = None
|
519 |
+
image_quality_good = visual_analysis_result.get("image_quality_for_deepforest", "No").lower() == "yes"
|
520 |
+
resolution_suitable = resolution_result.get("is_suitable", True)
|
521 |
+
|
522 |
+
if resolution_suitable and image_quality_good:
|
523 |
+
yield {"stage": "detector", "message": "Quality and resolution good - executing DeepForest detection with narrative generation...", "type": "progress"}
|
524 |
+
|
525 |
+
if session_state_manager.is_cancelled(session_id):
|
526 |
+
raise Exception("Processing cancelled by user")
|
527 |
+
|
528 |
+
print(f"\nSTEP 3: DeepForest Detection with R-tree and Narrative (Session {session_id})")
|
529 |
+
self._log_gpu_memory(session_id, "before", "DeepForest Detection")
|
530 |
+
detector_start = time.perf_counter()
|
531 |
+
|
532 |
+
visual_objects = visual_analysis_result.get("deepforest_objects_present", [])
|
533 |
+
|
534 |
+
try:
|
535 |
+
detection_result = self.detector_agent.execute_detection_with_context(
|
536 |
+
user_message=user_message,
|
537 |
+
session_id=session_id,
|
538 |
+
visual_objects_detected=visual_objects,
|
539 |
+
memory_context=memory_context
|
540 |
+
)
|
541 |
+
|
542 |
+
detector_time = time.perf_counter() - detector_start
|
543 |
+
self._log_gpu_memory(session_id, "after", "DeepForest Detection")
|
544 |
+
self._aggressive_gpu_cleanup(session_id, "after_deepforest_detection")
|
545 |
+
execution_summary["timings"]["detector_agent"] = detector_time
|
546 |
+
execution_summary["agents_executed"].append("detector")
|
547 |
+
execution_summary["execution_order"].append("detector")
|
548 |
+
execution_summary["deepforest_executed"] = True
|
549 |
+
agent_results["detector"] = detection_result
|
550 |
+
|
551 |
+
# Extract detection narrative and tool cache ID from current run
|
552 |
+
current_detection_narrative = detection_result.get("detection_narrative", "No detection narrative available")
|
553 |
+
|
554 |
+
# Combine cached narratives from memory with current detection narrative
|
555 |
+
combined_narratives = []
|
556 |
+
|
557 |
+
# Add cached narratives from memory's tool cache IDs (if any)
|
558 |
+
if memory_tool_cache_id:
|
559 |
+
cached_narrative = self._get_cached_detection_narrative(memory_tool_cache_id)
|
560 |
+
if cached_narrative:
|
561 |
+
combined_narratives.append(cached_narrative)
|
562 |
+
|
563 |
+
# Add current detection narratives for ALL tool results
|
564 |
+
tool_results = detection_result.get("tool_results", [])
|
565 |
+
if tool_results:
|
566 |
+
for tool_result in tool_results:
|
567 |
+
cache_key = tool_result.get("cache_key")
|
568 |
+
tool_arguments = tool_result.get("tool_arguments", {})
|
569 |
+
|
570 |
+
if cache_key and tool_arguments:
|
571 |
+
# Get all possible arguments including defaults from Config
|
572 |
+
from deepforest_agent.conf.config import Config
|
573 |
+
all_arguments = Config.DEEPFOREST_DEFAULTS.copy()
|
574 |
+
all_arguments.update(tool_arguments)
|
575 |
+
|
576 |
+
# Format tool call info with all arguments
|
577 |
+
args_str = ", ".join([f"{k}={v}" for k, v in all_arguments.items()])
|
578 |
+
|
579 |
+
formatted_current = f"**TOOL CACHE ID:** {cache_key}\nDeepForest tool run with arguments ({args_str}) and got the below narratives:\nDETECTION NARRATIVE:\n{current_detection_narrative}"
|
580 |
+
combined_narratives.append(formatted_current)
|
581 |
+
|
582 |
+
# If no tool results but we have narrative, add it without formatting
|
583 |
+
if not tool_results and current_detection_narrative and current_detection_narrative != "No detection narrative available":
|
584 |
+
combined_narratives.append(current_detection_narrative)
|
585 |
+
|
586 |
+
# Combine all narratives
|
587 |
+
detection_narrative = "\n\n".join(combined_narratives) if combined_narratives else "No detection narrative available"
|
588 |
+
|
589 |
+
print(f"Session {session_id} - DeepForest Detection completed with narrative")
|
590 |
+
|
591 |
+
except Exception as detector_error:
|
592 |
+
print(f"Session {session_id} - DeepForest Detection FAILED: {detector_error}")
|
593 |
+
detection_result = None
|
594 |
+
detection_narrative = f"DeepForest detection failed: {str(detector_error)}"
|
595 |
+
else:
|
596 |
+
skip_reasons = []
|
597 |
+
if not resolution_suitable:
|
598 |
+
skip_reasons.append("insufficient resolution")
|
599 |
+
if not image_quality_good:
|
600 |
+
skip_reasons.append("poor image quality")
|
601 |
+
|
602 |
+
print(f"Session {session_id} - Skipping DeepForest detection: {', '.join(skip_reasons)}")
|
603 |
+
execution_summary["deepforest_executed"] = False
|
604 |
+
execution_summary["deepforest_skip_reason"] = ", ".join(skip_reasons)
|
605 |
+
detection_narrative = f"DeepForest detection was skipped due to: {', '.join(skip_reasons)}"
|
606 |
+
|
607 |
+
yield {"stage": "ecology", "message": "Synthesizing ecological insights from all contexts...", "type": "progress"}
|
608 |
+
|
609 |
+
if session_state_manager.is_cancelled(session_id):
|
610 |
+
raise Exception("Processing cancelled by user")
|
611 |
+
|
612 |
+
print(f"\nSTEP 4: Ecology Analysis with Comprehensive Context (Session {session_id})")
|
613 |
+
self._log_gpu_memory(session_id, "before", "Ecology Analysis")
|
614 |
+
ecology_start = time.perf_counter()
|
615 |
+
|
616 |
+
# Prepare comprehensive context for ecology agent
|
617 |
+
comprehensive_context = self._prepare_comprehensive_context(
|
618 |
+
memory_context=memory_context,
|
619 |
+
visual_context=visual_context,
|
620 |
+
detection_narrative=detection_narrative,
|
621 |
+
tool_cache_id=memory_tool_cache_id
|
622 |
+
)
|
623 |
+
|
624 |
+
final_response = ""
|
625 |
+
try:
|
626 |
+
for token_result in self.ecology_agent.synthesize_analysis_streaming(
|
627 |
+
user_message=user_message,
|
628 |
+
memory_context=comprehensive_context,
|
629 |
+
cached_json=None,
|
630 |
+
current_json=None,
|
631 |
+
session_id=session_id
|
632 |
+
):
|
633 |
+
if session_state_manager.is_cancelled(session_id):
|
634 |
+
raise Exception("Processing cancelled by user")
|
635 |
+
|
636 |
+
final_response += token_result["token"]
|
637 |
+
|
638 |
+
yield {
|
639 |
+
"stage": "ecology_streaming",
|
640 |
+
"message": final_response,
|
641 |
+
"type": "streaming",
|
642 |
+
"is_complete": token_result["is_complete"]
|
643 |
+
}
|
644 |
+
|
645 |
+
if token_result["is_complete"]:
|
646 |
+
break
|
647 |
+
|
648 |
+
except Exception as ecology_error:
|
649 |
+
print(f"Session {session_id} - Ecology streaming error: {ecology_error}")
|
650 |
+
if not final_response:
|
651 |
+
final_response = f"Ecology analysis failed: {str(ecology_error)}"
|
652 |
+
|
653 |
+
finally:
|
654 |
+
ecology_time = time.perf_counter() - ecology_start
|
655 |
+
self._log_gpu_memory(session_id, "after", "Ecology Analysis")
|
656 |
+
self._aggressive_gpu_cleanup(session_id, "after_ecology_analysis")
|
657 |
+
execution_summary["timings"]["ecology_agent"] = ecology_time
|
658 |
+
execution_summary["agents_executed"].append("ecology")
|
659 |
+
execution_summary["execution_order"].append("ecology")
|
660 |
+
agent_results["ecology"] = {"final_response": final_response}
|
661 |
+
|
662 |
+
# Store context data for memory agent's next turn
|
663 |
+
current_turn = len(session_state_manager.get(session_id, "conversation_history", [])) // 2 + 1
|
664 |
+
all_tool_cache_ids = []
|
665 |
+
if memory_tool_cache_id:
|
666 |
+
all_tool_cache_ids.extend([id.strip() for id in memory_tool_cache_id.split(",")])
|
667 |
+
|
668 |
+
# Add all current tool cache IDs
|
669 |
+
tool_results = detection_result.get("tool_results", []) if detection_result else []
|
670 |
+
for tool_result in tool_results:
|
671 |
+
cache_key = tool_result.get("cache_key")
|
672 |
+
if cache_key:
|
673 |
+
all_tool_cache_ids.append(cache_key)
|
674 |
+
|
675 |
+
combined_tool_cache_id = ", ".join(all_tool_cache_ids) if all_tool_cache_ids else None
|
676 |
+
self.memory_agent.store_turn_context(
|
677 |
+
session_id=session_id,
|
678 |
+
turn_number=current_turn,
|
679 |
+
visual_context=visual_context,
|
680 |
+
detection_narrative=detection_narrative,
|
681 |
+
tool_cache_id=combined_tool_cache_id
|
682 |
+
)
|
683 |
+
|
684 |
+
# Final result
|
685 |
+
total_time = time.perf_counter() - start_time
|
686 |
+
execution_summary["timings"]["total"] = total_time
|
687 |
+
execution_summary["status"] = "completed_narrative_flow"
|
688 |
+
|
689 |
+
detection_data_monitor = self._format_detection_data_for_monitor(
|
690 |
+
detection_narrative=detection_narrative,
|
691 |
+
detections_list=detection_result.get("detections_list", []) if detection_result else None
|
692 |
+
)
|
693 |
+
|
694 |
+
print(f"Session {session_id} - NARRATIVE FLOW WORKFLOW COMPLETED")
|
695 |
+
|
696 |
+
yield {
|
697 |
+
"stage": "complete",
|
698 |
+
"message": final_response,
|
699 |
+
"type": "final",
|
700 |
+
"detection_data": detection_data_monitor,
|
701 |
+
"agent_results": agent_results,
|
702 |
+
"execution_summary": execution_summary,
|
703 |
+
"execution_time": total_time,
|
704 |
+
"status": "success"
|
705 |
+
}
|
706 |
+
|
707 |
+
def _prepare_comprehensive_context(
|
708 |
+
self,
|
709 |
+
memory_context: str,
|
710 |
+
visual_context: str,
|
711 |
+
detection_narrative: str,
|
712 |
+
tool_cache_id: Optional[str]
|
713 |
+
) -> str:
|
714 |
+
"""
|
715 |
+
Prepare comprehensive context combining all data sources with better formatting.
|
716 |
+
|
717 |
+
Args:
|
718 |
+
memory_context: Context from memory agent
|
719 |
+
visual_context: Visual analysis context
|
720 |
+
detection_narrative: R-tree based detection narrative
|
721 |
+
tool_cache_id: Tool cache reference if available
|
722 |
+
|
723 |
+
Returns:
|
724 |
+
Combined context string for ecology agent
|
725 |
+
"""
|
726 |
+
context_parts = []
|
727 |
+
|
728 |
+
# Memory context section
|
729 |
+
if memory_context and memory_context != "No memory context available":
|
730 |
+
context_parts.append("--- START OF MEMORY CONTEXT ---")
|
731 |
+
context_parts.append(memory_context)
|
732 |
+
context_parts.append("--- END OF MEMORY CONTEXT ---")
|
733 |
+
context_parts.append("")
|
734 |
+
|
735 |
+
# Tool cache reference
|
736 |
+
if tool_cache_id:
|
737 |
+
context_parts.append(f"**TOOL CACHE ID:** {tool_cache_id}")
|
738 |
+
context_parts.append("")
|
739 |
+
|
740 |
+
# Detection narrative section
|
741 |
+
if detection_narrative and detection_narrative not in ["No detection analysis available", ""]:
|
742 |
+
context_parts.append("--- START OF DETECTION ANALYSIS ---")
|
743 |
+
context_parts.append(detection_narrative)
|
744 |
+
context_parts.append("--- END OF DETECTION ANALYSIS ---")
|
745 |
+
context_parts.append("")
|
746 |
+
|
747 |
+
# Visual context section
|
748 |
+
if visual_context and visual_context != "No visual analysis available":
|
749 |
+
context_parts.append("--- START OF VISUAL ANALYSIS ---")
|
750 |
+
context_parts.append(visual_context)
|
751 |
+
context_parts.append("There may be information that are not clear or accurate in this visual analysis. So make sure to mention that this analysis is provided by a visual analysis agent and it may not be very accurate as there is no confidence score associated with it. You can only provide this analysis seperately in a different section and inform the user that you are not very confident about this analysis.")
|
752 |
+
context_parts.append("--- END OF VISUAL ANALYSIS ---")
|
753 |
+
context_parts.append("")
|
754 |
+
|
755 |
+
# If we have very little context, provide a meaningful message
|
756 |
+
if not context_parts or len("".join(context_parts)) < 50:
|
757 |
+
return "No comprehensive context available for this query. Please provide more information or try a different approach."
|
758 |
+
|
759 |
+
result_context = "\n".join(context_parts)
|
760 |
+
|
761 |
+
print(f"Prepared comprehensive context ({len(result_context)} characters)")
|
762 |
+
print(f"Context preview: {result_context[:200]}...")
|
763 |
+
|
764 |
+
return result_context
|
765 |
+
|
766 |
+
def _create_fallback_response(
|
767 |
+
self,
|
768 |
+
user_message: str,
|
769 |
+
agent_results: Dict[str, Any],
|
770 |
+
error: str,
|
771 |
+
session_id: str
|
772 |
+
) -> str:
|
773 |
+
"""Create a fallback response when the orchestrator encounters errors."""
|
774 |
+
response_parts = []
|
775 |
+
response_parts.append(f"I encountered some processing issues but can provide analysis based on available data:")
|
776 |
+
response_parts.append("")
|
777 |
+
|
778 |
+
memory_result = agent_results.get("memory", {})
|
779 |
+
if memory_result and memory_result.get("relevant_context"):
|
780 |
+
response_parts.append(f"**Memory Context**: {memory_result['relevant_context']}")
|
781 |
+
response_parts.append("")
|
782 |
+
|
783 |
+
visual_result = agent_results.get("visual_analysis", {})
|
784 |
+
if visual_result and visual_result.get("visual_analysis"):
|
785 |
+
response_parts.append(f"**Visual Analysis**: {visual_result['visual_analysis']}")
|
786 |
+
response_parts.append("")
|
787 |
+
|
788 |
+
detector_result = agent_results.get("detector", {})
|
789 |
+
if detector_result and detector_result.get("detection_narrative"):
|
790 |
+
response_parts.append(f"**Detection Results**: {detector_result['detection_narrative']}")
|
791 |
+
response_parts.append("")
|
792 |
+
|
793 |
+
response_parts.append(f"Note: Workflow was interrupted ({error}). Please try your query again for full results.")
|
794 |
+
|
795 |
+
return "\n".join(response_parts)
|
src/deepforest_agent/agents/visual_analysis_agent.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any, Optional
|
2 |
+
from PIL import Image
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
import torch
|
7 |
+
import gc
|
8 |
+
|
9 |
+
from deepforest_agent.models.qwen_vl_3b_instruct import QwenVL3BModelManager
|
10 |
+
from deepforest_agent.utils.image_utils import encode_pil_image_to_base64_url, determine_patch_size, get_image_dimensions_fast
|
11 |
+
from deepforest_agent.utils.state_manager import session_state_manager
|
12 |
+
from deepforest_agent.conf.config import Config
|
13 |
+
from deepforest_agent.utils.parsing_utils import (
|
14 |
+
parse_image_quality_for_deepforest,
|
15 |
+
parse_deepforest_objects_present,
|
16 |
+
parse_visual_analysis,
|
17 |
+
parse_additional_objects_json
|
18 |
+
)
|
19 |
+
from deepforest_agent.prompts.prompt_templates import create_full_image_quality_analysis_prompt, create_individual_tile_analysis_prompt
|
20 |
+
from deepforest_agent.utils.logging_utils import multi_agent_logger
|
21 |
+
from deepforest_agent.utils.tile_manager import tile_image_for_analysis
|
22 |
+
|
23 |
+
|
24 |
+
class VisualAnalysisAgent:
|
25 |
+
"""
|
26 |
+
Visual analysis agent responsible for analyzing images with unified full/tiled approach.
|
27 |
+
Uses Qwen VL model for multimodal understanding.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self):
|
31 |
+
"""Initialize the Visual Analysis Agent."""
|
32 |
+
self.agent_config = Config.AGENT_CONFIGS["visual_analysis"]
|
33 |
+
self.model_manager = QwenVL3BModelManager(Config.AGENT_MODELS["visual_analysis"])
|
34 |
+
|
35 |
+
def analyze_full_image(self, user_message: str, session_id: str) -> Dict[str, Any]:
|
36 |
+
"""
|
37 |
+
Analyze full image with automatic fallback to tiling on OOM.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
user_message: User's query
|
41 |
+
session_id: Session identifier
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Dict with unified structure for both full and tiled analysis
|
45 |
+
"""
|
46 |
+
if not session_state_manager.session_exists(session_id):
|
47 |
+
return {
|
48 |
+
"image_quality_for_deepforest": "No",
|
49 |
+
"deepforest_objects_present": [],
|
50 |
+
"additional_objects": [],
|
51 |
+
"visual_analysis": f"Session {session_id} not found.",
|
52 |
+
"status": "error",
|
53 |
+
"analysis_type": "error"
|
54 |
+
}
|
55 |
+
|
56 |
+
image = session_state_manager.get(session_id, "current_image")
|
57 |
+
if image is None:
|
58 |
+
return {
|
59 |
+
"image_quality_for_deepforest": "No",
|
60 |
+
"deepforest_objects_present": [],
|
61 |
+
"additional_objects": [],
|
62 |
+
"visual_analysis": f"No image available in session {session_id}.",
|
63 |
+
"status": "error",
|
64 |
+
"analysis_type": "error"
|
65 |
+
}
|
66 |
+
|
67 |
+
# Try full image analysis first
|
68 |
+
try:
|
69 |
+
print(f"Session {session_id} - Attempting full image analysis")
|
70 |
+
result = self._analyze_single_image(image, user_message, session_id, is_full_image=True)
|
71 |
+
|
72 |
+
if result["status"] == "success":
|
73 |
+
multi_agent_logger.log_agent_execution(
|
74 |
+
session_id=session_id,
|
75 |
+
agent_name="visual_analysis",
|
76 |
+
agent_input=f"Full image analysis for: {user_message}",
|
77 |
+
agent_output=result["visual_analysis"],
|
78 |
+
execution_time=0.0
|
79 |
+
)
|
80 |
+
return result
|
81 |
+
|
82 |
+
except Exception as e:
|
83 |
+
print(f"Session {session_id} - Full image analysis failed (likely OOM): {e}")
|
84 |
+
return self._analyze_with_tiling(user_message, session_id, str(e))
|
85 |
+
|
86 |
+
return self._analyze_with_tiling(user_message, session_id, "Full image analysis failed")
|
87 |
+
|
88 |
+
def _analyze_single_image(self, image: Image.Image, user_message: str, session_id: str,
|
89 |
+
is_full_image: bool = True, tile_location: str = "") -> Dict[str, Any]:
|
90 |
+
"""
|
91 |
+
Analyze a single image (full image or tile) with unified structure.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
image: PIL Image to analyze
|
95 |
+
user_message: User's query
|
96 |
+
session_id: Session identifier
|
97 |
+
is_full_image: Whether this is full image or tile
|
98 |
+
tile_location: Location description for tiles
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
Unified analysis result
|
102 |
+
"""
|
103 |
+
system_prompt = create_full_image_quality_analysis_prompt(user_message)
|
104 |
+
image_base64_url = encode_pil_image_to_base64_url(image)
|
105 |
+
|
106 |
+
messages = [
|
107 |
+
{"role": "system", "content": [{"type": "text", "text": system_prompt}]},
|
108 |
+
{
|
109 |
+
"role": "user",
|
110 |
+
"content": [
|
111 |
+
{"type": "image", "image": image_base64_url},
|
112 |
+
{"type": "text", "text": user_message}
|
113 |
+
]
|
114 |
+
}
|
115 |
+
]
|
116 |
+
|
117 |
+
response = self.model_manager.generate_response(
|
118 |
+
messages=messages,
|
119 |
+
max_new_tokens=self.agent_config["max_new_tokens"],
|
120 |
+
temperature=self.agent_config["temperature"]
|
121 |
+
)
|
122 |
+
|
123 |
+
# Parse structured response
|
124 |
+
image_quality = parse_image_quality_for_deepforest(response)
|
125 |
+
deepforest_objects = parse_deepforest_objects_present(response)
|
126 |
+
additional_objects = parse_additional_objects_json(response)
|
127 |
+
raw_visual_analysis = parse_visual_analysis(response)
|
128 |
+
|
129 |
+
# Format visual analysis with consistent prefix
|
130 |
+
if is_full_image:
|
131 |
+
width, height = image.size
|
132 |
+
visual_analysis = f"Full image analysis of image ({width}x{height}) is done. Here's the analysis: {raw_visual_analysis}"
|
133 |
+
analysis_type = "full_image"
|
134 |
+
else:
|
135 |
+
visual_analysis = f"The visual analysis of tiled image on ({tile_location}) this location is done. Here's the analysis: {raw_visual_analysis}"
|
136 |
+
analysis_type = "tiled_image"
|
137 |
+
|
138 |
+
return {
|
139 |
+
"image_quality_for_deepforest": image_quality,
|
140 |
+
"deepforest_objects_present": deepforest_objects,
|
141 |
+
"additional_objects": additional_objects,
|
142 |
+
"visual_analysis": visual_analysis,
|
143 |
+
"status": "success",
|
144 |
+
"analysis_type": analysis_type,
|
145 |
+
"raw_response": response
|
146 |
+
}
|
147 |
+
|
148 |
+
def _analyze_with_tiling(self, user_message: str, session_id: str, error_msg: str) -> Dict[str, Any]:
|
149 |
+
"""
|
150 |
+
Analyze image using tiling approach when full image fails.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
user_message: User's query
|
154 |
+
session_id: Session identifier
|
155 |
+
error_msg: Original error message
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
Combined analysis from tiled approach with same structure as full image
|
159 |
+
"""
|
160 |
+
print(f"Session {session_id} - Falling back to tiled analysis due to: {error_msg}")
|
161 |
+
|
162 |
+
image = session_state_manager.get(session_id, "current_image")
|
163 |
+
image_file_path = session_state_manager.get(session_id, "image_file_path")
|
164 |
+
|
165 |
+
if not image:
|
166 |
+
return {
|
167 |
+
"image_quality_for_deepforest": "No",
|
168 |
+
"deepforest_objects_present": [],
|
169 |
+
"additional_objects": [],
|
170 |
+
"visual_analysis": "No image available for tiled analysis.",
|
171 |
+
"status": "error",
|
172 |
+
"analysis_type": "error"
|
173 |
+
}
|
174 |
+
|
175 |
+
# Determine appropriate patch size
|
176 |
+
if image_file_path:
|
177 |
+
patch_size = determine_patch_size(image_file_path, image.size)
|
178 |
+
else:
|
179 |
+
max_dim = max(image.size)
|
180 |
+
if max_dim >= 5000:
|
181 |
+
patch_size = 1500 if max_dim <= 7500 else 2000
|
182 |
+
else:
|
183 |
+
patch_size = 1000
|
184 |
+
|
185 |
+
print(f"Session {session_id} - Using patch size {patch_size} for tiled analysis")
|
186 |
+
|
187 |
+
try:
|
188 |
+
tiles, tile_metadata = tile_image_for_analysis(
|
189 |
+
image=image,
|
190 |
+
patch_size=patch_size,
|
191 |
+
patch_overlap=Config.DEEPFOREST_DEFAULTS["patch_overlap"],
|
192 |
+
image_file_path=image_file_path
|
193 |
+
)
|
194 |
+
|
195 |
+
print(f"Session {session_id} - Created {len(tiles)} tiles for analysis")
|
196 |
+
|
197 |
+
# Analyze all tiles and combine results
|
198 |
+
all_visual_analyses = []
|
199 |
+
all_additional_objects = []
|
200 |
+
tile_results = []
|
201 |
+
|
202 |
+
for i, (tile, metadata) in enumerate(zip(tiles, tile_metadata)):
|
203 |
+
try:
|
204 |
+
tile_coords = metadata.get("window_coords", {})
|
205 |
+
location_desc = f"x:{tile_coords.get('x', 0)}-{tile_coords.get('x', 0) + tile_coords.get('width', 0)}, y:{tile_coords.get('y', 0)}-{tile_coords.get('y', 0) + tile_coords.get('height', 0)}"
|
206 |
+
|
207 |
+
# Analyze individual tile
|
208 |
+
tile_result = self._analyze_single_image(
|
209 |
+
image=tile,
|
210 |
+
user_message=user_message,
|
211 |
+
session_id=session_id,
|
212 |
+
is_full_image=False,
|
213 |
+
tile_location=location_desc
|
214 |
+
)
|
215 |
+
|
216 |
+
if tile_result["status"] == "success":
|
217 |
+
all_visual_analyses.append(tile_result["visual_analysis"])
|
218 |
+
all_additional_objects.extend(tile_result["additional_objects"])
|
219 |
+
|
220 |
+
# Store tile result for potential reuse
|
221 |
+
tile_results.append({
|
222 |
+
"tile_id": i,
|
223 |
+
"location": location_desc,
|
224 |
+
"coordinates": tile_coords,
|
225 |
+
"visual_analysis": tile_result["visual_analysis"],
|
226 |
+
"additional_objects": tile_result["additional_objects"]
|
227 |
+
})
|
228 |
+
|
229 |
+
# Log individual tile analysis
|
230 |
+
multi_agent_logger.log_agent_execution(
|
231 |
+
session_id=session_id,
|
232 |
+
agent_name=f"visual_tile_{i}",
|
233 |
+
agent_input=f"Tile {i+1} analysis: {user_message}",
|
234 |
+
agent_output=tile_result["visual_analysis"],
|
235 |
+
execution_time=0.0
|
236 |
+
)
|
237 |
+
|
238 |
+
print(f"Session {session_id} - Analyzed tile {i+1}/{len(tiles)}")
|
239 |
+
|
240 |
+
# Memory cleanup
|
241 |
+
del tile
|
242 |
+
gc.collect()
|
243 |
+
if torch.cuda.is_available():
|
244 |
+
torch.cuda.empty_cache()
|
245 |
+
|
246 |
+
except Exception as tile_error:
|
247 |
+
print(f"Session {session_id} - Tile {i} analysis failed: {tile_error}")
|
248 |
+
continue
|
249 |
+
|
250 |
+
if all_visual_analyses:
|
251 |
+
# Store tile results for potential reuse
|
252 |
+
session_state_manager.set(session_id, "tile_analysis_results", tile_results)
|
253 |
+
session_state_manager.set(session_id, "tiled_patch_size", patch_size)
|
254 |
+
|
255 |
+
# Combine all tile analyses
|
256 |
+
combined_visual_analysis = " ".join(all_visual_analyses)
|
257 |
+
|
258 |
+
return {
|
259 |
+
"image_quality_for_deepforest": "Yes",
|
260 |
+
"deepforest_objects_present": ["tree", "bird", "livestock"],
|
261 |
+
"additional_objects": all_additional_objects,
|
262 |
+
"visual_analysis": combined_visual_analysis,
|
263 |
+
"status": "tiled_success",
|
264 |
+
"analysis_type": "tiled_combined",
|
265 |
+
"tile_count": len(tiles),
|
266 |
+
"successful_tiles": len(all_visual_analyses),
|
267 |
+
"patch_size_used": patch_size
|
268 |
+
}
|
269 |
+
|
270 |
+
except Exception as tiling_error:
|
271 |
+
print(f"Session {session_id} - Tiled analysis also failed: {tiling_error}")
|
272 |
+
|
273 |
+
# Final fallback - resolution-based assessment
|
274 |
+
resolution_result = session_state_manager.get(session_id, "resolution_result")
|
275 |
+
if resolution_result and resolution_result.get("is_suitable"):
|
276 |
+
width, height = image.size
|
277 |
+
return {
|
278 |
+
"image_quality_for_deepforest": "Yes",
|
279 |
+
"deepforest_objects_present": ["tree", "bird", "livestock"],
|
280 |
+
"additional_objects": [],
|
281 |
+
"visual_analysis": f"Full image analysis of image ({width}x{height}) is done. Here's the analysis: Large image analyzed using resolution-based assessment. Original error: {error_msg}",
|
282 |
+
"status": "resolution_fallback",
|
283 |
+
"analysis_type": "resolution_based"
|
284 |
+
}
|
285 |
+
|
286 |
+
# Complete failure
|
287 |
+
width, height = image.size
|
288 |
+
return {
|
289 |
+
"image_quality_for_deepforest": "No",
|
290 |
+
"deepforest_objects_present": [],
|
291 |
+
"additional_objects": [],
|
292 |
+
"visual_analysis": f"Full image analysis of image ({width}x{height}) failed. Analysis could not be completed due to: {error_msg}",
|
293 |
+
"status": "error",
|
294 |
+
"analysis_type": "failed"
|
295 |
+
}
|
296 |
+
|
297 |
+
def get_tile_analysis_results(self, session_id: str) -> List[Dict[str, Any]]:
|
298 |
+
"""
|
299 |
+
Get stored tile analysis results for reuse.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
session_id: Session identifier
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
List of tile analysis results or empty list
|
306 |
+
"""
|
307 |
+
return session_state_manager.get(session_id, "tile_analysis_results", [])
|
src/deepforest_agent/conf/__init__.py
ADDED
File without changes
|
src/deepforest_agent/conf/config.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
class Config:
|
4 |
+
"""
|
5 |
+
Configuration class defining DeepForest model paths, visualization colors, and agent models.
|
6 |
+
"""
|
7 |
+
|
8 |
+
DEEPFOREST_MODELS = {
|
9 |
+
"bird": "weecology/deepforest-bird",
|
10 |
+
"tree": "weecology/deepforest-tree",
|
11 |
+
"livestock": "weecology/deepforest-livestock"
|
12 |
+
}
|
13 |
+
|
14 |
+
DEEPFOREST_DEFAULTS = {
|
15 |
+
"patch_size": 400,
|
16 |
+
"patch_overlap": 0.05,
|
17 |
+
"iou_threshold": 0.15,
|
18 |
+
"thresh": 0.55,
|
19 |
+
"alive_dead_trees": False
|
20 |
+
}
|
21 |
+
|
22 |
+
COLORS = {
|
23 |
+
"bird": (0, 0, 255), # Red (BGR)
|
24 |
+
"tree": (0, 255, 0), # Green (BGR)
|
25 |
+
"livestock": (255, 0, 0), # Blue (BGR)
|
26 |
+
"alive_tree": (255, 255, 0), # Cyan (BGR)
|
27 |
+
"dead_tree": (0, 165, 255) # Orange (BGR)
|
28 |
+
}
|
29 |
+
|
30 |
+
AGENT_MODELS = {
|
31 |
+
"memory": "HuggingFaceTB/SmolLM3-3B",
|
32 |
+
"deepforest_detector": "HuggingFaceTB/SmolLM3-3B",
|
33 |
+
"visual_analysis": "Qwen/Qwen2.5-VL-3B-Instruct",
|
34 |
+
"ecology_analysis": "meta-llama/Llama-3.2-3B-Instruct"
|
35 |
+
}
|
36 |
+
|
37 |
+
# Agent-specific generation parameters
|
38 |
+
AGENT_CONFIGS = {
|
39 |
+
"memory": {
|
40 |
+
"max_new_tokens": 16000,
|
41 |
+
"temperature": 0.6,
|
42 |
+
"top_p": 0.95
|
43 |
+
},
|
44 |
+
"deepforest_detector": {
|
45 |
+
"max_new_tokens": 16000,
|
46 |
+
"temperature": 0.6,
|
47 |
+
"top_p": 0.95
|
48 |
+
},
|
49 |
+
"visual_analysis": {
|
50 |
+
"max_new_tokens": 5000,
|
51 |
+
"temperature": 0.1
|
52 |
+
},
|
53 |
+
"ecology_analysis": {
|
54 |
+
"max_new_tokens": 16000,
|
55 |
+
"temperature": 0.6,
|
56 |
+
"top_p": 0.95
|
57 |
+
}
|
58 |
+
}
|
59 |
+
|
60 |
+
NO_ALBUMENTATIONS = os.getenv("NO_ALBUMENTATIONS", "")
|
src/deepforest_agent/models/__init__.py
ADDED
File without changes
|
src/deepforest_agent/models/llama32_3b_instruct.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from typing import Tuple, Dict, Any, Optional, List, Generator
|
3 |
+
import torch
|
4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
5 |
+
from transformers.generation.streamers import TextIteratorStreamer
|
6 |
+
from threading import Thread
|
7 |
+
|
8 |
+
from deepforest_agent.conf.config import Config
|
9 |
+
|
10 |
+
|
11 |
+
class Llama32ModelManager:
|
12 |
+
"""
|
13 |
+
Manages Llama-3.2-3B-Instruct model instances for text generation tasks.
|
14 |
+
|
15 |
+
Attributes:
|
16 |
+
model_id (str): HuggingFace model identifier
|
17 |
+
load_count (int): Number of times model has been loaded
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, model_id: str = Config.AGENT_MODELS["ecology_analysis"]):
|
21 |
+
"""
|
22 |
+
Initialize the Llama-3.2-3B model manager.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
model_id (str, optional): HuggingFace model identifier.
|
26 |
+
Defaults to "meta-llama/Llama-3.2-3B-Instruct".
|
27 |
+
"""
|
28 |
+
self.model_id = model_id
|
29 |
+
self.load_count = 0
|
30 |
+
|
31 |
+
def generate_response(
|
32 |
+
self,
|
33 |
+
messages: List[Dict[str, str]],
|
34 |
+
max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
|
35 |
+
temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
|
36 |
+
top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
|
37 |
+
tools: Optional[List[Dict[str, Any]]] = None
|
38 |
+
) -> str:
|
39 |
+
"""
|
40 |
+
Generate text response using Llama-3.2-3B-Instruct.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
messages: List of message dictionaries with 'role' and 'content'
|
44 |
+
max_new_tokens: Maximum tokens to generate
|
45 |
+
temperature: Sampling temperature
|
46 |
+
top_p: Top-p sampling
|
47 |
+
tools (Optional[List[Dict[str, Any]]]): List of tools (not used for Llama)
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
str: Generated response text
|
51 |
+
|
52 |
+
Raises:
|
53 |
+
Exception: If generation fails due to model issues, memory, or other errors
|
54 |
+
"""
|
55 |
+
print(f"Loading Llama-3.2-3B for inference #{self.load_count + 1}")
|
56 |
+
|
57 |
+
model, tokenizer = self._load_model()
|
58 |
+
self.load_count += 1
|
59 |
+
|
60 |
+
try:
|
61 |
+
# Llama uses standard chat template without xml_tools
|
62 |
+
text = tokenizer.apply_chat_template(
|
63 |
+
messages,
|
64 |
+
tokenize=False,
|
65 |
+
add_generation_prompt=True
|
66 |
+
)
|
67 |
+
|
68 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
69 |
+
|
70 |
+
generated_ids = model.generate(
|
71 |
+
model_inputs.input_ids,
|
72 |
+
max_new_tokens=max_new_tokens,
|
73 |
+
temperature=temperature,
|
74 |
+
top_p=top_p,
|
75 |
+
do_sample=True,
|
76 |
+
pad_token_id=tokenizer.eos_token_id
|
77 |
+
)
|
78 |
+
|
79 |
+
generated_ids = [
|
80 |
+
output_ids[len(input_ids):]
|
81 |
+
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
82 |
+
]
|
83 |
+
|
84 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
85 |
+
return response
|
86 |
+
|
87 |
+
except Exception as e:
|
88 |
+
print(f"Error during Llama-3.2-3B text generation: {e}")
|
89 |
+
raise e
|
90 |
+
|
91 |
+
finally:
|
92 |
+
print(f"Releasing Llama-3.2-3B GPU memory after inference")
|
93 |
+
if 'model' in locals():
|
94 |
+
if hasattr(model, 'cpu'):
|
95 |
+
model.cpu()
|
96 |
+
del model
|
97 |
+
if 'tokenizer' in locals():
|
98 |
+
del tokenizer
|
99 |
+
if 'model_inputs' in locals():
|
100 |
+
del model_inputs
|
101 |
+
if 'generated_ids' in locals():
|
102 |
+
del generated_ids
|
103 |
+
|
104 |
+
# Multiple garbage collection passes
|
105 |
+
for _ in range(3):
|
106 |
+
gc.collect()
|
107 |
+
|
108 |
+
if torch.cuda.is_available():
|
109 |
+
torch.cuda.empty_cache()
|
110 |
+
torch.cuda.ipc_collect()
|
111 |
+
torch.cuda.synchronize()
|
112 |
+
try:
|
113 |
+
torch.cuda.memory._record_memory_history(enabled=None)
|
114 |
+
except:
|
115 |
+
pass
|
116 |
+
print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
|
117 |
+
|
118 |
+
def generate_response_streaming(
|
119 |
+
self,
|
120 |
+
messages: List[Dict[str, str]],
|
121 |
+
max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
|
122 |
+
temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
|
123 |
+
top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
|
124 |
+
) -> Generator[Dict[str, Any], None, None]:
|
125 |
+
"""
|
126 |
+
Generate text response with streaming (token by token).
|
127 |
+
|
128 |
+
Args:
|
129 |
+
messages: List of message dictionaries with 'role' and 'content'
|
130 |
+
max_new_tokens: Maximum tokens to generate
|
131 |
+
temperature: Sampling temperature
|
132 |
+
top_p: Top-p sampling
|
133 |
+
|
134 |
+
Yields:
|
135 |
+
Dict[str, Any]: Dictionary containing:
|
136 |
+
- token: The generated token/text chunk
|
137 |
+
- is_complete: Whether generation is finished
|
138 |
+
|
139 |
+
Raises:
|
140 |
+
Exception: If generation fails due to model issues, memory, or other errors
|
141 |
+
"""
|
142 |
+
print(f"Loading Llama-3.2-3B for streaming inference #{self.load_count + 1}")
|
143 |
+
|
144 |
+
model, tokenizer = self._load_model()
|
145 |
+
self.load_count += 1
|
146 |
+
|
147 |
+
try:
|
148 |
+
text = tokenizer.apply_chat_template(
|
149 |
+
messages,
|
150 |
+
tokenize=False,
|
151 |
+
add_generation_prompt=True
|
152 |
+
)
|
153 |
+
|
154 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
155 |
+
|
156 |
+
streamer = TextIteratorStreamer(
|
157 |
+
tokenizer,
|
158 |
+
timeout=60.0,
|
159 |
+
skip_prompt=True,
|
160 |
+
skip_special_tokens=True
|
161 |
+
)
|
162 |
+
|
163 |
+
generation_kwargs = {
|
164 |
+
"input_ids": model_inputs.input_ids,
|
165 |
+
"max_new_tokens": max_new_tokens,
|
166 |
+
"temperature": temperature,
|
167 |
+
"top_p": top_p,
|
168 |
+
"do_sample": True,
|
169 |
+
"pad_token_id": tokenizer.eos_token_id,
|
170 |
+
"streamer": streamer
|
171 |
+
}
|
172 |
+
|
173 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
174 |
+
thread.start()
|
175 |
+
|
176 |
+
for new_text in streamer:
|
177 |
+
yield {"token": new_text, "is_complete": False}
|
178 |
+
|
179 |
+
thread.join()
|
180 |
+
yield {"token": "", "is_complete": True}
|
181 |
+
|
182 |
+
except Exception as e:
|
183 |
+
print(f"Error during Llama-3.2-3B streaming generation: {e}")
|
184 |
+
yield {"token": f"[Error: {str(e)}]", "is_complete": True}
|
185 |
+
|
186 |
+
finally:
|
187 |
+
print(f"Releasing Llama-3.2-3B GPU memory after inference")
|
188 |
+
if 'model' in locals():
|
189 |
+
if hasattr(model, 'cpu'):
|
190 |
+
model.cpu()
|
191 |
+
del model
|
192 |
+
if 'tokenizer' in locals():
|
193 |
+
del tokenizer
|
194 |
+
if 'model_inputs' in locals():
|
195 |
+
del model_inputs
|
196 |
+
if 'generated_ids' in locals():
|
197 |
+
del generated_ids
|
198 |
+
|
199 |
+
# Multiple garbage collection passes
|
200 |
+
for _ in range(3):
|
201 |
+
gc.collect()
|
202 |
+
|
203 |
+
if torch.cuda.is_available():
|
204 |
+
torch.cuda.empty_cache()
|
205 |
+
torch.cuda.ipc_collect()
|
206 |
+
torch.cuda.synchronize()
|
207 |
+
try:
|
208 |
+
torch.cuda.memory._record_memory_history(enabled=None)
|
209 |
+
except:
|
210 |
+
pass
|
211 |
+
print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
|
212 |
+
|
213 |
+
def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
214 |
+
"""
|
215 |
+
Private method for model and tokenizer loading.
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
Tuple[AutoModelForCausalLM, AutoTokenizer]: Loaded model and tokenizer
|
219 |
+
|
220 |
+
Raises:
|
221 |
+
Exception: If model loading fails due to network, memory, or other issues
|
222 |
+
"""
|
223 |
+
try:
|
224 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
225 |
+
self.model_id,
|
226 |
+
trust_remote_code=True
|
227 |
+
)
|
228 |
+
|
229 |
+
# Llama models may need specific configurations
|
230 |
+
model = AutoModelForCausalLM.from_pretrained(
|
231 |
+
self.model_id,
|
232 |
+
torch_dtype="auto",
|
233 |
+
device_map="auto",
|
234 |
+
trust_remote_code=True,
|
235 |
+
low_cpu_mem_usage=True
|
236 |
+
)
|
237 |
+
|
238 |
+
return model, tokenizer
|
239 |
+
|
240 |
+
except Exception as e:
|
241 |
+
print(f"Error loading Llama-3.2-3B model: {e}")
|
242 |
+
raise e
|
src/deepforest_agent/models/qwen_vl_3b_instruct.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from typing import Tuple, Dict, Any, Optional, List, Union
|
3 |
+
import torch
|
4 |
+
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
5 |
+
from PIL import Image
|
6 |
+
from qwen_vl_utils import process_vision_info
|
7 |
+
|
8 |
+
from deepforest_agent.conf.config import Config
|
9 |
+
|
10 |
+
|
11 |
+
class QwenVL3BModelManager:
|
12 |
+
"""Manages Qwen2.5-VL-3B model instances for visual analysis tasks.
|
13 |
+
|
14 |
+
Attributes:
|
15 |
+
model_id (str): HuggingFace model identifier
|
16 |
+
load_count (int): Number of times model has been loaded
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, model_id: str = Config.AGENT_MODELS["visual_analysis"]):
|
20 |
+
"""
|
21 |
+
Initialize the Qwen2.5-VL-3B model manager.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
model_id (str, optional): HuggingFace model identifier.
|
25 |
+
Defaults to "Qwen/Qwen2.5-VL-3B-Instruct".
|
26 |
+
"""
|
27 |
+
self.model_id = model_id
|
28 |
+
self.load_count = 0
|
29 |
+
|
30 |
+
def _load_model(self) -> Tuple[Qwen2_5_VLForConditionalGeneration, AutoProcessor]:
|
31 |
+
"""
|
32 |
+
Private method for model loading implementation.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Tuple[Qwen2_5_VLForConditionalGeneration, AutoProcessor]:
|
36 |
+
Loaded model and processor instances
|
37 |
+
|
38 |
+
Raises:
|
39 |
+
Exception: If model or processor loading fails
|
40 |
+
"""
|
41 |
+
try:
|
42 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
43 |
+
self.model_id,
|
44 |
+
torch_dtype="auto",
|
45 |
+
device_map="auto",
|
46 |
+
trust_remote_code=True
|
47 |
+
)
|
48 |
+
|
49 |
+
processor = AutoProcessor.from_pretrained(
|
50 |
+
self.model_id,
|
51 |
+
use_fast=True
|
52 |
+
)
|
53 |
+
|
54 |
+
return model, processor
|
55 |
+
|
56 |
+
except Exception as e:
|
57 |
+
print(f"Error loading Qwen VL model: {e}")
|
58 |
+
raise e
|
59 |
+
|
60 |
+
def generate_response(
|
61 |
+
self,
|
62 |
+
messages: List[Dict[str, Any]],
|
63 |
+
max_new_tokens: int = Config.AGENT_CONFIGS["visual_analysis"]["max_new_tokens"],
|
64 |
+
temperature: float = Config.AGENT_CONFIGS["visual_analysis"]["temperature"]
|
65 |
+
) -> str:
|
66 |
+
"""
|
67 |
+
Generate multimodal response.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
messages: List of messages with text and images
|
71 |
+
max_new_tokens: Maximum tokens to generate
|
72 |
+
temperature: Sampling temperature
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
str: Generated response text based on the input messages
|
76 |
+
|
77 |
+
Raises:
|
78 |
+
Exception: If text generation fails for any reason
|
79 |
+
"""
|
80 |
+
print(f"Loading Qwen VL for inference #{self.load_count + 1}")
|
81 |
+
|
82 |
+
model, processor = self._load_model()
|
83 |
+
self.load_count += 1
|
84 |
+
|
85 |
+
try:
|
86 |
+
# Process vision info using qwen_vl_utils
|
87 |
+
text = processor.apply_chat_template(
|
88 |
+
messages, tokenize=False, add_generation_prompt=True
|
89 |
+
)
|
90 |
+
|
91 |
+
# Use process_vision_info for proper image handling
|
92 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
93 |
+
|
94 |
+
inputs = processor(
|
95 |
+
text=[text],
|
96 |
+
images=image_inputs,
|
97 |
+
videos=video_inputs,
|
98 |
+
padding=True,
|
99 |
+
return_tensors="pt",
|
100 |
+
)
|
101 |
+
inputs = inputs.to(model.device)
|
102 |
+
|
103 |
+
generated_ids = model.generate(
|
104 |
+
**inputs,
|
105 |
+
max_new_tokens=max_new_tokens,
|
106 |
+
temperature=temperature,
|
107 |
+
do_sample=True if temperature > 0 else False
|
108 |
+
)
|
109 |
+
|
110 |
+
generated_ids_trimmed = [
|
111 |
+
out_ids[len(in_ids):]
|
112 |
+
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
113 |
+
]
|
114 |
+
|
115 |
+
response = processor.batch_decode(
|
116 |
+
generated_ids_trimmed,
|
117 |
+
skip_special_tokens=True,
|
118 |
+
clean_up_tokenization_spaces=False
|
119 |
+
)[0]
|
120 |
+
|
121 |
+
return response
|
122 |
+
|
123 |
+
except Exception as e:
|
124 |
+
print(f"Error during Qwen VL generation: {e}")
|
125 |
+
raise e
|
126 |
+
|
127 |
+
finally:
|
128 |
+
print(f"Releasing Qwen VL GPU memory after inference")
|
129 |
+
if 'model' in locals():
|
130 |
+
if hasattr(model, 'cpu'):
|
131 |
+
model.cpu()
|
132 |
+
del model
|
133 |
+
if 'processor' in locals():
|
134 |
+
del processor
|
135 |
+
if 'inputs' in locals():
|
136 |
+
del inputs
|
137 |
+
if 'generated_ids' in locals():
|
138 |
+
del generated_ids
|
139 |
+
|
140 |
+
# Multiple garbage collection passes
|
141 |
+
for _ in range(3):
|
142 |
+
gc.collect()
|
143 |
+
|
144 |
+
if torch.cuda.is_available():
|
145 |
+
torch.cuda.empty_cache()
|
146 |
+
torch.cuda.ipc_collect()
|
147 |
+
torch.cuda.synchronize()
|
148 |
+
try:
|
149 |
+
torch.cuda.memory._record_memory_history(enabled=None)
|
150 |
+
except:
|
151 |
+
pass
|
152 |
+
print(f"GPU memory after VLM cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
|
src/deepforest_agent/models/smollm3_3b.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from typing import Tuple, Dict, Any, Optional, List, Generator
|
3 |
+
import torch
|
4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
5 |
+
from transformers.generation.streamers import TextIteratorStreamer
|
6 |
+
from threading import Thread
|
7 |
+
|
8 |
+
from deepforest_agent.conf.config import Config
|
9 |
+
|
10 |
+
class SmolLM3ModelManager:
|
11 |
+
"""
|
12 |
+
Manages SmolLM3-3B model instances
|
13 |
+
|
14 |
+
Attributes:
|
15 |
+
model_id (str): HuggingFace model identifier
|
16 |
+
load_count (int): Number of times model has been loaded
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, model_id: str = Config.AGENT_MODELS["deepforest_detector"]):
|
20 |
+
"""
|
21 |
+
Initialize the SmolLM3 model manager.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
model_id (str, optional): HuggingFace model identifier.
|
25 |
+
Defaults to "HuggingFaceTB/SmolLM3-3B".
|
26 |
+
"""
|
27 |
+
self.model_id = model_id
|
28 |
+
self.load_count = 0
|
29 |
+
|
30 |
+
def generate_response(
|
31 |
+
self,
|
32 |
+
messages: List[Dict[str, str]],
|
33 |
+
max_new_tokens: int = Config.AGENT_CONFIGS["deepforest_detector"]["max_new_tokens"],
|
34 |
+
temperature: float = Config.AGENT_CONFIGS["deepforest_detector"]["temperature"],
|
35 |
+
top_p: float = Config.AGENT_CONFIGS["deepforest_detector"]["top_p"],
|
36 |
+
tools: Optional[List[Dict[str, Any]]] = None
|
37 |
+
) -> str:
|
38 |
+
"""
|
39 |
+
Generate text response
|
40 |
+
|
41 |
+
Args:
|
42 |
+
messages: List of message dictionaries with 'role' and 'content'
|
43 |
+
max_new_tokens: Maximum tokens to generate
|
44 |
+
temperature: Sampling temperature
|
45 |
+
top_p: Top-p sampling
|
46 |
+
tools (Optional[List[Dict[str, Any]]]): List of tools
|
47 |
+
|
48 |
+
Raises:
|
49 |
+
Exception: If generation fails due to model issues, memory, or other errors
|
50 |
+
"""
|
51 |
+
print(f"Loading SmolLM3 for inference #{self.load_count + 1}")
|
52 |
+
|
53 |
+
model, tokenizer = self._load_model()
|
54 |
+
self.load_count += 1
|
55 |
+
|
56 |
+
try:
|
57 |
+
if tools:
|
58 |
+
text = tokenizer.apply_chat_template(
|
59 |
+
messages,
|
60 |
+
xml_tools=tools,
|
61 |
+
tokenize=False,
|
62 |
+
add_generation_prompt=True
|
63 |
+
)
|
64 |
+
else:
|
65 |
+
text = tokenizer.apply_chat_template(
|
66 |
+
messages,
|
67 |
+
tokenize=False,
|
68 |
+
add_generation_prompt=True
|
69 |
+
)
|
70 |
+
|
71 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
72 |
+
|
73 |
+
generated_ids = model.generate(
|
74 |
+
model_inputs.input_ids,
|
75 |
+
max_new_tokens=max_new_tokens,
|
76 |
+
temperature=temperature,
|
77 |
+
top_p=top_p,
|
78 |
+
do_sample=True,
|
79 |
+
pad_token_id=tokenizer.eos_token_id
|
80 |
+
)
|
81 |
+
|
82 |
+
generated_ids = [
|
83 |
+
output_ids[len(input_ids):]
|
84 |
+
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
85 |
+
]
|
86 |
+
|
87 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
88 |
+
return response
|
89 |
+
|
90 |
+
except Exception as e:
|
91 |
+
print(f"Error during SmolLM3 text generation: {e}")
|
92 |
+
raise e
|
93 |
+
|
94 |
+
finally:
|
95 |
+
print(f"Releasing SmolLM3 GPU memory after inference")
|
96 |
+
if 'model' in locals():
|
97 |
+
if hasattr(model, 'cpu'):
|
98 |
+
model.cpu()
|
99 |
+
del model
|
100 |
+
if 'tokenizer' in locals():
|
101 |
+
del tokenizer
|
102 |
+
if 'model_inputs' in locals():
|
103 |
+
del model_inputs
|
104 |
+
if 'generated_ids' in locals():
|
105 |
+
del generated_ids
|
106 |
+
|
107 |
+
# Multiple garbage collection passes
|
108 |
+
for _ in range(3):
|
109 |
+
gc.collect()
|
110 |
+
|
111 |
+
if torch.cuda.is_available():
|
112 |
+
torch.cuda.empty_cache()
|
113 |
+
torch.cuda.ipc_collect()
|
114 |
+
torch.cuda.synchronize()
|
115 |
+
try:
|
116 |
+
torch.cuda.memory._record_memory_history(enabled=None)
|
117 |
+
except:
|
118 |
+
pass
|
119 |
+
print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
|
120 |
+
|
121 |
+
def generate_response_streaming(
|
122 |
+
self,
|
123 |
+
messages: List[Dict[str, str]],
|
124 |
+
max_new_tokens: int = Config.AGENT_CONFIGS["deepforest_detector"]["max_new_tokens"],
|
125 |
+
temperature: float = Config.AGENT_CONFIGS["deepforest_detector"]["temperature"],
|
126 |
+
top_p: float = Config.AGENT_CONFIGS["deepforest_detector"]["top_p"]
|
127 |
+
) -> Generator[Dict[str, Any], None, None]:
|
128 |
+
"""
|
129 |
+
Generate text response with streaming (token by token)
|
130 |
+
|
131 |
+
Args:
|
132 |
+
messages: List of message dictionaries with 'role' and 'content'
|
133 |
+
max_new_tokens: Maximum tokens to generate
|
134 |
+
temperature: Sampling temperature
|
135 |
+
top_p: Top-p sampling
|
136 |
+
|
137 |
+
Yields:
|
138 |
+
Dict[str, Any]: Dictionary containing:
|
139 |
+
- token: The generated token/text chunk
|
140 |
+
- is_complete: Whether generation is finished
|
141 |
+
|
142 |
+
Raises:
|
143 |
+
Exception: If generation fails due to model issues, memory, or other errors
|
144 |
+
"""
|
145 |
+
print(f"Loading SmolLM3 for streaming inference #{self.load_count + 1}")
|
146 |
+
|
147 |
+
model, tokenizer = self._load_model()
|
148 |
+
self.load_count += 1
|
149 |
+
|
150 |
+
try:
|
151 |
+
text = tokenizer.apply_chat_template(
|
152 |
+
messages,
|
153 |
+
tokenize=False,
|
154 |
+
add_generation_prompt=True
|
155 |
+
)
|
156 |
+
|
157 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
158 |
+
|
159 |
+
streamer = TextIteratorStreamer(
|
160 |
+
tokenizer,
|
161 |
+
timeout=60.0,
|
162 |
+
skip_prompt=True,
|
163 |
+
skip_special_tokens=True
|
164 |
+
)
|
165 |
+
|
166 |
+
generation_kwargs = {
|
167 |
+
"input_ids": model_inputs.input_ids,
|
168 |
+
"max_new_tokens": max_new_tokens,
|
169 |
+
"temperature": temperature,
|
170 |
+
"top_p": top_p,
|
171 |
+
"do_sample": True,
|
172 |
+
"pad_token_id": tokenizer.eos_token_id,
|
173 |
+
"streamer": streamer
|
174 |
+
}
|
175 |
+
|
176 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
177 |
+
thread.start()
|
178 |
+
|
179 |
+
for new_text in streamer:
|
180 |
+
yield {"token": new_text, "is_complete": False}
|
181 |
+
|
182 |
+
thread.join()
|
183 |
+
yield {"token": "", "is_complete": True}
|
184 |
+
|
185 |
+
except Exception as e:
|
186 |
+
print(f"Error during SmolLM3 streaming generation: {e}")
|
187 |
+
yield {"token": f"[Error: {str(e)}]", "is_complete": True}
|
188 |
+
|
189 |
+
finally:
|
190 |
+
print(f"Releasing SmolLM3 GPU memory after inference")
|
191 |
+
if 'model' in locals():
|
192 |
+
if hasattr(model, 'cpu'):
|
193 |
+
model.cpu()
|
194 |
+
del model
|
195 |
+
if 'tokenizer' in locals():
|
196 |
+
del tokenizer
|
197 |
+
if 'model_inputs' in locals():
|
198 |
+
del model_inputs
|
199 |
+
if 'generated_ids' in locals():
|
200 |
+
del generated_ids
|
201 |
+
|
202 |
+
# Multiple garbage collection passes
|
203 |
+
for _ in range(3):
|
204 |
+
gc.collect()
|
205 |
+
|
206 |
+
if torch.cuda.is_available():
|
207 |
+
torch.cuda.empty_cache()
|
208 |
+
torch.cuda.ipc_collect()
|
209 |
+
torch.cuda.synchronize()
|
210 |
+
try:
|
211 |
+
torch.cuda.memory._record_memory_history(enabled=None)
|
212 |
+
except:
|
213 |
+
pass
|
214 |
+
print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
|
215 |
+
|
216 |
+
def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
217 |
+
"""
|
218 |
+
Private method for model and tokenizer loading.
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
Tuple[AutoModelForCausalLM, AutoTokenizer]: Loaded model and tokenizer
|
222 |
+
|
223 |
+
Raises:
|
224 |
+
Exception: If model loading fails due to network, memory, or other issues
|
225 |
+
"""
|
226 |
+
try:
|
227 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
228 |
+
self.model_id,
|
229 |
+
trust_remote_code=True
|
230 |
+
)
|
231 |
+
|
232 |
+
model = AutoModelForCausalLM.from_pretrained(
|
233 |
+
self.model_id,
|
234 |
+
torch_dtype="auto",
|
235 |
+
device_map="auto",
|
236 |
+
trust_remote_code=True,
|
237 |
+
low_cpu_mem_usage=True
|
238 |
+
)
|
239 |
+
|
240 |
+
return model, tokenizer
|
241 |
+
|
242 |
+
except Exception as e:
|
243 |
+
print(f"Error loading SmolLM3 model: {e}")
|
244 |
+
raise e
|
src/deepforest_agent/prompts/__init__.py
ADDED
File without changes
|
src/deepforest_agent/prompts/prompt_templates.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Dict, List, Any
|
2 |
+
import json
|
3 |
+
|
4 |
+
from deepforest_agent.conf.config import Config
|
5 |
+
|
6 |
+
def get_deepforest_tool_schema() -> Dict[str, Any]:
|
7 |
+
"""
|
8 |
+
Get the DeepForest tool schema for structured tool calling.
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
Dict[str, Any]: Tool schema for run_deepforest_object_detection
|
12 |
+
"""
|
13 |
+
deepforest_tool_schema = {
|
14 |
+
"name": "run_deepforest_object_detection",
|
15 |
+
"description": "Performs object detection on ecological images using DeepForest models to detect birds, trees, livestock, and assess tree health. Use this tool for any queries related to ecological objects, wildlife detection, forest analysis, or tree health assessment.",
|
16 |
+
"parameters": {
|
17 |
+
"type": "object",
|
18 |
+
"properties": {
|
19 |
+
"model_names": {
|
20 |
+
"type": "array",
|
21 |
+
"items": {"type": "string", "enum": ["tree", "bird", "livestock"]},
|
22 |
+
"description": "List of models to use for detection. Select based on user query: 'tree' for vegetation or forest, 'bird' for avian species, 'livestock' for farm animals. Default: ['tree', 'bird', 'livestock']. Always include 'tree' when alive_dead_trees is true.",
|
23 |
+
"default": ["tree", "bird", "livestock"]
|
24 |
+
},
|
25 |
+
"patch_size": {
|
26 |
+
"type": "integer",
|
27 |
+
"description": f"Window size in pixels (default {Config.DEEPFOREST_DEFAULTS['patch_size']}) The size for the crops used to cut the input image/raster into smaller pieces.",
|
28 |
+
"default": Config.DEEPFOREST_DEFAULTS["patch_size"]
|
29 |
+
},
|
30 |
+
"patch_overlap": {
|
31 |
+
"type": "number",
|
32 |
+
"description": f"The horizontal and vertical overlap among patches (must be between 0-1) (default {Config.DEEPFOREST_DEFAULTS['patch_overlap']})",
|
33 |
+
"default": Config.DEEPFOREST_DEFAULTS["patch_overlap"]
|
34 |
+
},
|
35 |
+
"iou_threshold": {
|
36 |
+
"type": "number",
|
37 |
+
"description": f"Minimum IoU overlap among predictions between windows to be suppressed (default {Config.DEEPFOREST_DEFAULTS['iou_threshold']})",
|
38 |
+
"default": Config.DEEPFOREST_DEFAULTS["iou_threshold"]
|
39 |
+
},
|
40 |
+
"thresh": {
|
41 |
+
"type": "number",
|
42 |
+
"description": f"Score threshold used to filter bboxes after soft-NMS is performed (default {Config.DEEPFOREST_DEFAULTS['thresh']})",
|
43 |
+
"default": Config.DEEPFOREST_DEFAULTS["thresh"]
|
44 |
+
},
|
45 |
+
"alive_dead_trees": {
|
46 |
+
"type": "boolean",
|
47 |
+
"description": f"Enable tree health classification to distinguish between alive and dead trees. Required for forest health analysis. When true, 'tree' must be included in model_names. (default {Config.DEEPFOREST_DEFAULTS['alive_dead_trees']})",
|
48 |
+
"default": Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
|
49 |
+
}
|
50 |
+
},
|
51 |
+
"required": ["model_names"]
|
52 |
+
}
|
53 |
+
}
|
54 |
+
return deepforest_tool_schema
|
55 |
+
|
56 |
+
def format_memory_prompt(conversation_history: List[Dict[str, Any]], latest_message: str, conversation_context: str) -> str:
|
57 |
+
"""
|
58 |
+
Format the memory analysis prompt for new conversation history format.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
conversation_history: Filtered conversation history
|
62 |
+
latest_message: Current user message
|
63 |
+
conversation_context: Formatted conversation context with turn structure
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Formatted prompt for memory analysis
|
67 |
+
"""
|
68 |
+
prompt = f"""You are a conversation memory manager for an ecological data analytics assistant. Your role is to analyze previous conversation turns and determine if you can answer the user's query.
|
69 |
+
|
70 |
+
The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. The user may also ask follow-up questions based on previous answers. That's why you have access to the previous conversation turns so that you can determine if the answer is already available or if you can provide context for the agents.
|
71 |
+
|
72 |
+
You should not make up any information that is not present in the previous conversation turns. You should only use the previous conversation turns to answer the user's query. You can't also just make up wrong information with the previous conversation turns. Make sure you are addressing the user query when you are using the previous conversation turns.
|
73 |
+
|
74 |
+
When you provide RELEVANT_CONTEXT, you should provide with analysis and a comprehensive details about every data that you are providing. Do not just give quick and direct answers. The ecology agent will use this context to provide the final answer to the user. So, make sure you are providing all the relevant context that can help the ecology agent to provide the best possible answer to the user. Your tone should be very professional and analytical. You cannot make any assumptions or guesses.
|
75 |
+
|
76 |
+
You have access to previous conversation turns. Your task is to determine if the current user query can be answered using information from previous turns, and provide the tool cache ID if detection data needs to be retrieved.
|
77 |
+
|
78 |
+
Here is the Conversation History:
|
79 |
+
{conversation_context}
|
80 |
+
|
81 |
+
Latest user query: {latest_message}
|
82 |
+
|
83 |
+
Your response format:
|
84 |
+
|
85 |
+
**ANSWER_PRESENT:** [YES or NO]
|
86 |
+
[YES if the latest user query can be answered fully using information from previous conversation turns and if the latest query is exactly similar to any of the previous queries. Otherwise NO]
|
87 |
+
|
88 |
+
**TOOL_CACHE_ID:**
|
89 |
+
[Analyze tool call information and provide relevant Tool cache IDs from the previous turns that can answer the latest user query. If multiple turns are relevant, provide multiple Tool cache IDs separated by commas.]
|
90 |
+
|
91 |
+
**RELEVANT_CONTEXT:**
|
92 |
+
[Provide a comprehensive analysis using data from previous turns including visual analysis, detection narratives, and ecology responses that answer the user's query. Include specific turn references. If no previous context is relevant, state "No relevant context from previous conversations."]
|
93 |
+
|
94 |
+
/no_think"""
|
95 |
+
|
96 |
+
return prompt
|
97 |
+
|
98 |
+
def create_full_image_quality_analysis_prompt(user_message: str) -> str:
|
99 |
+
"""
|
100 |
+
Create system prompt for full image quality assessment.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
user_message: User's query
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
System prompt for full image quality analysis
|
107 |
+
"""
|
108 |
+
return f"""You are a computer vision expert. Your task is to analyze the ecological image with your image understanding ability. You will provide a comprehensive analysis of this ecological image.
|
109 |
+
|
110 |
+
The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. An ecological image is provided for you already to analyze. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. To answer the user's query, ecological image quality is very important. Otherwise, the DeepForest object detection will not work properly. That's why determining if the image is an ecological aerial/drone image with good quality is very important. You also have to analyze the ecological image completely and provide a comprehensive summary of what's in this ecological image and what's happening. The user likely wants to know about the objects present in this ecological image. It's going to help with the ecological analysis. User likes spatial details and specific location information.
|
111 |
+
|
112 |
+
So, make sure you are providing all the important details about this ecological image. Incorporate species identification, behavior observations, environmental conditions, and habitat characteristics if possible.
|
113 |
+
|
114 |
+
You should not make up any information that is not present in the image. You should only use the image to answer the user's query. You can't also just make up wrong information with the image. Do not miss any important details. If possible, zoom in on the image to see more details. Give specific location details when you explain what's in this image. But do not make up false location or explanation. Be strictly based on what you see in this image. Making up false information is not acceptable. Do not make assumptions or guesses. Analyze the image thoroughly before making any claims.
|
115 |
+
|
116 |
+
Your tone should be very professional and expert in visual analysis. User likes insightful and detailed analysis. So, don't worry about the length of your response. Make it as long as necessary to cover all important aspects of this image. The response must be related to the user query. User's query: "{user_message}". Try to answer the user query with your visual analysis. Follow the structure below in your response:
|
117 |
+
|
118 |
+
**IMAGE_QUALITY_FOR_DEEPFOREST:** [YES or NO]
|
119 |
+
YES if this is a good quality aerial/drone image with clear ecological objects (trees, birds, livestock) that would be suitable for automated DeepForest object detection. NO if image quality is poor, too close-up, wrong angle, blurry, or not an ecological aerial/drone image.
|
120 |
+
|
121 |
+
**DEEPFOREST_OBJECTS_PRESENT:** []
|
122 |
+
List the objects from ["bird", "tree", "livestock"] that are clearly visible in the image. Example: ["bird", "tree"].
|
123 |
+
|
124 |
+
**ADDITIONAL_OBJECTS:** [JSON array]
|
125 |
+
Any objects present in this image with rough coordinates that are not bird, tree or livestock. Do not include bird, tree or livestock coordinates here. Also do not make up any false objects. Only include objects that are clearly visible in the image and necessary according to the user query.
|
126 |
+
|
127 |
+
**VISUAL_ANALYSIS:**
|
128 |
+
[In this section, you will provide the comprehensive visual analysis of the image. You should start with a brief summary containing spatial analysis of what's in the image. Then, give a brief summary if it's an ecological aerial/drone image or not. Then, you should analyze the image completely and provide a comprehensive summary of what's in this image and what's happening. If possible, zoom in on the specific locations of the image to see more details. Answer what's present in the image mentioning specific objects and their counts if possible. But do not make up false counts or objects. Make sure to incorporate species identification, behavior observations, environmental conditions, and habitat characteristics according to the image. Answer the user query "{user_message}" with what you see in the image in detail with proper reasoning, insights, bounding box coordinates and evidence. It can be as long as necessary to cover all important aspects of the image. Do not hallucinate or guess this part and you must provide bounding box coordinates for all the objects you are mentioning in this section. You must mention that this analysis is provided by a visual analysis agent and it may not be very accurate as there is no confidence score associated with it.]
|
129 |
+
"""
|
130 |
+
|
131 |
+
def create_individual_tile_analysis_prompt(user_message: str) -> str:
|
132 |
+
"""
|
133 |
+
Create system prompt for individual tile analysis.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
user_message: User's query
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
System prompt for tile-by-tile analysis
|
140 |
+
"""
|
141 |
+
return f"""You are a computer vision expert. Your task is to analyze the given tiled image of an ecological image with your image understanding ability. You will provide a comprehensive analysis of this tile section of the image.
|
142 |
+
|
143 |
+
The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. A tiled image is provided for you already to analyze. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. To answer the user's query, ecological image quality is very important. Otherwise, the DeepForest object detection will not work properly. That's why determining if the tiled image is an ecological aerial/drone image with good quality is very important. You also have to analyze this tile section of the image completely and provide a comprehensive summary of what's in this tile and what's happening. The user likely wants to know about the objects present in this tile section of the image. It's going to help with the ecological analysis. User likes spatial details and specific location information, which is easily missed in a large image.
|
144 |
+
|
145 |
+
So, make sure you are providing all the important details about this tile section of the image. Incorporate species identification, behavior observations, environmental conditions, and habitat characteristics if possible.
|
146 |
+
|
147 |
+
You should not make up any information that is not present in the tiled image. You should only use the tiled image to answer the user's query. You can't also just make up wrong information with the image. Do not miss any important details. If possible, zoom in on the tiled image to see more details. Give specific location details when you explain what's in this tile. But do not make up false location or explanation. Be strictly based on what you see in this tile section of the image. Making up false information is not acceptable. Do not make assumptions or guesses. Analyze the image thoroughly before making any claims.
|
148 |
+
|
149 |
+
Your tone should be very professional and expert in visual analysis. User likes insightful and detailed analysis. So, don't worry about the length of your response. Make it as long as necessary to cover all important aspects of this tile section of the image. The response must be related to the user query. User's query: "{user_message}". Try to answer the user query with your visual analysis. Follow the structure below in your response:
|
150 |
+
|
151 |
+
**IMAGE_QUALITY_FOR_DEEPFOREST:** [YES or NO]
|
152 |
+
YES if this is a good quality aerial/drone tiled image with clear ecological objects (trees, birds, livestock) that would be suitable for automated DeepForest object detection. NO if image quality is poor, too close-up, wrong angle, blurry, or not an ecological aerial/drone image.
|
153 |
+
|
154 |
+
**DEEPFOREST_OBJECTS_PRESENT:** []
|
155 |
+
List the objects from ["bird", "tree", "livestock"] that are clearly visible in the image. Example: ["bird", "tree"].
|
156 |
+
|
157 |
+
**ADDITIONAL_OBJECTS:** [JSON array]
|
158 |
+
Any objects present in this tile section of the image with rough coordinates that are not bird, tree or livestock. Do not include bird, tree or livestock coordinates here. Also do not make up any false objects. Only include objects that are clearly visible in the image and necessary according to the user query.
|
159 |
+
|
160 |
+
**VISUAL_ANALYSIS:**
|
161 |
+
[In this section, you will provide the comprehensive visual analysis of this tile section of the image. You should start with a brief summary containing spatial analysis of what's in this tile section of the image. Then, give a brief summary if it's an ecological aerial/drone image or not. Then, you should analyze the tile completely and provide a comprehensive summary of what's in this tile and what's happening. If possible, zoom in on the specific locations of the tile to see more details. Answer what's present in the tile mentioning specific objects and their counts if possible. But do not make up false counts or objects. Make sure to incorporate species identification, behavior observations, environmental conditions, and habitat characteristics according to the tiled image. Answer the user query "{user_message}" with what you see in the image in detail with proper reasoning, insights, bounding box coordinates and evidence. It can be as long as necessary to cover all important aspects of the tiled image. Do not hallucinate or guess this part and you must provide bounding box coordinates for all the objects you are mentioning in this section. You must mention that this analysis is provided by a visual analysis agent and it may not be very accurate as there is no confidence score associated with it.]
|
162 |
+
"""
|
163 |
+
|
164 |
+
def create_detector_system_prompt_with_reasoning(user_message: str, memory_context: str, visual_objects: List[str]) -> str:
|
165 |
+
"""
|
166 |
+
Create the system prompt for the detector agent.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
user_message (str): The original user question
|
170 |
+
memory_context (str): Context from memory agent
|
171 |
+
visual_objects (List[str]): Objects detected by visual agent
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
System prompt for enhanced tool calling with all context included
|
175 |
+
"""
|
176 |
+
|
177 |
+
return """You are a smart DeepForest Tool Calling Agent with reasoning capabilities. You will receive:
|
178 |
+
|
179 |
+
1. **User Query**: {user_message}
|
180 |
+
2. **Memory Context**: {memory_context}
|
181 |
+
3. **Objects detected by visual analysis**: {visual_objects}
|
182 |
+
|
183 |
+
Your task is to call the "run_deepforest_object_detection" tool with intelligent parameter selection based on user query. You can always assume the image is provided. The image will be passed later during tool execution. So, right now based on available data and user query make the right choice. You may need to provide multiple tool calls if necessary according to User Query, and Memory Context.
|
184 |
+
|
185 |
+
REASONING PROCESS:
|
186 |
+
|
187 |
+
**STEP 1: PARAMETERS UNDERSTANDING**
|
188 |
+
You have to understand the query thoroughly to choose appropriate parameters. Remember these are the only parameters that are available. So, use your knowledge to utilize these parameters based on query.
|
189 |
+
- model_names (list): Choose models from this ["tree", "bird", "livestock"] list based on what user wants to detect. If alive_dead_trees is true for tree health or dead/alive trees make sure to add "tree" to the list along with other requested models.
|
190 |
+
- patch_size (int): Window size in pixels (default 400) The size for the crops used to cut the input image/raster into smaller pieces.
|
191 |
+
- patch_overlap (float): The horizontal and vertical overlap among patches (must be between 0-1) (default 0.05)
|
192 |
+
- iou_threshold (float): Minimum IoU overlap among predictions between windows to be suppressed (default 0.15)
|
193 |
+
- thresh (float): Score threshold used to filter bboxes after soft-NMS is performed (default 0.55)
|
194 |
+
- alive_dead_trees (bool): Whether to classify trees as alive/dead, needed for forest or tree health (default false). If you select this as true make sure to include "tree" to model_names list. If user wants to know about tree health, forest health, dead trees or alive trees, you must set this parameter to true.
|
195 |
+
|
196 |
+
**STEP 2: MEMORY CONTEXT INTEGRATION**
|
197 |
+
- Use memory context to clarify unclear queries
|
198 |
+
- If user query is vague, use conversation history to understand intent
|
199 |
+
- Select parameters based on intention from memory context
|
200 |
+
|
201 |
+
**STEP 3: VISUAL OBJECT FILTERING**
|
202 |
+
- The visual objects are: {visual_objects}
|
203 |
+
- After deciding on the tool arguments from Memory context and user query, in model_names validate the models if it's present in the visual objects. The models that are not present in the visual objects should be removed.
|
204 |
+
|
205 |
+
**STEP 4: PARAMETER REASONING WITH QUERY**
|
206 |
+
- Based on the user query and your parameter understanding, choose the parameters wisely. Think if you can use available model_names or other parameters to address the user query better.
|
207 |
+
|
208 |
+
**CRITICAL: ACCURATE REASONING ONLY**
|
209 |
+
- Base your reasoning only on the provided user query, memory context, and visual objects
|
210 |
+
- Do not assume capabilities or parameters not explicitly mentioned
|
211 |
+
- If visual objects list is empty or unclear, acknowledge this limitation
|
212 |
+
- Do not make up technical details about DeepForest that aren't in the parameter descriptions
|
213 |
+
|
214 |
+
Your response format:
|
215 |
+
**REASONING:** [Explain your visual filtering, memory integration, and parameter choices based only on provided information]
|
216 |
+
|
217 |
+
Then provide the tool calls using the schema.
|
218 |
+
|
219 |
+
Always provide clear reasoning for your parameter choices before making the tool call. Your reasoning helps users understand why you chose specific detection models and parameters for their query./no_think"""
|
220 |
+
|
221 |
+
def create_ecology_synthesis_prompt(
|
222 |
+
user_message: str,
|
223 |
+
comprehensive_context: str,
|
224 |
+
cached_json: Optional[Dict[str, Any]] = None,
|
225 |
+
current_json: Optional[Dict[str, Any]] = None
|
226 |
+
) -> str:
|
227 |
+
"""
|
228 |
+
Create system prompt for ecology agent with new context format.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
user_message: User's original query
|
232 |
+
comprehensive_context: Comprehensive context from memory + visual + detection narrative
|
233 |
+
cached_json (Optional[Dict[str, Any]]): A dictionary of previously
|
234 |
+
cached JSON data, if available. Defaults to None.
|
235 |
+
current_json (Optional[Dict[str, Any]]): A dictionary of new JSON data
|
236 |
+
from the current analysis step. Defaults to None.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
System prompt for ecological synthesis
|
240 |
+
"""
|
241 |
+
prompt = f"""You are a Geospatial Image Analysis and Interpretation Assistant. Your primary task is to interpret and reason about complex image data from multiple data sources to answer the user query. You must synthesize information from multiple data sources, including memory context(If there is anything relevant), visual analysis, and DeepForest Detection Summary to construct your answers. Your main task is to act as a bridge between the data and the user's understanding, translating technical information into clear, descriptive language and providing proper reasoning to support your findings.
|
242 |
+
|
243 |
+
The user is using DeepForest Agent which can analyze ecological images for objects like trees, birds, livestock, and assess tree health. The user may ask questions about the image content, object counts, spatial distributions, or ecological patterns. They're trying to understand the content of an image, specifically regarding the distribution of ecological objects like trees, birds, or other wildlife. The user is asking the agent to act as a helpful guide to understand the complex DeepForest analysis data and the visual analysis data. The user has provided a query: {user_message}.
|
244 |
+
|
245 |
+
Based on the provided context, you must synthesize all available information to provide a comprehensive answer to the user's query.
|
246 |
+
|
247 |
+
Context Data:
|
248 |
+
{comprehensive_context}
|
249 |
+
|
250 |
+
Your tone should be professional, helpful, and highly informative. Avoid being overly robotic or technical. Use simple language that a non-expert can understand easily. You must be empathetic and nonjudgmental, recognizing that the user may not be familiar with the technical details of geospatial analysis. Focus on ecological insights that directly answer the user's query.
|
251 |
+
|
252 |
+
Under no circumstances should you invent or hallucinate information that is not present in the multiple data sources. All your statements must be directly supported by the data you have been given. If the data is insufficient to answer the query, you must state that clearly and explain why. You must also not falsify the multiple data sources. If you are unsure about any information, it is better to acknowledge the uncertainty than to provide potentially incorrect information. Analyze the multiple data sources thoroughly before making any claims. Never hallucinate detection coordinates, object labels, object counts, visual analysis, or confidence scores. Never mention cache keys, or technical metadata in your response. Do not mix visual analysis with the detection analysis. And you must inform the user that you are not confident about the visual analysis as there is no confidence score associated with it but you are confident about the DeepForest detection data as it has confidence scores associated with it. So if there is any conflict between visual analysis and DeepForest detection data, you should always trust the DeepForest detection data more than the visual analysis. Always provide detection analysis with proper confidence scores ranges and detailed reasoning. It can have multiple paragraphs and sections if necessary.
|
253 |
+
|
254 |
+
The response can be as long as necessary to cover all important aspects of the user's query. Starting paragraph should be the "Direct Answer" that immediately addresses the user query with proper reasoning from the detection analysis, memory context (if there's any), and visual analysis. Your response should be based on the available "DETECTION ANALYSIS", which you have to provide a detailed breakdown with proper reasoning that will address the user query: {user_message}. If multiple detection analysis exists for multiple tool calls, you can provide a comprehensive comparison in "Result Comparison". Then you can mention some relevant information from the visual analysis to address the user query but remember you are not very confident about it. Then, you must also provide "Spatial Distribution and Ecological Patterns", and Translate detection results into "Ecological Interpretation from DeepForest Data". All of these sections should be a comprehensive and insightful answer that leverages all available data. Use markdown headings (##) to create distinct sections if the response is lengthy. Make sure to incorporate the multiple data sources to create these sections without hallucinating. Bold important keywords to make them stand out. Seperate into paragraphs for better readability. Use bullet points or numbered lists where appropriate to organize information clearly. Conclude with a clear and concise summary of your findings.
|
255 |
+
"""
|
256 |
+
|
257 |
+
return prompt
|
src/deepforest_agent/tools/__init__.py
ADDED
File without changes
|
src/deepforest_agent/tools/deepforest_tool.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
from typing import List, Optional, Tuple, Dict, Any
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from PIL import Image
|
10 |
+
from shapely.geometry import shape
|
11 |
+
|
12 |
+
from deepforest import main
|
13 |
+
from deepforest.model import CropModel
|
14 |
+
from deepforest_agent.conf.config import Config
|
15 |
+
from deepforest_agent.utils.image_utils import convert_rgb_to_bgr, convert_bgr_to_rgb, load_image_as_np_array, create_temp_image_file, cleanup_temp_file
|
16 |
+
|
17 |
+
|
18 |
+
class DeepForestPredictor:
|
19 |
+
"""Predictor class for DeepForest object detection models."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
"""Initialize the DeepForest predictor."""
|
23 |
+
pass
|
24 |
+
|
25 |
+
def _generate_detection_summary(self, predictions_df: pd.DataFrame,
|
26 |
+
alive_dead_trees: bool = False) -> str:
|
27 |
+
"""
|
28 |
+
Generate summary of detection results.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
predictions_df: DataFrame containing detection results
|
32 |
+
alive_dead_trees: Whether alive/dead tree classification was used
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
DeepForest Detection Summary String
|
36 |
+
"""
|
37 |
+
if predictions_df.empty:
|
38 |
+
return "No objects detected by DeepForest with the requested models."
|
39 |
+
|
40 |
+
detection_summary_parts = []
|
41 |
+
counts = predictions_df['label'].value_counts()
|
42 |
+
|
43 |
+
if 'classification_label' in predictions_df.columns:
|
44 |
+
non_tree_df = predictions_df[predictions_df['label'] != 'tree']
|
45 |
+
if not non_tree_df.empty:
|
46 |
+
non_tree_counts = non_tree_df['label'].value_counts()
|
47 |
+
for label, count in non_tree_counts.items():
|
48 |
+
label_str = str(label).replace('_', ' ')
|
49 |
+
if count == 1:
|
50 |
+
detection_summary_parts.append(f"{count} {label_str}")
|
51 |
+
else:
|
52 |
+
detection_summary_parts.append(f"{count} {label_str}s")
|
53 |
+
|
54 |
+
tree_df = predictions_df[predictions_df['label'] == 'tree']
|
55 |
+
if not tree_df.empty:
|
56 |
+
total_trees = len(tree_df)
|
57 |
+
classification_counts = tree_df['classification_label'].value_counts()
|
58 |
+
|
59 |
+
classification_parts = []
|
60 |
+
for class_label, count in classification_counts.items():
|
61 |
+
class_str = str(class_label).replace('_', ' ')
|
62 |
+
classification_parts.append(f"{count} are classified as {class_str}")
|
63 |
+
|
64 |
+
if total_trees == 1:
|
65 |
+
detection_summary_parts.append(f"from {total_trees} tree, {' and '.join(classification_parts)}")
|
66 |
+
else:
|
67 |
+
detection_summary_parts.append(f"from {total_trees} trees, {' and '.join(classification_parts)}")
|
68 |
+
else:
|
69 |
+
for label, count in counts.items():
|
70 |
+
label_str = str(label).replace('_', ' ')
|
71 |
+
if count == 1:
|
72 |
+
detection_summary_parts.append(f"{count} {label_str}")
|
73 |
+
else:
|
74 |
+
detection_summary_parts.append(f"{count} {label_str}s")
|
75 |
+
|
76 |
+
detection_summary = f"DeepForest detected: {', '.join(detection_summary_parts)}."
|
77 |
+
|
78 |
+
return detection_summary
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def _plot_boxes(image_array: np.ndarray, predictions: pd.DataFrame,
|
82 |
+
colors: dict, thickness: int = 2) -> np.ndarray:
|
83 |
+
"""
|
84 |
+
Plot bounding boxes on image.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
image_array: Input image as numpy array
|
88 |
+
predictions: DataFrame with detection results
|
89 |
+
colors: Color mapping for different labels
|
90 |
+
thickness: Line thickness for bounding boxes
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
Image array with drawn bounding boxes
|
94 |
+
"""
|
95 |
+
image = image_array.copy()
|
96 |
+
image = convert_rgb_to_bgr(image)
|
97 |
+
|
98 |
+
for _, row in predictions.iterrows():
|
99 |
+
xmin, ymin = int(row['xmin']), int(row['ymin'])
|
100 |
+
xmax, ymax = int(row['xmax']), int(row['ymax'])
|
101 |
+
|
102 |
+
if 'classification_label' in row and pd.notna(row['classification_label']):
|
103 |
+
label = str(row['classification_label'])
|
104 |
+
else:
|
105 |
+
label = str(row['label'])
|
106 |
+
color = colors.get(label.lower(), (200, 200, 200))
|
107 |
+
|
108 |
+
cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness)
|
109 |
+
|
110 |
+
text_x = xmin
|
111 |
+
text_y = ymin - 10 if ymin - 10 > 10 else ymin + 15
|
112 |
+
cv2.putText(image, label, (text_x, text_y),
|
113 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)
|
114 |
+
|
115 |
+
image = convert_bgr_to_rgb(image)
|
116 |
+
|
117 |
+
return image
|
118 |
+
|
119 |
+
def predict_objects(
|
120 |
+
self,
|
121 |
+
image_data_array: Optional[np.ndarray] = None,
|
122 |
+
image_file_path: Optional[str] = None,
|
123 |
+
model_names: Optional[List[str]] = None,
|
124 |
+
patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
|
125 |
+
patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
|
126 |
+
iou_threshold: float = Config.DEEPFOREST_DEFAULTS["iou_threshold"],
|
127 |
+
thresh: float = Config.DEEPFOREST_DEFAULTS["thresh"],
|
128 |
+
alive_dead_trees: bool = Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
|
129 |
+
) -> Tuple[str, Optional[np.ndarray], List[Dict[str, Any]]]:
|
130 |
+
"""
|
131 |
+
Predict objects using DeepForest models with predict_tile method of DeepForest models
|
132 |
+
|
133 |
+
Args:
|
134 |
+
image_data_array: Input image as numpy array (optional if image_file_path not provided)
|
135 |
+
image_file_path: Path to image file
|
136 |
+
model_names: List of model names to use for prediction
|
137 |
+
patch_size: Size of patches for tiled prediction
|
138 |
+
patch_overlap: Patch overlap among windows
|
139 |
+
iou_threshold: Minimum IoU overlap among predictions between windows to be suppressed
|
140 |
+
thresh: Score threshold used to filter bboxes after soft-NMS is performed
|
141 |
+
alive_dead_trees: Whether to classify trees as alive/dead
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
Tuple containing:
|
145 |
+
- detection_summary: Human-readable summary of detections
|
146 |
+
- annotated_image_array: Image with bounding boxes drawn
|
147 |
+
- detections_list: List of detection data
|
148 |
+
"""
|
149 |
+
|
150 |
+
if model_names is None:
|
151 |
+
model_names = ["tree", "bird", "livestock"]
|
152 |
+
|
153 |
+
if image_file_path is None and image_data_array is None:
|
154 |
+
raise ValueError("Either image_data_array or image_file_path must be provided")
|
155 |
+
|
156 |
+
temp_file_path = None
|
157 |
+
use_provided_path = image_file_path is not None
|
158 |
+
|
159 |
+
if not use_provided_path:
|
160 |
+
if image_data_array is not None:
|
161 |
+
temp_file_path = create_temp_image_file(image_data_array, suffix=".png")
|
162 |
+
working_file_path = temp_file_path
|
163 |
+
working_array = image_data_array
|
164 |
+
else:
|
165 |
+
raise ValueError("image_data_array cannot be None when use_provided_path is False")
|
166 |
+
else:
|
167 |
+
working_file_path = image_file_path
|
168 |
+
working_array = load_image_as_np_array(image_file_path)
|
169 |
+
|
170 |
+
all_predictions_df = pd.DataFrame({
|
171 |
+
"xmin": pd.Series(dtype=int),
|
172 |
+
"ymin": pd.Series(dtype=int),
|
173 |
+
"xmax": pd.Series(dtype=int),
|
174 |
+
"ymax": pd.Series(dtype=int),
|
175 |
+
"score": pd.Series(dtype=float),
|
176 |
+
"label": pd.Series(dtype=str),
|
177 |
+
"model_type": pd.Series(dtype=str)
|
178 |
+
})
|
179 |
+
|
180 |
+
model_instances = {}
|
181 |
+
for model_name_key in model_names:
|
182 |
+
model_path = Config.DEEPFOREST_MODELS.get(model_name_key)
|
183 |
+
if model_path is None:
|
184 |
+
print(f"Warning: Model '{model_name_key}' not found in "
|
185 |
+
f"Config.DEEPFOREST_MODELS. Skipping.")
|
186 |
+
continue
|
187 |
+
|
188 |
+
try:
|
189 |
+
model = main.deepforest()
|
190 |
+
model.load_model(model_name=model_path)
|
191 |
+
model_instances[model_name_key] = model
|
192 |
+
except Exception as e:
|
193 |
+
print(f"Error loading DeepForest model '{model_name_key}' "
|
194 |
+
f"from path '{model_path}': {e}. Skipping this model.")
|
195 |
+
continue
|
196 |
+
|
197 |
+
temp_file_path = None
|
198 |
+
|
199 |
+
# Process each model
|
200 |
+
for model_type, model in model_instances.items():
|
201 |
+
current_predictions = pd.DataFrame()
|
202 |
+
try:
|
203 |
+
if model_type == "tree" and alive_dead_trees:
|
204 |
+
crop_model_instance = CropModel(num_classes=2)
|
205 |
+
current_predictions = model.predict_tile(
|
206 |
+
raster_path=working_file_path,
|
207 |
+
patch_size=patch_size,
|
208 |
+
patch_overlap=patch_overlap,
|
209 |
+
crop_model=crop_model_instance,
|
210 |
+
iou_threshold=iou_threshold,
|
211 |
+
thresh=thresh
|
212 |
+
)
|
213 |
+
else:
|
214 |
+
current_predictions = model.predict_tile(
|
215 |
+
raster_path=working_file_path,
|
216 |
+
patch_size=patch_size,
|
217 |
+
patch_overlap=patch_overlap,
|
218 |
+
iou_threshold=iou_threshold,
|
219 |
+
thresh=thresh
|
220 |
+
)
|
221 |
+
|
222 |
+
if not current_predictions.empty:
|
223 |
+
current_predictions['model_type'] = model_type
|
224 |
+
|
225 |
+
if 'label' in current_predictions.columns:
|
226 |
+
current_predictions['label'] = (
|
227 |
+
current_predictions['label'].apply(
|
228 |
+
lambda x: str(x).lower()
|
229 |
+
)
|
230 |
+
)
|
231 |
+
|
232 |
+
# Handle alive/dead tree classification results
|
233 |
+
if (alive_dead_trees and 'cropmodel_label' in
|
234 |
+
current_predictions.columns and model_type == "tree"):
|
235 |
+
current_predictions['classification_label'] = (
|
236 |
+
current_predictions.apply(
|
237 |
+
lambda row: (
|
238 |
+
'alive_tree' if row['cropmodel_label'] == 0
|
239 |
+
else 'dead_tree' if row['cropmodel_label'] == 1
|
240 |
+
else row['label']
|
241 |
+
),
|
242 |
+
axis=1
|
243 |
+
)
|
244 |
+
)
|
245 |
+
if 'cropmodel_score' in current_predictions.columns:
|
246 |
+
current_predictions['classification_score'] = current_predictions['cropmodel_score']
|
247 |
+
current_predictions = current_predictions.drop(columns=['cropmodel_score'], errors='ignore')
|
248 |
+
|
249 |
+
current_predictions = current_predictions.drop(
|
250 |
+
columns=['cropmodel_label'],
|
251 |
+
errors='ignore'
|
252 |
+
)
|
253 |
+
|
254 |
+
all_predictions_df = pd.concat(
|
255 |
+
[all_predictions_df, current_predictions],
|
256 |
+
ignore_index=True
|
257 |
+
)
|
258 |
+
|
259 |
+
except Exception as e:
|
260 |
+
print(f"Error during DeepForest prediction for model "
|
261 |
+
f"'{model_type}': {e}")
|
262 |
+
if temp_file_path:
|
263 |
+
cleanup_temp_file(temp_file_path)
|
264 |
+
|
265 |
+
# Generate detection summary
|
266 |
+
detection_summary = self._generate_detection_summary(
|
267 |
+
all_predictions_df, alive_dead_trees
|
268 |
+
)
|
269 |
+
|
270 |
+
# Create annotated image with bounding boxes
|
271 |
+
annotated_image_array = None
|
272 |
+
if working_array.ndim == 2:
|
273 |
+
annotated_image_array = cv2.cvtColor(
|
274 |
+
working_array, cv2.COLOR_GRAY2RGB
|
275 |
+
)
|
276 |
+
elif (working_array.ndim == 3 and
|
277 |
+
working_array.shape[2] == 4):
|
278 |
+
annotated_image_array = cv2.cvtColor(
|
279 |
+
working_array, cv2.COLOR_RGBA2RGB
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
annotated_image_array = working_array.copy()
|
283 |
+
|
284 |
+
if annotated_image_array.dtype != np.uint8:
|
285 |
+
annotated_image_array = annotated_image_array.astype(np.uint8)
|
286 |
+
|
287 |
+
annotated_image_array = self._plot_boxes(
|
288 |
+
annotated_image_array, all_predictions_df, Config.COLORS
|
289 |
+
)
|
290 |
+
|
291 |
+
output_df = all_predictions_df.copy()
|
292 |
+
|
293 |
+
essential_columns = ['xmin', 'ymin', 'xmax', 'ymax', 'score', 'label']
|
294 |
+
if 'classification_label' in output_df.columns:
|
295 |
+
essential_columns.append('classification_label')
|
296 |
+
if 'classification_score' in output_df.columns:
|
297 |
+
essential_columns.append('classification_score')
|
298 |
+
|
299 |
+
output_df = output_df[
|
300 |
+
[col for col in essential_columns if col in output_df.columns]
|
301 |
+
]
|
302 |
+
detections_list = []
|
303 |
+
if not output_df.empty:
|
304 |
+
for _, row in output_df.iterrows():
|
305 |
+
record = {
|
306 |
+
"xmin": int(row['xmin']),
|
307 |
+
"ymin": int(row['ymin']),
|
308 |
+
"xmax": int(row['xmax']),
|
309 |
+
"ymax": int(row['ymax']),
|
310 |
+
"score": float(row['score']),
|
311 |
+
"label": str(row['label'])
|
312 |
+
}
|
313 |
+
if 'classification_label' in row:
|
314 |
+
record["classification_label"] = str(row['classification_label'])
|
315 |
+
if 'classification_score' in row:
|
316 |
+
try:
|
317 |
+
record["classification_score"] = float(row['classification_score'])
|
318 |
+
except (ValueError, TypeError):
|
319 |
+
pass
|
320 |
+
|
321 |
+
detections_list.append(record)
|
322 |
+
|
323 |
+
return detection_summary, annotated_image_array, detections_list
|
src/deepforest_agent/tools/tool_handler.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from json import JSONDecoder
|
3 |
+
import re
|
4 |
+
from typing import Dict, Any, Optional, Union, List
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from deepforest_agent.tools.deepforest_tool import DeepForestPredictor
|
9 |
+
from deepforest_agent.utils.state_manager import session_state_manager
|
10 |
+
from deepforest_agent.utils.image_utils import validate_image_path
|
11 |
+
from deepforest_agent.conf.config import Config
|
12 |
+
|
13 |
+
deepforest_predictor = DeepForestPredictor()
|
14 |
+
|
15 |
+
def run_deepforest_object_detection(
|
16 |
+
session_id: str,
|
17 |
+
model_names: List[str] = ["tree", "bird", "livestock"],
|
18 |
+
patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
|
19 |
+
patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
|
20 |
+
iou_threshold: float = Config.DEEPFOREST_DEFAULTS["iou_threshold"],
|
21 |
+
thresh: float = Config.DEEPFOREST_DEFAULTS["thresh"],
|
22 |
+
alive_dead_trees: bool = Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
|
23 |
+
) -> Dict[str, Any]:
|
24 |
+
"""
|
25 |
+
Run DeepForest object detection on the globally stored image.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
session_id (str): Unique session identifier for this user
|
29 |
+
model_names: List of model names to use ("tree", "bird", "livestock")
|
30 |
+
patch_size: Patch size for each window in pixels (not geographic units). The size for the crops used to cut the input image/raster into smaller pieces.
|
31 |
+
patch_overlap: Patch overlap among windows. The horizontal and vertical overlap among patches (must be between 0-1).
|
32 |
+
iou_threshold: Minimum IoU overlap among predictions between windows to be suppressed.
|
33 |
+
thresh: Score threshold used to filter bboxes after soft-NMS is performed.
|
34 |
+
alive_dead_trees: Whether to classify trees as alive/dead
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
Dictionary with detection_summary and detections_list
|
38 |
+
"""
|
39 |
+
# Validate session exists
|
40 |
+
if not session_state_manager.session_exists(session_id):
|
41 |
+
return {
|
42 |
+
"detection_summary": f"Session {session_id} not found.",
|
43 |
+
"detections_list": [],
|
44 |
+
"status": "error"
|
45 |
+
}
|
46 |
+
|
47 |
+
image_file_path = session_state_manager.get(session_id, "image_file_path")
|
48 |
+
current_image = session_state_manager.get(session_id, "current_image")
|
49 |
+
|
50 |
+
if image_file_path is None and current_image is None:
|
51 |
+
return {
|
52 |
+
"detection_summary": f"No image available for detection in session {session_id}.",
|
53 |
+
"detections_list": [],
|
54 |
+
"status": "error"
|
55 |
+
}
|
56 |
+
|
57 |
+
if image_file_path and not validate_image_path(image_file_path):
|
58 |
+
print(f"Warning: Invalid image file path {image_file_path}, falling back to PIL image")
|
59 |
+
image_file_path = None
|
60 |
+
|
61 |
+
try:
|
62 |
+
if image_file_path:
|
63 |
+
print(f"DeepForest: Processing image from file path: {image_file_path}")
|
64 |
+
detection_summary, annotated_image, detections_list = deepforest_predictor.predict_objects(
|
65 |
+
image_file_path=image_file_path,
|
66 |
+
model_names=model_names,
|
67 |
+
patch_size=patch_size,
|
68 |
+
patch_overlap=patch_overlap,
|
69 |
+
iou_threshold=iou_threshold,
|
70 |
+
thresh=thresh,
|
71 |
+
alive_dead_trees=alive_dead_trees
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
print(f"DeepForest: Processing PIL image (size: {current_image.size})")
|
75 |
+
image_array = np.array(current_image)
|
76 |
+
detection_summary, annotated_image, detections_list = deepforest_predictor.predict_objects(
|
77 |
+
image_data_array=image_array,
|
78 |
+
model_names=model_names,
|
79 |
+
patch_size=patch_size,
|
80 |
+
patch_overlap=patch_overlap,
|
81 |
+
iou_threshold=iou_threshold,
|
82 |
+
thresh=thresh,
|
83 |
+
alive_dead_trees=alive_dead_trees
|
84 |
+
)
|
85 |
+
|
86 |
+
if annotated_image is not None:
|
87 |
+
session_state_manager.set(session_id, "annotated_image", Image.fromarray(annotated_image))
|
88 |
+
|
89 |
+
result = {
|
90 |
+
"detection_summary": detection_summary,
|
91 |
+
"detections_list": detections_list,
|
92 |
+
"total_detections": len(detections_list),
|
93 |
+
"status": "success"
|
94 |
+
}
|
95 |
+
|
96 |
+
return result
|
97 |
+
|
98 |
+
except Exception as e:
|
99 |
+
error_msg = f"Error during image detection in session {session_id}: {str(e)}"
|
100 |
+
print(f"DeepForest Detection Error: {error_msg}")
|
101 |
+
return {
|
102 |
+
"detection_summary": error_msg,
|
103 |
+
"detections_list": [],
|
104 |
+
"total_detections": 0,
|
105 |
+
"status": "error"
|
106 |
+
}
|
107 |
+
|
108 |
+
def extract_all_tool_calls(text: str) -> List[Dict[str, Any]]:
|
109 |
+
"""
|
110 |
+
Extract all tool call information from model output text.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
text: The model's output text that may contain multiple tool calls
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
List of dictionaries with tool call info (empty list if none found)
|
117 |
+
"""
|
118 |
+
tool_calls = []
|
119 |
+
|
120 |
+
# Method 1: Wrapped in XML
|
121 |
+
xml_pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
|
122 |
+
xml_matches = re.findall(xml_pattern, text, re.DOTALL)
|
123 |
+
|
124 |
+
for match in xml_matches:
|
125 |
+
try:
|
126 |
+
result = json.loads(match.strip())
|
127 |
+
if isinstance(result, dict) and "name" in result and "arguments" in result:
|
128 |
+
print(f"Found valid XML tool call: {result}")
|
129 |
+
tool_calls.append(result)
|
130 |
+
except json.JSONDecodeError as e:
|
131 |
+
print(f"Failed to parse XML tool call JSON: {e}")
|
132 |
+
continue
|
133 |
+
|
134 |
+
# Method 2: If no XML format found, try raw JSON format
|
135 |
+
if not tool_calls:
|
136 |
+
decoder = JSONDecoder()
|
137 |
+
brace_start = 0
|
138 |
+
|
139 |
+
while True:
|
140 |
+
match = text.find('{', brace_start)
|
141 |
+
if match == -1:
|
142 |
+
break
|
143 |
+
try:
|
144 |
+
result, index = decoder.raw_decode(text[match:])
|
145 |
+
if isinstance(result, dict) and "name" in result and "arguments" in result:
|
146 |
+
print(f"Found valid raw JSON tool call: {result}")
|
147 |
+
tool_calls.append(result)
|
148 |
+
brace_start = match + index
|
149 |
+
else:
|
150 |
+
brace_start = match + 1
|
151 |
+
except ValueError:
|
152 |
+
brace_start = match + 1
|
153 |
+
|
154 |
+
print(f"Total tool calls extracted: {len(tool_calls)}")
|
155 |
+
return tool_calls
|
156 |
+
|
157 |
+
def handle_tool_call(tool_name: str, tool_arguments: Dict[str, Any], session_id: str) -> Union[str, Dict[str, Any]]:
|
158 |
+
"""
|
159 |
+
Handle tool call execution from tool name and arguments.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
tool_name (str): The name of the tool to be executed.
|
163 |
+
tool_arguments (Dict[str, Any]): A dictionary of arguments for the tool.
|
164 |
+
session_id: Unique session identifier for this user
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
Either error message (str) or tool execution result (dict)
|
168 |
+
"""
|
169 |
+
print(f"Tool Call Detected:")
|
170 |
+
print(f"Tool Name: {tool_name}")
|
171 |
+
print(f"Arguments: {tool_arguments}")
|
172 |
+
|
173 |
+
if tool_name == "run_deepforest_object_detection":
|
174 |
+
try:
|
175 |
+
result = run_deepforest_object_detection(session_id=session_id, **tool_arguments)
|
176 |
+
|
177 |
+
return result
|
178 |
+
|
179 |
+
except Exception as e:
|
180 |
+
error_msg = f"Error executing {tool_name} in session {session_id}: {str(e)}"
|
181 |
+
print(f"Tool Execution Failed: {error_msg}")
|
182 |
+
|
183 |
+
return error_msg
|
184 |
+
else:
|
185 |
+
error_msg = f"Unknown tool: {tool_name}"
|
186 |
+
print(f"Unknown Tool: {error_msg}")
|
187 |
+
|
188 |
+
return error_msg
|
src/deepforest_agent/utils/__init__.py
ADDED
File without changes
|
src/deepforest_agent/utils/cache_utils.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import tempfile
|
5 |
+
import os
|
6 |
+
from typing import Dict, Any, Optional, List
|
7 |
+
from PIL import Image
|
8 |
+
import pickle
|
9 |
+
import gzip
|
10 |
+
import base64
|
11 |
+
|
12 |
+
from deepforest_agent.conf.config import Config
|
13 |
+
from deepforest_agent.utils.image_utils import convert_pil_image_to_bytes
|
14 |
+
|
15 |
+
class ToolCallCache:
|
16 |
+
"""
|
17 |
+
Cache utility with data handling and efficient image storage.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, cache_dir: Optional[str] = None):
|
21 |
+
"""
|
22 |
+
Initialize the tool call cache with data handling.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
cache_dir: Directory to store cached images. If None, uses system temp directory.
|
26 |
+
"""
|
27 |
+
self.cache_data = {}
|
28 |
+
|
29 |
+
if cache_dir is None:
|
30 |
+
self.cache_dir = os.path.join(tempfile.gettempdir(), "deepforest_cache")
|
31 |
+
else:
|
32 |
+
self.cache_dir = cache_dir
|
33 |
+
|
34 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
35 |
+
print(f"Cache directory: {self.cache_dir}")
|
36 |
+
|
37 |
+
def _normalize_arguments(self, arguments: Dict[str, Any]) -> str:
|
38 |
+
"""
|
39 |
+
Normalize tool arguments to create a consistent cache key.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
arguments: Tool arguments to normalize
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Normalized JSON string of arguments sorted by key
|
46 |
+
"""
|
47 |
+
normalized_args = Config.DEEPFOREST_DEFAULTS.copy()
|
48 |
+
normalized_args.update(arguments)
|
49 |
+
if "model_names" in arguments:
|
50 |
+
normalized_args["model_names"] = arguments["model_names"]
|
51 |
+
|
52 |
+
print(f"Cache normalization: {arguments} -> {normalized_args}")
|
53 |
+
return json.dumps(normalized_args, sort_keys=True, separators=(',', ':'))
|
54 |
+
|
55 |
+
def _create_cache_key(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
56 |
+
"""
|
57 |
+
Create a unique cache key from tool name and arguments.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
tool_name: Name of the tool being called
|
61 |
+
arguments: Arguments passed to the tool
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
MD5 hash that uniquely identifies this tool call
|
65 |
+
"""
|
66 |
+
cache_input = f"{tool_name}:{self._normalize_arguments(arguments)}"
|
67 |
+
return hashlib.md5(cache_input.encode('utf-8')).hexdigest()
|
68 |
+
|
69 |
+
def _store_image(self, image: Image.Image, cache_key: str) -> str:
|
70 |
+
"""
|
71 |
+
Store PIL Image while preserving original characteristics.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
image: PIL Image to store
|
75 |
+
cache_key: Unique identifier for this cache entry
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
File path where the image was stored
|
79 |
+
"""
|
80 |
+
if image is None:
|
81 |
+
return None
|
82 |
+
|
83 |
+
image_filename = f"cached_image_{cache_key}.pkl.gz"
|
84 |
+
image_path = os.path.join(self.cache_dir, image_filename)
|
85 |
+
|
86 |
+
try:
|
87 |
+
# Pickle for exact PIL Image preservation, compressed with gzip
|
88 |
+
with gzip.open(image_path, 'wb') as f:
|
89 |
+
pickle.dump(image, f, protocol=pickle.HIGHEST_PROTOCOL)
|
90 |
+
|
91 |
+
file_size_mb = os.path.getsize(image_path) / (1024 * 1024)
|
92 |
+
print(f"Image cached to {image_path} ({file_size_mb:.2f} MB)")
|
93 |
+
|
94 |
+
return image_path
|
95 |
+
|
96 |
+
except Exception as e:
|
97 |
+
print(f"Error storing image efficiently: {e}")
|
98 |
+
return self._fallback_image_storage(image)
|
99 |
+
|
100 |
+
def _load_image(self, image_path: str) -> Optional[Image.Image]:
|
101 |
+
"""
|
102 |
+
Load PIL Image from storage.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
image_path: File path where image was stored
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
Reconstructed PIL Image, or None if loading fails
|
109 |
+
"""
|
110 |
+
if not image_path or not os.path.exists(image_path):
|
111 |
+
return None
|
112 |
+
|
113 |
+
try:
|
114 |
+
with gzip.open(image_path, 'rb') as f:
|
115 |
+
image = pickle.load(f)
|
116 |
+
|
117 |
+
print(f"Image loaded from cache: {image_path}")
|
118 |
+
return image
|
119 |
+
|
120 |
+
except Exception as e:
|
121 |
+
print(f"Error loading cached image: {e}")
|
122 |
+
return None
|
123 |
+
|
124 |
+
def _fallback_image_storage(self, image: Image.Image) -> str:
|
125 |
+
"""
|
126 |
+
Fallback method for image storage when storage fails.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
image: PIL Image to store
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
Base64 encoded string of the image
|
133 |
+
"""
|
134 |
+
img_bytes = convert_pil_image_to_bytes(image)
|
135 |
+
|
136 |
+
return base64.b64encode(img_bytes).decode('utf-8')
|
137 |
+
|
138 |
+
def get_cached_result(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
139 |
+
"""
|
140 |
+
Retrieve cached result with data handling.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
tool_name: Name of the tool being called
|
144 |
+
arguments: Arguments for the tool call
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Dictionary containing all cached data or None if not found
|
148 |
+
"""
|
149 |
+
cache_key = self._create_cache_key(tool_name, arguments)
|
150 |
+
|
151 |
+
if cache_key not in self.cache_data:
|
152 |
+
print(f"Cache MISS: No cached result for {tool_name} with key {cache_key}")
|
153 |
+
return None
|
154 |
+
|
155 |
+
cached_entry = self.cache_data[cache_key]
|
156 |
+
cached_result = {}
|
157 |
+
|
158 |
+
if "detection_summary" in cached_entry["result"]:
|
159 |
+
cached_result["detection_summary"] = cached_entry["result"]["detection_summary"]
|
160 |
+
print(f"Cache: Retrieved detection_summary: {cached_result['detection_summary']}")
|
161 |
+
|
162 |
+
if "detections_list" in cached_entry["result"]:
|
163 |
+
cached_result["detections_list"] = cached_entry["result"]["detections_list"]
|
164 |
+
print(f"Cache: Retrieved {len(cached_result['detections_list'])} detections")
|
165 |
+
|
166 |
+
if "total_detections" in cached_entry["result"]:
|
167 |
+
cached_result["total_detections"] = cached_entry["result"]["total_detections"]
|
168 |
+
|
169 |
+
if "status" in cached_entry["result"]:
|
170 |
+
cached_result["status"] = cached_entry["result"]["status"]
|
171 |
+
|
172 |
+
if "annotated_image_path" in cached_entry["result"]:
|
173 |
+
cached_result["annotated_image"] = self._load_image(
|
174 |
+
cached_entry["result"]["annotated_image_path"]
|
175 |
+
)
|
176 |
+
if cached_result["annotated_image"]:
|
177 |
+
print(f"Cache: Retrieved annotated image ({cached_result['annotated_image'].size})")
|
178 |
+
|
179 |
+
cached_result["cache_info"] = {
|
180 |
+
"cached_at": cached_entry["timestamp"],
|
181 |
+
"cache_hit": True,
|
182 |
+
"cache_key": cache_key,
|
183 |
+
"tool_name": tool_name,
|
184 |
+
"arguments": arguments
|
185 |
+
}
|
186 |
+
|
187 |
+
print(f"Successfully retrieved all data for {tool_name}")
|
188 |
+
return cached_result
|
189 |
+
|
190 |
+
def store_result(self, tool_name: str, arguments: Dict[str, Any], result: Dict[str, Any]) -> str:
|
191 |
+
"""
|
192 |
+
Store tool call result with data handling.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
tool_name: Name of the tool that was executed
|
196 |
+
arguments: Arguments that were passed to the tool
|
197 |
+
result: Result dictionary containing:
|
198 |
+
- detection_summary (str): Text summary of what was detected
|
199 |
+
- detections_list (List): List of detection objects
|
200 |
+
- total_detections (int): Count of detections
|
201 |
+
- status (str): Success/error status
|
202 |
+
- annotated_image (PIL.Image, optional): Image with annotations
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
Cache key that was used to store this result
|
206 |
+
"""
|
207 |
+
cache_key = self._create_cache_key(tool_name, arguments)
|
208 |
+
|
209 |
+
storable_result = {}
|
210 |
+
|
211 |
+
if "detection_summary" in result:
|
212 |
+
storable_result["detection_summary"] = result["detection_summary"]
|
213 |
+
print(f"Detection_summary = {result['detection_summary']}")
|
214 |
+
else:
|
215 |
+
print("No detection_summary found in result to cache")
|
216 |
+
|
217 |
+
if "detections_list" in result:
|
218 |
+
storable_result["detections_list"] = result["detections_list"]
|
219 |
+
print(f"Detections_list with {len(result['detections_list'])} items")
|
220 |
+
else:
|
221 |
+
print("No detections_list found in result to cache")
|
222 |
+
storable_result["detections_list"] = []
|
223 |
+
|
224 |
+
if "total_detections" in result:
|
225 |
+
storable_result["total_detections"] = result["total_detections"]
|
226 |
+
else:
|
227 |
+
storable_result["total_detections"] = len(storable_result["detections_list"])
|
228 |
+
|
229 |
+
if "status" in result:
|
230 |
+
storable_result["status"] = result["status"]
|
231 |
+
else:
|
232 |
+
storable_result["status"] = "unknown"
|
233 |
+
|
234 |
+
if "annotated_image" in result and result["annotated_image"] is not None:
|
235 |
+
image_path = self._store_image(result["annotated_image"], cache_key)
|
236 |
+
if image_path:
|
237 |
+
storable_result["annotated_image_path"] = image_path
|
238 |
+
print(f"Annotated_image stored efficiently")
|
239 |
+
else:
|
240 |
+
print("No annotated_image to store")
|
241 |
+
|
242 |
+
self.cache_data[cache_key] = {
|
243 |
+
"tool_name": tool_name,
|
244 |
+
"arguments": arguments.copy(),
|
245 |
+
"result": storable_result,
|
246 |
+
"timestamp": time.time(),
|
247 |
+
"cache_key": cache_key
|
248 |
+
}
|
249 |
+
|
250 |
+
print(f"Successfully cached all data for {tool_name} with key {cache_key}")
|
251 |
+
return cache_key
|
252 |
+
|
253 |
+
def get_cache_stats(self) -> Dict[str, Any]:
|
254 |
+
"""
|
255 |
+
Get detailed statistics about cached data.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
Dictionary with comprehensive cache statistics
|
259 |
+
"""
|
260 |
+
total_images = 0
|
261 |
+
total_detections = 0
|
262 |
+
cache_size_mb = 0
|
263 |
+
|
264 |
+
for entry in self.cache_data.values():
|
265 |
+
result = entry["result"]
|
266 |
+
|
267 |
+
if "annotated_image_path" in result:
|
268 |
+
total_images += 1
|
269 |
+
# Calculate file size if image exists
|
270 |
+
if os.path.exists(result["annotated_image_path"]):
|
271 |
+
cache_size_mb += os.path.getsize(result["annotated_image_path"]) / (1024 * 1024)
|
272 |
+
|
273 |
+
# Count total detections across all cached results
|
274 |
+
total_detections += result.get("total_detections", 0)
|
275 |
+
|
276 |
+
return {
|
277 |
+
"total_entries": len(self.cache_data),
|
278 |
+
"total_images_cached": total_images,
|
279 |
+
"total_detections_cached": total_detections,
|
280 |
+
"cache_size_mb": round(cache_size_mb, 2),
|
281 |
+
"cache_directory": self.cache_dir,
|
282 |
+
"tools_cached": set(entry["tool_name"] for entry in self.cache_data.values())
|
283 |
+
}
|
284 |
+
|
285 |
+
def cleanup_cache_files(self):
|
286 |
+
"""
|
287 |
+
Clean up cached image files from disk.
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
The total number of files that were successfully removed.
|
291 |
+
"""
|
292 |
+
files_removed = 0
|
293 |
+
for entry in self.cache_data.values():
|
294 |
+
if "annotated_image_path" in entry["result"]:
|
295 |
+
image_path = entry["result"]["annotated_image_path"]
|
296 |
+
if os.path.exists(image_path):
|
297 |
+
try:
|
298 |
+
os.remove(image_path)
|
299 |
+
files_removed += 1
|
300 |
+
except Exception as e:
|
301 |
+
print(f"Error removing cached image {image_path}: {e}")
|
302 |
+
|
303 |
+
print(f"Cleaned up {files_removed} cached image files")
|
304 |
+
return files_removed
|
305 |
+
|
306 |
+
tool_call_cache = ToolCallCache()
|
src/deepforest_agent/utils/detection_narrative_generator.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List, Dict, Any
|
3 |
+
from collections import Counter, defaultdict
|
4 |
+
|
5 |
+
from deepforest_agent.utils.rtree_spatial_utils import DetectionSpatialAnalyzer
|
6 |
+
|
7 |
+
|
8 |
+
class DetectionNarrativeGenerator:
|
9 |
+
"""
|
10 |
+
Generates natural language narratives from DeepForest detection results with proper classification handling.
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, image_width: int, image_height: int):
|
14 |
+
"""
|
15 |
+
Initialize narrative generator with image dimensions.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
image_width: Width of the image in pixels
|
19 |
+
image_height: Height of the image in pixels
|
20 |
+
"""
|
21 |
+
self.image_width = image_width
|
22 |
+
self.image_height = image_height
|
23 |
+
self.spatial_analyzer = DetectionSpatialAnalyzer(image_width, image_height)
|
24 |
+
|
25 |
+
def generate_comprehensive_narrative(self, detections_list: List[Dict[str, Any]]) -> str:
|
26 |
+
"""
|
27 |
+
Generate comprehensive detection narrative using spatial analysis with proper classification handling.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
detections_list: List of detection dictionaries from DeepForest
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
Natural language narrative describing all aspects of detections
|
34 |
+
"""
|
35 |
+
if not detections_list:
|
36 |
+
return "No objects were detected by DeepForest in this image."
|
37 |
+
|
38 |
+
# Add detections to spatial analyzer
|
39 |
+
self.spatial_analyzer.add_detections(detections_list)
|
40 |
+
|
41 |
+
# Get comprehensive statistics
|
42 |
+
stats = self.spatial_analyzer.get_detection_statistics()
|
43 |
+
grid_analysis = self.spatial_analyzer.get_grid_analysis()
|
44 |
+
|
45 |
+
narrative_parts = []
|
46 |
+
|
47 |
+
# 1. Overall Summary with proper classification handling
|
48 |
+
narrative_parts.append(self._generate_overall_summary(detections_list))
|
49 |
+
|
50 |
+
# 2. Confidence Analysis
|
51 |
+
narrative_parts.append(self._generate_confidence_analysis(detections_list))
|
52 |
+
|
53 |
+
# 3. Spatial Distribution Analysis
|
54 |
+
narrative_parts.append(self._generate_spatial_distribution_narrative(grid_analysis, detections_list))
|
55 |
+
|
56 |
+
# 4. Spatial Relationships Analysis using R-tree indexing
|
57 |
+
narrative_parts.append(self._generate_spatial_relationships_narrative(detections_list))
|
58 |
+
|
59 |
+
# 5. Object Coverage Analysis
|
60 |
+
narrative_parts.append(self._generate_coverage_analysis(detections_list))
|
61 |
+
|
62 |
+
return "\n\n".join(narrative_parts)
|
63 |
+
|
64 |
+
def _generate_overall_summary(self, detections_list: List[Dict[str, Any]]) -> str:
|
65 |
+
"""
|
66 |
+
Generate overall detection summary with proper classification handling.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
detections_list (List[Dict[str, Any]]): List of all detection results
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
str: Formatted summary section including:
|
73 |
+
- Total detection count and average confidence
|
74 |
+
- Base object counts (birds, trees, livestock)
|
75 |
+
- Tree classification breakdown (alive trees, dead trees)
|
76 |
+
"""
|
77 |
+
total_count = len(detections_list)
|
78 |
+
|
79 |
+
# Calculate overall confidence
|
80 |
+
scores = [s for s in (d.get("score") for d in detections_list) if s is not None and np.isfinite(s)]
|
81 |
+
overall_confidence = float(np.mean(scores)) if scores else 0.0
|
82 |
+
|
83 |
+
# Proper object counting with classification handling
|
84 |
+
base_label_counts = {} # bird, tree, livestock
|
85 |
+
classification_counts = {} # alive_tree, dead_tree
|
86 |
+
|
87 |
+
for detection in detections_list:
|
88 |
+
base_label = detection.get('label', 'unknown')
|
89 |
+
base_label_counts[base_label] = base_label_counts.get(base_label, 0) + 1
|
90 |
+
|
91 |
+
# Handle tree classifications
|
92 |
+
if base_label == 'tree':
|
93 |
+
classification_label = detection.get('classification_label')
|
94 |
+
classification_score = detection.get('classification_score')
|
95 |
+
|
96 |
+
# Only count valid classifications (not NaN or None)
|
97 |
+
if (classification_label and
|
98 |
+
classification_score is not None and
|
99 |
+
str(classification_label).lower() != 'nan' and
|
100 |
+
str(classification_score).lower() != 'nan'):
|
101 |
+
|
102 |
+
classification_counts[classification_label] = classification_counts.get(classification_label, 0) + 1
|
103 |
+
|
104 |
+
summary = f"**Overall Detection Summary**\n"
|
105 |
+
summary += f"In the whole image, {total_count} objects were detected with an average confidence of {overall_confidence:.3f}.\n\n"
|
106 |
+
|
107 |
+
# Object breakdown with proper classification display
|
108 |
+
object_parts = []
|
109 |
+
for label, count in base_label_counts.items():
|
110 |
+
label_name = label.replace('_', ' ')
|
111 |
+
|
112 |
+
if label == 'tree' and classification_counts:
|
113 |
+
# Special handling for trees with classifications
|
114 |
+
total_trees = count
|
115 |
+
classified_trees = sum(classification_counts.values())
|
116 |
+
|
117 |
+
if classified_trees > 0:
|
118 |
+
tree_part = f"{total_trees} trees are detected"
|
119 |
+
classification_parts = []
|
120 |
+
for class_label, class_count in classification_counts.items():
|
121 |
+
class_name = class_label.replace('_', ' ')
|
122 |
+
classification_parts.append(f"{class_count} {class_name}s")
|
123 |
+
|
124 |
+
tree_part += f". These {total_trees} trees are classified as {' and '.join(classification_parts)}"
|
125 |
+
object_parts.append(tree_part)
|
126 |
+
else:
|
127 |
+
object_parts.append(f"{count} {label_name}{'s' if count != 1 else ''}")
|
128 |
+
else:
|
129 |
+
object_parts.append(f"{count} {label_name}{'s' if count != 1 else ''}")
|
130 |
+
|
131 |
+
summary += "Whole image Object breakdown: " + ", ".join(object_parts) + "."
|
132 |
+
|
133 |
+
return summary
|
134 |
+
|
135 |
+
def _generate_confidence_analysis(self, detections_list: List[Dict[str, Any]]) -> str:
|
136 |
+
"""
|
137 |
+
Generate confidence-based analysis with proper classification handling.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
detections_list (List[Dict[str, Any]]): List of all detection results
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
str: Formatted confidence analysis section including:
|
144 |
+
- Object counts per confidence range
|
145 |
+
- Base object type breakdown within each range
|
146 |
+
- Tree classification details (alive/dead) within each range
|
147 |
+
"""
|
148 |
+
# Group by confidence ranges
|
149 |
+
confidence_groups = {
|
150 |
+
"Detections with High Confidence Score (0.7-1.0)": [],
|
151 |
+
"Detections with Medium Confidence Score (0.3-0.7)": [],
|
152 |
+
"Detections with Low Confidence Score (0.0-0.3)": []
|
153 |
+
}
|
154 |
+
|
155 |
+
for detection in detections_list:
|
156 |
+
score = detection.get('score', 0.0)
|
157 |
+
if score >= 0.7:
|
158 |
+
confidence_groups["Detections with High Confidence Score (0.7-1.0)"].append(detection)
|
159 |
+
elif score >= 0.3:
|
160 |
+
confidence_groups["Detections with Medium Confidence Score (0.3-0.7)"].append(detection)
|
161 |
+
else:
|
162 |
+
confidence_groups["Detections with Low Confidence Score (0.0-0.3)"].append(detection)
|
163 |
+
|
164 |
+
narrative = f"**Whole image Confidence Score Analysis**\n"
|
165 |
+
|
166 |
+
for conf_range, detections in confidence_groups.items():
|
167 |
+
if not detections:
|
168 |
+
narrative += f"{conf_range}: No objects detected\n"
|
169 |
+
continue
|
170 |
+
|
171 |
+
count = len(detections)
|
172 |
+
narrative += f"{conf_range}: {count} objects detected in the whole image\n"
|
173 |
+
|
174 |
+
# Count by base labels and classifications
|
175 |
+
base_counts = {}
|
176 |
+
class_counts = {}
|
177 |
+
|
178 |
+
for detection in detections:
|
179 |
+
base_label = detection.get('label', 'unknown')
|
180 |
+
base_counts[base_label] = base_counts.get(base_label, 0) + 1
|
181 |
+
|
182 |
+
if base_label == 'tree':
|
183 |
+
classification_label = detection.get('classification_label')
|
184 |
+
if (classification_label and
|
185 |
+
str(classification_label).lower() != 'nan'):
|
186 |
+
class_counts[classification_label] = class_counts.get(classification_label, 0) + 1
|
187 |
+
|
188 |
+
# Display breakdown
|
189 |
+
breakdown_parts = []
|
190 |
+
for label, label_count in base_counts.items():
|
191 |
+
if label == 'tree' and class_counts:
|
192 |
+
tree_part = f"{label_count} trees"
|
193 |
+
class_parts = []
|
194 |
+
for class_label, class_count in class_counts.items():
|
195 |
+
class_name = class_label.replace('_', ' ')
|
196 |
+
class_parts.append(f"{class_count} {class_name}s")
|
197 |
+
if class_parts:
|
198 |
+
tree_part += f" ({', '.join(class_parts)})"
|
199 |
+
breakdown_parts.append(tree_part)
|
200 |
+
else:
|
201 |
+
label_name = label.replace('_', ' ')
|
202 |
+
breakdown_parts.append(f"{label_count} {label_name}{'s' if label_count != 1 else ''}")
|
203 |
+
|
204 |
+
narrative += f" - {', '.join(breakdown_parts)}\n"
|
205 |
+
|
206 |
+
return narrative
|
207 |
+
|
208 |
+
def _generate_spatial_distribution_narrative(self, grid_analysis: Dict[str, Dict[str, Any]], detections_list: List[Dict[str, Any]]) -> str:
|
209 |
+
"""
|
210 |
+
Generate spatial distribution narrative using 9-grid analysis
|
211 |
+
|
212 |
+
Args:
|
213 |
+
grid_analysis (Dict[str, Dict[str, Any]]): Pre-computed grid analysis from spatial_analyzer
|
214 |
+
containing detection counts and confidence analysis for each grid section
|
215 |
+
detections_list (List[Dict[str, Any]]): Original detection list for additional processing
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
str: Formatted spatial distribution section including:
|
219 |
+
- Grid-by-grid object analysis with confidence breakdowns
|
220 |
+
- Tree classification details within each grid section
|
221 |
+
- Density pattern identification (dense vs sparse regions)
|
222 |
+
"""
|
223 |
+
narrative = f"**Spatial Distribution Analysis**\n"
|
224 |
+
narrative += f"The image is divided into nine grid sections for spatial analysis:\n\n"
|
225 |
+
|
226 |
+
# Grid-by-grid analysis
|
227 |
+
for grid_name, grid_data in grid_analysis.items():
|
228 |
+
total_dets = grid_data['total_detections']
|
229 |
+
conf_analysis = grid_data['confidence_analysis']
|
230 |
+
|
231 |
+
if total_dets == 0:
|
232 |
+
narrative += f"{grid_name}: No objects detected\n"
|
233 |
+
continue
|
234 |
+
|
235 |
+
narrative += f"{grid_name}: {total_dets} objects detected\n"
|
236 |
+
|
237 |
+
# Per confidence category analysis
|
238 |
+
for conf_category, conf_data in conf_analysis.items():
|
239 |
+
if conf_data['count'] > 0:
|
240 |
+
# Count base labels and classifications for this grid/confidence
|
241 |
+
grid_detections = [d for d in detections_list
|
242 |
+
if self._detection_in_grid(d, grid_data['bounds'])]
|
243 |
+
|
244 |
+
conf_range = self._get_confidence_range(conf_category)
|
245 |
+
conf_detections = [d for d in grid_detections
|
246 |
+
if conf_range[0] <= d.get('score', 0) < conf_range[1] or
|
247 |
+
(conf_range[1] == 1.0 and d.get('score', 0) == 1.0)]
|
248 |
+
|
249 |
+
base_counts, class_counts = self._count_labels_with_classification(conf_detections)
|
250 |
+
|
251 |
+
# Display object breakdown
|
252 |
+
object_desc = []
|
253 |
+
for label, count in base_counts.items():
|
254 |
+
if label == 'tree' and label in class_counts:
|
255 |
+
tree_desc = f"{count} trees"
|
256 |
+
if class_counts[label]:
|
257 |
+
class_parts = []
|
258 |
+
for class_label, class_count in class_counts[label].items():
|
259 |
+
class_name = class_label.replace('_', ' ')
|
260 |
+
class_parts.append(f"{class_count} {class_name}s")
|
261 |
+
tree_desc += f" ({', '.join(class_parts)})"
|
262 |
+
object_desc.append(tree_desc)
|
263 |
+
else:
|
264 |
+
label_name = label.replace('_', ' ')
|
265 |
+
object_desc.append(f"{count} {label_name}{'s' if count != 1 else ''}")
|
266 |
+
|
267 |
+
# Simple description
|
268 |
+
narrative += f" - {conf_category}: {', '.join(object_desc)}\n"
|
269 |
+
|
270 |
+
narrative += "\n"
|
271 |
+
|
272 |
+
# Overall density patterns
|
273 |
+
grid_counts = {name: data['total_detections'] for name, data in grid_analysis.items()}
|
274 |
+
avg_count = sum(grid_counts.values()) / len(grid_counts) if grid_counts else 0
|
275 |
+
|
276 |
+
dense_regions = [name for name, count in grid_counts.items() if count > avg_count]
|
277 |
+
sparse_regions = [name for name, count in grid_counts.items() if count < avg_count]
|
278 |
+
|
279 |
+
if dense_regions or sparse_regions:
|
280 |
+
narrative += "**Density Patterns:**\n"
|
281 |
+
if dense_regions:
|
282 |
+
narrative += f"Dense regions: {', '.join(dense_regions)}\n"
|
283 |
+
if sparse_regions:
|
284 |
+
narrative += f"Sparse regions: {', '.join(sparse_regions)}\n"
|
285 |
+
|
286 |
+
return narrative
|
287 |
+
|
288 |
+
def _generate_coverage_analysis(self, detections_list: List[Dict[str, Any]]) -> str:
|
289 |
+
"""
|
290 |
+
Generate object coverage analysis broken down by object type.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
detections_list (List[Dict[str, Any]]): List of all detection results
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
str: Formatted coverage analysis including:
|
297 |
+
- Percentage coverage for each object type (birds, trees, livestock)
|
298 |
+
- Tree classification coverage breakdown (alive trees vs dead trees)
|
299 |
+
- Total area calculations relative to full image
|
300 |
+
"""
|
301 |
+
narrative = f"**Object Coverage Analysis**\n"
|
302 |
+
|
303 |
+
total_image_area = self.image_width * self.image_height
|
304 |
+
|
305 |
+
# Calculate coverage by object type
|
306 |
+
base_coverage = {}
|
307 |
+
classification_coverage = {}
|
308 |
+
|
309 |
+
for detection in detections_list:
|
310 |
+
width = detection.get('xmax', 0) - detection.get('xmin', 0)
|
311 |
+
height = detection.get('ymax', 0) - detection.get('ymin', 0)
|
312 |
+
area = width * height
|
313 |
+
|
314 |
+
base_label = detection.get('label', 'unknown')
|
315 |
+
base_coverage[base_label] = base_coverage.get(base_label, 0) + area
|
316 |
+
|
317 |
+
# Handle tree classifications
|
318 |
+
if base_label == 'tree':
|
319 |
+
classification_label = detection.get('classification_label')
|
320 |
+
if (classification_label and
|
321 |
+
str(classification_label).lower() != 'nan'):
|
322 |
+
classification_coverage[classification_label] = classification_coverage.get(classification_label, 0) + area
|
323 |
+
|
324 |
+
# Display coverage percentages
|
325 |
+
coverage_parts = []
|
326 |
+
for label, area in base_coverage.items():
|
327 |
+
coverage_percent = (area / total_image_area) * 100
|
328 |
+
|
329 |
+
if label == 'tree' and classification_coverage:
|
330 |
+
# Show tree breakdown
|
331 |
+
tree_coverage = f"{label}s: {coverage_percent:.2f}%"
|
332 |
+
|
333 |
+
class_parts = []
|
334 |
+
for class_label, class_area in classification_coverage.items():
|
335 |
+
class_percent = (class_area / total_image_area) * 100
|
336 |
+
class_name = class_label.replace('_', ' ')
|
337 |
+
class_parts.append(f"{class_name}s: {class_percent:.2f}%")
|
338 |
+
|
339 |
+
if class_parts:
|
340 |
+
tree_coverage += f" ({', '.join(class_parts)})"
|
341 |
+
coverage_parts.append(tree_coverage)
|
342 |
+
else:
|
343 |
+
label_name = label.replace('_', ' ')
|
344 |
+
coverage_parts.append(f"{label_name}s: {coverage_percent:.2f}%")
|
345 |
+
|
346 |
+
narrative += ", ".join(coverage_parts) + " of the total image area."
|
347 |
+
|
348 |
+
return narrative
|
349 |
+
|
350 |
+
def _generate_spatial_relationships_narrative(self, detections_list: List[Dict[str, Any]]) -> str:
|
351 |
+
"""
|
352 |
+
Generate spatial relationships narrative using R-tree indexing.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
detections_list (List[Dict[str, Any]]): List of all detection results
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
str: Formatted spatial relationships section including:
|
359 |
+
- Count of high-confidence objects analyzed
|
360 |
+
- R-tree based intersection and proximity analysis
|
361 |
+
- Natural language descriptions of object relationships
|
362 |
+
- Confidence threshold information (>= 0.3)
|
363 |
+
"""
|
364 |
+
spatial_relationships = self.spatial_analyzer.analyze_spatial_relationships_with_indexing(confidence_threshold=0.3)
|
365 |
+
|
366 |
+
if not spatial_relationships:
|
367 |
+
return "**Spatial Relationships Analysis (Confidence ≥ 0.3)**\nNo objects with sufficient confidence found for spatial relationship analysis."
|
368 |
+
|
369 |
+
narrative = f"**Spatial Relationships Analysis in the whole image (Confidence ≥ 0.3)**\n"
|
370 |
+
|
371 |
+
# Generate narrative using the spatial analyzer
|
372 |
+
spatial_narrative = self.spatial_analyzer.generate_spatial_narrative(confidence_threshold=0.3)
|
373 |
+
narrative += spatial_narrative
|
374 |
+
|
375 |
+
return narrative
|
376 |
+
|
377 |
+
def _detection_in_grid(self, detection: Dict[str, Any], grid_bounds: Dict[str, float]) -> bool:
|
378 |
+
"""
|
379 |
+
Check if detection overlaps with grid bounds.
|
380 |
+
|
381 |
+
Args:
|
382 |
+
detection (Dict[str, Any]): Detection dictionary with 'xmin', 'ymin', 'xmax', 'ymax' keys
|
383 |
+
grid_bounds (Dict[str, float]): Grid section bounds with 'x_min', 'y_min', 'x_max', 'y_max' keys
|
384 |
+
|
385 |
+
Returns:
|
386 |
+
bool: True if detection bounding box overlaps with grid bounds, False otherwise
|
387 |
+
"""
|
388 |
+
det_xmin = detection.get('xmin', 0)
|
389 |
+
det_ymin = detection.get('ymin', 0)
|
390 |
+
det_xmax = detection.get('xmax', 0)
|
391 |
+
det_ymax = detection.get('ymax', 0)
|
392 |
+
|
393 |
+
return not (det_xmax <= grid_bounds['x_min'] or det_xmin >= grid_bounds['x_max'] or
|
394 |
+
det_ymax <= grid_bounds['y_min'] or det_ymin >= grid_bounds['y_max'])
|
395 |
+
|
396 |
+
def _get_confidence_range(self, conf_category: str) -> tuple:
|
397 |
+
"""
|
398 |
+
Get confidence range tuple from category string.
|
399 |
+
|
400 |
+
Args:
|
401 |
+
conf_category (str): Category name containing "High", "Medium", or other confidence indicator
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
tuple: (min_confidence, max_confidence) as floats
|
405 |
+
- High: (0.7, 1.0)
|
406 |
+
- Medium: (0.3, 0.7)
|
407 |
+
- Low/Other: (0.0, 0.3)
|
408 |
+
"""
|
409 |
+
if "High" in conf_category:
|
410 |
+
return (0.7, 1.0)
|
411 |
+
elif "Medium" in conf_category:
|
412 |
+
return (0.3, 0.7)
|
413 |
+
else:
|
414 |
+
return (0.0, 0.3)
|
415 |
+
|
416 |
+
def _count_labels_with_classification(self, detections: List[Dict[str, Any]]) -> tuple:
|
417 |
+
"""
|
418 |
+
Count base labels and classifications separately.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
detections (List[Dict[str, Any]]): List of detection dictionaries
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
tuple: (base_counts, class_counts) where:
|
425 |
+
- base_counts (Dict[str, int]): Count of each base object type
|
426 |
+
- class_counts (Dict[str, Dict[str, int]]): Nested count structure for
|
427 |
+
tree classifications under 'tree' key
|
428 |
+
"""
|
429 |
+
base_counts = {}
|
430 |
+
class_counts = {}
|
431 |
+
|
432 |
+
for detection in detections:
|
433 |
+
base_label = detection.get('label', 'unknown')
|
434 |
+
base_counts[base_label] = base_counts.get(base_label, 0) + 1
|
435 |
+
|
436 |
+
if base_label == 'tree':
|
437 |
+
classification_label = detection.get('classification_label')
|
438 |
+
if (classification_label and
|
439 |
+
str(classification_label).lower() != 'nan'):
|
440 |
+
|
441 |
+
if base_label not in class_counts:
|
442 |
+
class_counts[base_label] = {}
|
443 |
+
class_counts[base_label][classification_label] = class_counts[base_label].get(classification_label, 0) + 1
|
444 |
+
|
445 |
+
return base_counts, class_counts
|
src/deepforest_agent/utils/image_utils.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
from typing import Dict, Any, List, Literal, Optional, Tuple
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import tempfile
|
9 |
+
import rasterio
|
10 |
+
|
11 |
+
from deepforest_agent.conf.config import Config
|
12 |
+
|
13 |
+
def load_image_as_np_array(image_path: str) -> np.ndarray:
|
14 |
+
"""
|
15 |
+
Load an image from a file path as a NumPy array.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
image_path: Path to the image file
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
RGB image as numpy array, or None if not found
|
22 |
+
|
23 |
+
Raises:
|
24 |
+
FileNotFoundError: If image file is not found at any expected path
|
25 |
+
"""
|
26 |
+
if not os.path.exists(image_path):
|
27 |
+
raise FileNotFoundError(
|
28 |
+
f"Image not found at any expected path: {image_path}"
|
29 |
+
)
|
30 |
+
|
31 |
+
img = Image.open(image_path)
|
32 |
+
if img.mode != 'RGB':
|
33 |
+
img = img.convert('RGB')
|
34 |
+
return np.array(img)
|
35 |
+
|
36 |
+
|
37 |
+
def load_pil_image_from_path(image_path: str) -> Optional[Image.Image]:
|
38 |
+
"""
|
39 |
+
Load PIL Image from file path.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
image_path: Path to the image file
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
PIL Image object, or None if loading fails
|
46 |
+
|
47 |
+
Raises:
|
48 |
+
FileNotFoundError: If image file is not found
|
49 |
+
Exception: If image cannot be loaded or converted
|
50 |
+
"""
|
51 |
+
if not os.path.exists(image_path):
|
52 |
+
raise FileNotFoundError(f"Image not found at path: {image_path}")
|
53 |
+
|
54 |
+
try:
|
55 |
+
img = Image.open(image_path)
|
56 |
+
if img.mode != 'RGB':
|
57 |
+
img = img.convert('RGB')
|
58 |
+
return img
|
59 |
+
except Exception as e:
|
60 |
+
print(f"Error loading PIL image from {image_path}: {e}")
|
61 |
+
return None
|
62 |
+
|
63 |
+
|
64 |
+
def create_temp_image_file(image_array: np.ndarray, suffix: str = ".png") -> str:
|
65 |
+
"""
|
66 |
+
Create a temporary image file from numpy array.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
image_array: Image as numpy array
|
70 |
+
suffix: File extension (default: ".png")
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
Path to temporary file
|
74 |
+
|
75 |
+
Raises:
|
76 |
+
Exception: If temporary file creation fails
|
77 |
+
"""
|
78 |
+
try:
|
79 |
+
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
|
80 |
+
temp_file_path = tmp_file.name
|
81 |
+
|
82 |
+
pil_image = Image.fromarray(image_array)
|
83 |
+
pil_image.save(temp_file_path, format='PNG')
|
84 |
+
|
85 |
+
print(f"Created temporary image file: {temp_file_path}")
|
86 |
+
return temp_file_path
|
87 |
+
|
88 |
+
except Exception as e:
|
89 |
+
print(f"Error creating temporary image file: {e}")
|
90 |
+
raise e
|
91 |
+
|
92 |
+
|
93 |
+
def cleanup_temp_file(file_path: str) -> bool:
|
94 |
+
"""
|
95 |
+
Clean up temporary file.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
file_path: Path to file to remove
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
True if successful, False otherwise
|
102 |
+
"""
|
103 |
+
if file_path and os.path.exists(file_path):
|
104 |
+
try:
|
105 |
+
os.remove(file_path)
|
106 |
+
print(f"Cleaned up temporary file: {file_path}")
|
107 |
+
return True
|
108 |
+
except OSError as e:
|
109 |
+
print(f"Error cleaning up temporary file {file_path}: {e}")
|
110 |
+
return False
|
111 |
+
return False
|
112 |
+
|
113 |
+
|
114 |
+
def validate_image_path(image_path: str) -> bool:
|
115 |
+
"""
|
116 |
+
Validate if image path exists and is a valid image file.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
image_path: Path to validate
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
True if valid image path, False otherwise
|
123 |
+
"""
|
124 |
+
if not image_path or not os.path.exists(image_path):
|
125 |
+
return False
|
126 |
+
|
127 |
+
try:
|
128 |
+
with Image.open(image_path) as img:
|
129 |
+
img.verify()
|
130 |
+
return True
|
131 |
+
except Exception:
|
132 |
+
return False
|
133 |
+
|
134 |
+
|
135 |
+
def get_image_info(image_path: str) -> Optional[Dict[str, Any]]:
|
136 |
+
"""
|
137 |
+
Get basic information about an image file.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
image_path: Path to image file
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
Dictionary with image info or None if error
|
144 |
+
"""
|
145 |
+
try:
|
146 |
+
with Image.open(image_path) as img:
|
147 |
+
return {
|
148 |
+
"size": img.size,
|
149 |
+
"mode": img.mode,
|
150 |
+
"format": img.format,
|
151 |
+
"file_size_bytes": os.path.getsize(image_path)
|
152 |
+
}
|
153 |
+
except Exception as e:
|
154 |
+
print(f"Error getting image info for {image_path}: {e}")
|
155 |
+
return None
|
156 |
+
|
157 |
+
|
158 |
+
def encode_image_to_base64_url(image_array: np.ndarray, format: str = 'PNG',
|
159 |
+
quality: int = 80) -> Optional[str]:
|
160 |
+
"""
|
161 |
+
Encode a NumPy image array to a base64 data URL.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
image_array: Image as numpy array
|
165 |
+
format: Output format ('PNG' or 'JPEG')
|
166 |
+
quality: JPEG quality (only used for JPEG format)
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Base64 encoded data URL string, or None if encoding fails
|
170 |
+
"""
|
171 |
+
if image_array is None:
|
172 |
+
return None
|
173 |
+
|
174 |
+
try:
|
175 |
+
pil_image = Image.fromarray(image_array)
|
176 |
+
if pil_image.mode == 'RGBA':
|
177 |
+
background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
178 |
+
background.paste(pil_image, mask=pil_image.split()[3])
|
179 |
+
pil_image = background
|
180 |
+
elif pil_image.mode != 'RGB':
|
181 |
+
pil_image = pil_image.convert('RGB')
|
182 |
+
|
183 |
+
byte_arr = io.BytesIO()
|
184 |
+
if format.lower() == 'jpeg':
|
185 |
+
pil_image.save(byte_arr, format='JPEG', quality=quality)
|
186 |
+
elif format.lower() == 'png':
|
187 |
+
pil_image.save(byte_arr, format='PNG')
|
188 |
+
else:
|
189 |
+
raise ValueError(f"Unsupported format: {format}. Choose 'jpeg' or 'png'.")
|
190 |
+
|
191 |
+
encoded_string = base64.b64encode(byte_arr.getvalue()).decode('utf-8')
|
192 |
+
return f"data:image/{format.lower()};base64,{encoded_string}"
|
193 |
+
except Exception as e:
|
194 |
+
print(f"Error encoding image to base64: {e}")
|
195 |
+
return None
|
196 |
+
|
197 |
+
|
198 |
+
def convert_pil_image_to_bytes(image: Image.Image) -> bytes:
|
199 |
+
"""
|
200 |
+
Convert a PIL Image to bytes in PNG format.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
image: PIL Image object
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
Image bytes in PNG format
|
207 |
+
"""
|
208 |
+
img_byte_arr = io.BytesIO()
|
209 |
+
|
210 |
+
if image.mode != 'RGB':
|
211 |
+
image = image.convert('RGB')
|
212 |
+
image.save(img_byte_arr, format='PNG')
|
213 |
+
img_bytes = img_byte_arr.getvalue()
|
214 |
+
|
215 |
+
return img_bytes
|
216 |
+
|
217 |
+
|
218 |
+
def encode_pil_image_to_base64_url(image: Image.Image) -> str:
|
219 |
+
"""
|
220 |
+
Encode a PIL Image directly to a base64 data URL.
|
221 |
+
|
222 |
+
Args:
|
223 |
+
image: PIL Image object
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
Base64 encoded PNG data URL string
|
227 |
+
"""
|
228 |
+
img_bytes = convert_pil_image_to_bytes(image)
|
229 |
+
img_str = base64.b64encode(img_bytes).decode()
|
230 |
+
data_url = f"data:image/png;base64,{img_str}"
|
231 |
+
return data_url
|
232 |
+
|
233 |
+
|
234 |
+
def decode_base64_to_pil_image(base64_data: str) -> Image.Image:
|
235 |
+
"""
|
236 |
+
Decode base64 data to a PIL Image.
|
237 |
+
|
238 |
+
Handles both data URL format and raw base64 strings.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
base64_data: Base64 encoded image data, either as data URL
|
242 |
+
(...) or raw base64 string
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
PIL Image object
|
246 |
+
|
247 |
+
Raises:
|
248 |
+
ValueError: If base64 data is invalid or cannot be decoded
|
249 |
+
"""
|
250 |
+
try:
|
251 |
+
if base64_data.startswith('data:image'):
|
252 |
+
# Extract base64 part after the comma
|
253 |
+
base64_string = base64_data.split(',')[1]
|
254 |
+
else:
|
255 |
+
# Raw base64 data
|
256 |
+
base64_string = base64_data
|
257 |
+
|
258 |
+
image_bytes = base64.b64decode(base64_string)
|
259 |
+
pil_image = Image.open(io.BytesIO(image_bytes))
|
260 |
+
|
261 |
+
return pil_image
|
262 |
+
|
263 |
+
except Exception as e:
|
264 |
+
raise ValueError(f"Failed to decode base64 data to PIL Image: {e}")
|
265 |
+
|
266 |
+
|
267 |
+
def decode_base64_url_to_np_array(image_url: str) -> Optional[np.ndarray]:
|
268 |
+
"""
|
269 |
+
Decode a base64 data URL to a NumPy array.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
image_url: Base64 data URL (...)
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
RGB image as numpy array, or None if decoding fails
|
276 |
+
"""
|
277 |
+
if not image_url.startswith('data:image'):
|
278 |
+
print(f"Invalid data URL format: {image_url[:50]}...")
|
279 |
+
return None
|
280 |
+
|
281 |
+
try:
|
282 |
+
pil_image = decode_base64_to_pil_image(image_url)
|
283 |
+
|
284 |
+
if pil_image.mode != 'RGB':
|
285 |
+
pil_image = pil_image.convert('RGB')
|
286 |
+
|
287 |
+
return np.array(pil_image)
|
288 |
+
|
289 |
+
except ValueError as e:
|
290 |
+
print(f"Error extracting image from data URL: {e}")
|
291 |
+
return None
|
292 |
+
except Exception as e:
|
293 |
+
print(f"Unexpected error processing image URL: {e}")
|
294 |
+
return None
|
295 |
+
|
296 |
+
|
297 |
+
def convert_rgb_to_bgr(image_array: np.ndarray) -> np.ndarray:
|
298 |
+
"""
|
299 |
+
Convert an RGB NumPy image array to BGR format.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
image_array: RGB image as numpy array
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
BGR image as numpy array
|
306 |
+
"""
|
307 |
+
if (image_array.ndim == 3 and image_array.shape[2] == 3 and
|
308 |
+
image_array.dtype == np.uint8):
|
309 |
+
return cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
|
310 |
+
return image_array
|
311 |
+
|
312 |
+
|
313 |
+
def convert_bgr_to_rgb(image_array: np.ndarray) -> np.ndarray:
|
314 |
+
"""
|
315 |
+
Convert a BGR NumPy image array to RGB format.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
image_array: BGR image as numpy array
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
RGB image as numpy array
|
322 |
+
"""
|
323 |
+
if (image_array.ndim == 3 and image_array.shape[2] == 3 and
|
324 |
+
image_array.dtype == np.uint8):
|
325 |
+
return cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
|
326 |
+
return image_array
|
327 |
+
|
328 |
+
def check_image_resolution_for_deepforest(image_path: str, max_resolution_cm: float = 10.0) -> Dict[str, Any]:
|
329 |
+
"""
|
330 |
+
Resolution check for DeepForest suitability.
|
331 |
+
|
332 |
+
For GeoTIFF files: Check if pixel resolution is <= 10cm
|
333 |
+
For other formats: Allow processing with warning
|
334 |
+
|
335 |
+
Args:
|
336 |
+
image_path: Path to the image file
|
337 |
+
max_resolution_cm: Maximum required resolution in cm/pixel (default: 10.0)
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
Dict containing:
|
341 |
+
- is_suitable: bool - Whether resolution is suitable for DeepForest
|
342 |
+
- resolution_cm: float or None - Actual resolution in cm/pixel
|
343 |
+
- resolution_info: str - Resolution info
|
344 |
+
- is_georeferenced: bool - Whether image is a GeoTIFF
|
345 |
+
- warning: str or None - Warning message if any
|
346 |
+
"""
|
347 |
+
try:
|
348 |
+
with rasterio.open(image_path) as src:
|
349 |
+
if src.crs is None:
|
350 |
+
return _non_geotiff_result(image_path, "No coordinate system found")
|
351 |
+
if src.crs.is_geographic:
|
352 |
+
return _non_geotiff_result(image_path, "Geographic coordinates detected")
|
353 |
+
transform = src.transform
|
354 |
+
if transform.is_identity:
|
355 |
+
return _non_geotiff_result(image_path, "No spatial transformation found")
|
356 |
+
|
357 |
+
# Calculate pixel size
|
358 |
+
pixel_width = abs(transform.a)
|
359 |
+
pixel_height = abs(transform.e)
|
360 |
+
pixel_size = max(pixel_width, pixel_height)
|
361 |
+
|
362 |
+
# Convert to centimeters based on CRS units
|
363 |
+
crs_units = src.crs.to_dict().get('units', '').lower()
|
364 |
+
|
365 |
+
if crs_units in ['m', 'metre', 'meter']:
|
366 |
+
resolution_cm = pixel_size * 100
|
367 |
+
elif 'foot' in crs_units or crs_units == 'ft':
|
368 |
+
resolution_cm = pixel_size * 30.48
|
369 |
+
else:
|
370 |
+
return {
|
371 |
+
"is_suitable": True,
|
372 |
+
"resolution_cm": None,
|
373 |
+
"resolution_info": f"Unknown units '{crs_units}' - proceeding optimistically",
|
374 |
+
"is_georeferenced": True,
|
375 |
+
"warning": f"Cannot determine pixel size units: {crs_units}"
|
376 |
+
}
|
377 |
+
|
378 |
+
is_suitable = resolution_cm <= max_resolution_cm
|
379 |
+
|
380 |
+
return {
|
381 |
+
"is_suitable": is_suitable,
|
382 |
+
"resolution_cm": resolution_cm,
|
383 |
+
"resolution_info": f"{resolution_cm:.1f} cm/pixel ({'suitable' if is_suitable else 'insufficient'} for DeepForest)",
|
384 |
+
"is_georeferenced": True,
|
385 |
+
"warning": None if is_suitable else f"Resolution {resolution_cm:.1f} cm/pixel exceeds {max_resolution_cm} cm/pixel threshold"
|
386 |
+
}
|
387 |
+
|
388 |
+
except rasterio.RasterioIOError:
|
389 |
+
return _non_geotiff_result(image_path, "Not a GeoTIFF file")
|
390 |
+
except Exception as e:
|
391 |
+
return _non_geotiff_result(image_path, f"Error reading file: {str(e)}")
|
392 |
+
|
393 |
+
|
394 |
+
def _non_geotiff_result(image_path: str, reason: str) -> Dict[str, Any]:
|
395 |
+
"""
|
396 |
+
Helper function for non-GeoTIFF images to allow processing with warning.
|
397 |
+
|
398 |
+
Args:
|
399 |
+
image_path: Path to the image file
|
400 |
+
reason: Reason why it's not treated as GeoTIFF
|
401 |
+
|
402 |
+
Returns:
|
403 |
+
Dict with suitable=True but warning about using GeoTIFF
|
404 |
+
"""
|
405 |
+
file_ext = os.path.splitext(image_path)[1].lower()
|
406 |
+
|
407 |
+
return {
|
408 |
+
"is_suitable": True,
|
409 |
+
"resolution_cm": None,
|
410 |
+
"resolution_info": f"Non-geospatial image ({file_ext}) - proceeding without resolution check",
|
411 |
+
"is_georeferenced": False,
|
412 |
+
"warning": f"For optimal DeepForest results, use GeoTIFF images with ≤10 cm/pixel resolution. Current: {reason.lower()}"
|
413 |
+
}
|
414 |
+
|
415 |
+
def determine_patch_size(image_file_path: str, image_dimensions: Optional[Tuple[int, int]] = None) -> int:
|
416 |
+
"""
|
417 |
+
Determine patch size based on image file type and dimensions for OOM fallback strategy.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
image_file_path: Path to the image file
|
421 |
+
image_dimensions: Optional tuple of (width, height) if known
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
int: Patch size optimized for image type and size
|
425 |
+
"""
|
426 |
+
# Get image dimensions if not provided
|
427 |
+
if image_dimensions is None:
|
428 |
+
try:
|
429 |
+
with Image.open(image_file_path) as img:
|
430 |
+
width, height = img.size
|
431 |
+
except Exception:
|
432 |
+
return Config.DEEPFOREST_DEFAULTS["patch_size"]
|
433 |
+
else:
|
434 |
+
width, height = image_dimensions
|
435 |
+
|
436 |
+
# Determine maximum dimension
|
437 |
+
max_dimension = max(width, height)
|
438 |
+
|
439 |
+
# For large dimensions, use larger patch sizes to handle OOM
|
440 |
+
if max_dimension > 7500:
|
441 |
+
return 2000
|
442 |
+
else:
|
443 |
+
return 1500
|
444 |
+
|
445 |
+
def get_image_dimensions_fast(image_path: str) -> Optional[Tuple[int, int]]:
|
446 |
+
"""
|
447 |
+
Get image dimensions quickly without loading full image into memory.
|
448 |
+
|
449 |
+
Args:
|
450 |
+
image_path: Path to image file
|
451 |
+
|
452 |
+
Returns:
|
453 |
+
Tuple of (width, height) or None if cannot determine
|
454 |
+
"""
|
455 |
+
try:
|
456 |
+
# Try with PIL first
|
457 |
+
with Image.open(image_path) as img:
|
458 |
+
return img.size
|
459 |
+
except Exception:
|
460 |
+
try:
|
461 |
+
# Fallback to rasterio for GeoTIFF files
|
462 |
+
with rasterio.open(image_path) as src:
|
463 |
+
return (src.width, src.height)
|
464 |
+
except Exception:
|
465 |
+
return None
|
src/deepforest_agent/utils/logging_utils.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
from datetime import datetime, timezone
|
4 |
+
from typing import Dict, Any, Optional, List
|
5 |
+
from pathlib import Path
|
6 |
+
import threading
|
7 |
+
import json as json_module
|
8 |
+
|
9 |
+
|
10 |
+
class MultiAgentLogger:
|
11 |
+
"""
|
12 |
+
Logging system for conversation-style logs.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, logs_dir: str = "logs"):
|
16 |
+
"""
|
17 |
+
Initialize the multi-agent logger.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
logs_dir: Directory to store log files
|
21 |
+
"""
|
22 |
+
self.logs_dir = Path(logs_dir)
|
23 |
+
self.logs_dir.mkdir(exist_ok=True)
|
24 |
+
self._lock = threading.Lock()
|
25 |
+
|
26 |
+
print(f"Logging initialized. Logs directory: {self.logs_dir.absolute()}")
|
27 |
+
|
28 |
+
def _get_log_file_path(self, session_id: str) -> Path:
|
29 |
+
"""
|
30 |
+
Get the log file path for a specific session.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
session_id: Unique session identifier
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Path object for the session's log file
|
37 |
+
"""
|
38 |
+
date_str = datetime.now().strftime("%Y%m%d")
|
39 |
+
filename = f"session_{session_id}_{date_str}.log"
|
40 |
+
return self.logs_dir / filename
|
41 |
+
|
42 |
+
def _write_log_entry(self, session_id: str, agent_name: str, content: str) -> None:
|
43 |
+
"""
|
44 |
+
Write a log entry to the session's log file.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
session_id: Session identifier
|
48 |
+
agent_name: Current agent in the process
|
49 |
+
content: Current agent response
|
50 |
+
"""
|
51 |
+
with self._lock:
|
52 |
+
log_file_path = self._get_log_file_path(session_id)
|
53 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
54 |
+
|
55 |
+
try:
|
56 |
+
with open(log_file_path, 'a', encoding='utf-8') as f:
|
57 |
+
if agent_name == "SESSION_START":
|
58 |
+
f.write(f"=== SESSION {session_id} STARTED ===\n\n")
|
59 |
+
elif agent_name == "SESSION_EVENT":
|
60 |
+
f.write(f"{timestamp} - {content}\n\n")
|
61 |
+
else:
|
62 |
+
f.write(f"{timestamp} - {agent_name}: {content}\n\n")
|
63 |
+
f.flush()
|
64 |
+
except Exception as e:
|
65 |
+
print(f"Error writing to log file {log_file_path}: {e}")
|
66 |
+
|
67 |
+
def log_session_event(self, session_id: str, event_type: str, details: Optional[Dict[str, Any]] = None) -> None:
|
68 |
+
"""
|
69 |
+
Log session lifecycle events (creation, image upload, clearing, etc.).
|
70 |
+
|
71 |
+
Args:
|
72 |
+
session_id: Session identifier
|
73 |
+
event_type: Type of session event
|
74 |
+
details: Additional event details
|
75 |
+
"""
|
76 |
+
if event_type == "session_created":
|
77 |
+
self._write_log_entry(session_id, "SESSION_START", "")
|
78 |
+
if details:
|
79 |
+
image_size = details.get("image_size", "unknown")
|
80 |
+
image_mode = details.get("image_mode", "unknown")
|
81 |
+
self._write_log_entry(session_id, "SESSION_EVENT", f"Image uploaded: {image_size}, mode: {image_mode}")
|
82 |
+
else:
|
83 |
+
self._write_log_entry(session_id, "SESSION_EVENT", "Image uploaded: unknown")
|
84 |
+
elif event_type == "conversation_cleared":
|
85 |
+
self._write_log_entry(session_id, "SESSION_EVENT", "Conversation cleared")
|
86 |
+
elif event_type == "multi_agent_workflow_started":
|
87 |
+
self._write_log_entry(session_id, "SESSION_EVENT", "Multi-agent workflow started")
|
88 |
+
|
89 |
+
def log_user_query(self, session_id: str, user_message: str, message_context: Optional[Dict[str, Any]] = None) -> None:
|
90 |
+
"""
|
91 |
+
Log user queries and context.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
session_id: Session identifier
|
95 |
+
user_message: User's input message
|
96 |
+
message_context: Additional context (conversation length, etc.)
|
97 |
+
"""
|
98 |
+
self._write_log_entry(session_id, "USER", user_message)
|
99 |
+
|
100 |
+
def log_agent_execution(
|
101 |
+
self,
|
102 |
+
session_id: str,
|
103 |
+
agent_name: str,
|
104 |
+
agent_input: str,
|
105 |
+
agent_output: str,
|
106 |
+
execution_time: float,
|
107 |
+
additional_data: Optional[Dict[str, Any]] = None
|
108 |
+
) -> None:
|
109 |
+
"""
|
110 |
+
Log individual agent execution details.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
session_id: Session identifier
|
114 |
+
agent_name: Name of the agent (memory, detector, visual, ecology)
|
115 |
+
agent_input: Input provided to the agent
|
116 |
+
agent_output: Output generated by the agent
|
117 |
+
execution_time: Time taken for agent execution in seconds
|
118 |
+
additional_data: Agent-specific additional data
|
119 |
+
"""
|
120 |
+
|
121 |
+
if agent_name == "memory":
|
122 |
+
formatted_name = "Memory Agent"
|
123 |
+
elif agent_name == "detector":
|
124 |
+
formatted_name = "DeepForest Detector Agent"
|
125 |
+
elif agent_name == "visual":
|
126 |
+
formatted_name = "Visual Agent"
|
127 |
+
elif agent_name == "ecology":
|
128 |
+
formatted_name = "Ecology Agent"
|
129 |
+
else:
|
130 |
+
formatted_name = agent_name.title()
|
131 |
+
|
132 |
+
formatted_name_with_time = f"{formatted_name} ({execution_time:.2f}s)"
|
133 |
+
|
134 |
+
content = agent_output
|
135 |
+
self._write_log_entry(session_id, formatted_name_with_time, content)
|
136 |
+
|
137 |
+
def log_tool_call(
|
138 |
+
self,
|
139 |
+
session_id: str,
|
140 |
+
tool_name: str,
|
141 |
+
tool_arguments: Dict[str, Any],
|
142 |
+
tool_result: Dict[str, Any],
|
143 |
+
execution_time: float,
|
144 |
+
cache_hit: bool,
|
145 |
+
reasoning: Optional[str] = None
|
146 |
+
) -> None:
|
147 |
+
"""
|
148 |
+
Log tool calls, their results, and cache information.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
session_id: Session identifier
|
152 |
+
tool_name: Name of the tool that was called
|
153 |
+
tool_arguments: Arguments passed to the tool
|
154 |
+
tool_result: Result returned by the tool
|
155 |
+
execution_time: Time taken for tool execution
|
156 |
+
cache_hit: Whether this was served from cache
|
157 |
+
reasoning: AI's reasoning for this tool call
|
158 |
+
"""
|
159 |
+
if cache_hit:
|
160 |
+
status = "Cache Hit (0.00s)"
|
161 |
+
else:
|
162 |
+
status = f"Cache Miss - Executed DeepForest detection ({execution_time:.2f}s)"
|
163 |
+
|
164 |
+
content = f"{status}\n"
|
165 |
+
content += f"Detection Summary: {tool_result.get('detection_summary', 'No summary')}\n"
|
166 |
+
|
167 |
+
detections = tool_result.get('detections_list', [])
|
168 |
+
if detections:
|
169 |
+
content += f"Detection Data: {detections}"
|
170 |
+
|
171 |
+
self._write_log_entry(session_id, "DeepForest Function execution", content)
|
172 |
+
|
173 |
+
def log_error(self, session_id: str, error_type: str, error_message: str, context: Optional[Dict[str, Any]] = None) -> None:
|
174 |
+
"""
|
175 |
+
Log errors in simple format.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
session_id: Session identifier
|
179 |
+
error_type: Type/category of error
|
180 |
+
error_message: Error message
|
181 |
+
context: Additional context about where the error occurred
|
182 |
+
"""
|
183 |
+
self._write_log_entry(session_id, "ERROR", f"{error_type}: {error_message}")
|
184 |
+
|
185 |
+
def log_resolution_check(
|
186 |
+
self,
|
187 |
+
session_id: str,
|
188 |
+
image_file_path: str,
|
189 |
+
resolution_result: Dict[str, Any],
|
190 |
+
execution_time: float
|
191 |
+
) -> None:
|
192 |
+
"""
|
193 |
+
Log image resolution check results.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
session_id: Session identifier
|
197 |
+
image_file_path: Path to the image that was checked
|
198 |
+
resolution_result: Results from simplified resolution check
|
199 |
+
execution_time: Time taken for resolution check
|
200 |
+
"""
|
201 |
+
is_suitable = resolution_result.get("is_suitable", True)
|
202 |
+
resolution_info = resolution_result.get("resolution_info", "No resolution info")
|
203 |
+
is_georeferenced = resolution_result.get("is_georeferenced", False)
|
204 |
+
resolution_cm = resolution_result.get("resolution_cm")
|
205 |
+
warning = resolution_result.get("warning")
|
206 |
+
|
207 |
+
content = f"Image Resolution Check ({execution_time:.3f}s)\n"
|
208 |
+
content += f"File: {image_file_path}\n"
|
209 |
+
content += f"Result: {'Suitable' if is_suitable else 'Insufficient'} for DeepForest\n"
|
210 |
+
content += f"Details: {resolution_info}\n"
|
211 |
+
content += f"Type: {'GeoTIFF' if is_georeferenced else 'Regular image'}\n"
|
212 |
+
|
213 |
+
if resolution_cm is not None:
|
214 |
+
content += f"Resolution: {resolution_cm:.2f} cm/pixel\n"
|
215 |
+
|
216 |
+
if warning:
|
217 |
+
content += f"Warning: {warning}\n"
|
218 |
+
|
219 |
+
if not is_suitable:
|
220 |
+
content += "Impact: DeepForest detection will be skipped due to insufficient resolution"
|
221 |
+
elif warning:
|
222 |
+
content += "Impact: DeepForest detection will proceed with noted warning"
|
223 |
+
else:
|
224 |
+
content += "Impact: Resolution suitable for DeepForest detection"
|
225 |
+
|
226 |
+
self._write_log_entry(session_id, "Resolution Check", content)
|
227 |
+
|
228 |
+
def log_deepforest_skip(
|
229 |
+
self,
|
230 |
+
session_id: str,
|
231 |
+
skip_reasons: List[str],
|
232 |
+
resolution_result: Optional[Dict[str, Any]] = None,
|
233 |
+
visual_result: Optional[Dict[str, Any]] = None
|
234 |
+
) -> None:
|
235 |
+
"""
|
236 |
+
Log when DeepForest detection is skipped and why.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
session_id: Session identifier
|
240 |
+
skip_reasons: List of reasons why DeepForest was skipped
|
241 |
+
resolution_result: Resolution check results (optional)
|
242 |
+
visual_result: Visual analysis results (optional)
|
243 |
+
"""
|
244 |
+
content = "DeepForest Detection Skipped\n"
|
245 |
+
content += f"Reasons: {', '.join(skip_reasons)}\n"
|
246 |
+
|
247 |
+
# Add detailed reason breakdown
|
248 |
+
if "insufficient resolution" in ' '.join(skip_reasons).lower():
|
249 |
+
if resolution_result:
|
250 |
+
resolution_info = resolution_result.get("resolution_info", "No details")
|
251 |
+
content += f"Resolution Details: {resolution_info}\n"
|
252 |
+
|
253 |
+
if "poor image quality" in ' '.join(skip_reasons).lower():
|
254 |
+
if visual_result:
|
255 |
+
quality_assessment = visual_result.get("image_quality_for_deepforest", "Unknown")
|
256 |
+
content += f"Visual Quality Assessment: {quality_assessment}\n"
|
257 |
+
|
258 |
+
content += "Impact: Analysis will rely on visual analysis only"
|
259 |
+
|
260 |
+
self._write_log_entry(session_id, "DeepForest Skip Decision", content)
|
261 |
+
|
262 |
+
def log_tile_analysis(self, session_id: str, tile_id: int, result: Dict[str, Any], execution_time: float) -> None:
|
263 |
+
"""
|
264 |
+
Log individual tile analysis results.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
session_id: Session identifier
|
268 |
+
tile_id: Tile identifier
|
269 |
+
result: Tile analysis result
|
270 |
+
execution_time: Time taken for tile analysis
|
271 |
+
"""
|
272 |
+
content = f"Tile {tile_id} Analysis ({execution_time:.2f}s)\n"
|
273 |
+
|
274 |
+
coordinates = result.get('coordinates', {})
|
275 |
+
content += f"Coordinates: x={coordinates.get('x', 0)}, y={coordinates.get('y', 0)}, "
|
276 |
+
content += f"width={coordinates.get('width', 0)}, height={coordinates.get('height', 0)}\n"
|
277 |
+
|
278 |
+
additional_objects = result.get('additional_objects', [])
|
279 |
+
if additional_objects:
|
280 |
+
content += f"Additional Objects: {len(additional_objects)} objects detected\n"
|
281 |
+
for obj in additional_objects:
|
282 |
+
label = obj.get('label', 'unknown')
|
283 |
+
bbox = obj.get('bbox', 'no coordinates')
|
284 |
+
content += f" - {label} at {bbox}\n"
|
285 |
+
else:
|
286 |
+
content += f"Additional Objects: None detected\n"
|
287 |
+
|
288 |
+
visual_analysis = result.get('visual_analysis', '')
|
289 |
+
if visual_analysis:
|
290 |
+
content += f"Visual Analysis: {visual_analysis}\n"
|
291 |
+
|
292 |
+
assigned_detections = result.get('assigned_detections', [])
|
293 |
+
content += f"Assigned DeepForest Detections: {len(assigned_detections)}\n"
|
294 |
+
|
295 |
+
if 'error' in result:
|
296 |
+
content += f"Error: {result['error']}\n"
|
297 |
+
|
298 |
+
self._write_log_entry(session_id, f"Tile {tile_id} Analysis", content)
|
299 |
+
|
300 |
+
def log_spatial_relationships(
|
301 |
+
self,
|
302 |
+
session_id: str,
|
303 |
+
spatial_relationships: List[Dict[str, Any]],
|
304 |
+
execution_time: float
|
305 |
+
) -> None:
|
306 |
+
"""Log spatial relationships analysis results.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
session_id: The unique identifier for the current session.
|
310 |
+
spatial_relationships: A list of dictionaries, where each
|
311 |
+
dictionary contains details about an object's spatial
|
312 |
+
relationships, including its grid region and intersecting
|
313 |
+
objects.
|
314 |
+
execution_time: The time taken to perform the spatial
|
315 |
+
relationships analysis, in seconds.
|
316 |
+
"""
|
317 |
+
relationships_count = len(spatial_relationships)
|
318 |
+
content = f"Spatial Relationships Analysis ({execution_time:.3f}s)\n"
|
319 |
+
content += f"Analyzed {relationships_count} objects with confidence ≥ 0.3\n"
|
320 |
+
|
321 |
+
# Group by regions
|
322 |
+
by_region = {}
|
323 |
+
for rel in spatial_relationships:
|
324 |
+
region = rel['grid_region']
|
325 |
+
by_region[region] = by_region.get(region, 0) + 1
|
326 |
+
|
327 |
+
content += f"Distribution by region: {dict(by_region)}\n"
|
328 |
+
content += f"Objects with neighbors: {sum(1 for r in spatial_relationships if r['intersecting_objects'])}\n"
|
329 |
+
|
330 |
+
self._write_log_entry(session_id, "Spatial Relationships Analysis", content)
|
331 |
+
|
332 |
+
def log_detection_narrative(
|
333 |
+
self,
|
334 |
+
session_id: str,
|
335 |
+
detection_narrative: str,
|
336 |
+
detections_count: int,
|
337 |
+
execution_time: float
|
338 |
+
) -> None:
|
339 |
+
"""Log detection narrative generation.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
session_id: The unique identifier for the current session.
|
343 |
+
detection_narrative: The string containing the generated narrative.
|
344 |
+
detections_count: The total number of detections used to
|
345 |
+
generate the narrative.
|
346 |
+
execution_time: The time taken for narrative generation, in seconds.
|
347 |
+
"""
|
348 |
+
narrative_length = len(detection_narrative)
|
349 |
+
content = f"Detection Narrative Generation ({execution_time:.3f}s)\n"
|
350 |
+
content += f"Generated narrative for {detections_count} detections\n"
|
351 |
+
content += f"Narrative length: {narrative_length} characters\n"
|
352 |
+
content += f"Narrative content:\n{detection_narrative}"
|
353 |
+
|
354 |
+
self._write_log_entry(session_id, "Detection Narrative", content)
|
355 |
+
|
356 |
+
def log_visual_analysis_unified(
|
357 |
+
self,
|
358 |
+
session_id: str,
|
359 |
+
analysis_type: str,
|
360 |
+
visual_analysis: str,
|
361 |
+
additional_objects_count: int,
|
362 |
+
execution_time: float
|
363 |
+
) -> None:
|
364 |
+
"""Log unified visual analysis results.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
session_id: The unique identifier for the current session.
|
368 |
+
analysis_type: A string specifying the type of visual analysis
|
369 |
+
performed (e.g., 'segmentation', 'classification').
|
370 |
+
visual_analysis: The string containing the final analysis result.
|
371 |
+
additional_objects_count: The number of objects detected beyond
|
372 |
+
the initial set.
|
373 |
+
execution_time: The time taken for the visual analysis, in seconds.
|
374 |
+
"""
|
375 |
+
content = f"Visual Analysis - {analysis_type} ({execution_time:.3f}s)\n"
|
376 |
+
content += f"Additional objects detected: {additional_objects_count}\n"
|
377 |
+
content += f"Analysis: {visual_analysis}"
|
378 |
+
|
379 |
+
self._write_log_entry(session_id, f"Visual Analysis ({analysis_type})", content)
|
380 |
+
|
381 |
+
def get_session_log_summary(self, session_id: str) -> Dict[str, Any]:
|
382 |
+
"""
|
383 |
+
Get a summary of all logged events for a session.
|
384 |
+
|
385 |
+
Args:
|
386 |
+
session_id: Session identifier
|
387 |
+
|
388 |
+
Returns:
|
389 |
+
Dictionary containing session log summary
|
390 |
+
"""
|
391 |
+
log_file_path = self._get_log_file_path(session_id)
|
392 |
+
|
393 |
+
if not log_file_path.exists():
|
394 |
+
return {"error": f"No log file found for session {session_id}"}
|
395 |
+
|
396 |
+
try:
|
397 |
+
with open(log_file_path, 'r', encoding='utf-8') as f:
|
398 |
+
content = f.read()
|
399 |
+
|
400 |
+
return {
|
401 |
+
"session_id": session_id,
|
402 |
+
"log_file": str(log_file_path),
|
403 |
+
"content_preview": content
|
404 |
+
}
|
405 |
+
except Exception as e:
|
406 |
+
return {"error": f"Error reading log file: {str(e)}"}
|
407 |
+
|
408 |
+
def get_all_session_logs(self) -> List[str]:
|
409 |
+
"""
|
410 |
+
Get a list of all session IDs that have log files.
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
List of session IDs with existing log files
|
414 |
+
"""
|
415 |
+
session_ids = []
|
416 |
+
|
417 |
+
for log_file in self.logs_dir.glob("session_*.log"):
|
418 |
+
filename = log_file.stem
|
419 |
+
parts = filename.split("_")
|
420 |
+
if len(parts) >= 2:
|
421 |
+
session_id = parts[1]
|
422 |
+
session_ids.append(session_id)
|
423 |
+
|
424 |
+
return sorted(set(session_ids))
|
425 |
+
|
426 |
+
def cleanup_old_logs(self, days_to_keep: int = 7) -> int:
|
427 |
+
"""
|
428 |
+
Clean up log files older than specified days.
|
429 |
+
|
430 |
+
Args:
|
431 |
+
days_to_keep: Number of days of logs to retain
|
432 |
+
|
433 |
+
Returns:
|
434 |
+
Number of log files deleted
|
435 |
+
"""
|
436 |
+
cutoff_time = time.time() - (days_to_keep * 24 * 60 * 60)
|
437 |
+
deleted_count = 0
|
438 |
+
|
439 |
+
for log_file in self.logs_dir.glob("session_*.log"):
|
440 |
+
if log_file.stat().st_mtime < cutoff_time:
|
441 |
+
try:
|
442 |
+
log_file.unlink()
|
443 |
+
deleted_count += 1
|
444 |
+
except Exception as e:
|
445 |
+
print(f"Error deleting old log file {log_file}: {e}")
|
446 |
+
|
447 |
+
return deleted_count
|
448 |
+
|
449 |
+
multi_agent_logger = MultiAgentLogger()
|
src/deepforest_agent/utils/parsing_utils.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from typing import Dict, List, Any, Optional
|
4 |
+
|
5 |
+
|
6 |
+
def parse_image_quality_for_deepforest(response: str) -> str:
|
7 |
+
"""
|
8 |
+
Parse IMAGE_QUALITY_FOR_DEEPFOREST from response.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
response: Model response text
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
"Yes" or "No"
|
15 |
+
"""
|
16 |
+
quality_match = re.search(r'(?:\*\*)?IMAGE_QUALITY_FOR_DEEPFOREST[:\*\s]+\[?(YES|NO|Yes|No|yes|no)\]?', response, re.IGNORECASE)
|
17 |
+
if quality_match:
|
18 |
+
quality_value = quality_match.group(1).upper()
|
19 |
+
return "Yes" if quality_value == "YES" else "No"
|
20 |
+
return "No"
|
21 |
+
|
22 |
+
def parse_deepforest_objects_present(response: str) -> List[str]:
|
23 |
+
"""
|
24 |
+
Parse DEEPFOREST_OBJECTS_PRESENT from response.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
response: Model response text
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
List of objects present
|
31 |
+
"""
|
32 |
+
objects_match = re.search(r'(?:\*\*)?DEEPFOREST_OBJECTS_PRESENT[:\*\s]+(\[.*?\])', response, re.DOTALL)
|
33 |
+
if objects_match:
|
34 |
+
try:
|
35 |
+
objects_str = objects_match.group(1)
|
36 |
+
objects_str = re.sub(r'[`\'"]', '"', objects_str)
|
37 |
+
objects_list = json.loads(objects_str)
|
38 |
+
|
39 |
+
allowed_objects = ["bird", "tree", "livestock"]
|
40 |
+
validated_objects = [obj for obj in objects_list if obj in allowed_objects]
|
41 |
+
return validated_objects
|
42 |
+
except json.JSONDecodeError:
|
43 |
+
objects_str = objects_match.group(1)
|
44 |
+
manual_objects = re.findall(r'"(bird|tree|livestock)"', objects_str)
|
45 |
+
return list(set(manual_objects))
|
46 |
+
return []
|
47 |
+
|
48 |
+
|
49 |
+
def parse_additional_objects_json(response: str) -> List[Dict[str, Any]]:
|
50 |
+
"""
|
51 |
+
Parse ADDITIONAL_OBJECTS_JSON from response.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
response: Model response text
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
List of additional objects with coordinates
|
58 |
+
"""
|
59 |
+
additional_match = re.search(r'(?:\*\*)?ADDITIONAL_OBJECTS_JSON[:\*\s]+(.*?)(?=\n(?:\*\*)?(?:VISUAL_ANALYSIS|IMAGE_QUALITY|DEEPFOREST_OBJECTS)|$)', response, re.DOTALL)
|
60 |
+
if additional_match:
|
61 |
+
try:
|
62 |
+
additional_str = additional_match.group(1).strip()
|
63 |
+
if additional_str.startswith('```json'):
|
64 |
+
additional_str = additional_str[7:]
|
65 |
+
if additional_str.startswith('```'):
|
66 |
+
additional_str = additional_str[3:]
|
67 |
+
if additional_str.endswith('```'):
|
68 |
+
additional_str = additional_str[:-3]
|
69 |
+
|
70 |
+
additional_str = additional_str.strip()
|
71 |
+
|
72 |
+
if additional_str.startswith('[') and additional_str.endswith(']'):
|
73 |
+
additional_objects = json.loads(additional_str)
|
74 |
+
if isinstance(additional_objects, list):
|
75 |
+
return additional_objects
|
76 |
+
else:
|
77 |
+
additional_objects = []
|
78 |
+
for line in additional_str.split('\n'):
|
79 |
+
line = line.strip().rstrip(',')
|
80 |
+
if line and line.startswith('{') and line.endswith('}'):
|
81 |
+
try:
|
82 |
+
obj = json.loads(line)
|
83 |
+
additional_objects.append(obj)
|
84 |
+
except json.JSONDecodeError:
|
85 |
+
continue
|
86 |
+
return additional_objects
|
87 |
+
|
88 |
+
except Exception as e:
|
89 |
+
print(f"Error parsing additional objects JSON: {e}")
|
90 |
+
return []
|
91 |
+
|
92 |
+
|
93 |
+
def parse_visual_analysis(response: str) -> str:
|
94 |
+
"""
|
95 |
+
Parse VISUAL_ANALYSIS from response.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
response: Model response text
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
Visual analysis text
|
102 |
+
"""
|
103 |
+
analysis_match = re.search(r'(?:\*\*)?VISUAL_ANALYSIS[:\*\s]+(.*?)(?=\n(?:\*\*)?(?:IMAGE_QUALITY|DEEPFOREST_OBJECTS|ADDITIONAL_OBJECTS)|$)', response, re.IGNORECASE | re.DOTALL)
|
104 |
+
if analysis_match:
|
105 |
+
return analysis_match.group(1).strip()
|
106 |
+
else:
|
107 |
+
fallback_match = re.search(r'(?:\*\*)?VISUAL_ANALYSIS[:\*\s]+(.*)', response, re.IGNORECASE | re.DOTALL)
|
108 |
+
if fallback_match:
|
109 |
+
return fallback_match.group(1).strip()
|
110 |
+
return response
|
111 |
+
|
112 |
+
|
113 |
+
def parse_deepforest_agent_response_with_reasoning(response: str) -> Dict[str, Any]:
|
114 |
+
"""
|
115 |
+
Parse DeepForest detector agent response with reasoning.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
response: Model response text
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Dictionary with reasoning and tool calls
|
122 |
+
"""
|
123 |
+
from deepforest_agent.tools.tool_handler import extract_all_tool_calls
|
124 |
+
|
125 |
+
try:
|
126 |
+
tool_calls = extract_all_tool_calls(response)
|
127 |
+
|
128 |
+
if not tool_calls:
|
129 |
+
return {"error": "No valid tool calls found in response"}
|
130 |
+
|
131 |
+
reasoning_text = ""
|
132 |
+
first_json_match = re.search(r'\{[^}]*"name"[^}]*"arguments"[^}]*\}', response)
|
133 |
+
|
134 |
+
if first_json_match:
|
135 |
+
reasoning_text = response[:first_json_match.start()].strip()
|
136 |
+
reasoning_text = re.sub(r'^(REASONING:|Reasoning:|Analysis:|\*\*REASONING:\*\*)', '', reasoning_text).strip()
|
137 |
+
|
138 |
+
if not reasoning_text:
|
139 |
+
reasoning_text = "Tool calls generated based on analysis"
|
140 |
+
|
141 |
+
return {
|
142 |
+
"reasoning": reasoning_text,
|
143 |
+
"tool_calls": tool_calls
|
144 |
+
}
|
145 |
+
|
146 |
+
except Exception as e:
|
147 |
+
return {"error": f"Unexpected error parsing response: {str(e)}"}
|
148 |
+
|
149 |
+
def parse_memory_agent_response(response: str) -> Dict[str, Any]:
|
150 |
+
"""
|
151 |
+
Parse memory agent structured response format with new TOOL_CACHE_ID field.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
response: Model response text
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
Dictionary with answer_present, direct_answer, tool_cache_id, and relevant_context
|
158 |
+
"""
|
159 |
+
try:
|
160 |
+
# Parse ANSWER_PRESENT
|
161 |
+
answer_present_match = re.search(r'(?:\*\*)?ANSWER_PRESENT:(?:\*\*)?\s*\[?(YES|NO)\]?', response, re.IGNORECASE)
|
162 |
+
answer_present = False
|
163 |
+
if answer_present_match:
|
164 |
+
answer_present = answer_present_match.group(1).upper() == "YES"
|
165 |
+
|
166 |
+
# Parse TOOL_CACHE_ID
|
167 |
+
tool_cache_id_match = re.search(r'(?:\*\*)?TOOL_CACHE_ID:(?:\*\*)?\s*(.*?)(?=\n(?:\*\*)?(?:RELEVANT_CONTEXT|$))', response, re.IGNORECASE | re.DOTALL)
|
168 |
+
tool_cache_id = None
|
169 |
+
|
170 |
+
if tool_cache_id_match:
|
171 |
+
tool_cache_id_text = tool_cache_id_match.group(1).strip()
|
172 |
+
|
173 |
+
# Extract all cache IDs using multiple patterns
|
174 |
+
cache_ids = []
|
175 |
+
|
176 |
+
# Pattern 1: IDs within brackets [id1, id2, ...]
|
177 |
+
bracket_pattern = r'\[([^\[\]]*)\]'
|
178 |
+
bracket_matches = re.findall(bracket_pattern, tool_cache_id_text)
|
179 |
+
for bracket_content in bracket_matches:
|
180 |
+
if bracket_content.strip(): # Skip empty brackets
|
181 |
+
# Extract hex IDs from bracket content
|
182 |
+
hex_ids = re.findall(r'([a-fA-F0-9]{8,})', bracket_content)
|
183 |
+
cache_ids.extend(hex_ids)
|
184 |
+
|
185 |
+
# Pattern 2: Direct hex IDs (not in brackets)
|
186 |
+
# Remove bracketed content first, then find remaining hex IDs
|
187 |
+
text_without_brackets = re.sub(r'\[[^\[\]]*\]', '', tool_cache_id_text)
|
188 |
+
direct_hex_ids = re.findall(r'([a-fA-F0-9]{8,})', text_without_brackets)
|
189 |
+
cache_ids.extend(direct_hex_ids)
|
190 |
+
|
191 |
+
# Pattern 3: Standalone hex IDs on separate lines (check the whole response)
|
192 |
+
standalone_pattern = r'^([a-fA-F0-9]{8,})$'
|
193 |
+
standalone_matches = re.findall(standalone_pattern, response, re.MULTILINE)
|
194 |
+
cache_ids.extend(standalone_matches)
|
195 |
+
|
196 |
+
# Remove duplicates while preserving order
|
197 |
+
seen = set()
|
198 |
+
unique_cache_ids = []
|
199 |
+
for cache_id in cache_ids:
|
200 |
+
if cache_id not in seen:
|
201 |
+
seen.add(cache_id)
|
202 |
+
unique_cache_ids.append(cache_id)
|
203 |
+
|
204 |
+
if unique_cache_ids:
|
205 |
+
tool_cache_id = ", ".join(unique_cache_ids) if len(unique_cache_ids) > 1 else unique_cache_ids[0]
|
206 |
+
elif tool_cache_id_text and tool_cache_id_text.lower() not in ["", "empty", "none", "no tool cache id"]:
|
207 |
+
tool_cache_id = tool_cache_id_text
|
208 |
+
|
209 |
+
# Parse RELEVANT_CONTEXT
|
210 |
+
context_match = re.search(
|
211 |
+
r'(?:\*\*)?RELEVANT_CONTEXT:(?:\*\*)?\s*(.*?)(?=\n\*\*[A-Z_]+:|\Z)',
|
212 |
+
response,
|
213 |
+
re.IGNORECASE | re.DOTALL
|
214 |
+
)
|
215 |
+
|
216 |
+
relevant_context = ""
|
217 |
+
if context_match:
|
218 |
+
relevant_context = context_match.group(1).strip()
|
219 |
+
elif not answer_present:
|
220 |
+
relevant_context = response
|
221 |
+
|
222 |
+
return {
|
223 |
+
"answer_present": answer_present,
|
224 |
+
"direct_answer": "YES" if answer_present else "NO",
|
225 |
+
"tool_cache_id": tool_cache_id,
|
226 |
+
"relevant_context": relevant_context,
|
227 |
+
"raw_response": response
|
228 |
+
}
|
229 |
+
|
230 |
+
except Exception as e:
|
231 |
+
print(f"Error parsing memory response: {e}")
|
232 |
+
return {
|
233 |
+
"answer_present": False,
|
234 |
+
"direct_answer": "NO",
|
235 |
+
"tool_cache_id": None,
|
236 |
+
"relevant_context": response,
|
237 |
+
"raw_response": response
|
238 |
+
}
|
src/deepforest_agent/utils/rtree_spatial_utils.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List, Dict, Any, Tuple, Optional
|
3 |
+
from rtree import index
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
|
7 |
+
class DetectionSpatialAnalyzer:
|
8 |
+
"""
|
9 |
+
Spatial analyzer using R-tree for DeepForest detection results.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, image_width: int, image_height: int):
|
13 |
+
"""
|
14 |
+
Initialize spatial analyzer with image dimensions.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
image_width: Width of the image in pixels
|
18 |
+
image_height: Height of the image in pixels
|
19 |
+
"""
|
20 |
+
self.image_width = image_width
|
21 |
+
self.image_height = image_height
|
22 |
+
self.spatial_index = index.Index()
|
23 |
+
self.detections = []
|
24 |
+
|
25 |
+
def add_detections(self, detections_list: List[Dict[str, Any]]) -> None:
|
26 |
+
"""
|
27 |
+
Add detections to R-tree spatial index.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
detections_list: List of detection dictionaries with coordinates
|
31 |
+
"""
|
32 |
+
for i, detection in enumerate(detections_list):
|
33 |
+
xmin = detection.get('xmin', 0)
|
34 |
+
ymin = detection.get('ymin', 0)
|
35 |
+
xmax = detection.get('xmax', 0)
|
36 |
+
ymax = detection.get('ymax', 0)
|
37 |
+
|
38 |
+
# Validate box ordering - swap if necessary
|
39 |
+
if xmin > xmax:
|
40 |
+
xmin, xmax = xmax, xmin
|
41 |
+
if ymin > ymax:
|
42 |
+
ymin, ymax = ymax, ymin
|
43 |
+
|
44 |
+
# Clamp to image bounds
|
45 |
+
xmin = max(0, min(xmin, self.image_width))
|
46 |
+
ymin = max(0, min(ymin, self.image_height))
|
47 |
+
xmax = max(0, min(xmax, self.image_width))
|
48 |
+
ymax = max(0, min(ymax, self.image_height))
|
49 |
+
|
50 |
+
# Skip invalid boxes (zero area after validation)
|
51 |
+
if xmin >= xmax or ymin >= ymax:
|
52 |
+
continue
|
53 |
+
|
54 |
+
# Add to R-tree index
|
55 |
+
self.spatial_index.insert(i, (xmin, ymin, xmax, ymax))
|
56 |
+
|
57 |
+
# Store detection with spatial info
|
58 |
+
detection_copy = detection.copy()
|
59 |
+
detection_copy['detection_id'] = i
|
60 |
+
detection_copy['centroid_x'] = (xmin + xmax) / 2
|
61 |
+
detection_copy['centroid_y'] = (ymin + ymax) / 2
|
62 |
+
detection_copy['area'] = (xmax - xmin) * (ymax - ymin)
|
63 |
+
self.detections.append(detection_copy)
|
64 |
+
|
65 |
+
def get_grid_analysis(self) -> Dict[str, Dict[str, Any]]:
|
66 |
+
"""
|
67 |
+
Analyze detections using 3x3 grid system.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
Dictionary with analysis for each grid cell
|
71 |
+
"""
|
72 |
+
grid_width = self.image_width / 3
|
73 |
+
grid_height = self.image_height / 3
|
74 |
+
|
75 |
+
grid_names = {
|
76 |
+
(0, 0): "Top-Left (Northwest)", (1, 0): "Top-Center (North)", (2, 0): "Top-Right (Northeast)",
|
77 |
+
(0, 1): "Middle-Left (West)", (1, 1): "Center", (2, 1): "Middle-Right (East)",
|
78 |
+
(0, 2): "Bottom-Left (Southwest)", (1, 2): "Bottom-Center (South)", (2, 2): "Bottom-Right (Southeast)"
|
79 |
+
}
|
80 |
+
|
81 |
+
grid_analysis = {}
|
82 |
+
|
83 |
+
for (grid_x, grid_y), grid_name in grid_names.items():
|
84 |
+
# Define grid bounds
|
85 |
+
x_min = grid_x * grid_width
|
86 |
+
y_min = grid_y * grid_height
|
87 |
+
x_max = (grid_x + 1) * grid_width
|
88 |
+
y_max = (grid_y + 1) * grid_height
|
89 |
+
|
90 |
+
# Query R-tree for intersecting detections
|
91 |
+
intersecting_ids = list(self.spatial_index.intersection((x_min, y_min, x_max, y_max)))
|
92 |
+
grid_detections = [self.detections[i] for i in intersecting_ids]
|
93 |
+
|
94 |
+
# Analyze by confidence categories
|
95 |
+
confidence_analysis = self._analyze_confidence_categories(grid_detections)
|
96 |
+
|
97 |
+
grid_analysis[grid_name] = {
|
98 |
+
"total_detections": len(grid_detections),
|
99 |
+
"confidence_analysis": confidence_analysis,
|
100 |
+
"bounds": {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
|
101 |
+
}
|
102 |
+
|
103 |
+
return grid_analysis
|
104 |
+
|
105 |
+
def _analyze_confidence_categories(self, detections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
106 |
+
"""
|
107 |
+
Analyze detections by confidence categories.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
detections: List of detection dictionaries
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Analysis by confidence categories (Low, Medium, High)
|
114 |
+
"""
|
115 |
+
categories = {
|
116 |
+
"Detections with Low Confidence Score (0.0-0.3)": {"detections": [], "range": (0.0, 0.3)},
|
117 |
+
"Detections with Medium Confidence Score (0.3-0.7)": {"detections": [], "range": (0.3, 0.7)},
|
118 |
+
"Detections with High Confidence Score (0.7-1.0)": {"detections": [], "range": (0.7, 1.0)}
|
119 |
+
}
|
120 |
+
|
121 |
+
for detection in detections:
|
122 |
+
score = detection.get('score', 0.0)
|
123 |
+
if score < 0.3:
|
124 |
+
categories["Detections with Low Confidence Score (0.0-0.3)"]["detections"].append(detection)
|
125 |
+
elif score < 0.7:
|
126 |
+
categories["Detections with Medium Confidence Score (0.3-0.7)"]["detections"].append(detection)
|
127 |
+
else:
|
128 |
+
categories["Detections with High Confidence Score (0.7-1.0)"]["detections"].append(detection)
|
129 |
+
|
130 |
+
# Calculate statistics for each category
|
131 |
+
analysis = {}
|
132 |
+
for category_name, category_data in categories.items():
|
133 |
+
cat_detections = category_data["detections"]
|
134 |
+
if cat_detections:
|
135 |
+
areas = [d['area'] for d in cat_detections]
|
136 |
+
analysis[category_name] = {
|
137 |
+
"count": len(cat_detections),
|
138 |
+
"avg_area": np.mean(areas),
|
139 |
+
"min_area": np.min(areas),
|
140 |
+
"max_area": np.max(areas),
|
141 |
+
"total_area_covered": np.sum(areas),
|
142 |
+
"labels": [d.get('label', 'unknown') for d in cat_detections]
|
143 |
+
}
|
144 |
+
else:
|
145 |
+
analysis[category_name] = {
|
146 |
+
"count": 0,
|
147 |
+
"avg_area": 0,
|
148 |
+
"min_area": 0,
|
149 |
+
"max_area": 0,
|
150 |
+
"total_area_covered": 0,
|
151 |
+
"labels": []
|
152 |
+
}
|
153 |
+
|
154 |
+
return analysis
|
155 |
+
|
156 |
+
def analyze_spatial_relationships_with_indexing(self, confidence_threshold: float = 0.3) -> List[Dict[str, Any]]:
|
157 |
+
"""
|
158 |
+
Analyze spatial relationships using R-tree indexing for confidence >= 0.3 detections.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
confidence_threshold: Minimum confidence score (default: 0.3)
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
List of spatial relationship dictionaries with intersection and nearest neighbor data
|
165 |
+
"""
|
166 |
+
# Filter detections by confidence threshold
|
167 |
+
high_confidence_detections = [
|
168 |
+
d for d in self.detections
|
169 |
+
if d.get('score', 0.0) >= confidence_threshold
|
170 |
+
]
|
171 |
+
|
172 |
+
if not high_confidence_detections:
|
173 |
+
return []
|
174 |
+
|
175 |
+
relationships = []
|
176 |
+
|
177 |
+
for detection in high_confidence_detections:
|
178 |
+
# Get bounding box coordinates directly
|
179 |
+
xmin = detection.get('xmin', 0)
|
180 |
+
ymin = detection.get('ymin', 0)
|
181 |
+
xmax = detection.get('xmax', 0)
|
182 |
+
ymax = detection.get('ymax', 0)
|
183 |
+
detection_id = detection.get('detection_id', 0)
|
184 |
+
|
185 |
+
# Get object label (handle classification labels for trees)
|
186 |
+
if 'classification_label' in detection and detection['classification_label'] and str(detection['classification_label']).lower() != 'nan':
|
187 |
+
object_label = detection['classification_label']
|
188 |
+
else:
|
189 |
+
object_label = detection.get('label', 'unknown')
|
190 |
+
|
191 |
+
# Find intersecting objects using spatial index
|
192 |
+
intersecting_ids = list(self.spatial_index.intersection((xmin, ymin, xmax, ymax)))
|
193 |
+
|
194 |
+
# Remove self from intersections
|
195 |
+
intersecting_ids = [idx for idx in intersecting_ids if idx != detection_id]
|
196 |
+
|
197 |
+
# Get details of intersecting objects
|
198 |
+
intersecting_objects = []
|
199 |
+
for idx in intersecting_ids:
|
200 |
+
if idx < len(self.detections):
|
201 |
+
intersecting_detection = self.detections[idx]
|
202 |
+
if intersecting_detection.get('score', 0.0) >= confidence_threshold:
|
203 |
+
if 'classification_label' in intersecting_detection and intersecting_detection['classification_label'] and str(intersecting_detection['classification_label']).lower() != 'nan':
|
204 |
+
intersecting_label = intersecting_detection['classification_label']
|
205 |
+
else:
|
206 |
+
intersecting_label = intersecting_detection.get('label', 'unknown')
|
207 |
+
intersecting_objects.append(intersecting_label)
|
208 |
+
|
209 |
+
# Find nearest neighbor using spatial index
|
210 |
+
nearest_ids = list(self.spatial_index.nearest((xmin, ymin, xmax, ymax), 2)) # 2 to get self + nearest
|
211 |
+
nearest_neighbor = None
|
212 |
+
|
213 |
+
for idx in nearest_ids:
|
214 |
+
if idx != detection_id and idx < len(self.detections):
|
215 |
+
nearest_detection = self.detections[idx]
|
216 |
+
if nearest_detection.get('score', 0.0) >= confidence_threshold:
|
217 |
+
if 'classification_label' in nearest_detection and nearest_detection['classification_label'] and str(nearest_detection['classification_label']).lower() != 'nan':
|
218 |
+
nearest_label = nearest_detection['classification_label']
|
219 |
+
else:
|
220 |
+
nearest_label = nearest_detection.get('label', 'unknown')
|
221 |
+
nearest_neighbor = nearest_label
|
222 |
+
break
|
223 |
+
|
224 |
+
# Determine grid region
|
225 |
+
grid_region = self._determine_grid_region(detection)
|
226 |
+
|
227 |
+
# Count intersecting objects by type
|
228 |
+
object_counts = {}
|
229 |
+
for obj_label in intersecting_objects:
|
230 |
+
object_counts[obj_label] = object_counts.get(obj_label, 0) + 1
|
231 |
+
|
232 |
+
relationships.append({
|
233 |
+
'object_type': object_label,
|
234 |
+
'object_location': f"({ymin}, {xmin})",
|
235 |
+
'grid_region': grid_region,
|
236 |
+
'intersecting_objects': object_counts,
|
237 |
+
'nearest_neighbor': nearest_neighbor,
|
238 |
+
'confidence_score': detection.get('score', 0.0),
|
239 |
+
'total_intersections': len(intersecting_objects)
|
240 |
+
})
|
241 |
+
|
242 |
+
return relationships
|
243 |
+
|
244 |
+
def _determine_grid_region(self, detection: Dict[str, Any]) -> str:
|
245 |
+
"""
|
246 |
+
Determine which grid region a detection belongs to based on its centroid.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
detection: Detection dictionary with coordinates
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
Grid region name (e.g., "northern", "northwest", etc.)
|
253 |
+
"""
|
254 |
+
centroid_x = detection.get('centroid_x', 0)
|
255 |
+
centroid_y = detection.get('centroid_y', 0)
|
256 |
+
|
257 |
+
grid_width = self.image_width / 3
|
258 |
+
grid_height = self.image_height / 3
|
259 |
+
|
260 |
+
# Determine grid position
|
261 |
+
grid_x = int(centroid_x // grid_width)
|
262 |
+
grid_y = int(centroid_y // grid_height)
|
263 |
+
|
264 |
+
# Ensure within bounds
|
265 |
+
grid_x = min(2, max(0, grid_x))
|
266 |
+
grid_y = min(2, max(0, grid_y))
|
267 |
+
|
268 |
+
grid_names = {
|
269 |
+
(0, 0): "northwestern", (1, 0): "northern", (2, 0): "northeastern",
|
270 |
+
(0, 1): "western", (1, 1): "central", (2, 1): "eastern",
|
271 |
+
(0, 2): "southwestern", (1, 2): "southern", (2, 2): "southeastern"
|
272 |
+
}
|
273 |
+
|
274 |
+
return grid_names.get((grid_x, grid_y), "central")
|
275 |
+
|
276 |
+
def generate_spatial_narrative(self, confidence_threshold: float = 0.3) -> str:
|
277 |
+
"""
|
278 |
+
Generate narrative description of spatial relationships using R-tree analysis.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
confidence_threshold: Minimum confidence score for analysis (default: 0.3)
|
282 |
+
|
283 |
+
Returns:
|
284 |
+
Natural language narrative of spatial relationships
|
285 |
+
"""
|
286 |
+
relationships = self.analyze_spatial_relationships_with_indexing(confidence_threshold)
|
287 |
+
|
288 |
+
if not relationships:
|
289 |
+
return f"No objects with confidence score >= {confidence_threshold} found for spatial relationship analysis."
|
290 |
+
|
291 |
+
narrative_parts = []
|
292 |
+
|
293 |
+
# Process each relationship and only include different object types
|
294 |
+
for rel in relationships:
|
295 |
+
object_type = rel['object_type']
|
296 |
+
confidence_score = rel['confidence_score']
|
297 |
+
grid_region = rel['grid_region']
|
298 |
+
object_location = rel['object_location']
|
299 |
+
|
300 |
+
# Only process intersecting objects that are DIFFERENT from the main object
|
301 |
+
different_intersecting = {}
|
302 |
+
for intersecting_type, count in rel['intersecting_objects'].items():
|
303 |
+
if intersecting_type != object_type: # Only different object types
|
304 |
+
different_intersecting[intersecting_type] = count
|
305 |
+
|
306 |
+
# Generate narrative for intersecting different objects
|
307 |
+
if different_intersecting:
|
308 |
+
intersecting_parts = []
|
309 |
+
for obj_label, count in different_intersecting.items():
|
310 |
+
if count == 1:
|
311 |
+
intersecting_parts.append(f"{count} {obj_label.replace('_', ' ')}")
|
312 |
+
else:
|
313 |
+
intersecting_parts.append(f"{count} {obj_label.replace('_', ' ')}s")
|
314 |
+
|
315 |
+
intersecting_desc = ", ".join(intersecting_parts)
|
316 |
+
|
317 |
+
narrative_parts.append(
|
318 |
+
f"I am about {confidence_score*100:.1f}% confident that, in {grid_region} region, "
|
319 |
+
f"{intersecting_desc} found overlapping around the {object_type.replace('_', ' ')} "
|
320 |
+
f"object at location (top, left) = {object_location}.\n"
|
321 |
+
)
|
322 |
+
|
323 |
+
# Only add nearest neighbor information if it's a DIFFERENT object type
|
324 |
+
if rel['nearest_neighbor'] and rel['nearest_neighbor'] != object_type:
|
325 |
+
narrative_parts.append(
|
326 |
+
f"I am about {confidence_score*100:.1f}% confident that, in {grid_region} region, "
|
327 |
+
f"around the {object_type.replace('_', ' ')} at location (top, left) = {object_location} "
|
328 |
+
f"the nearest neighbor is a {rel['nearest_neighbor'].replace('_', ' ')}.\n"
|
329 |
+
)
|
330 |
+
|
331 |
+
if narrative_parts:
|
332 |
+
# Remove duplicates while preserving order
|
333 |
+
unique_narratives = []
|
334 |
+
seen = set()
|
335 |
+
for part in narrative_parts:
|
336 |
+
if part not in seen:
|
337 |
+
unique_narratives.append(part)
|
338 |
+
seen.add(part)
|
339 |
+
|
340 |
+
return " ".join(unique_narratives)
|
341 |
+
else:
|
342 |
+
return f"Spatial analysis completed for {len(relationships)} objects with confidence >= {confidence_threshold}, but no significant spatial relationships between different object types detected."
|
343 |
+
|
344 |
+
def get_detection_statistics(self) -> Dict[str, Any]:
|
345 |
+
"""
|
346 |
+
Get comprehensive detection statistics.
|
347 |
+
|
348 |
+
Returns:
|
349 |
+
Dictionary with overall statistics
|
350 |
+
"""
|
351 |
+
if not self.detections:
|
352 |
+
return {"total_count": 0}
|
353 |
+
|
354 |
+
# Basic counts and confidence
|
355 |
+
total_count = len(self.detections)
|
356 |
+
scores = [d.get('score', 0.0) for d in self.detections]
|
357 |
+
overall_confidence = np.mean(scores)
|
358 |
+
|
359 |
+
# Size statistics
|
360 |
+
areas = [d['area'] for d in self.detections]
|
361 |
+
avg_area = np.mean(areas)
|
362 |
+
min_area = np.min(areas)
|
363 |
+
max_area = np.max(areas)
|
364 |
+
total_area = np.sum(areas)
|
365 |
+
|
366 |
+
# Label distribution
|
367 |
+
labels = [d.get('label', 'unknown') for d in self.detections]
|
368 |
+
# Handle classification labels for trees
|
369 |
+
classified_labels = []
|
370 |
+
for d in self.detections:
|
371 |
+
if 'classification_label' in d and d['classification_label'] and str(d['classification_label']).lower() != 'nan':
|
372 |
+
classified_labels.append(d['classification_label'])
|
373 |
+
else:
|
374 |
+
classified_labels.append(d.get('label', 'unknown'))
|
375 |
+
|
376 |
+
from collections import Counter
|
377 |
+
label_counts = Counter(classified_labels)
|
378 |
+
|
379 |
+
return {
|
380 |
+
"total_count": total_count,
|
381 |
+
"overall_confidence": overall_confidence,
|
382 |
+
"size_stats": {
|
383 |
+
"avg_area": avg_area,
|
384 |
+
"min_area": min_area,
|
385 |
+
"max_area": max_area,
|
386 |
+
"total_area_covered": total_area
|
387 |
+
},
|
388 |
+
"label_distribution": dict(label_counts),
|
389 |
+
"confidence_distribution": {
|
390 |
+
"low_count": len([s for s in scores if s < 0.3]),
|
391 |
+
"medium_count": len([s for s in scores if 0.3 <= s < 0.7]),
|
392 |
+
"high_count": len([s for s in scores if s >= 0.7])
|
393 |
+
}
|
394 |
+
}
|
src/deepforest_agent/utils/state_manager.py
ADDED
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import uuid
|
3 |
+
import time
|
4 |
+
from typing import Optional, Any, Dict, List
|
5 |
+
|
6 |
+
from deepforest_agent.utils.cache_utils import tool_call_cache
|
7 |
+
|
8 |
+
|
9 |
+
class SessionStateManager:
|
10 |
+
"""
|
11 |
+
Session-based state manager with thread ID for the DeepForest Agent.
|
12 |
+
|
13 |
+
This class manages state for multiple concurrent users with each user
|
14 |
+
having their own session containing current image, conversation
|
15 |
+
history, and session information.
|
16 |
+
|
17 |
+
Attributes:
|
18 |
+
_lock (threading.Lock): Thread synchronization lock
|
19 |
+
_sessions (Dict[str, Dict[str, Any]]): Dictionary mapping session_ids to session state
|
20 |
+
_cleanup_interval (int): Time in seconds after which inactive sessions are cleaned up
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, cleanup_interval: int = 3600) -> None:
|
24 |
+
"""
|
25 |
+
Initialize the session state manager.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
cleanup_interval (int): Time in seconds after which inactive sessions
|
29 |
+
are eligible for cleanup (default: 1 hour)
|
30 |
+
"""
|
31 |
+
self._lock = threading.Lock()
|
32 |
+
self._sessions = {}
|
33 |
+
self._cleanup_interval = cleanup_interval
|
34 |
+
|
35 |
+
def create_session(self, image: Any = None) -> str:
|
36 |
+
"""
|
37 |
+
Create a new session with initial image.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
image (Any, optional): Initial image for the session
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
str: Unique session ID
|
44 |
+
"""
|
45 |
+
session_id = str(uuid.uuid4())[:12]
|
46 |
+
|
47 |
+
with self._lock:
|
48 |
+
self._sessions[session_id] = {
|
49 |
+
"current_image": image,
|
50 |
+
"conversation_history": [],
|
51 |
+
"annotated_image": None,
|
52 |
+
"thread_id": session_id,
|
53 |
+
"first_message": True,
|
54 |
+
"created_at": time.time(),
|
55 |
+
"last_accessed": time.time(),
|
56 |
+
"is_cancelled": False,
|
57 |
+
"is_processing": False,
|
58 |
+
"tool_call_history": [],
|
59 |
+
"visual_analysis_history": []
|
60 |
+
}
|
61 |
+
|
62 |
+
return session_id
|
63 |
+
|
64 |
+
def get_session_state(self, session_id: str) -> Dict[str, Any]:
|
65 |
+
"""
|
66 |
+
Get complete state for a specific session.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
session_id (str): The session ID to retrieve
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
Dict[str, Any]: Copy of session state dictionary
|
73 |
+
|
74 |
+
Raises:
|
75 |
+
KeyError: If session_id doesn't exist
|
76 |
+
"""
|
77 |
+
with self._lock:
|
78 |
+
if session_id not in self._sessions:
|
79 |
+
raise KeyError(f"Session {session_id} not found")
|
80 |
+
|
81 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
82 |
+
|
83 |
+
# Return a copy to prevent external modification
|
84 |
+
return self._sessions[session_id].copy()
|
85 |
+
|
86 |
+
def get(self, session_id: str, key: str, default: Any = None) -> Any:
|
87 |
+
"""
|
88 |
+
Get a value from session state.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
session_id (str): The session ID
|
92 |
+
key (str): The state key to retrieve
|
93 |
+
default (Any, optional): Default value if key not found.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
Any: The value associated with the key, or default if not found
|
97 |
+
|
98 |
+
Raises:
|
99 |
+
KeyError: If session_id doesn't exist
|
100 |
+
"""
|
101 |
+
with self._lock:
|
102 |
+
if session_id not in self._sessions:
|
103 |
+
raise KeyError(f"Session {session_id} not found")
|
104 |
+
|
105 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
106 |
+
|
107 |
+
return self._sessions[session_id].get(key, default)
|
108 |
+
|
109 |
+
def set(self, session_id: str, key: str, value: Any) -> None:
|
110 |
+
"""
|
111 |
+
Set a value in session state.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
session_id (str): The session ID
|
115 |
+
key (str): The state key to set
|
116 |
+
value (Any): The value to store
|
117 |
+
|
118 |
+
Raises:
|
119 |
+
KeyError: If session_id doesn't exist
|
120 |
+
"""
|
121 |
+
with self._lock:
|
122 |
+
if session_id not in self._sessions:
|
123 |
+
raise KeyError(f"Session {session_id} not found")
|
124 |
+
|
125 |
+
self._sessions[session_id][key] = value
|
126 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
127 |
+
|
128 |
+
def update(self, session_id: str, updates: Dict[str, Any]) -> None:
|
129 |
+
"""
|
130 |
+
Update multiple values in session state.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
session_id (str): The session ID
|
134 |
+
updates (Dict[str, Any]): Dictionary of key-value pairs to update
|
135 |
+
|
136 |
+
Raises:
|
137 |
+
KeyError: If session_id doesn't exist
|
138 |
+
"""
|
139 |
+
with self._lock:
|
140 |
+
if session_id not in self._sessions:
|
141 |
+
raise KeyError(f"Session {session_id} not found")
|
142 |
+
|
143 |
+
self._sessions[session_id].update(updates)
|
144 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
145 |
+
|
146 |
+
def set_processing_state(self, session_id: str, is_processing: bool) -> None:
|
147 |
+
"""
|
148 |
+
Set processing state for a session.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
session_id (str): The session ID
|
152 |
+
is_processing (bool): Whether processing is active
|
153 |
+
"""
|
154 |
+
with self._lock:
|
155 |
+
if session_id in self._sessions:
|
156 |
+
self._sessions[session_id]["is_processing"] = is_processing
|
157 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
158 |
+
|
159 |
+
def cancel_session(self, session_id: str) -> None:
|
160 |
+
"""
|
161 |
+
Cancel processing for a session.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
session_id (str): The session ID to cancel
|
165 |
+
"""
|
166 |
+
with self._lock:
|
167 |
+
if session_id in self._sessions:
|
168 |
+
self._sessions[session_id]["is_cancelled"] = True
|
169 |
+
self._sessions[session_id]["is_processing"] = False
|
170 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
171 |
+
|
172 |
+
def is_cancelled(self, session_id: str) -> bool:
|
173 |
+
"""
|
174 |
+
Check if session is cancelled.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
session_id (str): The session ID to check
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
bool: True if cancelled
|
181 |
+
"""
|
182 |
+
with self._lock:
|
183 |
+
if session_id not in self._sessions:
|
184 |
+
return True
|
185 |
+
return self._sessions[session_id].get("is_cancelled", False)
|
186 |
+
|
187 |
+
def reset_cancellation(self, session_id: str) -> None:
|
188 |
+
"""
|
189 |
+
Reset cancellation flag for a session.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
session_id (str): The session ID to reset
|
193 |
+
"""
|
194 |
+
with self._lock:
|
195 |
+
if session_id in self._sessions:
|
196 |
+
self._sessions[session_id]["is_cancelled"] = False
|
197 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
198 |
+
|
199 |
+
def add_tool_call_to_history(self, session_id: str, tool_name: str, arguments: Dict[str, Any], cache_key: str) -> None:
|
200 |
+
"""
|
201 |
+
Add a tool call to the session's tool call history.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
session_id (str): The session ID
|
205 |
+
tool_name (str): Name of the tool that was called
|
206 |
+
arguments (Dict[str, Any]): Arguments passed to the tool
|
207 |
+
cache_key (str): Cache key used for this tool call
|
208 |
+
|
209 |
+
Raises:
|
210 |
+
KeyError: If session_id doesn't exist
|
211 |
+
"""
|
212 |
+
with self._lock:
|
213 |
+
if session_id not in self._sessions:
|
214 |
+
raise KeyError(f"Session {session_id} not found")
|
215 |
+
|
216 |
+
tool_call_entry = {
|
217 |
+
"tool_name": tool_name,
|
218 |
+
"arguments": arguments.copy(),
|
219 |
+
"cache_key": cache_key,
|
220 |
+
"timestamp": time.time(),
|
221 |
+
"call_number": len(self._sessions[session_id]["tool_call_history"]) + 1
|
222 |
+
}
|
223 |
+
|
224 |
+
self._sessions[session_id]["tool_call_history"].append(tool_call_entry)
|
225 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
226 |
+
|
227 |
+
def get_tool_call_history(self, session_id: str) -> List[Dict[str, Any]]:
|
228 |
+
"""
|
229 |
+
Get the tool call history for a specific session.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
session_id (str): The session ID
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
List[Dict[str, Any]]: List of tool calls made in this session
|
236 |
+
|
237 |
+
Raises:
|
238 |
+
KeyError: If session_id doesn't exist
|
239 |
+
"""
|
240 |
+
with self._lock:
|
241 |
+
if session_id not in self._sessions:
|
242 |
+
raise KeyError(f"Session {session_id} not found")
|
243 |
+
|
244 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
245 |
+
return self._sessions[session_id]["tool_call_history"].copy()
|
246 |
+
|
247 |
+
def add_visual_analysis_to_history(self, session_id: str, visual_analysis: str, additional_objects: Optional[List[Dict[str, Any]]] = None) -> None:
|
248 |
+
"""
|
249 |
+
Add a visual analysis response to the session's history.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
session_id (str): The session ID
|
253 |
+
visual_analysis (str): Visual analysis text from visual agent
|
254 |
+
additional_objects (Optional[List[Dict[str, Any]]]): Additional objects detected by visual agent
|
255 |
+
|
256 |
+
Raises:
|
257 |
+
KeyError: If session_id doesn't exist
|
258 |
+
"""
|
259 |
+
with self._lock:
|
260 |
+
if session_id not in self._sessions:
|
261 |
+
raise KeyError(f"Session {session_id} not found")
|
262 |
+
|
263 |
+
visual_entry = {
|
264 |
+
"visual_analysis": visual_analysis,
|
265 |
+
"additional_objects": additional_objects or [],
|
266 |
+
"timestamp": time.time(),
|
267 |
+
"turn_number": len(self._sessions[session_id]["visual_analysis_history"]) + 1
|
268 |
+
}
|
269 |
+
|
270 |
+
self._sessions[session_id]["visual_analysis_history"].append(visual_entry)
|
271 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
272 |
+
|
273 |
+
def get_visual_analysis_history(self, session_id: str) -> List[Dict[str, Any]]:
|
274 |
+
"""
|
275 |
+
Get all visual analysis responses from previous turns.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
session_id (str): The session ID
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
List[Dict[str, Any]]: List of visual analysis entries with text and additional objects
|
282 |
+
|
283 |
+
Raises:
|
284 |
+
KeyError: If session_id doesn't exist
|
285 |
+
"""
|
286 |
+
with self._lock:
|
287 |
+
if session_id not in self._sessions:
|
288 |
+
raise KeyError(f"Session {session_id} not found")
|
289 |
+
|
290 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
291 |
+
|
292 |
+
return self._sessions[session_id]["visual_analysis_history"].copy()
|
293 |
+
|
294 |
+
def get_formatted_tool_call_history(self, session_id: str) -> str:
|
295 |
+
"""
|
296 |
+
Get formatted tool call history for memory agent context.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
session_id (str): The session ID
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
str: Formatted tool call history string
|
303 |
+
|
304 |
+
Raises:
|
305 |
+
KeyError: If session_id doesn't exist
|
306 |
+
"""
|
307 |
+
try:
|
308 |
+
tool_calls = self.get_tool_call_history(session_id)
|
309 |
+
if not tool_calls:
|
310 |
+
return "No previous tool calls in this session."
|
311 |
+
|
312 |
+
formatted_history = []
|
313 |
+
for tool_call in tool_calls:
|
314 |
+
call_info = f"Tool Call #{tool_call.get('call_number', 'N/A')}: "
|
315 |
+
call_info += f"{tool_call.get('tool_name', 'unknown')} "
|
316 |
+
call_info += f"with args {tool_call.get('arguments', {})}"
|
317 |
+
formatted_history.append(call_info)
|
318 |
+
|
319 |
+
return "\n".join(formatted_history)
|
320 |
+
except KeyError:
|
321 |
+
return f"Session {session_id} not found - no tool call history available."
|
322 |
+
|
323 |
+
def store_conversation_turn_context(
|
324 |
+
self,
|
325 |
+
session_id: str,
|
326 |
+
turn_number: int,
|
327 |
+
user_query: str,
|
328 |
+
visual_context: str,
|
329 |
+
detection_narrative: str,
|
330 |
+
tool_cache_id: Optional[str],
|
331 |
+
ecology_response: str
|
332 |
+
) -> None:
|
333 |
+
"""
|
334 |
+
Store complete turn context for memory agent.
|
335 |
+
|
336 |
+
Args:
|
337 |
+
session_id (str): The session ID
|
338 |
+
turn_number (int): Sequential number of this conversation turn (1-indexed)
|
339 |
+
user_query (str): The original user question of the current turn
|
340 |
+
visual_context (str): Complete visual analysis output from the visual agent
|
341 |
+
detection_narrative (str): Comprehensive spatial analysis narrative generated
|
342 |
+
from DeepForest detection results
|
343 |
+
tool_cache_id (Optional[str]): Cache identifier for DeepForest tool execution
|
344 |
+
results
|
345 |
+
ecology_response (str): Final synthesized ecological analysis response
|
346 |
+
"""
|
347 |
+
turn_data = {
|
348 |
+
"user_query": user_query,
|
349 |
+
"visual_context": visual_context,
|
350 |
+
"detection_narrative": detection_narrative,
|
351 |
+
"tool_cache_id": tool_cache_id or "No tool cache ID",
|
352 |
+
"ecology_response": ecology_response,
|
353 |
+
"timestamp": time.time()
|
354 |
+
}
|
355 |
+
|
356 |
+
self.set(session_id, f"conversation_turn_{turn_number}", turn_data)
|
357 |
+
|
358 |
+
# Update turn counter
|
359 |
+
current_turns = self.get(session_id, "total_turns", 0)
|
360 |
+
self.set(session_id, "total_turns", max(current_turns, turn_number))
|
361 |
+
|
362 |
+
def get_cache_stats_for_session(self, session_id: str) -> Dict[str, Any]:
|
363 |
+
"""
|
364 |
+
Get cache statistics specific to this session.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
session_id (str): The session ID
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
Dict[str, Any]: Cache statistics for this session
|
371 |
+
|
372 |
+
Raises:
|
373 |
+
KeyError: If session_id doesn't exist
|
374 |
+
"""
|
375 |
+
with self._lock:
|
376 |
+
if session_id not in self._sessions:
|
377 |
+
raise KeyError(f"Session {session_id} not found")
|
378 |
+
|
379 |
+
session_tool_calls = self._sessions[session_id]["tool_call_history"]
|
380 |
+
|
381 |
+
return {
|
382 |
+
"session_id": session_id,
|
383 |
+
"total_tool_calls": len(session_tool_calls),
|
384 |
+
"tool_calls": session_tool_calls,
|
385 |
+
"global_cache_stats": tool_call_cache.get_cache_stats()
|
386 |
+
}
|
387 |
+
|
388 |
+
def clear_session_cache_data(self, session_id: str) -> None:
|
389 |
+
"""
|
390 |
+
Clear tool call history for a specific session.
|
391 |
+
|
392 |
+
Note: This only clears the session's record of tool calls,
|
393 |
+
not the global cache itself.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
session_id (str): The session ID
|
397 |
+
|
398 |
+
Raises:
|
399 |
+
KeyError: If session_id doesn't exist
|
400 |
+
"""
|
401 |
+
with self._lock:
|
402 |
+
if session_id not in self._sessions:
|
403 |
+
raise KeyError(f"Session {session_id} not found")
|
404 |
+
|
405 |
+
self._sessions[session_id]["tool_call_history"] = []
|
406 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
407 |
+
|
408 |
+
def clear_conversation(self, session_id: str) -> None:
|
409 |
+
"""
|
410 |
+
Clear conversation-specific state for a session.
|
411 |
+
|
412 |
+
current_image and thread_id are preserved so that users can
|
413 |
+
start a new conversation without re-uploading the image.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
session_id (str): The session ID to clear
|
417 |
+
|
418 |
+
Raises:
|
419 |
+
KeyError: If session_id doesn't exist
|
420 |
+
"""
|
421 |
+
with self._lock:
|
422 |
+
if session_id not in self._sessions:
|
423 |
+
raise KeyError(f"Session {session_id} not found")
|
424 |
+
|
425 |
+
self._sessions[session_id].update({
|
426 |
+
"conversation_history": [],
|
427 |
+
"annotated_image": None,
|
428 |
+
"first_message": True,
|
429 |
+
"last_accessed": time.time(),
|
430 |
+
"is_cancelled": True,
|
431 |
+
"is_processing": False,
|
432 |
+
"tool_call_history": [],
|
433 |
+
"visual_analysis_history": []
|
434 |
+
})
|
435 |
+
|
436 |
+
def reset_for_new_image(self, session_id: str, image: Any) -> None:
|
437 |
+
"""
|
438 |
+
Reset session state for new image upload.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
session_id (str): The session ID
|
442 |
+
image (Any): The new image object (typically PIL Image)
|
443 |
+
|
444 |
+
Raises:
|
445 |
+
KeyError: If session_id doesn't exist
|
446 |
+
"""
|
447 |
+
with self._lock:
|
448 |
+
if session_id not in self._sessions:
|
449 |
+
raise KeyError(f"Session {session_id} not found")
|
450 |
+
|
451 |
+
self._sessions[session_id].update({
|
452 |
+
"current_image": image,
|
453 |
+
"conversation_history": [],
|
454 |
+
"annotated_image": None,
|
455 |
+
"first_message": True,
|
456 |
+
"last_accessed": time.time(),
|
457 |
+
"tool_call_history": [],
|
458 |
+
"visual_analysis_history": []
|
459 |
+
})
|
460 |
+
|
461 |
+
def add_to_conversation(self, session_id: str, message: Dict[str, Any]) -> None:
|
462 |
+
"""
|
463 |
+
Add a message to conversation history for a specific session.
|
464 |
+
|
465 |
+
Args:
|
466 |
+
session_id (str): The session ID
|
467 |
+
message (Dict[str, Any]): Message dictionary with role and content
|
468 |
+
|
469 |
+
Raises:
|
470 |
+
KeyError: If session_id doesn't exist
|
471 |
+
"""
|
472 |
+
with self._lock:
|
473 |
+
if session_id not in self._sessions:
|
474 |
+
raise KeyError(f"Session {session_id} not found")
|
475 |
+
|
476 |
+
self._sessions[session_id]["conversation_history"].append(message)
|
477 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
478 |
+
|
479 |
+
def get_conversation_length(self, session_id: str) -> int:
|
480 |
+
"""
|
481 |
+
Get the length of conversation history for a session.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
session_id (str): The session ID
|
485 |
+
|
486 |
+
Returns:
|
487 |
+
int: Number of messages in conversation history
|
488 |
+
|
489 |
+
Raises:
|
490 |
+
KeyError: If session_id doesn't exist
|
491 |
+
"""
|
492 |
+
with self._lock:
|
493 |
+
if session_id not in self._sessions:
|
494 |
+
raise KeyError(f"Session {session_id} not found")
|
495 |
+
|
496 |
+
self._sessions[session_id]["last_accessed"] = time.time()
|
497 |
+
return len(self._sessions[session_id]["conversation_history"])
|
498 |
+
|
499 |
+
def session_exists(self, session_id: str) -> bool:
|
500 |
+
"""
|
501 |
+
Check if a session exists.
|
502 |
+
|
503 |
+
Args:
|
504 |
+
session_id (str): The session ID to check
|
505 |
+
|
506 |
+
Returns:
|
507 |
+
bool: True if session exists, False otherwise
|
508 |
+
"""
|
509 |
+
with self._lock:
|
510 |
+
return session_id in self._sessions
|
511 |
+
|
512 |
+
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
513 |
+
"""
|
514 |
+
Get information about all active sessions.
|
515 |
+
|
516 |
+
Returns:
|
517 |
+
Dict[str, Dict[str, Any]]: Dictionary mapping session_ids to session info
|
518 |
+
"""
|
519 |
+
with self._lock:
|
520 |
+
session_info = {}
|
521 |
+
for session_id, session_data in self._sessions.items():
|
522 |
+
session_info[session_id] = {
|
523 |
+
"thread_id": session_data.get("thread_id"),
|
524 |
+
"created_at": session_data.get("created_at"),
|
525 |
+
"last_accessed": session_data.get("last_accessed"),
|
526 |
+
"conversation_length": len(session_data.get("conversation_history", [])),
|
527 |
+
"has_image": session_data.get("current_image") is not None,
|
528 |
+
"has_annotated_image": session_data.get("annotated_image") is not None,
|
529 |
+
"tool_calls_count": len(session_data.get("tool_call_history", []))
|
530 |
+
}
|
531 |
+
return session_info
|
532 |
+
|
533 |
+
def cleanup_inactive_sessions(self) -> int:
|
534 |
+
"""
|
535 |
+
Remove sessions that haven't been accessed recently.
|
536 |
+
|
537 |
+
Returns:
|
538 |
+
int: Number of sessions cleaned up
|
539 |
+
"""
|
540 |
+
current_time = time.time()
|
541 |
+
cleaned_count = 0
|
542 |
+
|
543 |
+
with self._lock:
|
544 |
+
inactive_sessions = []
|
545 |
+
for session_id, session_data in self._sessions.items():
|
546 |
+
last_accessed = session_data.get("last_accessed", 0)
|
547 |
+
if current_time - last_accessed > self._cleanup_interval:
|
548 |
+
inactive_sessions.append(session_id)
|
549 |
+
|
550 |
+
for session_id in inactive_sessions:
|
551 |
+
del self._sessions[session_id]
|
552 |
+
cleaned_count += 1
|
553 |
+
|
554 |
+
return cleaned_count
|
555 |
+
|
556 |
+
def delete_session(self, session_id: str) -> bool:
|
557 |
+
"""
|
558 |
+
Manually delete a specific session.
|
559 |
+
|
560 |
+
Args:
|
561 |
+
session_id (str): The session ID to delete
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
bool: True if session was deleted, False if it didn't exist
|
565 |
+
"""
|
566 |
+
with self._lock:
|
567 |
+
if session_id in self._sessions:
|
568 |
+
del self._sessions[session_id]
|
569 |
+
return True
|
570 |
+
return False
|
571 |
+
|
572 |
+
|
573 |
+
# Global session manager instance
|
574 |
+
session_state_manager = SessionStateManager()
|
src/deepforest_agent/utils/tile_manager.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import Tuple, Dict, List, Any, Optional
|
3 |
+
from PIL import Image
|
4 |
+
import rasterio as rio
|
5 |
+
from rasterio.windows import Window
|
6 |
+
from deepforest import preprocess
|
7 |
+
try:
|
8 |
+
import slidingwindow
|
9 |
+
SLIDINGWINDOW_AVAILABLE = True
|
10 |
+
except ImportError:
|
11 |
+
SLIDINGWINDOW_AVAILABLE = False
|
12 |
+
print("Warning: slidingwindow not available, falling back to deepforest preprocess")
|
13 |
+
|
14 |
+
from deepforest_agent.conf.config import Config
|
15 |
+
|
16 |
+
|
17 |
+
def tile_image_for_analysis(
|
18 |
+
image: Image.Image,
|
19 |
+
patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
|
20 |
+
patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
|
21 |
+
image_file_path: Optional[str] = None,
|
22 |
+
) -> Tuple[List[Image.Image], List[Dict[str, Any]]]:
|
23 |
+
"""
|
24 |
+
Tile am Image for visual analysis.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
image (Image.Image): PIL Image to tile
|
28 |
+
patch_size (int): Size of each tile in pixels (default: 400)
|
29 |
+
patch_overlap (float): Overlap between tiles as fraction 0-1 (default: 0.05)
|
30 |
+
image_file_path (Optional[str]): Path to raster file for memory-efficient dimension reading
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
Tuple containing:
|
34 |
+
- List[Image.Image]: List of PIL Image tiles
|
35 |
+
- List[Dict[str, Any]]: List of tile metadata with coordinates
|
36 |
+
|
37 |
+
Raises:
|
38 |
+
ValueError: If patch_overlap > 1 or image is too small for patch_size
|
39 |
+
Exception: If tiling process fails
|
40 |
+
"""
|
41 |
+
try:
|
42 |
+
# Use slidingwindow for all image types if available
|
43 |
+
if SLIDINGWINDOW_AVAILABLE:
|
44 |
+
height = width = None
|
45 |
+
method = "unknown"
|
46 |
+
|
47 |
+
if image_file_path:
|
48 |
+
try:
|
49 |
+
# Get raster shape without keeping file open
|
50 |
+
with rio.open(image_file_path) as src:
|
51 |
+
height = src.shape[0]
|
52 |
+
width = src.shape[1]
|
53 |
+
method = "slidingwindow_raster"
|
54 |
+
print(f"Using raster dimensions: {width}x{height} from file path")
|
55 |
+
except Exception as raster_error:
|
56 |
+
print(f"Raster reading failed: {raster_error}, using PIL image dimensions")
|
57 |
+
height = width = None
|
58 |
+
|
59 |
+
# If raster reading failed or no file path, get dimensions from PIL image
|
60 |
+
if height is None or width is None:
|
61 |
+
width, height = image.size
|
62 |
+
method = "slidingwindow_pil"
|
63 |
+
print(f"Using PIL dimensions: {width}x{height} from image object")
|
64 |
+
|
65 |
+
try:
|
66 |
+
# Generate windows using slidingwindow for any image type
|
67 |
+
windows = slidingwindow.generateForSize(
|
68 |
+
height=height,
|
69 |
+
width=width,
|
70 |
+
dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
|
71 |
+
maxWindowSize=patch_size,
|
72 |
+
overlapPercent=patch_overlap
|
73 |
+
)
|
74 |
+
|
75 |
+
print(f"Generated {len(windows)} tiles using slidingwindow with method: {method}")
|
76 |
+
|
77 |
+
tiles = []
|
78 |
+
tile_metadata = []
|
79 |
+
|
80 |
+
for i, window in enumerate(windows):
|
81 |
+
x = window.x
|
82 |
+
y = window.y
|
83 |
+
w = window.w
|
84 |
+
h = window.h
|
85 |
+
|
86 |
+
# Extract actual image data for this tile
|
87 |
+
if method == "slidingwindow_raster" and image_file_path:
|
88 |
+
try:
|
89 |
+
with rio.open(image_file_path) as src:
|
90 |
+
window_data = src.read(window=Window(x, y, w, h))
|
91 |
+
if window_data.ndim == 3:
|
92 |
+
window_data = window_data.transpose(1, 2, 0)
|
93 |
+
|
94 |
+
if window_data.dtype != np.uint8:
|
95 |
+
if window_data.max() <= 1.0:
|
96 |
+
window_data = (window_data * 255).astype(np.uint8)
|
97 |
+
else:
|
98 |
+
window_data = window_data.astype(np.uint8)
|
99 |
+
|
100 |
+
tile_pil = Image.fromarray(window_data)
|
101 |
+
print(f"Tile {i}: Read raster data {window_data.shape} -> PIL {tile_pil.size}")
|
102 |
+
|
103 |
+
except Exception as raster_read_error:
|
104 |
+
print(f"Failed to read raster tile {i}: {raster_read_error}")
|
105 |
+
tile_pil = image.crop((x, y, x + w, y + h))
|
106 |
+
print(f"Tile {i}: Fallback PIL crop -> {tile_pil.size}")
|
107 |
+
else:
|
108 |
+
tile_pil = image.crop((x, y, x + w, y + h))
|
109 |
+
print(f"Tile {i}: PIL crop ({x},{y},{x+w},{y+h}) -> {tile_pil.size}")
|
110 |
+
|
111 |
+
tiles.append(tile_pil)
|
112 |
+
|
113 |
+
# Create tile metadata with tile info
|
114 |
+
metadata = {
|
115 |
+
"tile_index": i,
|
116 |
+
"window_coords": {
|
117 |
+
"x": x,
|
118 |
+
"y": y,
|
119 |
+
"width": w,
|
120 |
+
"height": h
|
121 |
+
},
|
122 |
+
"tile_size": tile_pil.size,
|
123 |
+
"original_image_size": (width, height),
|
124 |
+
"method": method,
|
125 |
+
"actual_crop_bounds": (x, y, x + w, y + h)
|
126 |
+
}
|
127 |
+
tile_metadata.append(metadata)
|
128 |
+
|
129 |
+
print(f"Successfully created {len(tiles)} tiles using slidingwindow method")
|
130 |
+
return tiles, tile_metadata
|
131 |
+
|
132 |
+
except Exception as slidingwindow_error:
|
133 |
+
print(f"Slidingwindow method failed: {slidingwindow_error}, falling back to deepforest preprocess")
|
134 |
+
|
135 |
+
# Fallback to deepforest preprocess method only if slidingwindow failed
|
136 |
+
print(f"Using PIL-based tiling for image with size {image.size}")
|
137 |
+
|
138 |
+
numpy_image = np.array(image)
|
139 |
+
|
140 |
+
if numpy_image.shape[2] == 4:
|
141 |
+
numpy_image = numpy_image[:, :, :3]
|
142 |
+
elif numpy_image.shape[2] != 3:
|
143 |
+
raise ValueError(f"Image must have 3 channels (RGB), got {numpy_image.shape[2]}")
|
144 |
+
|
145 |
+
numpy_image = numpy_image.transpose(2, 0, 1)
|
146 |
+
numpy_image = numpy_image / 255.0
|
147 |
+
numpy_image = numpy_image.astype(np.float32)
|
148 |
+
|
149 |
+
print(f"Tiling image with shape {numpy_image.shape} using patch_size={patch_size}, patch_overlap={patch_overlap}")
|
150 |
+
|
151 |
+
windows = preprocess.compute_windows(numpy_image, patch_size, patch_overlap)
|
152 |
+
|
153 |
+
print(f"Generated {len(windows)} tiles for analysis using deepforest preprocess")
|
154 |
+
|
155 |
+
tiles = []
|
156 |
+
tile_metadata = []
|
157 |
+
|
158 |
+
for i, window in enumerate(windows):
|
159 |
+
tile_array = numpy_image[window.indices()]
|
160 |
+
tile_array = tile_array.transpose(1, 2, 0)
|
161 |
+
if tile_array.dtype != np.uint8:
|
162 |
+
tile_array = (tile_array * 255).astype(np.uint8) if tile_array.max() <= 1.0 else tile_array.astype(np.uint8)
|
163 |
+
|
164 |
+
tile_pil = Image.fromarray(tile_array)
|
165 |
+
tiles.append(tile_pil)
|
166 |
+
|
167 |
+
x, y, w, h = window.getRect()
|
168 |
+
print(f"DeepForest tile {i}: array shape {tile_array.shape} -> PIL {tile_pil.size}")
|
169 |
+
|
170 |
+
# Create tile metadata
|
171 |
+
metadata = {
|
172 |
+
"tile_index": i,
|
173 |
+
"window_coords": {
|
174 |
+
"x": x,
|
175 |
+
"y": y,
|
176 |
+
"width": w,
|
177 |
+
"height": h
|
178 |
+
},
|
179 |
+
"tile_size": tile_pil.size,
|
180 |
+
"original_image_size": image.size,
|
181 |
+
"method": "deepforest_preprocess"
|
182 |
+
}
|
183 |
+
tile_metadata.append(metadata)
|
184 |
+
|
185 |
+
if not tiles:
|
186 |
+
raise Exception("No tiles were created - check image dimensions and parameters")
|
187 |
+
|
188 |
+
# Check for empty or invalid tiles
|
189 |
+
valid_tiles = []
|
190 |
+
valid_metadata = []
|
191 |
+
for i, tile in enumerate(tiles):
|
192 |
+
if tile.size[0] > 0 and tile.size[1] > 0:
|
193 |
+
valid_tiles.append(tile)
|
194 |
+
valid_metadata.append(tile_metadata[i])
|
195 |
+
else:
|
196 |
+
print(f"Warning: Tile {i} has invalid size {tile.size}, skipping")
|
197 |
+
|
198 |
+
if not valid_tiles:
|
199 |
+
raise Exception("No valid tiles were created")
|
200 |
+
|
201 |
+
if len(valid_tiles) != len(tiles):
|
202 |
+
print(f"Filtered {len(tiles)} -> {len(valid_tiles)} valid tiles")
|
203 |
+
tiles = valid_tiles
|
204 |
+
tile_metadata = valid_metadata
|
205 |
+
|
206 |
+
print(f"Successfully created {len(tiles)} tiles for multi-image analysis using fallback method")
|
207 |
+
return tiles, tile_metadata
|
208 |
+
|
209 |
+
except Exception as e:
|
210 |
+
print(f"Error during image tiling: {e}")
|
211 |
+
raise e
|
tests/test_deepforest_tool.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from matplotlib import pyplot as plt
|
4 |
+
|
5 |
+
from deepforest_agent.conf.config import Config
|
6 |
+
from deepforest_agent.tools.deepforest_tool import DeepForestPredictor
|
7 |
+
from deepforest_agent.utils.image_utils import load_image_as_np_array
|
8 |
+
|
9 |
+
TEST_IMAGE_PATH_SMALL = "data/AWPE Pigeon Lake 2020 DJI_0005.JPG"
|
10 |
+
TEST_IMAGE_PATH_LARGE = "data/OSBS_029.tif"
|
11 |
+
|
12 |
+
deepforest_predictor = DeepForestPredictor()
|
13 |
+
|
14 |
+
|
15 |
+
def display_image_for_test(image_array: np.ndarray, title: str = "Test Image"):
|
16 |
+
"""
|
17 |
+
Display an image using matplotlib for visual inspection during testing.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
image_array: Image as numpy array
|
21 |
+
title: Title for the plot
|
22 |
+
"""
|
23 |
+
plt.imshow(image_array)
|
24 |
+
plt.axis('off')
|
25 |
+
plt.title(title)
|
26 |
+
plt.show()
|
27 |
+
|
28 |
+
|
29 |
+
def test_deepforest_predict_objects_basic_detection_bird():
|
30 |
+
"""Test basic bird detection with default parameters on a small image."""
|
31 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
32 |
+
if image_array is None:
|
33 |
+
return
|
34 |
+
|
35 |
+
summary, annotated_image, detections_list = (
|
36 |
+
deepforest_predictor.predict_objects(
|
37 |
+
image_data_array=image_array,
|
38 |
+
model_names=["bird"]
|
39 |
+
)
|
40 |
+
)
|
41 |
+
|
42 |
+
assert "DeepForest detected" in summary or "No objects detected" in summary
|
43 |
+
assert ("bird" in summary or "No objects detected" in summary)
|
44 |
+
assert annotated_image is not None
|
45 |
+
assert isinstance(annotated_image, np.ndarray)
|
46 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
47 |
+
|
48 |
+
assert isinstance(detections_list, list)
|
49 |
+
if detections_list:
|
50 |
+
bird_labels_found = any(
|
51 |
+
detection["label"] == "bird" for detection in detections_list if 'label' in detection
|
52 |
+
)
|
53 |
+
assert bird_labels_found
|
54 |
+
|
55 |
+
display_image_for_test(annotated_image, "Bird Detection Test")
|
56 |
+
|
57 |
+
|
58 |
+
def test_deepforest_predict_objects_basic_detection_tree():
|
59 |
+
"""Test basic tree detection with default parameters on a small image."""
|
60 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
61 |
+
if image_array is None:
|
62 |
+
return
|
63 |
+
|
64 |
+
summary, annotated_image, detections_list = (
|
65 |
+
deepforest_predictor.predict_objects(
|
66 |
+
image_data_array=image_array,
|
67 |
+
model_names=["tree"]
|
68 |
+
)
|
69 |
+
)
|
70 |
+
|
71 |
+
assert "DeepForest detected" in summary or "No objects detected" in summary
|
72 |
+
assert "tree" in summary or "No objects detected" in summary
|
73 |
+
assert annotated_image is not None
|
74 |
+
assert isinstance(annotated_image, np.ndarray)
|
75 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
76 |
+
|
77 |
+
assert isinstance(detections_list, list)
|
78 |
+
if detections_list:
|
79 |
+
tree_labels_found = any(
|
80 |
+
detection["label"] == "tree" for detection in detections_list if 'label' in detection
|
81 |
+
)
|
82 |
+
assert tree_labels_found
|
83 |
+
|
84 |
+
display_image_for_test(annotated_image, "Tree Detection Test")
|
85 |
+
|
86 |
+
|
87 |
+
def test_deepforest_predict_objects_multiple_models():
|
88 |
+
"""Test detection using multiple models simultaneously."""
|
89 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
90 |
+
if image_array is None:
|
91 |
+
return
|
92 |
+
|
93 |
+
summary, annotated_image, detections_list = (
|
94 |
+
deepforest_predictor.predict_objects(
|
95 |
+
image_data_array=image_array,
|
96 |
+
model_names=["bird", "tree", "livestock"]
|
97 |
+
)
|
98 |
+
)
|
99 |
+
|
100 |
+
assert "DeepForest detected" in summary or "No objects detected" in summary
|
101 |
+
assert annotated_image is not None
|
102 |
+
assert isinstance(annotated_image, np.ndarray)
|
103 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
104 |
+
|
105 |
+
assert isinstance(detections_list, list)
|
106 |
+
if detections_list:
|
107 |
+
labels = {detection['label'] for detection in detections_list if 'label' in detection}
|
108 |
+
assert "bird" in labels or "tree" in labels or "livestock" in labels
|
109 |
+
|
110 |
+
display_image_for_test(annotated_image, "Multiple Models Test")
|
111 |
+
|
112 |
+
|
113 |
+
def test_deepforest_predict_objects_large_image_processing():
|
114 |
+
"""Test processing of large images using tiled prediction."""
|
115 |
+
summary, annotated_image, detections_list = (
|
116 |
+
deepforest_predictor.predict_objects(
|
117 |
+
image_file_path=TEST_IMAGE_PATH_LARGE,
|
118 |
+
model_names=["tree"],
|
119 |
+
patch_size=Config.DEEPFOREST_DEFAULTS["patch_size"],
|
120 |
+
patch_overlap=Config.DEEPFOREST_DEFAULTS["patch_overlap"],
|
121 |
+
iou_threshold=Config.DEEPFOREST_DEFAULTS["iou_threshold"],
|
122 |
+
thresh=Config.DEEPFOREST_DEFAULTS["thresh"]
|
123 |
+
)
|
124 |
+
)
|
125 |
+
|
126 |
+
assert "DeepForest detected" in summary or "No objects detected" in summary
|
127 |
+
assert annotated_image is not None
|
128 |
+
assert isinstance(annotated_image, np.ndarray)
|
129 |
+
|
130 |
+
assert isinstance(detections_list, list)
|
131 |
+
if detections_list:
|
132 |
+
assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
|
133 |
+
|
134 |
+
display_image_for_test(annotated_image, "Large Image Processing Test")
|
135 |
+
|
136 |
+
|
137 |
+
def test_deepforest_predict_objects_custom_patch_size():
|
138 |
+
"""Test detection with custom patch size parameter."""
|
139 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
140 |
+
if image_array is None:
|
141 |
+
return
|
142 |
+
|
143 |
+
summary, annotated_image, detections_list = (
|
144 |
+
deepforest_predictor.predict_objects(
|
145 |
+
image_data_array=image_array,
|
146 |
+
model_names=["tree"],
|
147 |
+
patch_size=800,
|
148 |
+
patch_overlap=Config.DEEPFOREST_DEFAULTS["patch_overlap"],
|
149 |
+
iou_threshold=Config.DEEPFOREST_DEFAULTS["iou_threshold"],
|
150 |
+
thresh=Config.DEEPFOREST_DEFAULTS["thresh"]
|
151 |
+
)
|
152 |
+
)
|
153 |
+
|
154 |
+
assert "DeepForest detected" in summary or "No objects detected" in summary
|
155 |
+
assert annotated_image is not None
|
156 |
+
assert isinstance(annotated_image, np.ndarray)
|
157 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
158 |
+
|
159 |
+
assert isinstance(detections_list, list)
|
160 |
+
if detections_list:
|
161 |
+
assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
|
162 |
+
|
163 |
+
display_image_for_test(annotated_image, "Custom Patch Size Test")
|
164 |
+
|
165 |
+
|
166 |
+
def test_deepforest_predict_objects_multiple_custom_parameters():
|
167 |
+
"""Test detection with multiple custom parameters."""
|
168 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
169 |
+
if image_array is None:
|
170 |
+
return
|
171 |
+
|
172 |
+
summary, annotated_image, detections_list = (
|
173 |
+
deepforest_predictor.predict_objects(
|
174 |
+
image_data_array=image_array,
|
175 |
+
model_names=["tree"],
|
176 |
+
patch_size=600,
|
177 |
+
patch_overlap=0.1,
|
178 |
+
iou_threshold=0.3,
|
179 |
+
thresh=0.3
|
180 |
+
)
|
181 |
+
)
|
182 |
+
|
183 |
+
assert "DeepForest detected" in summary or "No objects detected" in summary
|
184 |
+
assert annotated_image is not None
|
185 |
+
assert isinstance(annotated_image, np.ndarray)
|
186 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
187 |
+
|
188 |
+
assert isinstance(detections_list, list)
|
189 |
+
if detections_list:
|
190 |
+
assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
|
191 |
+
|
192 |
+
display_image_for_test(annotated_image, "Multiple Custom Parameters Test")
|
193 |
+
|
194 |
+
|
195 |
+
def test_deepforest_predict_objects_alive_dead_trees():
|
196 |
+
"""Test alive/dead tree classification detection."""
|
197 |
+
summary, annotated_image, detections_list = (
|
198 |
+
deepforest_predictor.predict_objects(
|
199 |
+
image_file_path=TEST_IMAGE_PATH_LARGE,
|
200 |
+
model_names=["tree"],
|
201 |
+
alive_dead_trees=True
|
202 |
+
)
|
203 |
+
)
|
204 |
+
|
205 |
+
assert "DeepForest detected" in summary or "No objects detected" in summary
|
206 |
+
assert annotated_image is not None
|
207 |
+
assert isinstance(annotated_image, np.ndarray)
|
208 |
+
|
209 |
+
print(summary)
|
210 |
+
|
211 |
+
assert isinstance(detections_list, list)
|
212 |
+
if detections_list:
|
213 |
+
tree_detections = [d for d in detections_list if d.get('label') == 'tree']
|
214 |
+
assert len(tree_detections) > 0, "Expected at least one tree detection"
|
215 |
+
|
216 |
+
# Check for classification_label field in tree detections
|
217 |
+
classification_labels = {d.get('classification_label') for d in tree_detections
|
218 |
+
if 'classification_label' in d}
|
219 |
+
assert ('alive_tree' in classification_labels or 'dead_tree' in classification_labels), \
|
220 |
+
f"Expected alive_tree or dead_tree in classification labels, got: {classification_labels}"
|
221 |
+
|
222 |
+
# Check that summary mentions classification results
|
223 |
+
assert (("alive" in summary and "tree" in summary) or
|
224 |
+
("dead" in summary and "tree" in summary) or
|
225 |
+
("No objects detected" in summary)), \
|
226 |
+
f"Summary should mention alive/dead classification: {summary}"
|
227 |
+
|
228 |
+
display_image_for_test(annotated_image, "Alive/Dead Tree Detection Test")
|
229 |
+
|
230 |
+
|
231 |
+
def test_deepforest_predict_objects_no_detections():
|
232 |
+
"""Test the function gracefully handles cases with no detections."""
|
233 |
+
blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
|
234 |
+
|
235 |
+
summary, annotated_image, detections_list = (
|
236 |
+
deepforest_predictor.predict_objects(
|
237 |
+
image_data_array=blank_image,
|
238 |
+
model_names=["tree"],
|
239 |
+
thresh=1.0
|
240 |
+
)
|
241 |
+
)
|
242 |
+
|
243 |
+
assert "No objects detected by DeepForest" in summary
|
244 |
+
assert annotated_image is not None
|
245 |
+
assert isinstance(annotated_image, np.ndarray)
|
246 |
+
assert annotated_image.shape[:2] == blank_image.shape[:2]
|
247 |
+
|
248 |
+
assert isinstance(detections_list, list)
|
249 |
+
assert len(detections_list) == 0
|
250 |
+
|
251 |
+
display_image_for_test(annotated_image, "No Detections Test")
|
252 |
+
|
253 |
+
|
254 |
+
def test_deepforest_predict_objects_custom_thresholds():
|
255 |
+
"""Test detection with custom threshold parameters."""
|
256 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
257 |
+
if image_array is None:
|
258 |
+
return
|
259 |
+
|
260 |
+
summary, annotated_image, detections_list = (
|
261 |
+
deepforest_predictor.predict_objects(
|
262 |
+
image_data_array=image_array,
|
263 |
+
model_names=["tree"],
|
264 |
+
thresh=0.9,
|
265 |
+
iou_threshold=0.5
|
266 |
+
)
|
267 |
+
)
|
268 |
+
|
269 |
+
assert ("DeepForest detected" in summary or
|
270 |
+
"No objects detected" in summary)
|
271 |
+
assert annotated_image is not None
|
272 |
+
assert isinstance(annotated_image, np.ndarray)
|
273 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
274 |
+
|
275 |
+
assert isinstance(detections_list, list)
|
276 |
+
if detections_list:
|
277 |
+
assert any(detection['label'] == 'tree' for detection in detections_list if 'label' in detection)
|
278 |
+
|
279 |
+
display_image_for_test(annotated_image, "Custom Thresholds Test")
|
280 |
+
|
281 |
+
|
282 |
+
def test_deepforest_predict_objects_unsupported_model_name():
|
283 |
+
"""Test behavior with an unsupported model name."""
|
284 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
285 |
+
if image_array is None:
|
286 |
+
return
|
287 |
+
|
288 |
+
summary, annotated_image, detections_list = (
|
289 |
+
deepforest_predictor.predict_objects(
|
290 |
+
image_data_array=image_array,
|
291 |
+
model_names=["tree", "nonexistent_model"]
|
292 |
+
)
|
293 |
+
)
|
294 |
+
|
295 |
+
assert ("DeepForest detected" in summary or
|
296 |
+
"No objects detected" in summary)
|
297 |
+
assert annotated_image is not None
|
298 |
+
assert isinstance(annotated_image, np.ndarray)
|
299 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
300 |
+
|
301 |
+
assert isinstance(detections_list, list)
|
302 |
+
if detections_list:
|
303 |
+
labels = {detection['label'] for detection in detections_list if 'label' in detection}
|
304 |
+
assert 'tree' in labels
|
305 |
+
assert 'nonexistent_model' not in labels
|
306 |
+
|
307 |
+
display_image_for_test(annotated_image, "Unsupported Model Test")
|
308 |
+
|
309 |
+
|
310 |
+
def test_plot_boxes_basic():
|
311 |
+
"""Test _plot_boxes with some sample bounding box data."""
|
312 |
+
img = np.zeros((100, 100, 3), dtype=np.uint8) + 255
|
313 |
+
predictions = pd.DataFrame([
|
314 |
+
{'xmin': 10, 'ymin': 10, 'xmax': 30, 'ymax': 30,
|
315 |
+
'label': 'bird', 'score': 0.9},
|
316 |
+
{'xmin': 50, 'ymin': 50, 'xmax': 70, 'ymax': 70,
|
317 |
+
'label': 'tree', 'score': 0.8}
|
318 |
+
])
|
319 |
+
|
320 |
+
annotated_img = DeepForestPredictor._plot_boxes(
|
321 |
+
img, predictions, Config.COLORS
|
322 |
+
)
|
323 |
+
assert annotated_img.shape == img.shape
|
324 |
+
assert not np.array_equal(annotated_img, img)
|
325 |
+
|
326 |
+
display_image_for_test(annotated_img, "Plot Boxes Basic Test")
|
327 |
+
|
328 |
+
|
329 |
+
def test_plot_boxes_empty_predictions():
|
330 |
+
"""Test _plot_boxes with empty predictions DataFrame."""
|
331 |
+
img = np.zeros((100, 100, 3), dtype=np.uint8) + 255
|
332 |
+
|
333 |
+
predictions = pd.DataFrame({
|
334 |
+
"xmin": pd.Series(dtype=float),
|
335 |
+
"ymin": pd.Series(dtype=float),
|
336 |
+
"xmax": pd.Series(dtype=float),
|
337 |
+
"ymax": pd.Series(dtype=float),
|
338 |
+
"label": pd.Series(dtype=str),
|
339 |
+
"score": pd.Series(dtype=float)
|
340 |
+
})
|
341 |
+
|
342 |
+
annotated_img = DeepForestPredictor._plot_boxes(
|
343 |
+
img, predictions, Config.COLORS
|
344 |
+
)
|
345 |
+
assert np.array_equal(annotated_img, img)
|
346 |
+
|
347 |
+
display_image_for_test(annotated_img, "Empty Predictions Test")
|
348 |
+
|
349 |
+
|
350 |
+
def test_deepforest_predict_objects_default_parameters():
|
351 |
+
"""Test that default parameters work correctly with tiled prediction."""
|
352 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
353 |
+
if image_array is None:
|
354 |
+
return
|
355 |
+
|
356 |
+
summary, annotated_image, detections_list = (
|
357 |
+
deepforest_predictor.predict_objects(
|
358 |
+
image_data_array=image_array,
|
359 |
+
model_names=["tree"]
|
360 |
+
)
|
361 |
+
)
|
362 |
+
|
363 |
+
assert ("DeepForest detected" in summary or "No objects detected" in summary)
|
364 |
+
assert annotated_image is not None
|
365 |
+
assert isinstance(annotated_image, np.ndarray)
|
366 |
+
assert annotated_image.shape[:2] == image_array.shape[:2]
|
367 |
+
|
368 |
+
assert isinstance(detections_list, list)
|
369 |
+
|
370 |
+
print("Default parameters test completed successfully")
|
371 |
+
display_image_for_test(annotated_image, "Default Parameters Test")
|
372 |
+
|
373 |
+
|
374 |
+
def test_generate_detection_summary():
|
375 |
+
"""Test the _generate_detection_summary method directly."""
|
376 |
+
# Test with empty DataFrame
|
377 |
+
empty_df = pd.DataFrame()
|
378 |
+
summary = deepforest_predictor._generate_detection_summary(empty_df)
|
379 |
+
assert "No objects detected" in summary
|
380 |
+
|
381 |
+
# Test with basic detections
|
382 |
+
predictions_df = pd.DataFrame([
|
383 |
+
{'label': 'tree', 'score': 0.9},
|
384 |
+
{'label': 'tree', 'score': 0.8},
|
385 |
+
{'label': 'bird', 'score': 0.7}
|
386 |
+
])
|
387 |
+
summary = deepforest_predictor._generate_detection_summary(predictions_df)
|
388 |
+
assert "DeepForest detected" in summary
|
389 |
+
assert "2 trees" in summary
|
390 |
+
assert "1 bird" in summary
|
391 |
+
|
392 |
+
print("Detection summary tests completed successfully")
|
393 |
+
|
394 |
+
|
395 |
+
def test_detections_list_structure():
|
396 |
+
"""Test that detections_list has the correct structure."""
|
397 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
398 |
+
if image_array is None:
|
399 |
+
return
|
400 |
+
|
401 |
+
summary, annotated_image, detections_list = (
|
402 |
+
deepforest_predictor.predict_objects(
|
403 |
+
image_data_array=image_array,
|
404 |
+
model_names=["tree"]
|
405 |
+
)
|
406 |
+
)
|
407 |
+
|
408 |
+
assert isinstance(detections_list, list)
|
409 |
+
|
410 |
+
if detections_list:
|
411 |
+
for detection in detections_list:
|
412 |
+
assert isinstance(detection, dict)
|
413 |
+
assert 'xmin' in detection
|
414 |
+
assert 'ymin' in detection
|
415 |
+
assert 'xmax' in detection
|
416 |
+
assert 'ymax' in detection
|
417 |
+
assert 'score' in detection
|
418 |
+
assert 'label' in detection
|
419 |
+
|
420 |
+
assert isinstance(detection['xmin'], int)
|
421 |
+
assert isinstance(detection['ymin'], int)
|
422 |
+
assert isinstance(detection['xmax'], int)
|
423 |
+
assert isinstance(detection['ymax'], int)
|
424 |
+
assert isinstance(detection['score'], float)
|
425 |
+
assert isinstance(detection['label'], str)
|
426 |
+
|
427 |
+
print("Detections list structure test completed successfully")
|
428 |
+
|
429 |
+
|
430 |
+
def test_error_handling_invalid_model():
|
431 |
+
"""Test error handling when all models are invalid."""
|
432 |
+
image_array = load_image_as_np_array(TEST_IMAGE_PATH_SMALL)
|
433 |
+
if image_array is None:
|
434 |
+
return
|
435 |
+
|
436 |
+
summary, annotated_image, detections_list = (
|
437 |
+
deepforest_predictor.predict_objects(
|
438 |
+
image_data_array=image_array,
|
439 |
+
model_names=["invalid_model_1", "invalid_model_2"]
|
440 |
+
)
|
441 |
+
)
|
442 |
+
|
443 |
+
assert "No objects detected" in summary
|
444 |
+
assert annotated_image is not None
|
445 |
+
assert isinstance(annotated_image, np.ndarray)
|
446 |
+
assert isinstance(detections_list, list)
|
447 |
+
assert len(detections_list) == 0
|
448 |
+
|
449 |
+
print("Error handling test completed successfully")
|
450 |
+
|
451 |
+
|
452 |
+
def test_input_validation():
|
453 |
+
"""Test input validation for the predict_objects method."""
|
454 |
+
# Test with neither image_data_array nor image_file_path provided
|
455 |
+
try:
|
456 |
+
deepforest_predictor.predict_objects(
|
457 |
+
image_data_array=None,
|
458 |
+
image_file_path=None,
|
459 |
+
model_names=["tree"]
|
460 |
+
)
|
461 |
+
assert False, "Should have raised ValueError"
|
462 |
+
except ValueError as e:
|
463 |
+
assert "Either image_data_array or image_file_path must be provided" in str(e)
|
464 |
+
|
465 |
+
print("Input validation test completed successfully")
|