import os
import sys
import gradio as gr
from PIL import Image
import torch
import numpy as np
import cv2

# Writable cache directory for HF downloads
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/data/hf-cache")
try:
    os.makedirs(HF_CACHE_DIR, exist_ok=True)
except Exception:
    pass

# Add custom modules to path - try multiple possible locations
possible_paths = [
    "./custom_models",
    "../custom_models",
    "./Dense-Captioning-Platform/custom_models"
]
for path in possible_paths:
    if os.path.exists(path):
        sys.path.insert(0, os.path.abspath(path))
        break

# Add mmcv to path if it exists
if os.path.exists('./mmcv'):
    sys.path.insert(0, os.path.abspath('./mmcv'))
    print("✅ Added local mmcv to path")

# Import and register custom modules
try:
    from custom_models import register
    print("✅ Custom modules registered successfully")
except Exception as e:
    print(f"⚠️ Warning: Could not register custom modules: {e}")

# ----------------------
# Optional MedSAM integration
# ----------------------
class MedSAMIntegrator:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.medsam_model = None
        self.current_image = None
        self.current_image_path = None
        self.embedding = None
        self._load_medsam_model()

    def _ensure_segment_anything(self):
        try:
            import segment_anything  # noqa: F401
            return True
        except Exception as e:
            print(f"❌ segment_anything not available: {e}. Install it in Dockerfile to enable MedSAM.")
            return False

    def _load_medsam_model(self):
        try:
            # Ensure library is present
            if not self._ensure_segment_anything():
                print("MedSAM features disabled (segment_anything not available)")
                return
            from segment_anything import sam_model_registry as _reg
            import torch as _torch
            # Preferred local path in HF cache
            medsam_ckpt_path = os.path.join(HF_CACHE_DIR, "medsam_vit_b.pth")
            # If not present, fetch from HF Hub using provided repo or default
            if not os.path.exists(medsam_ckpt_path):
                try:
                    from huggingface_hub import hf_hub_download, list_repo_files
                    repo_id = os.environ.get("HF_MEDSAM_REPO", "Aniketg6/Fine-Tuned-MedSAM")
                    print(f"📥 Trying to download MedSAM checkpoint from {repo_id} ...")
                    files = list_repo_files(repo_id)
                    candidate = None
                    for f in files:
                        lf = f.lower()
                        if lf.endswith(".pth") or lf.endswith(".pt"):
                            candidate = f
                            break
                    if candidate is None:
                        candidate = "medsam_vit_b.pth"
                    ckpt_path = hf_hub_download(repo_id=repo_id, filename=candidate, cache_dir=HF_CACHE_DIR)
                    medsam_ckpt_path = ckpt_path
                    print(f"✅ Downloaded MedSAM checkpoint: {medsam_ckpt_path}")
                except Exception as dl_err:
                    print(f"❌ Could not fetch MedSAM checkpoint from HF Hub: {dl_err}")
                    print("MedSAM features disabled (no checkpoint)")
                    return
            # Load checkpoint
            checkpoint = _torch.load(medsam_ckpt_path, map_location='cpu')
            self.medsam_model = _reg["vit_b"](checkpoint=None)
            self.medsam_model.load_state_dict(checkpoint)
            self.medsam_model.to(self.device)
            self.medsam_model.eval()
            print("✅ MedSAM model loaded successfully")
        except Exception as e:
            print(f"❌ MedSAM model not available: {e}. MedSAM features disabled.")

    def is_available(self):
        return self.medsam_model is not None

    def load_image(self, image_path, precomputed_embedding=None):
        try:
            from skimage import transform, io  # local import to avoid hard dep if unused
            img_np = io.imread(image_path)
            if len(img_np.shape) == 2:
                img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
            else:
                img_3c = img_np
            self.current_image = img_3c
            self.current_image_path = image_path
            if precomputed_embedding is not None:
                if not self.set_precomputed_embedding(precomputed_embedding):
                    self.get_embeddings()
            else:
                self.get_embeddings()
            return True
        except Exception as e:
            print(f"Error loading image for MedSAM: {e}")
            return False

    def get_embeddings(self):
        if self.current_image is None or self.medsam_model is None:
            return None
        from skimage import transform
        img_1024 = transform.resize(
            self.current_image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        img_1024 = (img_1024 - img_1024.min()) / np.clip(img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None)
        img_1024_tensor = (
            torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
        )
        # Encode once without tracking gradients; the embedding is reused for every box prompt
        with torch.no_grad():
            self.embedding = self.medsam_model.image_encoder(img_1024_tensor)
        return self.embedding

    def set_precomputed_embedding(self, embedding_array):
        try:
            if isinstance(embedding_array, np.ndarray):
                embedding_tensor = torch.tensor(embedding_array).to(self.device)
                self.embedding = embedding_tensor
                return True
            return False
        except Exception as e:
            print(f"Error setting precomputed embedding: {e}")
            return False

    def medsam_inference(self, box_1024, height, width):
        if self.embedding is None or self.medsam_model is None:
            return None
        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=self.embedding.device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]
        # Decode the mask without tracking gradients
        with torch.no_grad():
            sparse_embeddings, dense_embeddings = self.medsam_model.prompt_encoder(
                points=None, boxes=box_torch, masks=None,
            )
            low_res_logits, _ = self.medsam_model.mask_decoder(
                image_embeddings=self.embedding,
                image_pe=self.medsam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            low_res_pred = torch.sigmoid(low_res_logits)
            low_res_pred = torch.nn.functional.interpolate(
                low_res_pred, size=(height, width), mode="bilinear", align_corners=False,
            )
        low_res_pred = low_res_pred.squeeze().cpu().numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        return medsam_seg

    def segment_with_box(self, bbox):
        if self.embedding is None or self.current_image is None:
            return None
        try:
            H, W, _ = self.current_image.shape
            x1, y1, x2, y2 = bbox
            x1 = max(0, min(int(x1), W - 1))
            y1 = max(0, min(int(y1), H - 1))
            x2 = max(0, min(int(x2), W - 1))
            y2 = max(0, min(int(y2), H - 1))
            if x2 <= x1:
                x2 = min(x1 + 10, W - 1)
            if y2 <= y1:
                y2 = min(y1 + 10, H - 1)
            box_np = np.array([[x1, y1, x2, y2]], dtype=float)
            box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
            medsam_mask = self.medsam_inference(box_1024, H, W)
            if medsam_mask is not None:
                return {"mask": medsam_mask, "confidence": 1.0, "method": "medsam_box"}
            return None
        except Exception as e:
            print(f"Error in MedSAM box-based segmentation: {e}")
            return None


# Single global instance
_medsam = MedSAMIntegrator()
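
# Illustrative usage sketch (defined here for reference, never called by the app): how the
# MedSAM wrapper above is meant to be driven with a pixel-space box prompt. The file name
# "example.png" and the box coordinates are hypothetical placeholders.
def _example_medsam_box_prompt(image_path="example.png", bbox=(30, 40, 200, 220)):
    if not _medsam.is_available():
        return None
    # load_image() reads the image and computes the ViT-B image embedding once;
    # segment_with_box() then decodes a binary mask for each box prompt.
    if not _medsam.load_image(image_path):
        return None
    seg = _medsam.segment_with_box(list(bbox))
    # When not None, seg is {"mask": HxW uint8 array, "confidence": float, "method": "medsam_box"}
    return seg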

def _extract_bboxes_from_mmdet_result(det_result):
    """Extract Nx4 xyxy bboxes from various MMDet result formats."""
    boxes = []
    try:
        # MMDet 3.x: list of DetDataSample
        if isinstance(det_result, list) and len(det_result) > 0:
            sample = det_result[0]
            if hasattr(sample, 'pred_instances'):
                inst = sample.pred_instances
                if hasattr(inst, 'bboxes'):
                    b = inst.bboxes
                    # mmengine structures may use .tensor for boxes
                    if hasattr(b, 'tensor'):
                        b = b.tensor
                    boxes = b.detach().cpu().numpy().tolist()
        # Single DetDataSample
        elif hasattr(det_result, 'pred_instances'):
            inst = det_result.pred_instances
            if hasattr(inst, 'bboxes'):
                b = inst.bboxes
                if hasattr(b, 'tensor'):
                    b = b.tensor
                boxes = b.detach().cpu().numpy().tolist()
        # MMDet 2.x: tuple of (bbox_result, segm_result)
        elif isinstance(det_result, tuple) and len(det_result) >= 1:
            bbox_result = det_result[0]
            # bbox_result is a list per class, each Nx5 [x1, y1, x2, y2, score]
            if isinstance(bbox_result, (list, tuple)):
                for arr in bbox_result:
                    try:
                        arr_np = np.array(arr)
                        if arr_np.ndim == 2 and arr_np.shape[1] >= 4:
                            boxes.extend(arr_np[:, :4].tolist())
                    except Exception:
                        continue
    except Exception as e:
        print(f"Failed to parse MMDet result for boxes: {e}")
    return boxes


def _overlay_masks_on_image(image_pil, mask_list, alpha=0.4):
    """Overlay binary masks on an image with random colors."""
    if image_pil is None or not mask_list:
        return image_pil
    img = np.array(image_pil.convert('RGB'))
    overlay = img.copy()
    for idx, m in enumerate(mask_list):
        if m is None or 'mask' not in m or m['mask'] is None:
            continue
        mask = m['mask'].astype(bool)
        color = np.random.RandomState(seed=idx + 1234).randint(0, 255, size=3)
        overlay[mask] = (0.5 * overlay[mask] + 0.5 * color).astype(np.uint8)
    blended = (alpha * overlay + (1 - alpha) * img).astype(np.uint8)
    return Image.fromarray(blended)


def _mask_to_polygons(mask: np.ndarray):
    """Convert a binary mask (H, W) to a list of polygons ([[x, y], ...]) using OpenCV contours."""
    try:
        mask_u8 = (mask.astype(np.uint8) * 255)
        contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        polygons = []
        for cnt in contours:
            if cnt is None or len(cnt) < 3:
                continue
            # Simplify contour slightly
            epsilon = 0.002 * cv2.arcLength(cnt, True)
            approx = cv2.approxPolyDP(cnt, epsilon, True)
            poly = approx.reshape(-1, 2).tolist()
            polygons.append(poly)
        return polygons
    except Exception as e:
        print(f"_mask_to_polygons failed: {e}")
        return []
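
# Illustrative sketch (not wired into the app): turning a MedSAM mask into compact polygon
# coordinates instead of a full HxW array, e.g. for lighter JSON payloads. The image path
# and bbox below are hypothetical placeholders.
def _example_mask_to_polygons(image_path="example.png", bbox=(30, 40, 200, 220)):
    if not _medsam.is_available() or not _medsam.load_image(image_path):
        return []
    seg = _medsam.segment_with_box(list(bbox))
    if seg is None:
        return []
    # Each polygon is a list of [x, y] vertices from cv2.findContours on the binary mask
    return _mask_to_polygons(seg["mask"])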

def _find_largest_foreground_bbox(pil_img: Image.Image):
    """Heuristic: find the largest foreground region bbox via Otsu threshold on grayscale.
    Returns [x1, y1, x2, y2], or the full-image bbox if none is found."""
    try:
        img = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # Otsu threshold (invert if needed by checking the mean)
        _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Assume foreground is darker; invert if thresholding yields a white-majority image
        if th.mean() > 127:
            th = 255 - th
        # Morphological close to connect regions
        kernel = np.ones((5, 5), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            W, H = pil_img.size
            return [0, 0, W - 1, H - 1]
        # Largest contour by area
        cnt = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(cnt)
        # Pad a little
        pad = int(0.02 * max(w, h))
        x1 = max(0, x - pad)
        y1 = max(0, y - pad)
        x2 = min(img.shape[1] - 1, x + w + pad)
        y2 = min(img.shape[0] - 1, y + h + pad)
        return [x1, y1, x2, y2]
    except Exception as e:
        print(f"_find_largest_foreground_bbox failed: {e}")
        W, H = pil_img.size
        return [0, 0, W - 1, H - 1]


def _find_topk_foreground_bboxes(pil_img: Image.Image, max_regions: int = 20, min_area: int = 100):
    """Find the top-K foreground bboxes via Otsu threshold + morphology. Returns a list of [x1, y1, x2, y2]."""
    try:
        img = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        if th.mean() > 127:
            th = 255 - th
        kernel = np.ones((3, 3), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
        th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        bboxes = []
        H, W = img.shape[:2]
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < min_area:
                continue
            x, y, w, h = cv2.boundingRect(cnt)
            # Filter very thin shapes
            if w < 5 or h < 5:
                continue
            pad = int(0.01 * max(w, h))
            x1 = max(0, x - pad)
            y1 = max(0, y - pad)
            x2 = min(W - 1, x + w + pad)
            y2 = min(H - 1, y + h + pad)
            bboxes.append([x1, y1, x2, y2])
            if len(bboxes) >= max_regions:
                break
        return bboxes
    except Exception as e:
        print(f"_find_topk_foreground_bboxes failed: {e}")
        return []

# Try to import mmdet for inference
try:
    from mmdet.apis import init_detector, inference_detector
    MM_DET_AVAILABLE = True
    print("✅ MMDetection available for inference")
except ImportError as e:
    print(f"⚠️ MMDetection import failed: {e}")
    print("❌ MMDetection not available - install in Dockerfile")
    MM_DET_AVAILABLE = False

# === Chart Type Classification (DocFigure) ===
print("🔄 Loading Chart Classification Model...")

# Chart type labels from the DocFigure dataset (28 classes)
CHART_TYPE_LABELS = [
    'Line graph', 'Natural image', 'Table', '3D object', 'Bar plot', 'Scatter plot',
    'Medical image', 'Sketch', 'Geographic map', 'Flow chart', 'Heat map', 'Mask',
    'Block diagram', 'Venn diagram', 'Confusion matrix', 'Histogram', 'Box plot',
    'Vector plot', 'Pie chart', 'Surface plot', 'Algorithm', 'Contour plot',
    'Tree diagram', 'Bubble chart', 'Polar plot', 'Area chart', 'Pareto chart', 'Radar chart'
]

try:
    # Load the chart_type.pth model file from the Hugging Face Hub
    from huggingface_hub import hf_hub_download
    from torchvision import transforms
    print("📥 Downloading chart_type.pth from Hugging Face Hub...")
    chart_type_path = hf_hub_download(
        repo_id="hanszhu/ChartTypeNet-DocFigure",
        filename="chart_type.pth",
        cache_dir=HF_CACHE_DIR
    )
    print(f"✅ Downloaded to: {chart_type_path}")

    # Load the PyTorch model
    loaded_data = torch.load(chart_type_path, map_location='cpu')

    # Check if it's a state dict or a complete model
    if isinstance(loaded_data, dict):
        # Check if it's a checkpoint with model_state_dict
        if "model_state_dict" in loaded_data:
            print("🔄 Loading checkpoint, extracting model_state_dict...")
            state_dict = loaded_data["model_state_dict"]
        else:
            # It's a direct state dict
            print("🔄 Loading state dict, creating model architecture...")
            state_dict = loaded_data
        # Strip "backbone." prefix from state dict keys if present
        cleaned_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("backbone."):
                # Remove "backbone." prefix
                new_key = key[9:]
                cleaned_state_dict[new_key] = value
            else:
                cleaned_state_dict[key] = value
        print(f"🔄 Cleaned state dict: {len(cleaned_state_dict)} keys")

        # Create the model architecture
        from torchvision.models import resnet50
        chart_type_model = resnet50(pretrained=False)

        # Create the correct classifier structure to match the state dict
        import torch.nn as nn
        in_features = chart_type_model.fc.in_features
        dropout = nn.Dropout(0.5)
        chart_type_model.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            dropout,
            nn.Linear(512, 28)
        )

        # Load the cleaned state dict
        chart_type_model.load_state_dict(cleaned_state_dict)
    else:
        # It's a complete model
        chart_type_model = loaded_data

    chart_type_model.eval()

    # Create a simple preprocessing pipeline for the model (ImageNet normalization)
    chart_type_processor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    CHART_TYPE_AVAILABLE = True
    print("✅ Chart classification model loaded")
except Exception as e:
    print(f"⚠️ Failed to load chart classification model: {e}")
    import traceback
    print("🔍 Full traceback:")
    traceback.print_exc()
    CHART_TYPE_AVAILABLE = False
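
# Illustrative sketch (not called by the app): running the classifier head on its own,
# mirroring the preprocessing used in analyze() below. "chart.png" is a hypothetical path.
def _example_classify_chart_type(image_path="chart.png"):
    if not CHART_TYPE_AVAILABLE:
        return None
    img = Image.open(image_path).convert("RGB")
    batch = chart_type_processor(img).unsqueeze(0)  # shape (1, 3, 224, 224)
    with torch.no_grad():
        logits = chart_type_model(batch)
    idx = int(logits.argmax(dim=-1).item())
    label = CHART_TYPE_LABELS[idx] if 0 <= idx < len(CHART_TYPE_LABELS) else f"Unknown ({idx})"
    return idx, label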

# === Chart Element Detection (Cascade R-CNN) ===
element_model = None
datapoint_model = None

print(f"🔍 MM_DET_AVAILABLE: {MM_DET_AVAILABLE}")

if MM_DET_AVAILABLE:
    # Check if config files exist
    element_config = "models/chart_elementnet_swin.py"
    point_config = "models/chart_pointnet_swin.py"
    print("🔍 Checking config files...")
    print(f"🔍 Element config exists: {os.path.exists(element_config)}")
    print(f"🔍 Point config exists: {os.path.exists(point_config)}")
    print(f"🔍 Current working directory: {os.getcwd()}")
    print(f"🔍 Files in models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")

    try:
        print("🔄 Loading ChartElementNet-MultiClass (Cascade R-CNN)...")
        print(f"🔍 Config path: {element_config}")
        print("🔍 Weights repo: hanszhu/ChartElementNet-MultiClass")
        # Download the weights from the Hugging Face Hub
        from huggingface_hub import hf_hub_download
        print("📥 Downloading ChartElementNet weights from Hugging Face Hub...")
        element_checkpoint = hf_hub_download(
            repo_id="hanszhu/ChartElementNet-MultiClass",
            filename="chart_label+.pth",
            cache_dir=HF_CACHE_DIR
        )
        print(f"✅ Downloaded to: {element_checkpoint}")
        # Use the local config with the downloaded weights
        element_model = init_detector(element_config, element_checkpoint, device="cpu")
        print("✅ ChartElementNet loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load ChartElementNet: {e}")
        print(f"🔍 Error type: {type(e).__name__}")
        import traceback
        print("🔍 Full traceback:")
        traceback.print_exc()

    try:
        print("🔄 Loading ChartPointNet-InstanceSeg (Mask R-CNN)...")
        print(f"🔍 Config path: {point_config}")
        print("🔍 Weights repo: hanszhu/ChartPointNet-InstanceSeg")
        # Download the weights from the Hugging Face Hub
        # (re-import locally so this block does not depend on the previous try succeeding)
        from huggingface_hub import hf_hub_download
        print("📥 Downloading ChartPointNet weights from Hugging Face Hub...")
        datapoint_checkpoint = hf_hub_download(
            repo_id="hanszhu/ChartPointNet-InstanceSeg",
            filename="chart_datapoint.pth",
            cache_dir=HF_CACHE_DIR
        )
        print(f"✅ Downloaded to: {datapoint_checkpoint}")
        # Use the local config with the downloaded weights
        datapoint_model = init_detector(point_config, datapoint_checkpoint, device="cpu")
        print("✅ ChartPointNet loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load ChartPointNet: {e}")
        print(f"🔍 Error type: {type(e).__name__}")
        import traceback
        print("🔍 Full traceback:")
        traceback.print_exc()

else:
    print("❌ MMDetection not available - cannot load custom models")
    print("🔍 MM_DET_AVAILABLE was False")

print("🔍 Final model status:")
print(f"🔍 element_model: {element_model is not None}")
print(f"🔍 datapoint_model: {datapoint_model is not None}")

# === Main prediction function ===
def analyze(image):
    """
    Analyze a chart image and return comprehensive results.

    Args:
        image: Input chart image (filepath string or PIL.Image)

    Returns:
        dict: Analysis results containing:
            - chart_type_id (int): Numeric chart type identifier (0-27)
            - chart_type_label (str): Human-readable chart type name
            - element_result (str or dict): Detected chart elements (titles, axes, legends, etc.)
            - datapoint_result (str or dict): Segmented data points and regions
            - status (str): Processing status message
            - processing_time (float): Time taken for analysis in seconds
    """
    import time
    from PIL import Image

    start_time = time.time()

    # Handle filepath input (convert to a PIL Image)
    if isinstance(image, str):
        # It's a filepath, load the image
        image = Image.open(image).convert("RGB")
    elif image is None:
        return {"error": "No image provided"}
    # Ensure we have a PIL Image
    if not isinstance(image, Image.Image):
        return {"error": "Invalid image format"}

    result = {
        "chart_type_id": "Model not available",
        "chart_type_label": "Model not available",
        "element_result": "MMDetection models not available",
        "datapoint_result": "MMDetection models not available",
        "status": "Basic chart classification only",
        "processing_time": 0.0,
        "medsam": {"available": False}
    }

    # Chart Type Classification
    if CHART_TYPE_AVAILABLE:
        try:
            # Preprocess the image for the PyTorch model
            processed_image = chart_type_processor(image).unsqueeze(0)  # add batch dimension
            # Get the prediction
            with torch.no_grad():
                outputs = chart_type_model(processed_image)
            # Handle different output formats
            if isinstance(outputs, torch.Tensor):
                logits = outputs
            elif hasattr(outputs, 'logits'):
                logits = outputs.logits
            else:
                logits = outputs
            predicted_class = logits.argmax(dim=-1).item()
            result["chart_type_id"] = predicted_class
            result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
            result["status"] = "Chart classification completed"
        except Exception as e:
            result["chart_type_id"] = f"Error: {str(e)}"
            result["chart_type_label"] = f"Error: {str(e)}"
            result["status"] = "Error in chart classification"

    # Chart Element Detection (Cascade R-CNN)
    if element_model is not None:
        try:
            # Convert the PIL image to a numpy array for MMDetection (RGB -> BGR)
            np_img = np.array(image.convert("RGB"))[:, :, ::-1]
            element_result = inference_detector(element_model, np_img)
            # Convert the result to a more API-friendly format
            if isinstance(element_result, tuple):
                bbox_result, segm_result = element_result
                element_data = {
                    "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
                    "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
                }
            else:
                element_data = str(element_result)
            result["element_result"] = element_data
            result["status"] = "Chart classification + element detection completed"
        except Exception as e:
            result["element_result"] = f"Error: {str(e)}"

    # Chart Data Point Segmentation (Mask R-CNN)
    if datapoint_model is not None:
        try:
            # Convert the PIL image to a numpy array for MMDetection (RGB -> BGR)
            np_img = np.array(image.convert("RGB"))[:, :, ::-1]
            datapoint_result = inference_detector(datapoint_model, np_img)
            # Convert the result to a more API-friendly format
            if isinstance(datapoint_result, tuple):
                bbox_result, segm_result = datapoint_result
                datapoint_data = {
                    "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
                    "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
                }
            else:
                datapoint_data = str(datapoint_result)
            result["datapoint_result"] = datapoint_data
            result["status"] = "Full analysis completed"
        except Exception as e:
            result["datapoint_result"] = f"Error: {str(e)}"

    # If predicted as a medical image, flag whether MedSAM masks can be generated
    try:
        label_lower = str(result.get("chart_type_label", "")).strip().lower()
        if label_lower == "medical image":
            if _medsam.is_available():
                # Indicate availability; masks are generated in the .then() chain
                result["medsam"] = {"available": True}
            else:
                # Not available; include the reason
                result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
    except Exception as e:
        print(f"MedSAM JSON augmentation failed: {e}")

    result["processing_time"] = round(time.time() - start_time, 3)
    return result

def analyze_with_medsam(base_result, image):
    """Auto-generate segmentations for medical images using SAM ViT-H if available,
    otherwise fall back to MedSAM over top-K foreground boxes. Returns the updated JSON and an overlay image."""
    try:
        if not isinstance(base_result, dict):
            return base_result, None
        label = str(base_result.get("chart_type_label", "")).strip().lower()
        if label != "medical image" or not _medsam.is_available():
            return base_result, None
        pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
        if pil_img is None:
            return base_result, None

        # Prepare the embedding
        img_path = image if isinstance(image, str) else None
        if img_path is None:
            tmp_path = "./_tmp_input_image.png"
            pil_img.save(tmp_path)
            img_path = tmp_path
        _medsam.load_image(img_path)

        segmentations = []
        masks_for_overlay = []

        # AUTO segmentation path
        try:
            from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
            import cv2 as _cv2
            # If the ViT-H checkpoint is present, use the SAM automatic mask generator (download if missing)
            vit_h_ckpt = os.path.join(HF_CACHE_DIR, "sam_vit_h_4b8939.pth")
            if not os.path.exists(vit_h_ckpt):
                try:
                    from huggingface_hub import hf_hub_download
                    vit_h_ckpt = hf_hub_download(
                        repo_id="Aniketg6/SAM",
                        filename="sam_vit_h_4b8939.pth",
                        cache_dir=HF_CACHE_DIR
                    )
                    print(f"✅ Downloaded SAM ViT-H checkpoint to: {vit_h_ckpt}")
                except Exception as dlh:
                    print(f"❌ Failed to download SAM ViT-H checkpoint: {dlh}")
            if os.path.exists(vit_h_ckpt):
                # SAM expects RGB input; cv2.imread returns BGR, so convert
                img_rgb = _cv2.cvtColor(_cv2.imread(img_path), _cv2.COLOR_BGR2RGB)
                sam = sam_model_registry["vit_h"](checkpoint=vit_h_ckpt)
                mask_generator = SamAutomaticMaskGenerator(sam)
                masks = mask_generator.generate(img_rgb)
                # Keep the top-12 masks by stability_score
                masks = sorted(masks, key=lambda m: m.get('stability_score', 0), reverse=True)[:12]
                for m in masks:
                    seg = m.get('segmentation', None)
                    if seg is None:
                        continue
                    seg_u8 = seg.astype(np.uint8)
                    segmentations.append({
                        "mask": seg_u8.tolist(),
                        "confidence": float(m.get('stability_score', 1.0)),
                        "method": "sam_auto"
                    })
                    masks_for_overlay.append({"mask": seg_u8})
            else:
                # Fallback: derive candidate boxes and run MedSAM per box
                cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=8, min_area=200)
                for bbox in cand_bboxes:
                    m = _medsam.segment_with_box(bbox)
                    if m is None or not isinstance(m.get('mask'), np.ndarray):
                        continue
                    segmentations.append({
                        "mask": m['mask'].astype(np.uint8).tolist(),
                        "confidence": float(m.get('confidence', 1.0)),
                        "method": m.get("method", "medsam_box_auto")
                    })
                    masks_for_overlay.append(m)
        except Exception as auto_e:
            print(f"Automatic MedSAM segmentation failed: {auto_e}")

        W, H = pil_img.size
        base_result["medsam"] = {
            "available": True,
            "height": H,
            "width": W,
            "segmentations": segmentations,
            "num_segments": len(segmentations)
        }
        overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
        return base_result, overlay_img
    except Exception as e:
        print(f"analyze_with_medsam failed: {e}")
        return base_result, None
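
# Illustrative sketch (not called by the app) of running the two steps locally, mirroring the
# click -> then chain wired up in the Gradio UI below. "scan.png" is a hypothetical path.
def _example_run_locally(image_path="scan.png"):
    base = analyze(image_path)
    base, overlay = analyze_with_medsam(base, image_path)
    # 'base' is the JSON-style dict; 'overlay' is a PIL image, or None for non-medical inputs
    return base, overlay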

# === Gradio UI with API enhancements ===
# Create a Blocks interface with an explicit API name for a stable API surface
with gr.Blocks(
    title="📊 Dense Captioning Platform"
) as demo:
    gr.Markdown("# 📊 Dense Captioning Platform")
    gr.Markdown("""
**Comprehensive Chart Analysis API**

Upload a chart image to get:
- **Chart Type Classification**: Identifies the type of chart (line, bar, scatter, etc.)
- **Element Detection**: Detects chart elements like titles, axes, legends, and data points
- **Data Point Segmentation**: Segments individual data points and regions

Masks will be automatically generated for medical images when supported.

**API Usage:**
```python
from gradio_client import Client, handle_file

client = Client("hanszhu/Dense-Captioning-Platform")
result = client.predict(
    image=handle_file('path/to/your/chart.png'),
    api_name="/predict"
)
print(result)
```

**Supported Chart Types:** Line graphs, Bar plots, Scatter plots, Pie charts, Heat maps, and 23+ more
""")

    with gr.Row():
        with gr.Column():
            # Input
            image_input = gr.Image(
                type="filepath",  # ✅ required for gradio_client
                label="Upload Chart Image",
                height=400
            )
            # Analyze button (single)
            analyze_btn = gr.Button(
                "🔍 Analyze",
                variant="primary",
                size="lg"
            )
        with gr.Column():
            # Output JSON
            result_output = gr.JSON(
                label="Analysis Results",
                height=400
            )
            # Overlay image output (populated only for medical images)
            overlay_output = gr.Image(
                label="MedSAM Overlay (Medical images)",
                height=400
            )

    # Single API endpoint for JSON
    analyze_event = analyze_btn.click(
        fn=analyze,
        inputs=image_input,
        outputs=result_output,
        api_name="/predict"  # ✅ standard API name that gradio_client expects
    )
    # Automatic overlay generation step for medical images
    analyze_event.then(
        fn=analyze_with_medsam,
        inputs=[result_output, image_input],
        outputs=[result_output, overlay_output],
    )

    # Add some examples
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
        ],
        inputs=image_input,
        label="Try with this example"
    )

# Launch with API-friendly settings
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",  # Allow external connections
        "server_port": 7860,
        "share": False,            # Set to True if you want a public link
        "show_error": True,        # Show detailed errors for debugging
        "quiet": False,            # Show startup messages
        "show_api": True,          # Enable API documentation
        "ssr_mode": False          # Disable experimental SSR in Docker env
    }

    # Lightweight keepalive (self-ping) to avoid idle shutdowns
    try:
        import threading, time, requests

        def _keepalive():
            url = "http://127.0.0.1:7860/"
            while True:
                try:
                    requests.get(url, timeout=3)
                except Exception:
                    pass
                time.sleep(60)

        threading.Thread(target=_keepalive, daemon=True).start()
    except Exception:
        pass

    # Enable queue for gradio_client compatibility
    demo.queue().launch(**launch_kwargs)  # ✅ required for gradio_client to work