# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import json
import os
import sys
import tempfile
import time
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
# Tiny logging helper used throughout the app (stdout is what the Space shows).
def log(msg: str) -> None:
    """Print *msg* prefixed with the current local time as ``[HH:MM:SS]``.

    Output is flushed immediately so lines appear in container logs right away.
    """
    stamp = time.strftime('%H:%M:%S')
    print(f"[{stamp}] {msg}", flush=True)
# Writable cache directory for HF downloads | |
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/data/hf-cache") | |
try: | |
os.makedirs(HF_CACHE_DIR, exist_ok=True) | |
except Exception: | |
pass | |
# Add custom modules to path: the first of these locations that exists wins.
possible_paths = [
    "./custom_models",
    "../custom_models",
    "./Dense-Captioning-Platform/custom_models",
]
for path in possible_paths:
    if os.path.exists(path):
        sys.path.insert(0, os.path.abspath(path))
        break

# A local ./mmcv checkout is inserted at the front of sys.path, so it shadows
# any installed mmcv.
if os.path.exists('./mmcv'):
    sys.path.insert(0, os.path.abspath('./mmcv'))
    print("[OK] Added local mmcv to path")

# Importing custom_models.register is done for its side effects (presumably
# registering custom modules with MMDet/MMEngine - confirm in custom_models).
# Best-effort: the app still starts without it.
try:
    from custom_models import register  # noqa: F401
    print("[OK] Custom modules registered successfully")
except Exception as e:
    print(f"[WARN] Could not register custom modules: {e}")
