# Author: mlbench123
# Commit: 4f02379 (verified) — "Create app.py"
import gradio as gr
import cv2
import numpy as np
import io
import os
import zipfile
import tempfile
from PIL import Image
import matplotlib
matplotlib.use("Agg")
# ─── Cellpose model (lazy) ────────────────────────────────────────────────────
# Process-wide Cellpose model, created on the first get_model() call.
_model = None


def get_model():
    """Return the shared Cellpose-SAM model, building it on first use.

    The "cpsam" weights are fetched from the ``mouseland/cellpose-sam``
    Hugging Face repo and loaded into a CPU-only ``CellposeModel``. Later
    calls return the cached instance without touching the network.
    """
    global _model
    if _model is not None:
        return _model
    # Imported lazily so these packages are only loaded when a model is needed.
    from cellpose import models
    from huggingface_hub import hf_hub_download

    weights_path = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
    _model = models.CellposeModel(gpu=False, pretrained_model=weights_path)
    return _model
# ─── Image helpers ────────────────────────────────────────────────────────────
def normalize99(img, lower=1.0, upper=99.0):
    """Percentile-normalise *img* so its *lower*/*upper* percentiles map to 0 and 1.

    Values outside that percentile band are NOT clipped, so the result can fall
    outside [0, 1]. The ``1e-10`` term avoids division by zero on constant
    images. The input is never modified.

    Args:
        img: array-like of numbers, any shape.
        lower: percentile (0-100) mapped to 0. Defaults to 1.
        upper: percentile (0-100) mapped to 1. Defaults to 99.

    Returns:
        Floating-point ndarray with the same shape as *img* (computed in float32).
    """
    # asarray already yields a fresh float32 array for non-float32 input; the
    # old copy() + astype() pair copied the data twice.
    x = np.asarray(img, dtype=np.float32)
    # One call computes both anchors instead of two separate passes over x.
    p_lo, p_hi = np.percentile(x, [lower, upper])
    return (x - p_lo) / (1e-10 + p_hi - p_lo)
def _fit_within(ny, nx, resize):
    """Return (height, width) scaled so the longer side equals *resize*, each side >= 1 px."""
    if ny > nx:
        return resize, max(1, int(nx / ny * resize))
    return max(1, int(ny / nx * resize)), resize


def image_resize(img, resize=1000):
    """Downscale *img* so its longer side is at most *resize* pixels.

    Aspect ratio is preserved (the short side is truncated to an int, but
    never below 1 px, so extremely thin images no longer make cv2.resize
    fail on a 0-sized target). Images already within the limit are not
    resampled.

    Args:
        img: HxW or HxWxC array.
        resize: maximum length of the longer side, in pixels.

    Returns:
        The (possibly resized) image as uint8. ``astype`` does not clip, so
        input should already be 8-bit data.
    """
    ny, nx = img.shape[:2]
    if max(ny, nx) > resize:
        new_ny, new_nx = _fit_within(ny, nx, resize)
        # This branch only ever shrinks; INTER_AREA avoids the aliasing that
        # the default bilinear filter produces when downscaling.
        img = cv2.resize(img, (new_nx, new_ny), interpolation=cv2.INTER_AREA)
    return img.astype(np.uint8)
def run_cellpose(img, model, flow_threshold=0.4, cellprob_threshold=0.0, niter=250):
    """Segment *img* with a Cellpose model and return its label mask.

    Args:
        img: image array passed straight to ``model.eval``.
        model: Cellpose model exposing ``eval(img, niter=..., flow_threshold=...,
            cellprob_threshold=...)``.
        flow_threshold: forwarded to ``model.eval``.
        cellprob_threshold: forwarded to ``model.eval``.
        niter: dynamics iterations forwarded to ``model.eval`` (default 250,
            the value previously hard-coded here).

    Returns:
        The first element of ``model.eval``'s result, i.e. the label mask
        (0 = background, 1..N = one id per detected object, as the rest of
        this file assumes).
    """
    outputs = model.eval(
        img,
        niter=niter,
        flow_threshold=flow_threshold,
        cellprob_threshold=cellprob_threshold,
    )
    # CellposeModel.eval returns (masks, flows, styles); some Cellpose
    # wrappers append a 4th item (diameters). Indexing works for both,
    # whereas a fixed 3-way unpack would raise on the latter.
    return outputs[0]
