| import gradio as gr |
| import torch |
| from transformers import ( |
| BlipProcessor, BlipForConditionalGeneration, |
| TrOCRProcessor, VisionEncoderDecoderModel, |
| AutoProcessor, AutoModelForCausalLM |
| ) |
| from PIL import Image |
| import easyocr |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| import numpy as np |
| import cv2 |
| import io |
| import base64 |
| import requests |
| import warnings |
| import json |
| from datetime import datetime |
| from typing import Dict, List, Any, Optional |
| import re |
|
|
| |
# Silence every warning process-wide (the default category=Warning matches all
# warning classes), presumably to hide noisy library output during model loading.
# NOTE(review): this also hides DeprecationWarnings from gradio/transformers;
# consider narrowing to specific categories or modules.
warnings.filterwarnings("ignore")
|
|
class StructuredChartAnalyzer:
    """Analyze chart images with captioning, OCR and simple CV heuristics.

    ``analyze_chart_with_prompt`` is the entry point. It returns a nested dict
    with the keys metadata, image_info, text_extraction, chart_analysis,
    data_insights, quality_metrics, errors and, optionally, advanced_analysis.
    ``format_results_for_display`` renders that dict as Markdown.
    """

    # Label used when no chart type can be determined. It is also the initial
    # chart_type, so a failed BLIP call is never counted as "identified".
    _UNKNOWN_CHART_TYPE = "Unknown Chart Type"

    # A single alternation scanned left to right, so every number in the text is
    # matched exactly once. Running several overlapping patterns one after the
    # other reported "45%" as a percentage and then again as plain numbers.
    _NUMBER_PATTERN = re.compile(
        r"\$?"                         # optional currency sign
        r"(?:\d{1,3}(?:,\d{3})+|\d+)"  # 1,234,567 (grouped) or 1234567
        r"(?:\.\d+)?"                  # optional fraction; a bare trailing "." is not consumed
        r"%?"                          # optional percent sign
    )

    # Characters stripped from both ends of a word before it is considered as a
    # label, so "Revenue:" and "Revenue" count as the same label.
    _LABEL_PUNCTUATION = ".,;:!?()[]{}\"'"

    # Keyword lists that vote for a chart type (see _score_chart_types).
    _CHART_INDICATORS = {
        'bar_chart': ['bar', 'column', 'histogram', 'vertical bars', 'horizontal bars'],
        'line_chart': ['line', 'trend', 'time series', 'curve', 'linear'],
        'pie_chart': ['pie', 'circular', 'slice', 'wedge', 'donut'],
        'scatter_plot': ['scatter', 'scatterplot', 'correlation', 'points', 'dots', 'plot'],
        'area_chart': ['area', 'filled', 'stacked area'],
        'box_plot': ['box', 'whisker', 'quartile', 'median'],
        'heatmap': ['heat', 'heatmap', 'color coded', 'matrix', 'intensity'],
        'gauge': ['gauge', 'dial', 'speedometer', 'meter'],
        'funnel': ['funnel', 'conversion', 'stages'],
        'radar': ['radar', 'spider', 'web chart'],
    }

    def __init__(self):
        """Load all models and the predefined prompt templates."""
        self.load_models()
        self.prompt_templates = self._init_prompt_templates()

    def _init_prompt_templates(self) -> Dict[str, str]:
        """Return the predefined prompt text for each analysis type."""
        return {
            "comprehensive": "Analyze this chart comprehensively. Identify the chart type, extract all visible text including titles, labels, legends, and data values. Describe the data trends, patterns, and key insights.",

            "data_extraction": "Focus on extracting numerical data from this chart. Identify all data points, values, categories, and measurements. Pay special attention to axis labels, data series, and quantitative information.",

            "visual_elements": "Describe the visual elements of this chart including colors, chart type, layout, axes, legends, and overall design. Focus on the structural components.",

            "trend_analysis": "Analyze the trends and patterns shown in this chart. Identify increasing/decreasing trends, correlations, outliers, and significant data patterns. Provide insights about what the data reveals.",

            "accessibility": "Describe this chart in a way that would be helpful for visually impaired users. Include all textual content, data relationships, and key findings in a clear, structured manner.",

            "business_insights": "Analyze this chart from a business perspective. What are the key performance indicators, trends, and actionable insights that can be derived from this data?"
        }

    def load_models(self):
        """Load BLIP, TrOCR and EasyOCR (required) and Florence-2 (optional).

        Sets the ``self.models_loaded`` flags that every analysis step checks
        before it uses a model.

        Raises:
            Exception: re-raised from the loader when BLIP, TrOCR or EasyOCR
                cannot be loaded. A Florence-2 failure is only printed.
        """
        self.models_loaded = {
            'blip': False,
            'trocr': False,
            'easyocr': False,
            'florence': False
        }

        try:
            print("Loading BLIP model...")
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
            self.models_loaded['blip'] = True

            print("Loading TrOCR model...")
            self.trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
            self.trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
            self.models_loaded['trocr'] = True

            print("Loading EasyOCR...")
            self.ocr_reader = easyocr.Reader(['en'], gpu=False)
            self.models_loaded['easyocr'] = True

            try:
                print("Attempting to load Florence-2...")
                # NOTE(review): trust_remote_code=True executes Python code downloaded
                # from the model repository; pin a `revision` for production use.
                self.florence_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
                self.florence_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
                self.models_loaded['florence'] = True
                print("Florence-2 loaded successfully!")
            except Exception as e:
                # Florence-2 is optional: continue without it.
                print(f"Florence-2 not available: {e}")
                self.models_loaded['florence'] = False

            print("Model loading completed!")

        except Exception as e:
            print(f"Error loading models: {e}")
            raise  # keep the original traceback

    def analyze_chart_with_prompt(
        self,
        image,
        custom_prompt: Optional[str] = None,
        analysis_type: str = "comprehensive",
        include_florence: bool = True,
        ocr_confidence_threshold: float = 0.0,
    ) -> Dict[str, Any]:
        """Analyze a chart image and return a structured dictionary.

        Args:
            image: PIL image or an array accepted by ``Image.fromarray``.
            custom_prompt: Optional free-form prompt. It is recorded in the
                metadata and sent to Florence-2. The BLIP and OCR steps do not
                use it.
            analysis_type: Key of ``self.prompt_templates``. Florence-2 only
                runs for "comprehensive" (or "advanced").
            include_florence: Set to False to skip Florence-2 even when it is
                loaded.
            ocr_confidence_threshold: EasyOCR detections below this confidence
                are ignored. The default 0.0 keeps all of them.

        Returns:
            The structured result. Failures are appended to ``"errors"`` rather
            than raised.
        """
        structured_output = {
            "metadata": {
                "timestamp": datetime.now().isoformat(),
                "analysis_type": analysis_type,
                "models_used": [model for model, loaded in self.models_loaded.items() if loaded],
                "prompt_used": custom_prompt or self.prompt_templates.get(analysis_type, self.prompt_templates["comprehensive"])
            },
            "image_info": {},
            "text_extraction": {},
            "chart_analysis": {},
            "data_insights": {},
            "quality_metrics": {},
            "errors": []
        }

        if image is None:
            structured_output["errors"].append("No image provided")
            return structured_output

        try:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)

            # Report the image as supplied, then normalise it to RGB. The
            # EasyOCR/OpenCV steps reshape to 3 channels and use
            # COLOR_RGB2GRAY, which RGBA, L or P input would break.
            structured_output["image_info"] = self._extract_image_info(image)
            if image.mode != "RGB":
                image = image.convert("RGB")

            text_extraction = self._extract_text_comprehensive(image, min_confidence=ocr_confidence_threshold)
            structured_output["text_extraction"] = text_extraction
            structured_output["chart_analysis"] = self._analyze_chart_structure(image, text_extraction)
            structured_output["data_insights"] = self._extract_data_insights(image, structured_output)
            structured_output["quality_metrics"] = self._assess_quality(image, structured_output)

            if include_florence and self.models_loaded['florence'] and analysis_type in ("comprehensive", "advanced"):
                structured_output["advanced_analysis"] = self._florence_advanced_analysis(image, custom_prompt)

        except Exception as e:
            structured_output["errors"].append(f"Analysis error: {str(e)}")

        return structured_output

    def _extract_image_info(self, image: Image.Image) -> Dict[str, Any]:
        """Return size, format, mode, transparency flag and aspect ratio.

        ``aspect_ratio`` is width/height rounded to 2 places, or None for a
        zero-height image. Any failure is returned as ``{"error": ...}``.
        """
        try:
            width, height = image.size
            return {
                "dimensions": {
                    "width": width,
                    "height": height
                },
                "format": image.format or "Unknown",
                "mode": image.mode,
                # Palette images can carry transparency in image.info rather than an alpha band.
                "has_transparency": image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info,
                "aspect_ratio": round(width / height, 2) if height else None
            }
        except Exception as e:
            return {"error": str(e)}

    def _extract_text_comprehensive(self, image: Image.Image, min_confidence: float = 0.0) -> Dict[str, Any]:
        """Run every loaded OCR engine and merge the results.

        Args:
            image: RGB image.
            min_confidence: EasyOCR detections scoring below this are dropped.

        Returns:
            A dict with the engines that succeeded, each engine's text
            ("Error: ..." on failure), EasyOCR's per-detection confidences, the
            combined text of the successful engines, and the numbers and labels
            found in that text. Numbers read by both engines appear twice.
        """
        text_results = {
            "methods_used": [],
            "extracted_texts": {},
            "confidence_scores": {},
            "combined_text": "",
            "detected_numbers": [],
            "detected_labels": []
        }
        # Successful outputs are tracked explicitly. Filtering on an "Error:"
        # prefix would also drop OCR text that happens to start that way.
        successful_texts = []

        if self.models_loaded['trocr']:
            try:
                # NOTE(review): trocr-base-printed is presumably tuned for single
                # text-line crops; confirm it adds value on whole charts.
                pixel_values = self.trocr_processor(image, return_tensors="pt").pixel_values
                generated_ids = self.trocr_model.generate(pixel_values, max_length=200)
                trocr_text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                text_results["extracted_texts"]["trocr"] = trocr_text
                text_results["methods_used"].append("TrOCR")
                successful_texts.append(trocr_text)
            except Exception as e:
                text_results["extracted_texts"]["trocr"] = f"Error: {str(e)}"

        if self.models_loaded['easyocr']:
            try:
                easyocr_data = [
                    {
                        "text": text,
                        "confidence": float(confidence),
                        # Convert to built-in numbers. The bbox may hold NumPy
                        # scalars depending on the EasyOCR version, which breaks
                        # JSON output.
                        "bbox": np.asarray(bbox).tolist(),
                    }
                    for bbox, text, confidence in self.ocr_reader.readtext(np.array(image))
                    if confidence >= min_confidence
                ]
                easyocr_text = ' '.join(entry["text"] for entry in easyocr_data)
                text_results["extracted_texts"]["easyocr"] = easyocr_text
                text_results["confidence_scores"]["easyocr"] = easyocr_data
                text_results["methods_used"].append("EasyOCR")
                successful_texts.append(easyocr_text)
            except Exception as e:
                text_results["extracted_texts"]["easyocr"] = f"Error: {str(e)}"

        text_results["combined_text"] = " ".join(successful_texts)
        text_results["detected_numbers"] = self._extract_numbers(text_results["combined_text"])
        text_results["detected_labels"] = self._extract_potential_labels(text_results["combined_text"])

        return text_results

    def _extract_numbers(self, text: str) -> List[Dict[str, Any]]:
        """Find the numbers in *text*, each reported once, in order of appearance.

        Recognises plain numbers (``12``, ``3.5``), thousands-grouped numbers
        (``1,234``), currency amounts (``$1,200``) and percentages (``45%``).

        Returns:
            A list of ``{"value": matched text, "position": (start, end),
            "type": "percentage" | "currency" | "number"}``.
        """
        numbers = []
        for match in self._NUMBER_PATTERN.finditer(text or ""):
            value = match.group()
            if "%" in value:
                kind = "percentage"
            elif "$" in value:
                kind = "currency"
            else:
                kind = "number"
            numbers.append({"value": value, "position": match.span(), "type": kind})
        return numbers

    def _extract_potential_labels(self, text: str) -> List[str]:
        """Return Title-case or UPPER-case words from *text* as candidate labels.

        Surrounding punctuation is stripped. Pure numbers and words shorter than
        two characters are skipped. Duplicates are removed, keeping the first
        occurrence, so the order is deterministic (unlike ``list(set(...))``).
        """
        labels: Dict[str, None] = {}
        for raw_word in (text or "").split():
            word = raw_word.strip(self._LABEL_PUNCTUATION)
            if len(word) < 2 or re.match(r'^\d+\.?\d*$', word):
                continue
            if word.istitle() or word.isupper():
                labels.setdefault(word, None)
        return list(labels)

    def _analyze_chart_structure(self, image: Image.Image, text_data: Dict) -> Dict[str, Any]:
        """Caption the chart with BLIP, guess its type and gather visual statistics.

        ``confidence`` is the share of chart-type keyword hits won by the chosen
        type. It is a rough heuristic, not a calibrated probability, and stays
        0.0 when nothing matched.
        """
        analysis = {
            "chart_type": self._UNKNOWN_CHART_TYPE,
            "confidence": 0.0,
            "visual_elements": {},
            "layout_analysis": {}
        }

        if self.models_loaded['blip']:
            try:
                inputs = self.blip_processor(image, return_tensors="pt")
                out = self.blip_model.generate(**inputs, max_length=150)
                description = self.blip_processor.decode(out[0], skip_special_tokens=True)
                analysis["description"] = description

                ocr_text = text_data.get("combined_text", "")
                analysis["chart_type"] = self._detect_chart_type_advanced(description, ocr_text)
                scores = self._score_chart_types(description, ocr_text)
                if scores:
                    analysis["confidence"] = round(max(scores.values()) / sum(scores.values()), 2)
            except Exception as e:
                analysis["description"] = f"Error: {str(e)}"

        try:
            analysis["visual_elements"] = self._analyze_visual_elements(image)
            analysis["layout_analysis"] = self._analyze_layout(image)
        except Exception as e:
            analysis["visual_elements"] = {"error": str(e)}

        return analysis

    def _score_chart_types(self, description: str, text: str) -> Dict[str, int]:
        """Count, per chart type, how many of its keywords occur as whole words.

        A keyword may carry an "s", "es", "d" or "ed" suffix ("bars", "boxes",
        "curved"). Matching on word boundaries stops "line" from firing on
        "outline" and "pie" on "piece". Types with no hits are omitted.
        """
        combined_text = f"{description} {text}".lower()
        scores: Dict[str, int] = {}
        for chart_type, keywords in self._CHART_INDICATORS.items():
            hits = sum(
                1 for keyword in keywords
                if re.search(rf"\b{re.escape(keyword)}(?:s|es|d|ed)?\b", combined_text)
            )
            if hits:
                scores[chart_type] = hits
        return scores

    def _detect_chart_type_advanced(self, description: str, text: str) -> str:
        """Return the best-scoring chart type as a display name such as "Bar Chart".

        Ties go to the type listed first in ``_CHART_INDICATORS``. With no
        keyword hits the result is "Unknown Chart Type".
        """
        scores = self._score_chart_types(description, text)
        if not scores:
            return self._UNKNOWN_CHART_TYPE
        best_type = max(scores, key=scores.get)
        return best_type.replace('_', ' ').title()

    def _analyze_visual_elements(self, image: Image.Image) -> Dict[str, Any]:
        """Compute colour, edge, brightness and contrast statistics.

        brightness and contrast are the mean and standard deviation of the
        8-bit grayscale image divided by 255. edge_density is the fraction of
        pixels marked by Canny(50, 150). Failures are returned as
        ``{"error": ...}``.
        """
        try:
            if image.mode != "RGB":
                image = image.convert("RGB")
            image_np = np.array(image)
            colors = image_np.reshape(-1, 3)

            # Pack each RGB triple into one integer. np.unique on a 1-D array
            # gives the same count much faster than np.unique(axis=0) over rows.
            packed = (
                (colors[:, 0].astype(np.uint32) << 16)
                | (colors[:, 1].astype(np.uint32) << 8)
                | colors[:, 2].astype(np.uint32)
            )

            gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)

            return {
                "color_count": int(np.unique(packed).size),
                "dominant_colors": self._get_dominant_colors(colors),
                "edge_density": float(np.count_nonzero(edges) / edges.size),
                "brightness": float(np.mean(gray) / 255),
                "contrast": float(np.std(gray) / 255)
            }
        except Exception as e:
            return {"error": str(e)}

    def _get_dominant_colors(self, colors: np.ndarray, n_colors: int = 5, max_samples: int = 20000) -> List[List[int]]:
        """Return up to *n_colors* representative RGB colours.

        Args:
            colors: Pixel rows as produced by ``_analyze_visual_elements``
                (the image reshaped to ``(-1, 3)``).
            n_colors: Maximum number of colours to return.
            max_samples: Pixels are subsampled with a fixed stride down to about
                this many before clustering. This keeps KMeans fast on large
                images and keeps the result deterministic.

        Returns:
            KMeans cluster centres (as ints) when scikit-learn is usable.
            Otherwise, the most frequent exact colours, most common first.
        """
        if len(colors) == 0 or n_colors <= 0:
            return []

        step = max(1, -(-len(colors) // max(1, max_samples)))  # ceiling division
        sample = colors[::step]
        unique, counts = np.unique(sample, axis=0, return_counts=True)

        try:
            from sklearn.cluster import KMeans  # optional dependency
            kmeans = KMeans(n_clusters=min(n_colors, len(unique)), random_state=42)
            kmeans.fit(sample)
            return [center.astype(int).tolist() for center in kmeans.cluster_centers_]
        except Exception:
            # scikit-learn is missing or clustering failed: fall back to colour frequency.
            order = np.argsort(-counts, kind="stable")[:n_colors]
            return unique[order].tolist()

    def _analyze_layout(self, image: Image.Image) -> Dict[str, Any]:
        """Count long horizontal/vertical strokes and report fixed layout regions.

        ``has_grid`` is True when more than two lines are found in each
        direction. Failures are returned as ``{"error": ...}``.
        """
        try:
            gray = np.array(image.convert('L'))

            h_lines = self._detect_horizontal_lines(gray)
            v_lines = self._detect_vertical_lines(gray)

            return {
                "horizontal_lines": len(h_lines),
                "vertical_lines": len(v_lines),
                "has_grid": len(h_lines) > 2 and len(v_lines) > 2,
                "image_regions": self._identify_regions(gray)
            }
        except Exception as e:
            return {"error": str(e)}

    def _detect_lines(self, gray_image: np.ndarray, kernel_size) -> List:
        """Return contours of strokes at least ``kernel_size`` long in one direction.

        The image is binarised with an inverted Otsu threshold first, so strokes
        become foreground. On raw intensities the opening and ``findContours``
        mostly see the light background rather than the strokes.
        NOTE(review): this assumes dark lines on a lighter background.
        Expects an 8-bit single-channel image (required by THRESH_OTSU).
        """
        _, binary = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        detected_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
        cnts = cv2.findContours(detected_lines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (image, contours, hierarchy).
        return cnts[0] if len(cnts) == 2 else cnts[1]

    def _detect_horizontal_lines(self, gray_image: np.ndarray) -> List:
        """Return contours of horizontal strokes at least 25 px wide."""
        return self._detect_lines(gray_image, (25, 1))

    def _detect_vertical_lines(self, gray_image: np.ndarray) -> List:
        """Return contours of vertical strokes at least 25 px tall."""
        return self._detect_lines(gray_image, (1, 25))

    def _identify_regions(self, image: np.ndarray) -> Dict[str, Any]:
        """Split the image into fixed-proportion title, chart and legend bands.

        These are proportion-based guesses, not detections: the top 10% is the
        title, the next 70% the chart, and the bottom 20% the legend.
        """
        h, w = image.shape[:2]
        return {
            "title_region": {"y": 0, "height": h // 10},
            "chart_area": {"y": h // 10, "height": int(h * 0.7)},
            "legend_area": {"y": int(h * 0.8), "height": h // 5},
            "total_dimensions": {"width": w, "height": h}
        }

    def _extract_data_insights(self, image: Image.Image, analysis_data: Dict) -> Dict[str, Any]:
        """Turn the detected numbers into values and summary statistics.

        Every detected number (plain, grouped, currency or percentage) is parsed
        after removing "$", "%", "," and whitespace. For example, "$1,200"
        becomes 1200.0 and "45%" becomes 45.0. ``image`` is not used.
        Categories are the detected labels. They are not paired with the values.
        Failures are recorded under ``"error"``.
        """
        insights = {
            "numerical_data": [],
            "categories": [],
            "trends": [],
            "outliers": [],
            "summary_statistics": {}
        }

        try:
            text_extraction = analysis_data["text_extraction"]
            numerical_values = []
            for num_data in text_extraction["detected_numbers"]:
                cleaned = re.sub(r"[$%,\s]", "", num_data["value"])
                try:
                    numerical_values.append(float(cleaned))
                except ValueError:
                    continue  # nothing numeric left after cleaning

            if numerical_values:
                insights["numerical_data"] = numerical_values
                insights["summary_statistics"] = {
                    "count": len(numerical_values),
                    "min": min(numerical_values),
                    "max": max(numerical_values),
                    # Plain Python floats rather than NumPy scalars.
                    "mean": float(np.mean(numerical_values)),
                    "median": float(np.median(numerical_values)),
                    "std": float(np.std(numerical_values)) if len(numerical_values) > 1 else 0
                }

            insights["categories"] = text_extraction["detected_labels"]

        except Exception as e:
            insights["error"] = str(e)

        return insights

    def _assess_quality(self, image: Image.Image, analysis_data: Dict) -> Dict[str, Any]:
        """Score how readable and complete the extracted information is.

        overall_score = 0.6 * (share of completeness checks passed)
                      + 0.4 * min(1, 0.5 * text_length/100 + 0.5 * ocr_methods/2).
        ``image`` is not used. Failures are recorded under ``"error"``.
        """
        quality = {
            "overall_score": 0.0,
            "readability": {},
            "completeness": {},
            "technical_quality": {}
        }

        try:
            text_extraction = analysis_data["text_extraction"]
            text_methods = len(text_extraction["methods_used"])
            extracted_text_length = len(text_extraction["combined_text"])

            quality["readability"] = {
                "text_extraction_methods": text_methods,
                "text_length": extracted_text_length,
                "numbers_detected": len(text_extraction["detected_numbers"]),
                "labels_detected": len(text_extraction["detected_labels"])
            }

            chart_type = analysis_data["chart_analysis"]["chart_type"]
            quality["completeness"] = {
                # NOTE(review): this only checks whether the literal word "title"
                # was read, which says little about whether the chart has a title.
                "has_title": "title" in text_extraction["combined_text"].lower(),
                "has_numerical_data": len(text_extraction["detected_numbers"]) > 0,
                "has_labels": len(text_extraction["detected_labels"]) > 0,
                # Both the placeholder and the no-match label mean "not identified".
                "chart_type_identified": chart_type not in ("unknown", self._UNKNOWN_CHART_TYPE)
            }

            visual_elements = analysis_data["chart_analysis"].get("visual_elements", {})
            if not visual_elements.get("error"):
                quality["technical_quality"] = {
                    "image_brightness": visual_elements.get("brightness", 0),
                    "image_contrast": visual_elements.get("contrast", 0),
                    "color_diversity": visual_elements.get("color_count", 0),
                    "edge_clarity": visual_elements.get("edge_density", 0)
                }

            completeness_score = sum(quality["completeness"].values()) / len(quality["completeness"])
            readability_score = min(1.0, (extracted_text_length / 100) * 0.5 + (text_methods / 2) * 0.5)

            quality["overall_score"] = completeness_score * 0.6 + readability_score * 0.4

        except Exception as e:
            quality["error"] = str(e)

        return quality

    def _florence_advanced_analysis(self, image: Image.Image, custom_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Run a fixed set of Florence-2 tasks, plus the custom prompt if given.

        Returns:
            ``{task_name: result}``. Each result comes from
            ``_parse_florence_output``, plus ``"parsed"`` from the processor's
            ``post_process_generation`` when that works (otherwise
            ``"parse_error"``). A task that failed maps to ``{"error": ...}``.
            When Florence-2 is not loaded, only ``{"error": ...}`` is returned.
        """
        if not self.models_loaded['florence']:
            return {"error": "Florence-2 model not available"}

        florence_tasks = {
            "object_detection": "<OD>",
            "dense_caption": "<DENSE_REGION_CAPTION>",
            "ocr_with_regions": "<OCR_WITH_REGION>",
            "detailed_caption": "<MORE_DETAILED_CAPTION>"
        }

        if custom_prompt:
            # Sent on its own. Florence-2's processor asserts that a no-input
            # task token such as <CAPTION> is the entire text, so
            # "<CAPTION>{prompt}" always failed.
            # NOTE(review): Florence-2 is trained on task tokens, so free-form
            # prompts may give low-quality output.
            florence_tasks["custom_analysis"] = custom_prompt

        florence_results = {}
        for task_name, prompt in florence_tasks.items():
            try:
                inputs = self.florence_processor(text=prompt, images=image, return_tensors="pt")
                generated_ids = self.florence_model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=1024,
                    num_beams=3,
                    do_sample=False
                )
                # Keep special tokens: post_process_generation needs the <loc_*> tokens.
                generated_text = self.florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
                result = self._parse_florence_output(generated_text, prompt)
                try:
                    result["parsed"] = self.florence_processor.post_process_generation(
                        generated_text, task=prompt, image_size=image.size
                    )
                except Exception as parse_error:
                    # Structured parsing is best effort; the raw text is still returned.
                    result["parse_error"] = str(parse_error)
                florence_results[task_name] = result
            except Exception as e:
                florence_results[task_name] = {"error": str(e)}

        return florence_results

    def _parse_florence_output(self, output: str, prompt: str) -> Dict[str, Any]:
        """Clean raw Florence-2 text and return it as a dict.

        The prompt echo and the ``<s>``, ``</s>`` and ``<pad>`` markers are
        removed. If the remainder is a JSON object, the decoded object is
        returned. Otherwise the result is ``{"raw_output": text}``. On failure
        it is ``{"error": ..., "raw_output": output}``.
        """
        try:
            text = output.replace(prompt, "") if prompt and prompt in output else output
            for token in ("<s>", "</s>", "<pad>"):
                text = text.replace(token, "")
            text = text.strip()

            if text.startswith('{') and text.endswith('}'):
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    pass  # looked like JSON but was not; fall back to raw text

            return {"raw_output": text}

        except Exception as e:
            return {"error": str(e), "raw_output": output}

    def format_results_for_display(self, structured_output: Dict[str, Any]) -> str:
        """Render a result from ``analyze_chart_with_prompt`` as Markdown.

        Sections with no data are omitted. Missing statistics are shown as
        "N/A" instead of raising.
        """
        def fmt2(value: Any) -> str:
            # Apply ":.2f" only to numbers, so a placeholder such as "N/A"
            # renders as-is instead of raising ValueError.
            return f"{value:.2f}" if isinstance(value, (int, float)) else str(value)

        parts = ["# π Enhanced Chart Analysis Results\n\n"]

        metadata = structured_output.get("metadata", {})
        parts.append(f"**Analysis Type:** {metadata.get('analysis_type', 'Unknown')}\n")
        parts.append(f"**Timestamp:** {metadata.get('timestamp', 'Unknown')}\n")
        parts.append(f"**Models Used:** {', '.join(metadata.get('models_used', []))}\n\n")

        image_info = structured_output.get("image_info", {})
        if image_info and not image_info.get("error"):
            dims = image_info.get("dimensions", {})
            parts.append("## πΌοΈ Image Information\n")
            parts.append(f"**Dimensions:** {dims.get('width', 'Unknown')} x {dims.get('height', 'Unknown')}\n")
            parts.append(f"**Format:** {image_info.get('format', 'Unknown')}\n")
            parts.append(f"**Aspect Ratio:** {image_info.get('aspect_ratio', 'Unknown')}\n\n")

        chart_analysis = structured_output.get("chart_analysis", {})
        parts.append("## π Chart Analysis\n")
        parts.append(f"**Chart Type:** {chart_analysis.get('chart_type', 'Unknown')}\n")
        if chart_analysis.get("description"):
            parts.append(f"**Description:** {chart_analysis['description']}\n\n")
        else:
            parts.append("\n")

        text_extraction = structured_output.get("text_extraction", {})
        if text_extraction.get("combined_text"):
            parts.append("## π Extracted Text\n")
            parts.append(f"**Methods Used:** {', '.join(text_extraction.get('methods_used', []))}\n")
            parts.append(f"**Combined Text:** {text_extraction['combined_text']}\n")

            if text_extraction.get("detected_numbers"):
                parts.append(f"**Numbers Found:** {len(text_extraction['detected_numbers'])}\n")

            if text_extraction.get("detected_labels"):
                parts.append(f"**Labels Found:** {', '.join(text_extraction['detected_labels'])}\n\n")
            else:
                parts.append("\n")

        data_insights = structured_output.get("data_insights", {})
        if data_insights.get("summary_statistics"):
            stats = data_insights["summary_statistics"]
            parts.append("## π Data Insights\n")
            parts.append(f"**Data Points:** {stats.get('count', 0)}\n")
            parts.append(f"**Range:** {stats.get('min', 'N/A')} - {stats.get('max', 'N/A')}\n")
            parts.append(f"**Average:** {fmt2(stats.get('mean', 'N/A'))}\n")
            parts.append(f"**Median:** {fmt2(stats.get('median', 'N/A'))}\n\n")

        quality = structured_output.get("quality_metrics", {})
        if quality.get("overall_score") is not None:
            parts.append("## β Quality Assessment\n")
            parts.append(f"**Overall Score:** {fmt2(quality['overall_score'])}/1.0\n")

            completeness = quality.get("completeness", {})
            if completeness:
                parts.append(f"**Has Title:** {'Yes' if completeness.get('has_title') else 'No'}\n")
                parts.append(f"**Has Data:** {'Yes' if completeness.get('has_numerical_data') else 'No'}\n")
                parts.append(f"**Chart Type Identified:** {'Yes' if completeness.get('chart_type_identified') else 'No'}\n\n")

        errors = structured_output.get("errors", [])
        if errors:
            parts.append("## β οΈ Errors\n")
            parts.extend(f"- {error}\n" for error in errors)
            parts.append("\n")

        return "".join(parts)
|
|
| |
# Module-level analyzer shared by every Gradio callback defined below.
# Constructing it calls load_models(), so importing this module downloads/loads
# all models and raises if BLIP, TrOCR or EasyOCR cannot be loaded.
analyzer = StructuredChartAnalyzer()
|
|
def analyze_with_structured_output(image, analysis_type, custom_prompt, include_florence):
    """Run the shared analyzer and return ``(markdown, result dict, csv text or None)``.

    NOTE(review): no event handler in this file references this function; the UI
    uses ``analyze_chart_comprehensive`` instead.

    Args:
        image: PIL image or array, passed straight to the analyzer.
        analysis_type: Key of ``analyzer.prompt_templates``.
        custom_prompt: Optional free-form prompt. Blank or None means "use the
            template".
        include_florence: When False, the Florence-2 section is removed from
            the result.

    Returns:
        The Markdown report, the structured result, and a CSV of the numerical
        values with one category per row (blank when there are fewer labels
        than values). The CSV is None when no numbers were found.
    """
    prompt_to_use = custom_prompt if custom_prompt and custom_prompt.strip() else None

    structured_result = analyzer.analyze_chart_with_prompt(
        image,
        custom_prompt=prompt_to_use,
        analysis_type=analysis_type
    )
    if not include_florence:
        # The analyzer decides on its own whether to run Florence-2, so the
        # flag is honoured by dropping that section from the result.
        structured_result.pop("advanced_analysis", None)

    formatted_display = analyzer.format_results_for_display(structured_result)

    csv_data = None
    data_insights = structured_result.get("data_insights", {})
    values = data_insights.get("numerical_data") or []
    if values:
        # Labels and values are detected independently and rarely have the same
        # count. Pad or trim the labels to one per value, because unequal
        # columns make pd.DataFrame raise ValueError.
        categories = list(data_insights.get("categories") or [])[: len(values)]
        categories += [""] * (len(values) - len(categories))
        csv_data = pd.DataFrame({'Values': values, 'Categories': categories}).to_csv(index=False)

    return formatted_display, structured_result, csv_data
|
|
| |
with gr.Blocks(title="Enhanced Chart Analyzer with Structured Output", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π Enhanced Chart Analyzer with Structured JSON Output")
    gr.Markdown("Upload a chart image and get comprehensive analysis with structured data output. Supports custom prompts and multiple AI models.")

    def update_template_display(analysis_type):
        """Return Markdown showing the prompt template for *analysis_type*."""
        if not analysis_type:
            return "No template selected"
        template = analyzer.prompt_templates.get(analysis_type, 'No template available')
        return f"**{analysis_type.title()} Template:**\n\n{template}"

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## π Analysis Configuration")

            image_input = gr.Image(
                type="pil",
                label="Upload Chart Image",
                height=300
            )

            analysis_type = gr.Dropdown(
                choices=list(analyzer.prompt_templates.keys()),
                value="comprehensive",
                label="Analysis Type",
                info="Choose predefined analysis type or use custom prompt"
            )

            custom_prompt = gr.Textbox(
                label="Custom Analysis Prompt",
                placeholder="Enter your custom analysis instructions here...",
                lines=3,
                info="Optional: Override the selected analysis type with a custom prompt"
            )

            with gr.Accordion("Prompt Templates", open=False):
                # Set the initial value at construction time. Assigning
                # `.value` after the component is built is not reliably
                # reflected in the rendered page.
                template_display = gr.Markdown(value=update_template_display("comprehensive"))

            analysis_type.change(update_template_display, inputs=[analysis_type], outputs=[template_display])

            with gr.Accordion("Advanced Settings", open=False):
                include_florence = gr.Checkbox(
                    label="Use Florence-2 Advanced Analysis",
                    value=True,
                    info="Include advanced computer vision analysis (if model available)"
                )

                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    label="OCR Confidence Threshold"
                )

            analyze_btn = gr.Button("π Analyze Chart", variant="primary", size="lg")
            clear_btn = gr.Button("ποΈ Clear All", variant="secondary")

        with gr.Column(scale=2):
            gr.Markdown("## π Analysis Results")

            with gr.Tabs():
                with gr.Tab("π Formatted Results"):
                    formatted_output = gr.Markdown(
                        value="Upload an image and click 'Analyze Chart' to see results here.",
                        label="Analysis Results"
                    )

                with gr.Tab("π§ Structured JSON"):
                    json_output = gr.JSON(
                        label="Complete Structured Output",
                        show_label=True
                    )

                with gr.Tab("π Data Export"):
                    gr.Markdown("### Export Options")

                    with gr.Row():
                        # Hidden until generate_export_files produces a file.
                        json_download = gr.File(
                            label="Download JSON Results",
                            visible=False
                        )
                        csv_download = gr.File(
                            label="Download CSV Data",
                            visible=False
                        )

                    export_btn = gr.Button("π₯ Generate Export Files")
                    export_status = gr.Textbox(label="Export Status", interactive=False)

    gr.Markdown("## π― Example Prompts")

    example_prompts = [
        ["What are the main trends shown in this chart?", "trend_analysis"],
        ["Extract all numerical data points and their labels", "data_extraction"],
        ["Describe this chart for accessibility purposes", "accessibility"],
        ["What business insights can be derived from this data?", "business_insights"],
        ["Analyze the performance metrics shown in this dashboard", "comprehensive"]
    ]

    gr.Examples(
        examples=example_prompts,
        inputs=[custom_prompt, analysis_type],
        label="Try these example prompts:"
    )

    def analyze_chart_comprehensive(image, analysis_type, custom_prompt, include_florence, confidence_threshold):
        """Analyze the uploaded chart for the "Analyze" button.

        Returns:
            ``(markdown report, structured result, status text)``, matching
            ``outputs=[formatted_output, json_output, export_status]``.
        """
        if image is None:
            return "Please upload an image first.", {}, "No data to export"

        try:
            prompt = (custom_prompt or "").strip() or None
            structured_result = analyzer.analyze_chart_with_prompt(
                image,
                custom_prompt=prompt,
                analysis_type=analysis_type
            )
            if not include_florence:
                # The analyzer runs Florence-2 on its own when it is available;
                # honour the checkbox by dropping that section from the result.
                structured_result.pop("advanced_analysis", None)
            # NOTE(review): confidence_threshold is accepted for the slider but is
            # not forwarded to the analyzer here.

            formatted_display = analyzer.format_results_for_display(structured_result)
            return formatted_display, structured_result, "✅ Analysis completed successfully"

        except Exception as e:
            error_msg = f"❌ Analysis failed: {str(e)}"
            return error_msg, {"error": str(e)}, error_msg

    def generate_export_files(json_data):
        """Write the current result to downloadable JSON and CSV files.

        ``gr.File`` expects a file path as its value, not file contents, so both
        exports are written to named temporary files (kept on disk after
        closing). NOTE(review): this assumes the Gradio version in use allows
        serving files from ``tempfile.gettempdir()``.

        Returns:
            ``(JSON file update, CSV file update, status)``. The CSV is produced
            only when numerical data was found. It lists the detected numbers
            when available, otherwise the values with padded categories.
        """
        import tempfile  # only needed for exports

        hidden = gr.update(value=None, visible=False)
        if not isinstance(json_data, dict) or not json_data or json_data.get("error"):
            return hidden, hidden, "❌ No valid data to export"

        try:
            with tempfile.NamedTemporaryFile(
                "w", suffix=".json", prefix="chart_analysis_", delete=False, encoding="utf-8"
            ) as json_file:
                json.dump(json_data, json_file, indent=2, default=str)

            csv_update = hidden
            data_insights = json_data.get("data_insights", {})
            numerical_values = data_insights.get("numerical_data") or []
            detected_numbers = json_data.get("text_extraction", {}).get("detected_numbers") or []

            if numerical_values:
                if detected_numbers:
                    export_df = pd.DataFrame([
                        {
                            'Value': num_data.get('value', ''),
                            'Type': num_data.get('type', ''),
                            'Position': str(num_data.get('position', ''))
                        }
                        for num_data in detected_numbers
                    ])
                else:
                    df_data = {'Numerical_Values': numerical_values}
                    categories = list(data_insights.get("categories") or [])
                    if categories:
                        # Pad or trim to one label per value. This works on a
                        # copy, so the caller's dict is not mutated.
                        n_values = len(numerical_values)
                        df_data['Categories'] = (categories + [""] * n_values)[:n_values]
                    export_df = pd.DataFrame(df_data)

                with tempfile.NamedTemporaryFile(
                    "w", suffix=".csv", prefix="chart_numbers_", delete=False, encoding="utf-8", newline=""
                ) as csv_file:
                    export_df.to_csv(csv_file, index=False)
                csv_update = gr.update(value=csv_file.name, visible=True)

            return gr.update(value=json_file.name, visible=True), csv_update, "✅ Export files generated successfully"

        except Exception as e:
            return hidden, hidden, f"❌ Export failed: {str(e)}"

    def clear_all_inputs():
        """Reset the inputs, outputs and hidden download components."""
        return (
            None,
            "Upload an image and click 'Analyze Chart' to see results here.",
            {},
            "No data to export",
            "",
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False)
        )

    # Each component appears at most once per output list.
    analyze_btn.click(
        fn=analyze_chart_comprehensive,
        inputs=[image_input, analysis_type, custom_prompt, include_florence, confidence_threshold],
        outputs=[formatted_output, json_output, export_status]
    )

    export_btn.click(
        fn=generate_export_files,
        inputs=[json_output],
        outputs=[json_download, csv_download, export_status]
    )

    clear_btn.click(
        fn=clear_all_inputs,
        outputs=[image_input, formatted_output, json_output, export_status, custom_prompt, json_download, csv_download]
    )
|
|
| |
def load_image_from_url(url):
    """Download an image over HTTP(S) for the image input.

    Args:
        url: The address typed by the user. Surrounding whitespace is ignored.

    Returns:
        ``(image, status)``, matching ``outputs=[image_input, export_status]``.
        On success this is a decoded PIL image and a success message. On
        failure it is ``None`` and an error message. Non-http(s) URLs and
        downloads larger than 20 MiB are refused.
    """
    from urllib.parse import urlparse  # local: only this helper needs it

    max_bytes = 20 * 1024 * 1024
    url = (url or "").strip()

    try:
        # NOTE(review): the server fetches any user-supplied http(s) URL and the
        # app listens on 0.0.0.0, which is an SSRF vector if exposed publicly.
        # Consider an allow-list.
        if urlparse(url).scheme not in ("http", "https"):
            raise ValueError("only http:// and https:// URLs are supported")

        # Stream the body so an oversized response is rejected without being
        # held in memory in full.
        with requests.get(url, timeout=10, stream=True) as response:
            response.raise_for_status()
            chunks = []
            received = 0
            for chunk in response.iter_content(chunk_size=64 * 1024):
                received += len(chunk)
                if received > max_bytes:
                    raise ValueError(f"image is larger than {max_bytes // (1024 * 1024)} MiB")
                chunks.append(chunk)

        image = Image.open(io.BytesIO(b"".join(chunks)))
        # Decode now, so a corrupt or non-image payload fails here with a clear
        # message rather than later inside the analysis pipeline.
        image.load()
        return image, "✅ Image loaded successfully from URL"
    except Exception as e:
        return None, f"❌ Failed to load image: {str(e)}"
|
|
| |
with demo:
    # Extra accordion appended below the main layout. It loads a chart from a
    # URL directly into the image input and reports into the shared status box.
    with gr.Accordion("π Load from URL", open=False):
        url_input = gr.Textbox(label="Image URL", placeholder="https://example.com/chart.png")
        load_url_btn = gr.Button("π₯ Load from URL")

    load_url_btn.click(load_image_from_url, [url_input], [image_input, export_status])
|
|
if __name__ == "__main__":
    # Startup banner, printed line by line.
    banner_lines = (
        "π Starting Enhanced Chart Analyzer...",
        "π Features:",
        " - Structured JSON output",
        " - Custom analysis prompts",
        " - Multiple AI models (BLIP, TrOCR, EasyOCR, Florence-2)",
        " - Data export capabilities",
        " - Quality assessment",
        " - Advanced visual analysis",
    )
    for banner_line in banner_lines:
        print(banner_line)

    # Preferred configuration: all interfaces, fixed port, errors shown in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "debug": True,
    }
    try:
        demo.launch(**launch_options)
    except Exception as launch_error:
        print(f"β Error launching app: {launch_error}")
        print("π Trying fallback launch...")
        # Last resort: retry with Gradio's default launch settings.
        demo.launch()