# ----------------------
# Optional MedSAM integration
# ----------------------
class MedSAMIntegrator:
    """Box-prompted segmentation with MedSAM (a SAM ViT-B checkpoint).

    The model is loaded eagerly in ``__init__``. If ``segment_anything`` or the
    checkpoint is unavailable, the instance is still created, but
    :meth:`is_available` returns False and segmentation calls return None.

    Per-image state (``current_image``, ``current_image_path``, ``embedding``)
    lives on the instance, so one instance works on one image at a time.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.medsam_model = None
        self.current_image = None  # last image read by load_image, as an (H, W, 3) array
        self.current_image_path = None
        self.embedding = None  # image-encoder output for current_image
        self._load_medsam_model()

    def _ensure_segment_anything(self):
        """Return True if the ``segment_anything`` package can be imported."""
        try:
            import segment_anything  # noqa: F401
            return True
        except Exception as e:
            print(f"[ERROR] segment_anything not available: {e}. Install it in Dockerfile to enable MedSAM.")
            return False

    def _load_medsam_model(self):
        """Load MedSAM weights into ``self.medsam_model``; leave it None on any failure."""
        try:
            if not self._ensure_segment_anything():
                print("MedSAM features disabled (segment_anything not available)")
                return
            from segment_anything import sam_model_registry as _reg

            # Preferred local path in the HF cache; otherwise download from the Hub.
            medsam_ckpt_path = os.path.join(HF_CACHE_DIR, "medsam_vit_b.pth")
            if not os.path.exists(medsam_ckpt_path):
                try:
                    from huggingface_hub import hf_hub_download, list_repo_files
                    repo_id = os.environ.get("HF_MEDSAM_REPO", "Aniketg6/Fine-Tuned-MedSAM")
                    print(f"[INFO] Trying to download MedSAM checkpoint from {repo_id} ...")
                    # First .pth/.pt file in the repo, else the conventional file name.
                    candidate = next(
                        (f for f in list_repo_files(repo_id) if f.lower().endswith((".pth", ".pt"))),
                        "medsam_vit_b.pth",
                    )
                    medsam_ckpt_path = hf_hub_download(
                        repo_id=repo_id, filename=candidate, cache_dir=HF_CACHE_DIR
                    )
                    print(f"[OK] Downloaded MedSAM checkpoint: {medsam_ckpt_path}")
                except Exception as dl_err:
                    print(f"[ERROR] Could not fetch MedSAM checkpoint from HF Hub: {dl_err}")
                    print("MedSAM features disabled (no checkpoint)")
                    return

            checkpoint = torch.load(medsam_ckpt_path, map_location='cpu')
            model = _reg["vit_b"](checkpoint=None)
            model.load_state_dict(checkpoint)
            model.to(self.device)
            model.eval()
            # Publish only a fully initialised model: a failed load_state_dict must not
            # leave a randomly initialised network behind that is_available() reports.
            self.medsam_model = model
            print("[OK] MedSAM model loaded successfully")
        except Exception as e:
            print(f"[ERROR] MedSAM model not available: {e}. MedSAM features disabled.")

    def is_available(self):
        """Return True when the MedSAM model was loaded."""
        return self.medsam_model is not None

    @staticmethod
    def _to_rgb(img_np):
        """Return *img_np* with exactly three channels.

        Grayscale (H, W) and (H, W, 1) inputs are replicated; extra channels
        such as alpha are dropped.
        """
        if img_np.ndim == 2:
            return np.repeat(img_np[:, :, None], 3, axis=-1)
        if img_np.shape[-1] == 1:
            return np.repeat(img_np, 3, axis=-1)
        return img_np[:, :, :3]

    def load_image(self, image_path, precomputed_embedding=None):
        """Make *image_path* the current image and compute (or install) its embedding.

        State from any previously loaded image is cleared first, so a failed load
        never leaves an old image or embedding behind for the next segmentation.
        A NumPy *precomputed_embedding* is used instead of running the encoder.

        Returns:
            True if the image was read (the embedding may still be None when the
            model is unavailable), False if reading or encoding raised.
        """
        self.current_image = None
        self.current_image_path = None
        self.embedding = None
        try:
            from skimage import io  # local import to avoid a hard dependency when unused
            self.current_image = self._to_rgb(io.imread(image_path))
            self.current_image_path = image_path
            if precomputed_embedding is None or not self.set_precomputed_embedding(precomputed_embedding):
                self.get_embeddings()
            return True
        except Exception as e:
            # Don't keep a half-loaded image without a matching embedding.
            self.current_image = None
            self.current_image_path = None
            self.embedding = None
            print(f"Error loading image for MedSAM: {e}")
            return False

    def get_embeddings(self):
        """Run the image encoder on ``current_image`` and cache the result in ``embedding``.

        Returns the embedding, or None if no image or model is loaded.
        """
        if self.current_image is None or self.medsam_model is None:
            return None
        from skimage import transform
        img_1024 = transform.resize(
            self.current_image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        # Min-max normalise to [0, 1]; the clip avoids division by zero on constant images.
        img_1024 = (img_1024 - img_1024.min()) / np.clip(
            img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
        )
        img_1024_tensor = torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
        # Inference only: without no_grad the encoder would build an autograd graph.
        with torch.no_grad():
            self.embedding = self.medsam_model.image_encoder(img_1024_tensor)
        return self.embedding

    def set_precomputed_embedding(self, embedding_array):
        """Install a NumPy embedding (cast to float32) as ``embedding``.

        Returns True on success, False for non-ndarray input or conversion errors.
        """
        try:
            if not isinstance(embedding_array, np.ndarray):
                return False
            # Model weights are float32; a float64 array would cause dtype mismatches.
            self.embedding = torch.tensor(embedding_array, dtype=torch.float32).to(self.device)
            return True
        except Exception as e:
            print(f"Error setting precomputed embedding: {e}")
            return False

    def medsam_inference(self, box_1024, height, width):
        """Decode a binary mask for *box_1024* (boxes in 1024x1024 model space).

        Returns a (height, width) uint8 mask of 0/1 values (probability > 0.5),
        or None if no embedding or model is loaded.
        """
        if self.embedding is None or self.medsam_model is None:
            return None
        # no_grad is required: otherwise the outputs require grad and .numpy() raises.
        with torch.no_grad():
            box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=self.embedding.device)
            if len(box_torch.shape) == 2:
                box_torch = box_torch[:, None, :]  # (B, 4) -> (B, 1, 4)
            sparse_embeddings, dense_embeddings = self.medsam_model.prompt_encoder(
                points=None, boxes=box_torch, masks=None,
            )
            low_res_logits, _ = self.medsam_model.mask_decoder(
                image_embeddings=self.embedding,
                image_pe=self.medsam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            low_res_pred = torch.sigmoid(low_res_logits)
            low_res_pred = torch.nn.functional.interpolate(
                low_res_pred, size=(height, width), mode="bilinear", align_corners=False,
            )
            low_res_pred = low_res_pred.squeeze().cpu().numpy()
        return (low_res_pred > 0.5).astype(np.uint8)

    def segment_with_box(self, bbox):
        """Segment the region inside *bbox* = [x1, y1, x2, y2] in current_image pixels.

        Coordinates are clamped to the image; an inverted or empty box is widened
        by up to 10 px (still clamped). Returns {"mask": uint8 mask,
        "confidence": 1.0 (fixed placeholder), "method": "medsam_box"}, or None if
        nothing is loaded or inference fails.
        """
        if self.embedding is None or self.current_image is None:
            return None
        try:
            H, W = self.current_image.shape[:2]
            x1, y1, x2, y2 = bbox
            x1 = max(0, min(int(x1), W - 1))
            y1 = max(0, min(int(y1), H - 1))
            x2 = max(0, min(int(x2), W - 1))
            y2 = max(0, min(int(y2), H - 1))
            if x2 <= x1:
                x2 = min(x1 + 10, W - 1)
            if y2 <= y1:
                y2 = min(y1 + 10, H - 1)
            box_np = np.array([[x1, y1, x2, y2]], dtype=float)
            # Rescale from image pixels to the 1024x1024 space the model works in.
            box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
            medsam_mask = self.medsam_inference(box_1024, H, W)
            if medsam_mask is not None:
                return {"mask": medsam_mask, "confidence": 1.0, "method": "medsam_box"}
            return None
        except Exception as e:
            print(f"Error in MedSAM box-based segmentation: {e}")
            return None
# Single process-wide MedSAM instance shared by all requests. Constructing it
# tries to load the model immediately; if that fails, the object still exists
# and is_available() returns False.
# NOTE(review): per-image state (current_image, embedding) lives on this shared
# object, so concurrent requests would overwrite each other's image - confirm
# the Gradio queue serializes these calls.
_medsam = MedSAMIntegrator()
def _extract_bboxes_from_mmdet_result(det_result): | |
"""Extract Nx4 xyxy bboxes from various MMDet result formats.""" | |
boxes = [] | |
try: | |
# MMDet 3.x: list of DetDataSample | |
if isinstance(det_result, list) and len(det_result) > 0: | |
sample = det_result[0] | |
if hasattr(sample, 'pred_instances'): | |
inst = sample.pred_instances | |
if hasattr(inst, 'bboxes'): | |
b = inst.bboxes | |
# mmengine structures may use .tensor for boxes | |
if hasattr(b, 'tensor'): | |
b = b.tensor | |
boxes = b.detach().cpu().numpy().tolist() | |
# Single DetDataSample | |
elif hasattr(det_result, 'pred_instances'): | |
inst = det_result.pred_instances | |
if hasattr(inst, 'bboxes'): | |
b = inst.bboxes | |
if hasattr(b, 'tensor'): | |
b = b.tensor | |
boxes = b.detach().cpu().numpy().tolist() | |
# MMDet 2.x: tuple of (bbox_result, segm_result) | |
elif isinstance(det_result, tuple) and len(det_result) >= 1: | |
bbox_result = det_result[0] | |
# bbox_result is list per class, each Nx5 [x1,y1,x2,y2,score] | |
if isinstance(bbox_result, (list, tuple)): | |
for arr in bbox_result: | |
try: | |
arr_np = np.array(arr) | |
if arr_np.ndim == 2 and arr_np.shape[1] >= 4: | |
boxes.extend(arr_np[:, :4].tolist()) | |
except Exception: | |
continue | |
except Exception as e: | |
print(f"Failed to parse MMDet result for boxes: {e}") | |
return boxes | |
def _overlay_masks_on_image(image_pil, mask_list, alpha=0.4): | |
"""Overlay binary masks on an image with random colors.""" | |
if image_pil is None or not mask_list: | |
return image_pil | |
img = np.array(image_pil.convert('RGB')) | |
overlay = img.copy() | |
for idx, m in enumerate(mask_list): | |
if m is None or 'mask' not in m or m['mask'] is None: | |
continue | |
mask = m['mask'].astype(bool) | |
color = np.random.RandomState(seed=idx + 1234).randint(0, 255, size=3) | |
overlay[mask] = (0.5 * overlay[mask] + 0.5 * color).astype(np.uint8) | |
blended = (alpha * overlay + (1 - alpha) * img).astype(np.uint8) | |
return Image.fromarray(blended) | |
def _mask_to_polygons(mask: np.ndarray): | |
"""Convert a binary mask (H,W) to a list of polygons ([[x,y], ...]) using OpenCV contours.""" | |
try: | |
mask_u8 = (mask.astype(np.uint8) * 255) | |
contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
polygons = [] | |
for cnt in contours: | |
if cnt is None or len(cnt) < 3: | |
continue | |
# Simplify contour slightly | |
epsilon = 0.002 * cv2.arcLength(cnt, True) | |
approx = cv2.approxPolyDP(cnt, epsilon, True) | |
poly = approx.reshape(-1, 2).tolist() | |
polygons.append(poly) | |
return polygons | |
except Exception as e: | |
print(f"_mask_to_polygons failed: {e}") | |
return [] | |
def _find_largest_foreground_bbox(pil_img: Image.Image):
    """Return a padded [x1, y1, x2, y2] box around the largest foreground blob.

    Foreground comes from an Otsu threshold of the grayscale image. The result
    is inverted when most pixels come out white, so dark content counts as
    foreground. A 5x5 morphological close (2 iterations) then merges fragments.
    The box of the largest contour is padded by 2% of its longer side and
    clipped to the image. When no blob is found, or any step raises, the whole
    image [0, 0, W - 1, H - 1] is returned.
    """
    def whole_image_box():
        width, height = pil_img.size
        return [0, 0, width - 1, height - 1]

    try:
        rgb = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        if binary.mean() > 127:
            binary = 255 - binary
        closing_kernel = np.ones((5, 5), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, closing_kernel, iterations=2)
        contours, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return whole_image_box()
        bx, by, bw, bh = cv2.boundingRect(max(contours, key=cv2.contourArea))
        margin = int(0.02 * max(bw, bh))
        img_h, img_w = rgb.shape[:2]
        return [
            max(0, bx - margin),
            max(0, by - margin),
            min(img_w - 1, bx + bw + margin),
            min(img_h - 1, by + bh + margin),
        ]
    except Exception as e:
        print(f"_find_largest_foreground_bbox failed: {e}")
        return whole_image_box()
def _find_topk_foreground_bboxes(pil_img: Image.Image, max_regions: int = 20, min_area: int = 100):
    """Return up to *max_regions* padded [x1, y1, x2, y2] boxes around foreground blobs.

    Foreground comes from an Otsu threshold of the grayscale image, inverted
    when most pixels come out white, followed by a 3x3 open (1 iteration) and
    a 3x3 close (2 iterations). Blobs are visited largest-area first. Blobs
    with area below *min_area*, or narrower or shorter than 5 px, are skipped.
    Each box is padded by 1% of its longer side and clipped to the image.
    Returns [] when nothing qualifies or on any error.
    """
    try:
        rgb = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        if binary.mean() > 127:
            binary = 255 - binary
        element = np.ones((3, 3), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, element, iterations=1)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, element, iterations=2)
        contours, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []
        img_h, img_w = rgb.shape[:2]
        found = []
        for contour in sorted(contours, key=cv2.contourArea, reverse=True):
            if cv2.contourArea(contour) < min_area:
                continue
            bx, by, bw, bh = cv2.boundingRect(contour)
            if min(bw, bh) < 5:
                continue  # too thin to be a useful region
            margin = int(0.01 * max(bw, bh))
            found.append([
                max(0, bx - margin),
                max(0, by - margin),
                min(img_w - 1, bx + bw + margin),
                min(img_h - 1, by + bh + margin),
            ])
            if len(found) >= max_regions:
                break
        return found
    except Exception as e:
        print(f"_find_topk_foreground_bboxes failed: {e}")
        return []
# Try to import mmdet for inference. Any Exception is caught, not just
# ImportError: mmdet checks the installed mmcv version at import time and
# raises AssertionError on a mismatch, which would otherwise crash app start.
try:
    from mmdet.apis import init_detector, inference_detector
    MM_DET_AVAILABLE = True
    print("[OK] MMDetection available for inference")
except Exception as e:
    print(f"[WARN] MMDetection import failed: {e}")
    print("[ERROR] MMDetection not available - install in Dockerfile")
    MM_DET_AVAILABLE = False
# === Chart Type Classification (DocFigure) === | |
print("π Loading Chart Classification Model...") | |
# Chart type labels from DocFigure dataset (28 classes) | |
CHART_TYPE_LABELS = [ | |
'Line graph', 'Natural image', 'Table', '3D object', 'Bar plot', 'Scatter plot', | |
'Medical image', 'Sketch', 'Geographic map', 'Flow chart', 'Heat map', 'Mask', | |
'Block diagram', 'Venn diagram', 'Confusion matrix', 'Histogram', 'Box plot', | |
'Vector plot', 'Pie chart', 'Surface plot', 'Algorithm', 'Contour plot', | |
'Tree diagram', 'Bubble chart', 'Polar plot', 'Area chart', 'Pareto chart', 'Radar chart' | |
] | |
try: | |
# Load the chart_type.pth model file from Hugging Face Hub | |
from huggingface_hub import hf_hub_download | |
from torchvision import transforms | |
print("π Downloading chart_type.pth from Hugging Face Hub...") | |
chart_type_path = hf_hub_download( | |
repo_id="hanszhu/ChartTypeNet-DocFigure", | |
filename="chart_type.pth", | |
cache_dir=HF_CACHE_DIR | |
) | |
print(f"β Downloaded to: {chart_type_path}") | |
# Load the PyTorch model | |
loaded_data = torch.load(chart_type_path, map_location='cpu') | |
# Check if it's a state dict or a complete model | |
if isinstance(loaded_data, dict): | |
# Check if it's a checkpoint with model_state_dict | |
if "model_state_dict" in loaded_data: | |
print("π Loading checkpoint, extracting model_state_dict...") | |
state_dict = loaded_data["model_state_dict"] | |
else: | |
# It's a direct state dict | |
print("π Loading state dict, creating model architecture...") | |
state_dict = loaded_data | |
# Strip "backbone." prefix from state dict keys if present | |
cleaned_state_dict = {} | |
for key, value in state_dict.items(): | |
if key.startswith("backbone."): | |
# Remove "backbone." prefix | |
new_key = key[9:] | |
cleaned_state_dict[new_key] = value | |
else: | |
cleaned_state_dict[key] = value | |
print(f"π Cleaned state dict: {len(cleaned_state_dict)} keys") | |
# Create the model architecture | |
from torchvision.models import resnet50 | |
chart_type_model = resnet50(pretrained=False) | |
# Create the correct classifier structure to match the state dict | |
import torch.nn as nn | |
in_features = chart_type_model.fc.in_features | |
dropout = nn.Dropout(0.5) | |
chart_type_model.fc = nn.Sequential( | |
nn.Linear(in_features, 512), | |
nn.ReLU(inplace=True), | |
dropout, | |
nn.Linear(512, 28) | |
) | |
# Load the cleaned state dict | |
chart_type_model.load_state_dict(cleaned_state_dict) | |
else: | |
# It's a complete model | |
chart_type_model = loaded_data | |
chart_type_model.eval() | |
# Create a simple processor for the model | |
chart_type_processor = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
]) | |
CHART_TYPE_AVAILABLE = True | |
print("β Chart classification model loaded") | |
except Exception as e: | |
print(f"β οΈ Failed to load chart classification model: {e}") | |
import traceback | |
print("π Full traceback:") | |
traceback.print_exc() | |
CHART_TYPE_AVAILABLE = False | |
# === Chart Element Detection (Cascade R-CNN) and Data Point Segmentation (Mask R-CNN) ===
element_model = None
datapoint_model = None


def _load_mmdet_model(short_name, full_name, arch, config_path, repo_id, filename):
    """Download *filename* from the HF repo *repo_id* and build a CPU MMDet detector.

    Returns the detector, or None (after printing the traceback) if the
    download or ``init_detector`` fails, so one broken model does not stop the app.
    """
    try:
        # Imported here so each model load is independent of earlier try-blocks.
        from huggingface_hub import hf_hub_download

        print(f"[INFO] Loading {full_name} ({arch})...")
        print(f"[INFO] Config path: {config_path}")
        print(f"[INFO] Weights path: {repo_id}")
        print(f"[INFO] Downloading {short_name} weights from Hugging Face Hub...")
        checkpoint = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=HF_CACHE_DIR)
        print(f"[OK] Downloaded to: {checkpoint}")
        # Local config file combined with the downloaded weights, always on CPU.
        model = init_detector(config_path, checkpoint, device="cpu")
        print(f"[OK] {short_name} loaded successfully")
        return model
    except Exception as e:
        print(f"[ERROR] Failed to load {short_name}: {e}")
        print(f"[INFO] Error type: {type(e).__name__}")
        print("[INFO] Full traceback:")
        traceback.print_exc()
        return None


print(f"[INFO] MM_DET_AVAILABLE: {MM_DET_AVAILABLE}")
if MM_DET_AVAILABLE:
    element_config = "models/chart_elementnet_swin.py"
    point_config = "models/chart_pointnet_swin.py"
    print("[INFO] Checking config files...")
    print(f"[INFO] Element config exists: {os.path.exists(element_config)}")
    print(f"[INFO] Point config exists: {os.path.exists(point_config)}")
    print(f"[INFO] Current working directory: {os.getcwd()}")
    # isdir (not exists): os.listdir would raise if 'models' were a regular file.
    models_listing = os.listdir('models') if os.path.isdir('models') else 'models directory not found'
    print(f"[INFO] Files in models directory: {models_listing}")

    element_model = _load_mmdet_model(
        "ChartElementNet", "ChartElementNet-MultiClass", "Cascade R-CNN",
        element_config, "hanszhu/ChartElementNet-MultiClass", "chart_label+.pth",
    )
    datapoint_model = _load_mmdet_model(
        "ChartPointNet", "ChartPointNet-InstanceSeg", "Mask R-CNN",
        point_config, "hanszhu/ChartPointNet-InstanceSeg", "chart_datapoint.pth",
    )
else:
    print("[ERROR] MMDetection not available - cannot load custom models")
    print("[INFO] MM_DET_AVAILABLE was False")

print("[INFO] Final model status:")
print(f"[INFO] element_model: {element_model is not None}")
print(f"[INFO] datapoint_model: {datapoint_model is not None}")
# === Main prediction function ===
def _serialize_detector_result(det_result):
    """Convert an ``inference_detector`` result into something gr.JSON can render.

    A ``(bbox_result, segm_result)`` tuple becomes {"bboxes": ..., "segments": ...},
    using ``.tolist()`` where available and ``str(...)`` otherwise. Any other
    result is stringified whole.
    """
    if isinstance(det_result, tuple):
        bbox_result, segm_result = det_result
        return {
            "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
            "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
        }
    return str(det_result)


def _run_detector(model, bgr_image, stage):
    """Run *model* on *bgr_image* and serialize the output; log and return None on failure."""
    try:
        return _serialize_detector_result(inference_detector(model, bgr_image))
    except Exception:
        log(f"analyze: {stage} error")
        traceback.print_exc()
        return None


def analyze(image):
    """Classify a chart image and run chart element / data point detection.

    Args:
        image: a file path (what gr.Image(type="filepath") passes), a PIL image, or None.

    Returns:
        A dict with chart_type_id, chart_type_label, element_result,
        datapoint_result, status, processing_time (seconds) and a "medsam"
        availability entry. Element and data point detection are skipped for
        images classified as "Medical image". Returns {"error": ...} for
        missing or invalid input, or on an unexpected failure.
    """
    try:
        log("analyze: start")
        start_time = time.time()

        if image is None:
            return {"error": "No image provided"}
        if isinstance(image, str):
            # Decode the file and close it right away.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        if not isinstance(image, Image.Image):
            return {"error": "Invalid image format"}
        # Normalize PIL inputs too: RGBA/L images would break the classifier's 3-channel Normalize.
        image = image.convert("RGB")

        result = {
            "chart_type_id": "Model not available",
            "chart_type_label": "Model not available",
            "element_result": "MMDetection models not available",
            "datapoint_result": "MMDetection models not available",
            "status": "Basic chart classification only",
            "processing_time": 0.0,
            "medsam": {"available": False}
        }

        # Chart Type Classification
        if CHART_TYPE_AVAILABLE:
            try:
                processed_image = chart_type_processor(image).unsqueeze(0)
                with torch.no_grad():
                    outputs = chart_type_model(processed_image)
                # A plain nn.Module returns a tensor; objects exposing .logits are also accepted.
                logits = outputs if isinstance(outputs, torch.Tensor) else getattr(outputs, 'logits', outputs)
                predicted_class = logits.argmax(dim=-1).item()
                result["chart_type_id"] = predicted_class
                result["chart_type_label"] = (
                    CHART_TYPE_LABELS[predicted_class]
                    if 0 <= predicted_class < len(CHART_TYPE_LABELS)
                    else f"Unknown ({predicted_class})"
                )
                result["status"] = "Chart classification completed"
                log(f"analyze: chart_type={result['chart_type_label']} ({result['chart_type_id']})")
            except Exception:
                log("analyze: chart classification error")
                traceback.print_exc()

        is_medical = str(result.get("chart_type_label", "")).strip().lower() == "medical image"
        if is_medical:
            # Chart detectors are not run on medical images; report MedSAM availability instead.
            result["element_result"] = "skipped for medical image"
            result["datapoint_result"] = "skipped for medical image"
            if _medsam.is_available():
                result["medsam"] = {"available": True}
            else:
                result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
        elif element_model is not None or datapoint_model is not None:
            bgr_image = np.array(image)[:, :, ::-1]  # RGB -> BGR channel order for MMDet
            if element_model is not None:
                element_data = _run_detector(element_model, bgr_image, "element detection")
                if element_data is not None:
                    result["element_result"] = element_data
                    result["status"] = "Chart classification + element detection completed"
                    log("analyze: element detection done")
            if datapoint_model is not None:
                datapoint_data = _run_detector(datapoint_model, bgr_image, "datapoint segmentation")
                if datapoint_data is not None:
                    result["datapoint_result"] = datapoint_data
                    result["status"] = "Full analysis completed"
                    log("analyze: datapoint segmentation done")

        result["processing_time"] = round(time.time() - start_time, 3)
        log(f"analyze: end in {result['processing_time']}s")
        return result
    except Exception:
        log("analyze: fatal error")
        traceback.print_exc()
        return {"error": "Internal error in analyze"}
def _parse_prompt_json(text):
    """Decode an optional JSON prompt string; empty input gives [], invalid JSON is logged and gives []."""
    if not text:
        return []
    try:
        return json.loads(text)
    except (TypeError, ValueError):
        log("analyze_with_medsam: failed to parse prompts JSON")
        return []


def _normalize_bbox_prompts(parsed):
    """Return a list of candidate boxes; a single flat [x1, y1, x2, y2] is wrapped in a list."""
    if not isinstance(parsed, list):
        return []
    if len(parsed) == 4 and all(isinstance(v, (int, float)) for v in parsed):
        return [parsed]
    return parsed


def _normalize_point_prompts(parsed):
    """Return a list of point groups; a single {"points": ...} dict is wrapped in a list."""
    if isinstance(parsed, dict):
        return [parsed]
    return parsed if isinstance(parsed, list) else []


def _points_to_bbox(item, width, height, pad=20):
    """Return the [x1, y1, x2, y2] box around a point group, padded by *pad* px and clipped.

    *item* is {"points": [[x, y], ...]}, [[x, y], ...] or a single [x, y].
    Returns None when the points cannot be read as (x, y) pairs.
    """
    try:
        pts = item.get("points") if isinstance(item, dict) else item
        pts_np = np.asarray(pts, dtype=float).reshape(-1, 2)
        x_min, y_min = pts_np.min(axis=0)
        x_max, y_max = pts_np.max(axis=0)
    except (TypeError, ValueError):
        return None
    return [max(0, x_min - pad), max(0, y_min - pad), min(width - 1, x_max + pad), min(height - 1, y_max + pad)]


def _medsam_seg_entry(m, default_method, include_raw_masks):
    """Turn a segment_with_box result into a JSON-friendly entry, or None if it has no mask."""
    if m is None or not isinstance(m.get('mask'), np.ndarray):
        return None
    mask_np = m['mask'].astype(np.uint8)
    seg_entry = {
        "confidence": float(m.get('confidence', 1.0)),
        "method": m.get("method", default_method),
        "polygons": _mask_to_polygons(mask_np)
    }
    if include_raw_masks:
        seg_entry["mask"] = mask_np.tolist()
    return seg_entry


def _load_into_medsam(image, pil_img):
    """Load the request image into the shared MedSAM instance; return load_image's success flag.

    File paths are passed straight through. In-memory images are written to a
    private temporary PNG, removed afterwards, instead of a fixed shared path,
    so concurrent requests cannot read each other's pixels.
    """
    if isinstance(image, str):
        return _medsam.load_image(image)
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        pil_img.save(tmp_path)
        return _medsam.load_image(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def analyze_with_medsam(base_result, image, include_raw_masks=False, bboxes_json="", points_json=""):
    """Add prompt-driven MedSAM segmentations to an ``analyze`` result.

    Runs only when *base_result* is a dict labelled "Medical image", MedSAM is
    available, and at least one prompt is supplied:

    * *bboxes_json*: JSON list of [x1, y1, x2, y2] pixel boxes (a single box is also accepted);
    * *points_json*: JSON list of point groups, each {"points": [[x, y], ...]} or
      [[x, y], ...]; each group is turned into its bounding box padded by 20 px.

    Returns:
        ``(base_result, overlay)``. When segmentation runs, base_result["medsam"]
        holds the image size and one entry per successful prompt (polygons, plus
        raw masks if *include_raw_masks*). overlay is a PIL image with the masks
        blended in, or None when nothing was segmented or the step was skipped.
    """
    try:
        log("analyze_with_medsam: start")
        if not isinstance(base_result, dict):
            return base_result, None
        label = str(base_result.get("chart_type_label", "")).strip().lower()
        if label != "medical image" or not _medsam.is_available():
            log("analyze_with_medsam: skip (non-medical or MedSAM unavailable)")
            return base_result, None
        if image is None:
            return base_result, None

        # Parse prompts first: the image encoder is expensive and useless without prompts.
        parsed_bboxes = _normalize_bbox_prompts(_parse_prompt_json(bboxes_json))
        parsed_points = _normalize_point_prompts(_parse_prompt_json(points_json))
        if not parsed_bboxes and not parsed_points:
            log("analyze_with_medsam: no prompts provided; skipping segmentation")
            return base_result, None

        if isinstance(image, str):
            with Image.open(image) as opened:
                pil_img = opened.convert("RGB")
        else:
            pil_img = image
        if not _load_into_medsam(image, pil_img):
            # Ignoring a failed load could segment an image left over from an earlier request.
            log("analyze_with_medsam: MedSAM could not load the image; skipping segmentation")
            return base_result, None

        H, W = _medsam.current_image.shape[:2]
        prompts = [
            (bbox, "medsam_box")
            for bbox in parsed_bboxes
            if isinstance(bbox, (list, tuple)) and len(bbox) == 4
        ]
        for item in parsed_points:
            bbox = _points_to_bbox(item, W, H)
            if bbox is not None:
                prompts.append((bbox, "medsam_points_box"))

        segmentations = []
        masks_for_overlay = []
        for bbox, default_method in prompts:
            try:
                m = _medsam.segment_with_box(bbox)
                seg_entry = _medsam_seg_entry(m, default_method, include_raw_masks)
            except Exception:
                # One bad prompt must not discard the others.
                traceback.print_exc()
                continue
            if seg_entry is None:
                continue
            segmentations.append(seg_entry)
            masks_for_overlay.append(m)

        W, H = pil_img.size
        base_result["medsam"] = {
            "available": True,
            "height": H,
            "width": W,
            "segmentations": segmentations,
            "num_segments": len(segmentations)
        }
        log(f"analyze_with_medsam: segments={len(segmentations)}")
        overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
        log("analyze_with_medsam: end")
        return base_result, overlay_img
    except Exception:
        log("analyze_with_medsam: fatal error")
        traceback.print_exc()
        return base_result, None
# === Gradio UI with API enhancements ===
# Blocks interface with explicit API names so the API surface stays stable.
with gr.Blocks(
    title="Dense Captioning Platform"
) as demo:
    gr.Markdown("# Dense Captioning Platform")
    gr.Markdown("""