# ─── YOLO Annotation Exporter ─────────────────────────────────────────────────
def export_yolo_annotations(masks, img_shape, class_id=0):
    """Convert a Cellpose label mask into YOLO segmentation label text.

    Every label id 1..masks.max() becomes (at most) one line::

        class_id x1 y1 x2 y2 ...

    where each x is divided by the image width and each y by the image
    height, rounded to 6 decimals. Only the largest external contour of a
    label is used. Labels with no contour, or with fewer than 4 contour
    points, are skipped. Every polygon gets *class_id* (0 = generic grain,
    to be split into whole/broken later, e.g. on Roboflow).

    Returns:
        (text, num_labels): the newline-joined lines, and ``int(masks.max())``.
        The count includes labels that were skipped.
    """
    height, width = img_shape[:2]
    label_count = int(masks.max())
    rows = []
    for label in range(1, label_count + 1):
        binary = np.asarray(masks == label, dtype=np.uint8)
        found, _hierarchy = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(found) == 0:
            continue
        # Keep only the biggest outline; smaller ones are speckle noise.
        outline = max(found, key=cv2.contourArea).squeeze()
        if outline.ndim < 2 or len(outline) < 4:
            continue
        coords = []
        for px, py in outline:
            coords.extend((round(float(px) / width, 6), round(float(py) / height, 6)))
        rows.append(f"{class_id} " + " ".join(str(v) for v in coords))
    return "\n".join(rows), label_count
