import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from smolagents import Tool
import plotly.graph_objects as go
import plotly.express as px
from jinja2 import Template
class ReportGeneratorTool(Tool):
    """Tool for generating interactive HTML vulnerability reports with charts.

    The tool receives vulnerability data as a JSON string (either a single
    object or an array of objects), renders a CVSS histogram, a monthly
    publication timeline, a summary and a details table into one HTML
    document, tries to save that document under ``REPORTS_DIR`` and returns
    the HTML to the caller.

    Record keys read by this class: ``id``, ``cvss``, ``epss``, ``kev``,
    ``published_time`` and ``summary``. All of them are optional.
    """

    name = "generate_vulnerability_report"
    description = "Generates an interactive HTML report with charts and vulnerability analysis. The report is generated from CVEDB search results."
    inputs = {
        "vulnerability_data": {
            "type": "string",
            "description": "Vulnerability data in JSON format",
        },
        "report_type": {
            "type": "string",
            "description": "Report type: 'cve' for a specific CVE or 'product' for a product",
        },
    }
    output_type = "string"

    # Folder (relative to the working directory) that receives saved reports.
    REPORTS_DIR = "reports"
    # Pinned plotly.js bundle loaded by the report page; the charts cannot
    # render without it.
    PLOTLY_CDN_URL = "https://cdn.plot.ly/plotly-2.35.2.min.js"
    # Maximum number of summary characters shown per table row.
    SUMMARY_PREVIEW_CHARS = 100

    def __init__(self):
        super().__init__()
        # Base HTML template. It consumes every variable passed to
        # ``render()`` in ``forward`` and provides the two <div> elements
        # that ``Plotly.newPlot`` targets by id.
        self.html_template = """<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>Vulnerability Report</title>
  <script src="{{ plotly_cdn_url }}"></script>
  <style>
    body { font-family: sans-serif; margin: 2em; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #cccccc; padding: 4px 8px; text-align: left; }
    tr.critical { background: #fde8e8; }
    tr.high { background: #fff4e0; }
    tr.low { background: #e9f7ec; }
  </style>
</head>
<body>
  <h1>Vulnerability Report</h1>
  <p>Generated on: {{ generation_date }}</p>
  <h2>Summary</h2>
  {{ summary }}
  <h2>Severity Distribution (CVSS)</h2>
  <div id="cvss_chart"></div>
  <h2>Vulnerability Timeline</h2>
  <div id="timeline_chart"></div>
  <h2>Vulnerability Details</h2>
  {{ vulnerabilities_table }}
  <script>
  {{ plotly_js }}
  </script>
</body>
</html>
"""

    def forward(self, vulnerability_data: str, report_type: str) -> str:
        """Generate an HTML report with interactive charts from vulnerability data.

        Args:
            vulnerability_data: JSON text holding one vulnerability object or
                an array of vulnerability objects.
            report_type: Label for the report ('cve' or 'product'); it is used
                in the name of the saved file.

        Returns:
            On success, a message naming the saved file followed by the full
            HTML. If the file cannot be written, the HTML is still returned
            with a note about the failed save. On any other failure, a string
            starting with ``"Error generating report: "``.
        """
        try:
            data = json.loads(vulnerability_data)
            # Fail early, with a clear message, on JSON scalars and the like.
            self._normalize_records(data)

            # Generate charts with Plotly
            cvss_chart = self._generate_cvss_chart(data)
            timeline_chart = self._generate_timeline_chart(data)
            # Generate vulnerability table
            vulnerabilities_table = self._generate_vulnerabilities_table(data)
            # Generate summary
            summary = self._generate_summary(data, report_type)

            # One clock reading so the date in the page and the timestamp in
            # the filename always agree.
            now = datetime.now()

            # The chart JSON only carries numbers, "YYYY-MM" labels and fixed
            # titles, so no user-supplied text reaches the <script> element.
            plotly_js = f"""
            var cvssData = {cvss_chart};
            var timelineData = {timeline_chart};
            Plotly.newPlot('cvss_chart', cvssData.data, cvssData.layout);
            Plotly.newPlot('timeline_chart', timelineData.data, timelineData.layout);
            """

            # Render template
            report_html = Template(self.html_template).render(
                generation_date=now.strftime("%Y-%m-%d %H:%M:%S"),
                summary=summary,
                vulnerabilities_table=vulnerabilities_table,
                plotly_js=plotly_js,
                plotly_cdn_url=self.PLOTLY_CDN_URL,
            )

            # report_type comes from the caller; keep only characters that
            # are safe inside a file name so it cannot add path components.
            safe_type = "".join(
                ch if ch.isalnum() or ch in "-_" else "_" for ch in str(report_type)
            )
            timestamp = now.strftime("%Y%m%d_%H%M%S")
            filename = f"vulnerability_report_{safe_type}_{timestamp}.html"
            filepath = os.path.join(self.REPORTS_DIR, filename)

            # Saving is best effort: on a read-only deployment (for example a
            # hosted Space) the caller should still receive the report.
            try:
                os.makedirs(self.REPORTS_DIR, exist_ok=True)
                with open(filepath, "w", encoding="utf-8") as f:
                    f.write(report_html)
            except OSError as exc:
                return f"Report generated but could not be saved ({exc}):\n\n{report_html}"

            return f"Report generated and saved as: {filepath}\n\n{report_html}"
        except Exception as e:
            # Tool boundary: report the failure as text instead of raising.
            return f"Error generating report: {e}"

    @staticmethod
    def _normalize_records(data: object) -> List[Dict]:
        """Return the vulnerability records in ``data`` as a list of dicts.

        A single object becomes a one-element list; list entries that are not
        objects are dropped.

        Raises:
            ValueError: If ``data`` is neither a dict nor a list.
        """
        if isinstance(data, dict):
            return [data]
        if isinstance(data, list):
            return [item for item in data if isinstance(item, dict)]
        raise ValueError(
            "vulnerability_data must be a JSON object or a JSON array of objects"
        )

    @staticmethod
    def _numeric_cvss(value: object) -> Optional[float]:
        """Return ``value`` as a float if it is a real number, else ``None``."""
        # bool is a subclass of int; True/False are not CVSS scores.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        return float(value)

    @staticmethod
    def _esc(value: object) -> str:
        """Return ``value`` as text that is safe to place inside HTML."""
        from html import escape

        return escape(str(value), quote=True)

    @staticmethod
    def _month_key(value: object) -> Optional[str]:
        """Return the ``YYYY-MM`` month of an ISO-like timestamp, or ``None``."""
        if not isinstance(value, str) or not value:
            return None
        try:
            parsed = datetime.strptime(value, "%Y-%m-%dT%H:%M:%S")
        except ValueError:
            # Also accept other ISO 8601 variants (date only, fractional
            # seconds, UTC offset, trailing "Z").
            candidate = value[:-1] + "+00:00" if value.endswith("Z") else value
            try:
                parsed = datetime.fromisoformat(candidate)
            except ValueError:
                return None
        return parsed.strftime("%Y-%m")

    def _generate_cvss_chart(self, data: object) -> str:
        """Generate the CVSS score distribution chart.

        Args:
            data: One vulnerability dict or a list of them.

        Returns:
            The Plotly figure serialized as a JSON string. Records without a
            numeric ``cvss`` value are left out of the histogram.
        """
        cvss_scores = [
            score
            for score in (
                self._numeric_cvss(v.get('cvss')) for v in self._normalize_records(data)
            )
            if score is not None
        ]

        fig = go.Figure()
        fig.add_trace(go.Histogram(
            x=cvss_scores,
            nbinsx=10,
            name='CVSS Scores'
        ))
        fig.update_layout(
            title='CVSS Score Distribution',
            xaxis_title='CVSS Score',
            yaxis_title='Number of Vulnerabilities',
            showlegend=False
        )
        return fig.to_json()

    def _generate_timeline_chart(self, data: object) -> str:
        """Generate the vulnerability timeline chart (count per month).

        Args:
            data: One vulnerability dict or a list of them.

        Returns:
            The Plotly figure serialized as a JSON string. Records whose
            ``published_time`` is missing or cannot be parsed are skipped.
        """
        from collections import Counter

        month_keys = (
            self._month_key(v.get('published_time'))
            for v in self._normalize_records(data)
        )
        date_counts = Counter(key for key in month_keys if key is not None)

        # "YYYY-MM" strings sort chronologically.
        months = sorted(date_counts)
        counts = [date_counts[m] for m in months]

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=months,
            y=counts,
            mode='lines+markers',
            name='Vulnerabilities'
        ))
        fig.update_layout(
            title='Vulnerability Timeline Trend',
            xaxis_title='Month',
            yaxis_title='Number of Vulnerabilities',
            showlegend=False
        )
        return fig.to_json()

    def _generate_vulnerabilities_table(self, data: object) -> str:
        """Generate an HTML table with vulnerability details.

        Args:
            data: One vulnerability dict or a list of them.

        Returns:
            An HTML ``<table>`` with one row per record, or a short paragraph
            when there are no records. Every value taken from the input is
            HTML-escaped. Rows carry a CSS class derived from the CVSS score:
            ``critical`` (>= 7.0), ``high`` (>= 4.0) or ``low``.
        """
        vulnerabilities = self._normalize_records(data)
        if not vulnerabilities:
            return "<p>No vulnerability data available.</p>"

        def display(record: Dict, key: str) -> str:
            # Missing keys and explicit nulls are both shown as a placeholder.
            value = record.get(key)
            return self._esc('Not available' if value is None else value)

        rows = []
        for vuln in vulnerabilities:
            # Determine risk class based on CVSS score
            score = self._numeric_cvss(vuln.get('cvss'))
            if score is None:
                row_open = "<tr>"
            elif score >= 7.0:
                row_open = '<tr class="critical">'
            elif score >= 4.0:
                row_open = '<tr class="high">'
            else:
                row_open = '<tr class="low">'

            summary = vuln.get('summary')
            summary = 'Not available' if summary is None else str(summary)
            # Only mark the text as shortened when it actually was.
            if len(summary) > self.SUMMARY_PREVIEW_CHARS:
                summary = summary[:self.SUMMARY_PREVIEW_CHARS] + "..."

            cells = [
                display(vuln, 'id'),
                display(vuln, 'cvss'),
                display(vuln, 'epss'),
                'Yes' if vuln.get('kev', False) else 'No',
                display(vuln, 'published_time'),
                self._esc(summary),
            ]
            rows.append(
                row_open + "".join(f"<td>{cell}</td>" for cell in cells) + "</tr>"
            )

        header = (
            "<tr>"
            "<th>CVE ID</th>"
            "<th>CVSS Score</th>"
            "<th>EPSS Score</th>"
            "<th>Known Exploitable</th>"
            "<th>Publication Date</th>"
            "<th>Summary</th>"
            "</tr>"
        )
        return (
            "<table>\n<thead>\n" + header + "\n</thead>\n<tbody>\n"
            + "\n".join(rows)
            + "\n</tbody>\n</table>"
        )

    def _generate_summary(self, data: object, report_type: str) -> str:
        """Generate an HTML summary of the vulnerability data.

        Args:
            data: One vulnerability dict or a list of them.
            report_type: Accepted for interface compatibility; the summary
                text does not currently depend on it.

        Returns:
            HTML paragraphs with the record count, the number of records
            flagged ``kev``, and the mean of the numeric CVSS scores ("N/A"
            when there are none). The publication date is included only when
            ``data`` is a single record.
        """
        records = self._normalize_records(data)
        total_vulns = len(records)
        exploited = sum(1 for v in records if v.get('kev', False))

        scores = [
            score
            for score in (self._numeric_cvss(v.get('cvss')) for v in records)
            if score is not None
        ]
        avg_cvss = f"{sum(scores) / len(scores):.2f}" if scores else "N/A"

        lines = [
            f"<p>Found {total_vulns} vulnerabilities.</p>",
            f"<p>Exploited vulnerabilities: {exploited}</p>",
            f"<p>Average CVSS Score: {avg_cvss}</p>",
        ]
        # A single publication date is only meaningful for a single record;
        # a list has no ``.get`` and no one date to show.
        if isinstance(data, dict):
            published = data.get('published_time')
            published = 'N/A' if published is None else published
            lines.append(f"<p>Publication date: {self._esc(published)}</p>")
        return "\n".join(lines)