Spaces:
Running
Running
| """ | |
| Datawrapper Chart Generation Client | |
| Integrates RAG pipeline with Datawrapper API for intelligent chart creation. | |
| """ | |
| import json | |
| import os | |
| from typing import Optional, Tuple | |
| import pandas as pd | |
| from .prompts import ( | |
| CHART_SELECTION_SYSTEM_PROMPT, | |
| get_chart_selection_prompt, | |
| get_chart_styling_prompt | |
| ) | |
| from .llm_client import create_llm_client | |
| from .rag_pipeline import GraphicsDesignPipeline | |
| # Import Datawrapper MCP handlers directly | |
| from datawrapper_mcp.handlers.create import create_chart as mcp_create_chart | |
| from datawrapper_mcp.handlers.publish import publish_chart as mcp_publish_chart | |
| from datawrapper_mcp.handlers.retrieve import get_chart_info as mcp_get_chart_info | |
def get_data_summary(df: pd.DataFrame) -> str:
    """
    Generate a summary of the DataFrame structure and content.

    The summary lists the shape, the column names, the columns grouped into
    numeric / text / date families, and a preview of the first three rows.

    Args:
        df: Input DataFrame

    Returns:
        String summary of data characteristics
    """
    def _join_labels(columns) -> str:
        # Column labels are not necessarily strings (e.g. integer labels from a
        # CSV read with header=None), so stringify before joining.
        return ", ".join(str(label) for label in columns)

    summary_parts = [
        f"Rows: {len(df)}, Columns: {len(df.columns)}",
        f"Column names: {_join_labels(df.columns)}",
    ]

    numeric_cols = df.select_dtypes(include=["number"]).columns
    # "string" also picks up pandas' dedicated StringDtype columns.
    text_cols = df.select_dtypes(include=["object", "string"]).columns
    # "datetime" alone matches only tz-naive columns; "datetimetz" adds tz-aware ones.
    date_cols = df.select_dtypes(include=["datetime", "datetimetz"]).columns

    for family, cols in (("Numeric", numeric_cols), ("Text", text_cols), ("Date", date_cols)):
        if len(cols):
            summary_parts.append(f"{family} columns: {_join_labels(cols)}")

    # Data preview (first 3 rows)
    summary_parts.append(f"\nData preview:\n{df.head(3).to_string()}")
    return "\n".join(summary_parts)
def analyze_csv_for_chart_type(
    df: pd.DataFrame,
    user_prompt: str,
    rag_pipeline: GraphicsDesignPipeline
) -> Tuple[str, str]:
    """
    Use RAG and LLM to determine the best chart type for the data.

    Retrieves chart-selection guidance from the RAG pipeline, asks the LLM for
    a JSON object with ``chart_type`` and ``reasoning`` keys, and validates the
    chart type against the supported set. If the reply cannot be parsed into a
    JSON object, a line chart is returned instead of raising. An unknown chart
    type also falls back to ``"line"``.

    Args:
        df: Input DataFrame
        user_prompt: User's description of what they want to visualize
        rag_pipeline: RAG pipeline for retrieving best practices

    Returns:
        Tuple of (chart_type, reasoning)
    """
    import re

    valid_types = {
        "bar", "line", "area", "scatter", "column",
        "stacked_bar", "arrow", "multiple_column",
    }

    # Get data summary
    data_summary = get_data_summary(df)

    # Query RAG for chart selection best practices
    rag_query = f"chart type selection for {user_prompt}"
    relevant_docs = rag_pipeline.retrieve_documents(rag_query, k=3)
    rag_context = rag_pipeline.vectorstore.format_documents_for_context(relevant_docs)

    # Generate chart type recommendation using LLM
    chart_prompt = get_chart_selection_prompt()
    full_prompt = chart_prompt.format(
        user_prompt=user_prompt,
        data_summary=data_summary,
        rag_context=rag_context
    )
    llm_client = create_llm_client(
        model=os.getenv("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct"),
        temperature=0.3,  # Lower temperature for more deterministic chart selection
        max_tokens=500
    )
    response = llm_client.generate(
        prompt=full_prompt,
        system_prompt=CHART_SELECTION_SYSTEM_PROMPT
    )

    try:
        text = response.strip()
        # Unwrap a markdown code fence with any (or no) language tag -- ```json,
        # ```JSON, bare ``` -- and tolerate a missing closing fence.
        fence = re.search(r"```[A-Za-z0-9_-]*\s*(.*?)(?:```|$)", text, re.DOTALL)
        if fence:
            text = fence.group(1).strip()
        try:
            result = json.loads(text)
        except json.JSONDecodeError:
            # Models often surround the object with prose; retry on the outermost {...}.
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            result = json.loads(text[start:end + 1])
        if not isinstance(result, dict):
            raise ValueError(f"expected a JSON object, got {type(result).__name__}")

        # Normalize spellings such as "Bar" or "stacked bar" before validating.
        chart_type = (
            str(result.get("chart_type") or "line")
            .strip().lower().replace(" ", "_").replace("-", "_")
        )
        reasoning = str(result.get("reasoning") or "")
    except (ValueError, TypeError, AttributeError) as e:
        # ValueError covers json.JSONDecodeError; AttributeError covers a non-str reply.
        print(f"Error parsing chart type response: {e}")
        print(f"Response was: {response}")
        # Default to line chart
        return "line", "Using default line chart due to parsing error"

    if chart_type not in valid_types:
        chart_type = "line"  # Default fallback
    return chart_type, reasoning