def make_preview(img_np, masks):
    """Return a PIL image of *img_np* with each labelled grain outlined.

    For every label id 1..masks.max(), the external contours are drawn 2 px
    wide in colour (220, 38, 38) (red for RGB input). Drawing happens on a
    copy, so *img_np* itself is left untouched.
    """
    canvas = img_np.copy()
    last_label = int(masks.max())
    label = 1
    while label <= last_label:
        outline, _hierarchy = cv2.findContours(
            (masks == label).astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        cv2.drawContours(canvas, outline, -1, (220, 38, 38), 2)
        label += 1
    return Image.fromarray(canvas)
# ─── Main batch processor ─────────────────────────────────────────────────────
# Dataset config written next to images/ and labels/. train/val must point at
# the folders this exporter actually creates.
_DATA_YAML = (
    "# YOLO Dataset — Rice Grain Segmentation\n"
    "# Generated by MLBench Annotation Tool\n\n"
    "# Dataset root; train/val below are relative to it. Use an absolute path\n"
    "# here if you launch training from outside this folder.\n"
    "path: .\n"
    "# The exporter writes one unsplit set, so train and val share it until\n"
    "# you create a proper split (e.g. on Roboflow).\n"
    "train: images\n"
    "val: images\n\n"
    "nc: 2\n"
    "names:\n"
    "  0: whole_grain\n"
    "  1: broken_grain\n\n"
    "# NOTE: All grains are currently class 0 (whole_grain).\n"
    "# Upload to Roboflow and re-label broken grains as class 1.\n"
)

_README = (
    "# Rice Grain YOLO Dataset\n\n"
    "## Folder Structure\n"
    "```\n"
    "dataset/\n"
    " images/ ← your rice photos (.jpg)\n"
    " labels/ ← YOLO polygon annotations (.txt)\n"
    " data.yaml ← class config for YOLO training\n"
    "```\n\n"
    "## Label Format (YOLO Segmentation)\n"
    "Each .txt file has one line per grain:\n"
    "```\n"
    "class_id x1 y1 x2 y2 x3 y3 ... (normalized 0–1)\n"
    "```\n\n"
    "## Classes\n"
    "| ID | Name |\n"
    "|----|-------------|\n"
    "| 0 | whole_grain |\n"
    "| 1 | broken_grain |\n\n"
    "## Next Steps\n"
    "1. Upload this zip to **Roboflow** (Import > YOLOv8 Segmentation format)\n"
    "2. Re-label broken grains as class `1` in Roboflow\n"
    "3. Export from Roboflow as YOLOv8 format\n"
    "4. Train: `yolo segment train data=data.yaml model=yolov8n-seg.pt epochs=100`\n"
)


def _unique_stem(stem, used):
    """Return *stem*, suffixed with _1, _2, ... if needed, so it is unique within *used*.

    Comparison is case-insensitive because the ZIP may be unpacked on a
    case-insensitive filesystem. The chosen name is recorded in *used*.
    """
    candidate, suffix = stem, 1
    while candidate.casefold() in used:
        candidate = f"{stem}_{suffix}"
        suffix += 1
    used.add(candidate.casefold())
    return candidate


def _annotate_image(filepath, stem, model, flow_threshold, cellprob_threshold,
                    images_dir, labels_dir):
    """Segment one image and write its ``images/<stem>.jpg`` + ``labels/<stem>.txt`` pair.

    Returns ``(preview, num_grains)``. When no grains are detected nothing is
    written and ``(None, 0)`` is returned. Exceptions propagate to the caller.
    """
    with Image.open(filepath) as pil_img:
        img_np = np.array(pil_img.convert("RGB"))
    img_np = image_resize(img_np, resize=1000)
    masks = run_cellpose(
        img_np, model,
        flow_threshold=float(flow_threshold),
        cellprob_threshold=float(cellprob_threshold),
    )
    num_grains = int(masks.max())
    if num_grains == 0:
        return None, 0
    annotation_txt, _ = export_yolo_annotations(masks, img_np.shape, class_id=0)
    with open(os.path.join(labels_dir, f"{stem}.txt"), "w", encoding="utf-8") as f:
        f.write(annotation_txt)
    Image.fromarray(img_np).save(os.path.join(images_dir, f"{stem}.jpg"), quality=95)
    return make_preview(img_np, masks), num_grains


def _write_dataset_metadata(root):
    """Write data.yaml and README.md into *root*.

    UTF-8 is explicit because both contain non-ASCII characters.
    """
    for name, content in (("data.yaml", _DATA_YAML), ("README.md", _README)):
        with open(os.path.join(root, name), "w", encoding="utf-8") as f:
            f.write(content)


def _zip_directory(src_dir, zip_path):
    """Deflate every file under *src_dir* into *zip_path*, with paths relative to *src_dir*."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(src_dir):
            for name in sorted(files):
                full_path = os.path.join(root, name)
                zf.write(full_path, os.path.relpath(full_path, src_dir))


def process_batch(image_files, flow_threshold, cellprob_threshold, progress=gr.Progress()):
    """Segment uploaded images and package a YOLO segmentation dataset ZIP.

    Args:
        image_files: uploaded files, either filepath strings or objects with
            a ``.name`` path (older Gradio file objects).
        flow_threshold: Cellpose flow threshold (converted to float).
        cellprob_threshold: Cellpose cell-probability threshold (converted to float).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        (previews, summary, zip_path):
        - previews: list of (PIL image with outlines, caption) for the gallery.
        - summary: Markdown summary plus a per-image log.
        - zip_path: path to the ZIP, or None when nothing was uploaded.
    """
    if not image_files:
        return [], "⚠️ No images uploaded.", None

    model = get_model()
    previews = []  # (PIL image, caption) for the gallery
    log_lines = []
    failed = []
    total_grains = 0
    used_stems = set()

    # The ZIP must outlive this call so Gradio can serve it. The staging tree
    # is only needed until it has been zipped, so it is removed automatically.
    zip_path = os.path.join(tempfile.mkdtemp(), "rice_yolo_dataset.zip")
    with tempfile.TemporaryDirectory() as stage_dir:
        images_dir = os.path.join(stage_dir, "images")
        labels_dir = os.path.join(stage_dir, "labels")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(labels_dir, exist_ok=True)

        for idx, file_obj in enumerate(
            progress.tqdm(image_files, desc="Processing images"), start=1
        ):
            filepath = file_obj if isinstance(file_obj, str) else file_obj.name
            # Deduplicate so e.g. a.png and a.jpg don't overwrite each other's files.
            fname = _unique_stem(os.path.splitext(os.path.basename(filepath))[0], used_stems)
            try:
                preview, num_grains = _annotate_image(
                    filepath, fname, model, flow_threshold, cellprob_threshold,
                    images_dir, labels_dir,
                )
            except Exception as e:  # batch boundary: log and continue with the next image
                log_lines.append(f"❌ [{idx}] {fname} — Error: {e}")
                failed.append(fname)
                continue
            if num_grains == 0:
                log_lines.append(f"⚠️ [{idx}] {fname} — No grains detected, skipped.")
                failed.append(fname)
                continue
            previews.append((preview, f"{fname} — {num_grains} grains"))
            total_grains += num_grains
            log_lines.append(f"✅ [{idx}] {fname} — {num_grains} grains annotated.")

        _write_dataset_metadata(stage_dir)
        _zip_directory(stage_dir, zip_path)

    ok_count = len(image_files) - len(failed)
    summary = (
        "### ✅ Done!\n"
        f"- **{ok_count} / {len(image_files)}** images processed\n"
        f"- **{total_grains}** total grains annotated\n"
        f"- **{len(failed)}** failed: {', '.join(failed) if failed else 'none'}\n\n"
        "**Download the ZIP below → upload to Roboflow → label broken grains → train YOLO!**\n\n"
        "---\n" + "\n".join(log_lines)
    )
    return previews, summary, zip_path
# ─── UI ───────────────────────────────────────────────────────────────────────
# Custom CSS injected into the Blocks page. #header styles the banner <div>
# in the gr.HTML block, while #run-btn and #dl-btn match the elem_id values
# given to the Run button and the download gr.File.
# NOTE(review): `.gr-gallery-item` is a class name from older Gradio releases;
# confirm it still exists in the installed version, otherwise the rule is a no-op.
CSS = """
body { font-family: 'IBM Plex Mono', monospace; }
#header {
background: #0F172A;
padding: 20px 24px 14px;
border-radius: 10px;
margin-bottom: 12px;
}
#run-btn { margin-top: 8px; background: #7C3AED !important; }
#dl-btn { margin-top: 6px; }
.gr-gallery-item img { border-radius: 6px; }
"""
# Gradio "Soft" theme with violet primary, indigo secondary and slate neutral hues.
THEME = gr.themes.Soft(
    primary_hue="violet",
    secondary_hue="indigo",
    neutral_hue="slate",
)
# Gradio UI: inputs and settings on the left, segmentation previews and the
# run summary on the right. Clicking Run calls process_batch.
with gr.Blocks(theme=THEME, css=CSS, title="Rice YOLO Annotator") as demo:
    # Banner, styled by the #header rule in CSS.
    gr.HTML("""
    <div id="header">
      <span style="font-size:1.9rem;font-weight:900;color:#F1F5F9;font-family:monospace;">
        ML<span style="color:#EF4444;">Bench</span>
        <span style="font-size:1rem;font-weight:400;color:#94A3B8;margin-left:12px;">
          Rice Grain → YOLO Annotation Exporter
        </span>
      </span>
      <p style="color:#64748B;font-size:0.85rem;margin-top:6px;font-family:monospace;">
        Upload up to 50 images · Cellpose segments each grain ·
        Download ZIP with YOLO labels ready for Roboflow
      </p>
    </div>
    """)

    with gr.Row():
        # ── LEFT: uploads, settings, run button, workflow notes, download ──
        with gr.Column(scale=1):
            gr.Markdown("### 📂 Upload Images")
            image_input = gr.File(
                file_count="multiple",
                file_types=["image"],
                label="Drop up to 50 rice images here",
                height=180,
            )
            with gr.Accordion("⚙️ Cellpose Settings", open=False):
                flow_thresh = gr.Slider(
                    0.0, 1.0, value=0.4, step=0.05,
                    label="Flow Threshold",
                    info="Higher = stricter (fewer false grains)",
                )
                cellprob_thresh = gr.Slider(
                    -4.0, 4.0, value=0.0, step=0.5,
                    label="Cell Probability Threshold",
                    info="Lower = detect more grains",
                )
            run_btn = gr.Button(
                "🚀 Run Cellpose & Export Annotations",
                variant="primary", size="lg", elem_id="run-btn",
            )
            # Kept at column 0: indented Markdown lines would render as a code block.
            gr.Markdown("""
### 📋 Workflow
1. Upload up to 50 images here
2. Click **Run** — Cellpose segments every grain
3. Download the ZIP
4. Upload ZIP to **Roboflow** (format: YOLOv8 Segmentation)
5. Re-label broken grains as `broken_grain` class
6. Export & train YOLOv8!
""")
            download_btn = gr.File(
                label="⬇️ Download YOLO Dataset ZIP",
                interactive=False,
                elem_id="dl-btn",
            )

        # ── RIGHT: previews and summary ──
        with gr.Column(scale=2):
            gr.Markdown("### 🔍 Segmentation Previews")
            gallery = gr.Gallery(
                label="",
                show_label=False,
                columns=3,
                height=460,
                object_fit="contain",
            )
            summary_box = gr.Markdown(
                value="_Results will appear here after processing..._"
            )

    # Outputs match process_batch's (previews, summary, zip_path) return order.
    run_btn.click(
        fn=process_batch,
        inputs=[image_input, flow_thresh, cellprob_thresh],
        outputs=[gallery, summary_box, download_btn],
    )

if __name__ == "__main__":
    demo.launch(share=True)