**Comprehensive Chart Analysis API**

Upload a chart image to get:

- **Chart Type Classification**: Identifies the type of chart (line, bar, scatter, etc.)
- **Element Detection**: Detects chart elements like titles, axes, legends, data points
- **Data Point Segmentation**: Segments individual data points and regions

For medical images, element and data point detection are skipped. MedSAM masks are
generated only when box or point prompts are sent to the `/medsam` API endpoint;
this web page does not collect prompts, so no overlay appears here.

**API Usage:**

```python
from gradio_client import Client, handle_file

client = Client("hanszhu/Dense-Captioning-Platform")
result = client.predict(
    image=handle_file('path/to/your/chart.png'),
    api_name="/predict"
)
print(result)
```

**Supported Chart Types:** Line graphs, Bar plots, Scatter plots, Pie charts, Heat maps, and 23 more
""")

    with gr.Row():
        with gr.Column():
            # Input
            image_input = gr.Image(
                type="filepath",  # REQUIRED for gradio_client
                label="Upload Chart Image",
                height=400,
                elem_id="image-input"
            )
            # Hidden inputs used only by the /medsam API endpoint.
            include_raw_masks_cb = gr.Checkbox(value=False, visible=False, elem_id="include-raw-masks")
            bboxes_tb = gr.Textbox(value="", visible=False, elem_id="bboxes-json")
            points_tb = gr.Textbox(value="", visible=False, elem_id="points-json")

            # Analyze button (single)
            analyze_btn = gr.Button(
                "Analyze",
                variant="primary",
                size="lg",
                elem_id="analyze-btn"
            )

        with gr.Column():
            # Output JSON
            result_output = gr.JSON(
                label="Analysis Results",
                height=400,
                elem_id="result-output"
            )
            # Overlay image output (populated only for prompted medical images)
            overlay_output = gr.Image(
                label="MedSAM Overlay (Medical images)",
                height=400,
                elem_id="overlay-output"
            )

    # Single API endpoint for JSON
    analyze_event = analyze_btn.click(
        fn=analyze,
        inputs=image_input,
        outputs=result_output,
        api_name="/predict"  # standard API name that gradio_client expects
    )

    # MedSAM step (prompt-only): skipped when no prompts are provided.
    analyze_event.then(
        fn=analyze_with_medsam,
        inputs=[result_output, image_input, include_raw_masks_cb, bboxes_tb, points_tb],
        outputs=[result_output, overlay_output],
        api_name="/medsam"
    )

    # Add some examples
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
        ],
        inputs=image_input,
        label="Try with this example"
    )
# Start the server only when this file is executed directly.
if __name__ == "__main__":
    launch_kwargs = dict(
        server_name="0.0.0.0",  # accept connections from outside the container
        server_port=7860,
        share=False,            # True would also create a public share link
        show_error=True,        # surface Python errors in the UI for debugging
        quiet=False,            # print startup messages
        show_api=True,          # publish the API documentation page
        ssr_mode=False,         # experimental server-side rendering stays off in Docker
    )
    # gradio_client talks to the app through the queue, so enable it before launching.
    demo.queue().launch(**launch_kwargs)