def generate_chart_config(
    chart_type: str,
    df: pd.DataFrame,
    user_prompt: str,
    rag_pipeline: GraphicsDesignPipeline
) -> dict:
    """
    Generate Datawrapper chart configuration using RAG and LLM.

    Retrieves styling/accessibility guidance from the RAG pipeline and asks the
    LLM for a JSON configuration object. The returned dict always has a
    non-empty ``title``: the first 100 characters of ``user_prompt``, or
    ``"Data Visualization"`` when the prompt is empty. If the reply cannot be
    parsed into a JSON object, a minimal config (title + ``source_name``) is
    returned instead of raising.

    Args:
        chart_type: Type of chart to create
        df: Input DataFrame
        user_prompt: User's visualization request
        rag_pipeline: RAG pipeline for retrieving design best practices

    Returns:
        Dictionary with chart configuration
    """
    import re

    # Shared by the success and failure paths so both behave the same for an empty prompt.
    fallback_title = user_prompt[:100] if user_prompt else "Data Visualization"

    # Get data summary
    data_summary = get_data_summary(df)

    # Query RAG for styling and design best practices
    rag_query = f"chart design best practices colors accessibility {chart_type}"
    relevant_docs = rag_pipeline.retrieve_documents(rag_query, k=3)
    rag_context = rag_pipeline.vectorstore.format_documents_for_context(relevant_docs)

    # Generate chart configuration using LLM
    styling_prompt = get_chart_styling_prompt()
    full_prompt = styling_prompt.format(
        chart_type=chart_type,
        user_prompt=user_prompt,
        data_summary=data_summary,
        rag_context=rag_context
    )
    llm_client = create_llm_client(
        model=os.getenv("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct"),
        temperature=0.5,
        max_tokens=800
    )
    response = llm_client.generate(
        prompt=full_prompt,
        system_prompt="You are a data visualization expert. Generate valid JSON configuration for Datawrapper charts."
    )

    try:
        text = response.strip()
        # Unwrap a markdown code fence with any (or no) language tag, tolerating
        # a missing closing fence.
        fence = re.search(r"```[A-Za-z0-9_-]*\s*(.*?)(?:```|$)", text, re.DOTALL)
        if fence:
            text = fence.group(1).strip()
        try:
            config = json.loads(text)
        except json.JSONDecodeError:
            # Retry on the outermost {...} span in case the object is wrapped in prose.
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            config = json.loads(text[start:end + 1])
        if not isinstance(config, dict):
            raise ValueError(f"expected a JSON object, got {type(config).__name__}")
    except (ValueError, TypeError, AttributeError) as e:
        # ValueError covers json.JSONDecodeError; AttributeError covers a non-str reply.
        print(f"Error parsing chart config: {e}")
        print(f"Response was: {response}")
        # Return minimal config
        return {
            "title": fallback_title,
            "source_name": "User Data"
        }

    # Ensure basic required fields; a missing, null or blank title is replaced too.
    if not config.get("title"):
        config["title"] = fallback_title
    return config
def _mcp_result_text(result, action: str) -> str:
    """
    Return the text payload of the first content item of an MCP handler result.

    Args:
        result: Sequence returned by a datawrapper_mcp handler; only the
            ``.text`` attribute of its first item is read.
        action: Human-readable step name used in error messages.

    Returns:
        The non-blank text of ``result[0]``.

    Raises:
        ValueError: If ``result`` is empty or its first item's text is blank.
    """
    if not result:
        raise ValueError(f"Empty response from {action}")
    text = result[0].text
    if not text or text.strip() == "":
        raise ValueError(f"Empty text in {action} response")
    return text


