Spaces:
Running
Running
import gradio as gr | |
import os | |
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU if needed | |
import torch | |
import numpy as np | |
from PIL import Image | |
from PIL import Image as PILImage | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import io | |
from skimage.io import imread | |
from skimage.color import rgb2gray | |
from csbdeep.utils import normalize | |
from stardist.models import StarDist2D | |
from stardist.plot import render_label | |
from MEDIARFormer import MEDIARFormer | |
from Predictor import Predictor | |
from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot | |
# Load SegFormer
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Base checkpoint whose head is re-sized to 8 labels, and the fine-tuned
# weights that are loaded on top of it.
_SEGFORMER_BASE_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"
_SEGFORMER_FINETUNED_WEIGHTS = "trained_model_200.pt"

processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
model_segformer = SegformerForSemanticSegmentation.from_pretrained(
    _SEGFORMER_BASE_CHECKPOINT,
    num_labels=8,
    # num_labels=8 does not match the base checkpoint's head, so let
    # transformers re-initialize the mismatched weights instead of raising.
    ignore_mismatched_sizes=True,
)
model_segformer.load_state_dict(
    torch.load(_SEGFORMER_FINETUNED_WEIGHTS, map_location="cpu")
)
model_segformer.eval()  # inference mode

# StarDist model (pretrained '2D_versatile_fluo')
model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')

