AshenH committed on
Commit
6860773
·
verified ·
1 Parent(s): e4818d5

Update tools/report_tool.py

Browse files
Files changed (1) hide show
  1. tools/report_tool.py +339 -22
tools/report_tool.py CHANGED
@@ -1,21 +1,282 @@
1
  # space/tools/report_tool.py
2
  import os
 
3
  from typing import Optional, Dict, Any
 
4
 
5
  import pandas as pd
6
- from jinja2 import Environment, FileSystemLoader
7
 
8
  from utils.tracing import Tracer
9
  from utils.config import AppConfig
10
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class ReportTool:
 
 
 
 
 
13
  def __init__(self, cfg: AppConfig, tracer: Tracer):
14
  self.cfg = cfg
15
  self.tracer = tracer
16
- templates_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "templates"))
17
- self.env = Environment(loader=FileSystemLoader(templates_dir), autoescape=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def render_and_save(
20
  self,
21
  user_query: str,
@@ -24,22 +285,78 @@ class ReportTool:
24
  explain_images: Dict[str, str],
25
  plan: Dict[str, Any],
26
  ) -> str:
27
- tmpl = self.env.get_template("report_template.md")
28
- html_body = tmpl.render(
29
- user_query=user_query,
30
- plan=plan,
31
- sql_preview=sql_preview.to_markdown(index=False) if isinstance(sql_preview, pd.DataFrame) else "",
32
- predict_preview=predict_preview.to_markdown(index=False) if isinstance(predict_preview, pd.DataFrame) else "",
33
- explain_images=explain_images or {},
34
- )
35
- out_name = f"report_{pd.Timestamp.utcnow().strftime('%Y%m%d_%H%M%S')}.html"
36
- out_path = os.path.abspath(os.path.join(os.getcwd(), out_name))
37
- css_link = "templates/report_styles.css"
38
- html = f'<link rel="stylesheet" href="{css_link}">\n' + html_body
39
- with open(out_path, "w", encoding="utf-8") as f:
40
- f.write(html)
41
- try:
42
- self.tracer.trace_event("report", {"path": out_name})
43
- except Exception:
44
- pass
45
- return out_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # space/tools/report_tool.py
2
  import os
3
+ import logging
4
  from typing import Optional, Dict, Any
5
+ from datetime import datetime
6
 
7
  import pandas as pd
8
+ from jinja2 import Environment, FileSystemLoader, TemplateNotFound
9
 
10
  from utils.tracing import Tracer
11
  from utils.config import AppConfig
12
 
13
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Constants
# Default cap on the number of dataframe rows rendered into a preview table.
MAX_PREVIEW_ROWS = 100
# Size threshold in megabytes: input validation rejects embedded images whose
# combined length exceeds it, and saving a report larger than it logs a warning.
MAX_REPORT_SIZE_MB = 50
18
+
19
+
20
class ReportToolError(Exception):
    """Error raised by the report tool when setup, rendering, or saving fails."""
23
+
24
 
25
  class ReportTool:
26
+ """
27
+ Generates HTML reports from analysis results.
28
+ Includes error handling, size limits, and proper template management.
29
+ """
30
+
31
  def __init__(self, cfg: AppConfig, tracer: Tracer):
32
  self.cfg = cfg
33
  self.tracer = tracer
34
+
35
+ # Setup Jinja2 environment
36
+ try:
37
+ templates_dir = os.path.abspath(
38
+ os.path.join(os.path.dirname(__file__), "..", "templates")
39
+ )
40
+
41
+ if not os.path.exists(templates_dir):
42
+ logger.warning(f"Templates directory not found: {templates_dir}. Creating it.")
43
+ os.makedirs(templates_dir, exist_ok=True)
44
+
45
+ self.env = Environment(
46
+ loader=FileSystemLoader(templates_dir),
47
+ autoescape=False, # We control the content
48
+ trim_blocks=True,
49
+ lstrip_blocks=True
50
+ )
51
+
52
+ logger.info(f"Report tool initialized with templates from: {templates_dir}")
53
+
54
+ except Exception as e:
55
+ raise ReportToolError(f"Failed to initialize report tool: {e}") from e
56
+
57
+ def _validate_inputs(
58
+ self,
59
+ user_query: str,
60
+ sql_preview: Optional[pd.DataFrame],
61
+ predict_preview: Optional[pd.DataFrame],
62
+ explain_images: Dict[str, str],
63
+ plan: Dict[str, Any]
64
+ ) -> tuple[bool, str]:
65
+ """
66
+ Validate report generation inputs.
67
+ Returns (is_valid, error_message).
68
+ """
69
+ if not user_query or not user_query.strip():
70
+ return False, "User query is empty"
71
+
72
+ if not plan or not isinstance(plan, dict):
73
+ return False, "Plan is invalid"
74
+
75
+ # Check explain_images size
76
+ if explain_images:
77
+ total_size = sum(len(img) for img in explain_images.values())
78
+ size_mb = total_size / (1024 * 1024)
79
+ if size_mb > MAX_REPORT_SIZE_MB:
80
+ return False, f"Embedded images too large: {size_mb:.2f} MB (max {MAX_REPORT_SIZE_MB} MB)"
81
+
82
+ return True, ""
83
+
84
def _prepare_dataframe_preview(self, df: Optional[pd.DataFrame], max_rows: int = MAX_PREVIEW_ROWS) -> str:
    """
    Convert a dataframe to a GitHub-style markdown table with a row limit.

    Args:
        df: Data to preview. ``None``, an empty frame, or any value that is
            not a ``pandas.DataFrame`` produces an empty preview.
        max_rows: Maximum number of rows included in the table; when the
            frame is longer, a note with the number of omitted rows is added.

    Returns:
        The markdown table (plus the omitted-rows note when truncated), an
        empty string when there is nothing to show, or an italic error note
        when the conversion itself fails.
    """
    # Guard with isinstance rather than `df is None`: a non-DataFrame value
    # would otherwise fail on the `.empty` lookup before the try block.
    if not isinstance(df, pd.DataFrame) or df.empty:
        return ""

    try:
        omitted_rows = len(df) - max_rows
        if omitted_rows > 0:
            preview_df = df.head(max_rows)
            suffix = f"\n\n*... and {omitted_rows} more rows*"
        else:
            preview_df = df
            suffix = ""

        markdown = preview_df.to_markdown(index=False, tablefmt="github")
        return markdown + suffix

    except Exception as e:
        # Best effort: a failed conversion becomes a visible note in the report.
        logger.warning(f"Failed to convert dataframe to markdown: {e}")
        return f"*Error displaying data: {str(e)}*"
108
+
109
+ def _get_template_name(self) -> str:
110
+ """
111
+ Determine which template to use.
112
+ Falls back to creating a default if none exists.
113
+ """
114
+ template_name = "report_template.md"
115
+
116
+ try:
117
+ # Check if template exists
118
+ self.env.get_template(template_name)
119
+ return template_name
120
+ except TemplateNotFound:
121
+ logger.warning(f"Template '{template_name}' not found. Creating default template.")
122
+ self._create_default_template()
123
+ return template_name
124
+
125
def _create_default_template(self):
    """
    Write the built-in report template into the templates directory.

    The file is written as ``report_template.md`` in the first search path of
    the environment's loader. A write failure is logged and not raised.
    """
    fallback_markdown = """# Analysis Report

**Generated:** {{ timestamp }}

## User Query
{{ user_query }}

## Execution Plan
**Steps:** {{ plan.steps | join(', ') }}

**Rationale:** {{ plan.rationale }}

{% if sql_preview %}
## Data Query Results
{{ sql_preview }}
{% endif %}

{% if predict_preview %}
## Predictions
{{ predict_preview }}
{% endif %}

{% if explain_images %}
## Model Explanations

{% if explain_images.global_bar %}
### Feature Importance
![Feature Importance]({{ explain_images.global_bar }})
{% endif %}

{% if explain_images.beeswarm %}
### Feature Effects
![Feature Effects]({{ explain_images.beeswarm }})
{% endif %}
{% endif %}

---
*Report generated by Tabular Agentic XAI*
"""

    target_path = os.path.join(self.env.loader.searchpath[0], "report_template.md")

    try:
        with open(target_path, 'w', encoding='utf-8') as handle:
            handle.write(fallback_markdown)
        logger.info(f"Created default template at: {target_path}")
    except Exception as e:
        # Best effort: the caller continues even when the file cannot be written.
        logger.error(f"Failed to create default template: {e}")
176
+
177
def _render_template(
    self,
    user_query: str,
    sql_preview_md: str,
    predict_preview_md: str,
    explain_images: Dict[str, str],
    plan: Dict[str, Any]
) -> str:
    """
    Render the report template with the provided data.

    Args:
        user_query: Original query text shown in the report.
        sql_preview_md: Markdown table for the query results, or "".
        predict_preview_md: Markdown table for the predictions, or "".
        explain_images: Mapping of plot name to image reference; ``None`` is
            passed to the template as an empty dict.
        plan: Execution plan exposed to the template as ``plan``.

    Returns:
        The rendered template text.

    Raises:
        ReportToolError: If the template cannot be loaded or rendered.
    """
    # Function-scope import: only `datetime` itself is imported at module level.
    from datetime import timezone

    try:
        template_name = self._get_template_name()
        template = self.env.get_template(template_name)

        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the formatted string is the same.
        generated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

        # NOTE(review): the environment is built with autoescape=False, so
        # user_query and the table text reach the output unescaped — confirm
        # this is acceptable for reports opened in a browser.
        context = {
            "timestamp": generated_at,
            "user_query": user_query,
            "plan": plan,
            "sql_preview": sql_preview_md,
            "predict_preview": predict_preview_md,
            "explain_images": explain_images or {}
        }

        html_body = template.render(**context)
        logger.info(f"Template rendered successfully: {len(html_body)} characters")

        return html_body

    except Exception as e:
        raise ReportToolError(f"Template rendering failed: {e}") from e
208
+
209
def _save_report(self, html_content: str) -> str:
    """
    Wrap the rendered content in an HTML page and write it to disk.

    The file is written to the directory named by the ``REPORT_OUTPUT_DIR``
    environment variable, or to the current working directory when that is
    unset. The stylesheet is embedded in the page, so the report does not
    depend on any file outside itself.

    Args:
        html_content: Rendered report body, placed inside ``<body>``.

    Returns:
        The file name of the saved report (not its full path).

    Raises:
        ReportToolError: If the page cannot be built or written.
    """
    # Function-scope import: only `datetime` itself is imported at module level.
    from datetime import timezone

    try:
        # Timezone-aware replacement for the deprecated datetime.utcnow().
        timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')

        # Determine output directory
        output_dir = os.getenv("REPORT_OUTPUT_DIR", os.getcwd())
        os.makedirs(output_dir, exist_ok=True)

        # Embed the project stylesheet instead of linking to its filesystem
        # path, which would not resolve once the report is moved or served.
        css_path = os.path.join(
            os.path.dirname(__file__), "..", "templates", "report_styles.css"
        )
        try:
            with open(css_path, encoding='utf-8') as css_file:
                css_block = f"<style>\n{css_file.read()}\n</style>"
        except OSError:
            # Basic built-in styling when the stylesheet is missing or unreadable.
            css_block = """
<style>
body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
h1 { color: #2c3e50; border-bottom: 3px solid #3498db; padding-bottom: 10px; }
h2 { color: #34495e; margin-top: 30px; }
table { border-collapse: collapse; width: 100%; margin: 20px 0; }
th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
th { background-color: #3498db; color: white; }
tr:nth-child(even) { background-color: #f2f2f2; }
img { max-width: 100%; height: auto; margin: 20px 0; }
code { background-color: #f4f4f4; padding: 2px 6px; border-radius: 3px; }
pre { background-color: #f4f4f4; padding: 15px; border-radius: 5px; overflow-x: auto; }
</style>
"""

        # Construct full HTML
        full_html = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Analysis Report - {timestamp}</title>
{css_block}
</head>
<body>
{html_content}
</body>
</html>
"""

        # The timestamp has one-second resolution, so two reports can share
        # it. Exclusive-create mode refuses to overwrite an existing file; on
        # a clash a numeric suffix is appended and the write is retried.
        attempt = 0
        while True:
            suffix = f"_{attempt}" if attempt else ""
            filename = f"report_{timestamp}{suffix}.html"
            filepath = os.path.abspath(os.path.join(output_dir, filename))
            try:
                with open(filepath, 'x', encoding='utf-8') as f:
                    f.write(full_html)
                break
            except FileExistsError:
                attempt += 1

        # Check file size
        file_size_mb = os.path.getsize(filepath) / (1024 * 1024)
        logger.info(f"Report saved: {filepath} ({file_size_mb:.2f} MB)")

        if file_size_mb > MAX_REPORT_SIZE_MB:
            logger.warning(f"Report file is large: {file_size_mb:.2f} MB")

        return filename

    except Exception as e:
        raise ReportToolError(f"Failed to save report: {e}") from e
279
+
280
  def render_and_save(
281
  self,
282
  user_query: str,
 
285
  explain_images: Dict[str, str],
286
  plan: Dict[str, Any],
287
  ) -> str:
288
+ """
289
+ Render and save analysis report.
290
+
291
+ Args:
292
+ user_query: Original user query
293
+ sql_preview: SQL query results (optional)
294
+ predict_preview: Prediction results (optional)
295
+ explain_images: Dictionary of explanation plots (name -> data URI)
296
+ plan: Execution plan dictionary
297
+
298
+ Returns:
299
+ Filename of saved report
300
+
301
+ Raises:
302
+ ReportToolError: If report generation fails
303
+ """
304
+ try:
305
+ logger.info("Generating analysis report...")
306
+
307
+ # Validate inputs
308
+ is_valid, error_msg = self._validate_inputs(
309
+ user_query, sql_preview, predict_preview, explain_images, plan
310
+ )
311
+ if not is_valid:
312
+ raise ReportToolError(f"Invalid inputs: {error_msg}")
313
+
314
+ # Prepare dataframe previews
315
+ sql_preview_md = self._prepare_dataframe_preview(sql_preview)
316
+ predict_preview_md = self._prepare_dataframe_preview(predict_preview)
317
+
318
+ # Render template
319
+ html_content = self._render_template(
320
+ user_query=user_query,
321
+ sql_preview_md=sql_preview_md,
322
+ predict_preview_md=predict_preview_md,
323
+ explain_images=explain_images,
324
+ plan=plan
325
+ )
326
+
327
+ # Save report
328
+ filename = self._save_report(html_content)
329
+
330
+ # Trace event
331
+ if self.tracer:
332
+ self.tracer.trace_event("report", {
333
+ "filename": filename,
334
+ "has_sql": bool(sql_preview_md),
335
+ "has_predictions": bool(predict_preview_md),
336
+ "num_images": len(explain_images) if explain_images else 0
337
+ })
338
+
339
+ logger.info(f"Report generation successful: {filename}")
340
+ return filename
341
+
342
+ except ReportToolError:
343
+ raise
344
+ except Exception as e:
345
+ error_msg = f"Report generation failed: {str(e)}"
346
+ logger.error(error_msg)
347
+ if self.tracer:
348
+ self.tracer.trace_event("report_error", {"error": error_msg})
349
+ raise ReportToolError(error_msg) from e
350
+
351
def list_available_templates(self) -> list:
    """
    Return the file names in the templates directory that look like templates.

    A file counts when its name ends in ``.md``, ``.html`` or ``.jinja2``.
    Any failure while listing is logged and yields an empty list.
    """
    template_suffixes = ('.md', '.html', '.jinja2')

    try:
        search_root = self.env.loader.searchpath[0]
        matches = []
        for entry in os.listdir(search_root):
            if entry.endswith(template_suffixes):
                matches.append(entry)
        return matches
    except Exception as e:
        logger.warning(f"Failed to list templates: {e}")
        return []