async def create_and_publish_chart(
    df: pd.DataFrame,
    user_prompt: str,
    rag_pipeline: GraphicsDesignPipeline,
    api_token: Optional[str] = None
) -> dict:
    """
    Complete workflow: analyze data, select chart type, create and publish chart.

    Publishing is best-effort: a publish failure is reported via the
    ``published`` / ``publish_message`` keys instead of failing the call. Any
    other error after the token check is caught and reported as a dict with
    ``success=False``.

    Args:
        df: Input DataFrame
        user_prompt: User's visualization request
        rag_pipeline: RAG pipeline instance
        api_token: Datawrapper API token (defaults to env var). It is currently
            only checked for presence; see the review note below.

    Returns:
        Dictionary with chart info including iframe URL. On failure the dict
        holds ``success=False``, ``error``, ``chart_type``, ``chart_id`` (set
        when the chart was already created) and ``public_url=None``.

    Raises:
        ValueError: If no token is given and DATAWRAPPER_ACCESS_TOKEN is unset/empty.
    """
    if api_token is None:
        api_token = os.getenv("DATAWRAPPER_ACCESS_TOKEN")
    if not api_token:
        raise ValueError("DATAWRAPPER_ACCESS_TOKEN not found in environment")
    # NOTE(review): api_token is validated here but never forwarded to the MCP
    # handlers below, so they must obtain credentials themselves (presumably
    # from DATAWRAPPER_ACCESS_TOKEN). An explicitly passed token is therefore
    # ignored -- confirm against datawrapper_mcp.

    # Pre-bind everything the error handlers report, instead of probing locals().
    chart_type = None
    chart_id = None
    last_text = None  # most recent MCP payload handed to json.loads

    try:
        # Step 1: Analyze data and select chart type
        chart_type, reasoning = analyze_csv_for_chart_type(df, user_prompt, rag_pipeline)

        # Step 2: Generate chart configuration
        chart_config = generate_chart_config(chart_type, df, user_prompt, rag_pipeline)

        # Step 3: Convert DataFrame to list of dicts for Datawrapper
        data_list = df.to_dict('records')

        # Step 4: Create chart using MCP handler
        create_args = {
            "data": data_list,
            "chart_type": chart_type,
            "chart_config": chart_config
        }
        last_text = _mcp_result_text(await mcp_create_chart(create_args), "chart creation")
        result_data = json.loads(last_text)
        chart_id = result_data.get("chart_id")
        if not chart_id:
            raise ValueError(f"Failed to get chart_id from creation response. Response was: {result_data}")

        # Step 5: Try to publish chart using MCP handler (best effort)
        publish_success = False
        try:
            publish_text = _mcp_result_text(
                await mcp_publish_chart({"chart_id": chart_id}), "chart publishing"
            )
            publish_data = json.loads(publish_text)
            # Settle the message before flagging success, so "published" and
            # "publish_message" can never contradict each other.
            if isinstance(publish_data, dict):
                publish_message = publish_data.get("message", "Published successfully")
            else:
                publish_message = "Published successfully"
            publish_success = True
        except Exception as publish_error:  # deliberate: publish failure is non-fatal
            publish_message = f"Publish failed: {str(publish_error)}"

        # Step 6: Get full chart info using MCP handler
        last_text = _mcp_result_text(
            await mcp_get_chart_info({"chart_id": chart_id}), "chart info retrieval"
        )
        chart_info = json.loads(last_text)

        # Return complete info
        return {
            "success": True,
            "chart_id": chart_id,
            "chart_type": chart_type,
            "reasoning": reasoning,
            "public_url": chart_info.get("public_url"),
            "edit_url": chart_info.get("edit_url"),
            "published": publish_success,
            "publish_message": publish_message,
            "title": chart_config.get("title", "Chart")
        }
    except json.JSONDecodeError as e:
        error_msg = f"JSON parsing error: {str(e)}"
        print(f"Error in chart creation: {error_msg}")
        # last_text is the payload that actually failed to parse (creation or chart info).
        print(f"Failed to parse: {last_text if last_text is not None else 'N/A'}")
        return {
            "success": False,
            "error": error_msg,
            "chart_type": chart_type,
            "chart_id": chart_id,
            "public_url": None
        }
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        print(f"Error in chart creation: {error_msg}")
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": error_msg,
            "chart_type": chart_type,
            "chart_id": chart_id,
            "public_url": None
        }
def get_iframe_html(chart_url: str, height: int = 600, aria_label: str = "Chart") -> str:
    """
    Generate iframe HTML for embedding a Datawrapper chart.

    The URL and the accessible label are HTML-escaped before being placed in
    attributes, so values containing quotes, ``<`` or ``&`` cannot break out
    of the markup.

    Args:
        chart_url: Public URL of the chart. A falsy value (``""`` or ``None``)
            yields a "No chart available" placeholder instead of an iframe.
        height: Height of iframe in pixels
        aria_label: Accessible name for the iframe (defaults to ``"Chart"``).

    Returns:
        HTML string with iframe
    """
    import html

    if not chart_url:
        return "<div style='padding: 50px; text-align: center;'>No chart available</div>"

    safe_src = html.escape(chart_url, quote=True)
    safe_label = html.escape(aria_label, quote=True)
    return f"""
<div style="width: 100%; height: {height}px;">
    <iframe
        src="{safe_src}"
        style="width: 100%; height: 100%; border: none;"
        frameborder="0"
        scrolling="no"
        aria-label="{safe_label}">
    </iframe>
</div>
"""