# Cellpose model, CPU only
model_cellpose = cellpose_models.CellposeModel(gpu=False)
# Handle SegFormer prediction
def infer_segformer(image):
    """Run the fine-tuned SegFormer on *image* and return a colorized class mask.

    Args:
        image: PIL image; it is converted to RGB before preprocessing.

    Returns:
        Tuple ``(rgb_image, mask_image)``. ``rgb_image`` is the RGB-converted
        input. ``mask_image`` is a PIL RGB image of the same size as the input,
        in which each predicted class index 0..7 is painted with a fixed color.
    """
    image = image.convert("RGB")
    inputs = processor_segformer(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model_segformer(**inputs).logits
    # SegFormer produces logits at 1/4 of the processor's (resized) input
    # resolution. Upsample them to the original image size before argmax so
    # the mask lines up pixel-for-pixel with the input.
    logits = torch.nn.functional.interpolate(
        logits,
        size=(image.height, image.width),
        mode="bilinear",
        align_corners=False,
    )
    pred_mask = logits.argmax(dim=1)[0].cpu().numpy()
    # One RGB color per class index (class 0 -> black).
    palette = np.array(
        [
            [0, 0, 0],
            [255, 0, 0],
            [0, 255, 0],
            [0, 0, 255],
            [255, 255, 0],
            [255, 0, 255],
            [0, 255, 255],
            [128, 128, 128],
        ],
        dtype=np.uint8,
    )
    # Index the palette with the class map: (H, W) -> (H, W, 3).
    color_mask = palette[pred_mask]
    return image, Image.fromarray(color_mask)
# Handle StarDist prediction
def infer_stardist(image):
    """Segment *image* with StarDist and return the label overlay.

    Args:
        image: PIL image in any mode.

    Returns:
        Tuple ``(image, overlay)``. ``image`` is the unchanged input. ``overlay``
        is a PIL RGB image of the instance labels rendered over the grayscale
        input by ``render_label``.
    """
    # The model is fed a single 2-D grayscale array. Multi-band modes (RGB,
    # RGBA, LA, CMYK, ...) and palette images go through RGB -> luminance.
    # Previously only mode 'RGB' was converted, so e.g. RGBA reached the model
    # as an (H, W, 4) array. Single-band images are used as-is, which keeps
    # their original bit depth.
    if image.mode == "P" or len(image.getbands()) > 1:
        image_gray = rgb2gray(np.array(image.convert("RGB")))
    else:
        image_gray = np.array(image)
    labels, _ = model_stardist.predict_instances(normalize(image_gray))
    overlay = render_label(labels, img=image_gray)
    # Treat render_label's output as float RGBA in [0, 1]: drop alpha and scale
    # to 8-bit. Clipping keeps any overshoot from wrapping around in the cast.
    overlay = (np.clip(overlay[..., :3], 0, 1) * 255).astype(np.uint8)
    return image, Image.fromarray(overlay)
# Handle MEDIAR prediction
_MEDIAR_MODEL_ARGS = {
    "classes": 3,
    "decoder_channels": [1024, 512, 256, 128, 64],
    "decoder_pab_channels": 256,
    "encoder_name": 'mit_b5',
    "in_channels": 3,
}
_mediar_model = None  # built lazily by _get_mediar_model()


def _get_mediar_model():
    """Build the MEDIARFormer and load its weights once, then reuse it.

    Previously the network was reconstructed and the checkpoint reloaded from
    disk on every request.
    NOTE(review): assumes Predictor does not mutate the model in a way that
    matters across calls — confirm.
    """
    global _mediar_model
    if _mediar_model is None:
        model = MEDIARFormer(**_MEDIAR_MODEL_ARGS)
        weights = torch.load("from_phase1.pth", map_location="cpu")
        # strict=False silently ignores missing/unexpected keys (unchanged
        # behavior). NOTE(review): confirm the checkpoint matches this
        # architecture.
        model.load_state_dict(weights, strict=False)
        model.eval()
        _mediar_model = model
    return _mediar_model


def infer_mediar(image, temp_dir="temp_mediar"):
    """Segment *image* with MEDIAR and return a rendered label map.

    The image is written to ``<temp_dir>/input_image.tiff``. The Predictor runs
    on that directory, and its result is read back from
    ``<temp_dir>/input_image_label.tiff``.

    Args:
        image: PIL image.
        temp_dir: Working directory for the Predictor's input and output files.

    Returns:
        Tuple ``(image, rendered)``. ``image`` is the unchanged input.
        ``rendered`` is a PNG-backed PIL image of the prediction drawn with the
        ``cividis`` colormap.
    """
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tiff")
    output_path = os.path.join(temp_dir, "input_image_label.tiff")
    image.save(input_path)

    predictor = Predictor(
        _get_mediar_model(), "cpu", temp_dir, temp_dir,
        algo_params={"use_tta": False},
    )
    predictor.img_names = ["input_image.tiff"]
    predictor.conduct_prediction()
    pred = imread(output_path)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(pred, cmap="cividis")
    ax.axis("off")
    buf = io.BytesIO()
    # Save and close this specific figure rather than pyplot's global
    # "current figure", which concurrent requests share.
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return image, Image.open(buf)
# Handle Cellpose prediction
def infer_cellpose(image, temp_dir="temp_cellpose"):
    """Segment *image* with Cellpose and return its segmentation figure.

    The image is saved as ``<temp_dir>/input_image.tif`` and re-read with
    Cellpose's own reader. The figure produced by ``show_segmentation`` is
    saved to ``<temp_dir>/overlay.png``.

    Args:
        image: PIL image.
        temp_dir: Working directory for the input TIFF and overlay PNG.

    Returns:
        Tuple ``(image, overlay)``. ``image`` is the unchanged input.
        ``overlay`` is a PIL image of the rendered figure, fully loaded into
        memory.
    """
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tif")
    output_overlay = os.path.join(temp_dir, "overlay.png")
    # Save image
    image.save(input_path)
    img = cellpose_io.imread(input_path)
    masks, flows, _styles = model_cellpose.eval(img, batch_size=1)

    fig = plt.figure(figsize=(12, 5))
    cellpose_plot.show_segmentation(fig, img, masks, flows[0])
    # Lay out this figure explicitly rather than pyplot's global current figure.
    fig.tight_layout()
    fig.savefig(output_overlay)
    plt.close(fig)

    # Load eagerly and detach from the file. A lazily opened image would hold
    # the file handle, and could read overlay.png after another request has
    # overwritten it.
    with Image.open(output_overlay) as overlay_file:
        overlay = overlay_file.copy()
    return image, overlay
# Wrapper function
def segment(model_name, image):
    """Dispatch *image* to the inference function selected by *model_name*.

    Returns whatever the selected ``infer_*`` function returns. On failure it
    returns ``(None, message)`` with a human-readable error string, when:
      * "Cellpose" is selected and the image reports a non-TIFF format, or
      * *model_name* is not one of the supported models.
    """
    # The image may not carry a usable ``format`` (e.g. PIL images supplied by
    # Gradio). When it is absent or None, the Cellpose format check is skipped.
    reported = getattr(image, "format", None)
    ext = None if reported is None else reported.lower()

    if model_name == "Cellpose" and ext not in ("tiff", "tif", None):
        return None, "❌ Cellpose only supports `.tif` or `.tiff` images."

    handlers = (
        ("SegFormer", infer_segformer),
        ("StarDist", infer_stardist),
        ("MEDIAR", infer_mediar),
        ("Cellpose", infer_cellpose),
    )
    for name, run in handlers:
        if model_name == name:
            return run(image)
    return None, f"❌ Unknown model: {model_name}"
with gr.Blocks(title="Cell Segmentation Explorer") as app:
    gr.Markdown("## Cell Segmentation Explorer")
    gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=["SegFormer", "StarDist", "MEDIAR", "Cellpose"],
                label="Select Segmentation Model",
                value="SegFormer"
            )
            image_input = gr.Image(type="pil", label="Uploaded Image")
            description_box = gr.Markdown("Accepted formats: `.png`, `.jpg`, `.tif`, `.tiff`.")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")

    def handle_submit(model_name, img):
        """Run the selected model and return only the segmentation result.

        Returns None when no image was provided. Raises gr.Error with
        segment()'s message when segment() reports a failure.
        """
        if img is None:
            return None
        _, result = segment(model_name, img)  # Only return the mask (segmentation result)
        # segment() signals failures (unsupported format, unknown model) by
        # returning a message string in place of the mask. An Image output
        # cannot show a string, so surface it as a Gradio error instead.
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    submit_btn.click(
        fn=handle_submit,
        inputs=[model_dropdown, image_input],
        outputs=output_image
    )
    clear_btn.click(
        lambda: [None, None],
        inputs=None,
        outputs=[image_input, output_image]
    )

    # === SAMPLE IMAGES SECTION ===
    gr.Markdown("---")
    gr.Markdown("### Sample Images (click to use as input)")
    # Original and resized thumbnails
    original_sample_paths = [
        "img1.png",
        "img2.png",
        "img3.png"
    ]
    resized_sample_paths = []
    for idx, p in enumerate(original_sample_paths):
        # Open with a context manager so the source file handle is released;
        # resize() returns a new in-memory image.
        with PILImage.open(p) as src:
            thumb = src.resize((128, 128))
        temp_path = f"/tmp/sample_resized_{idx}.png"
        thumb.save(temp_path)
        resized_sample_paths.append(temp_path)

    sample_image_components = []
    with gr.Row():
        for i, img_path in enumerate(resized_sample_paths):
            def load_full_image(idx=i):  # Capture loop index properly
                """Return the full-size sample image *idx*, loaded into memory."""
                with PILImage.open(original_sample_paths[idx]) as full:
                    return full.copy()

            sample_img = gr.Image(value=img_path, type="pil", interactive=True, show_label=False)
            sample_img.select(
                fn=load_full_image,
                inputs=[],
                outputs=image_input
            )
            sample_image_components.append(sample_img)

app.launch()