mlbench123 commited on
Commit
4f02379
Β·
verified Β·
1 Parent(s): 416ee51

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +338 -0
app.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import io
5
+ import os
6
+ import zipfile
7
+ import tempfile
8
+ from PIL import Image
9
+ import matplotlib
10
+ matplotlib.use("Agg")
11
+
12
+ # ─── Cellpose model (lazy) ────────────────────────────────────────────────────
13
+ _model = None
14
+
15
+ def get_model():
16
+ global _model
17
+ if _model is None:
18
+ from cellpose import models
19
+ from huggingface_hub import hf_hub_download
20
+ fpath = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
21
+ _model = models.CellposeModel(gpu=False, pretrained_model=fpath)
22
+ return _model
23
+
24
+ # ─── Image helpers ────────────────────────────────────────────────────────────
25
+ def normalize99(img):
26
+ X = img.copy().astype(np.float32)
27
+ p1, p99 = np.percentile(X, 1), np.percentile(X, 99)
28
+ return (X - p1) / (1e-10 + p99 - p1)
29
+
30
+ def image_resize(img, resize=1000):
31
+ ny, nx = img.shape[:2]
32
+ if max(ny, nx) > resize:
33
+ if ny > nx:
34
+ nx = int(nx / ny * resize); ny = resize
35
+ else:
36
+ ny = int(ny / nx * resize); nx = resize
37
+ img = cv2.resize(img, (nx, ny))
38
+ return img.astype(np.uint8)
39
+
40
+ def run_cellpose(img, model, flow_threshold=0.4, cellprob_threshold=0.0):
41
+ masks, flows, _ = model.eval(
42
+ img, niter=250,
43
+ flow_threshold=flow_threshold,
44
+ cellprob_threshold=cellprob_threshold,
45
+ )
46
+ return masks
47
+
48
+ # ─── YOLO Annotation Exporter ─────────────────────────────────────────────────
49
+ def export_yolo_annotations(masks, img_shape, class_id=0):
50
+ """
51
+ Converts Cellpose masks β†’ YOLO segmentation format.
52
+
53
+ YOLO segmentation line format:
54
+ class_id x1 y1 x2 y2 ... (all normalized 0–1)
55
+
56
+ class_id = 0 β†’ 'grain' (you will split into broken/whole on Roboflow)
57
+ """
58
+ h, w = img_shape[:2]
59
+ lines = []
60
+ num_grains = int(masks.max())
61
+
62
+ for i in range(1, num_grains + 1):
63
+ # Binary mask for this single grain
64
+ single = (masks == i).astype(np.uint8)
65
+ contours, _ = cv2.findContours(single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
66
+
67
+ if not contours:
68
+ continue
69
+
70
+ # Pick the largest contour (in case of tiny noise)
71
+ c = max(contours, key=cv2.contourArea)
72
+ c = c.squeeze()
73
+
74
+ if c.ndim < 2 or len(c) < 4:
75
+ continue
76
+
77
+ # Normalize each point to [0, 1]
78
+ norm_pts = []
79
+ for x, y in c:
80
+ norm_pts.append(round(float(x) / w, 6))
81
+ norm_pts.append(round(float(y) / h, 6))
82
+
83
+ pts_str = " ".join(map(str, norm_pts))
84
+ lines.append(f"{class_id} {pts_str}")
85
+
86
+ return "\n".join(lines), num_grains
87
+
88
+
89
+ def make_preview(img_np, masks):
90
+ """Draw red outlines of all grain masks on the image for preview."""
91
+ preview = img_np.copy()
92
+ for i in range(1, int(masks.max()) + 1):
93
+ single = (masks == i).astype(np.uint8)
94
+ contours, _ = cv2.findContours(single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
95
+ cv2.drawContours(preview, contours, -1, (220, 38, 38), 2)
96
+ return Image.fromarray(preview)
97
+
98
+
99
+ # ─── Main batch processor ─────────────────────────────────────────────────────
100
+ def process_batch(image_files, flow_threshold, cellprob_threshold, progress=gr.Progress()):
101
+ """
102
+ Takes a list of uploaded image file paths.
103
+ Returns:
104
+ - Gallery of preview images (with outlines)
105
+ - Summary text
106
+ - Path to downloadable ZIP
107
+ """
108
+ if not image_files:
109
+ return [], "⚠️ No images uploaded.", None
110
+
111
+ model = get_model()
112
+
113
+ previews = [] # (PIL image, caption) for gallery
114
+ log_lines = []
115
+ total_grains = 0
116
+ failed = []
117
+
118
+ # Temp folder to collect annotation files
119
+ tmp_dir = tempfile.mkdtemp()
120
+ images_dir = os.path.join(tmp_dir, "images")
121
+ labels_dir = os.path.join(tmp_dir, "labels")
122
+ os.makedirs(images_dir, exist_ok=True)
123
+ os.makedirs(labels_dir, exist_ok=True)
124
+
125
+ for idx, file_obj in enumerate(progress.tqdm(image_files, desc="Processing images")):
126
+ # file_obj is a filepath string when using gr.File with type="filepath"
127
+ filepath = file_obj if isinstance(file_obj, str) else file_obj.name
128
+ fname = os.path.splitext(os.path.basename(filepath))[0]
129
+
130
+ try:
131
+ pil_img = Image.open(filepath).convert("RGB")
132
+ img_np = np.array(pil_img)
133
+ img_np = image_resize(img_np, resize=1000)
134
+
135
+ masks = run_cellpose(img_np, model,
136
+ flow_threshold=float(flow_threshold),
137
+ cellprob_threshold=float(cellprob_threshold))
138
+
139
+ num_grains = int(masks.max())
140
+
141
+ if num_grains == 0:
142
+ log_lines.append(f"⚠️ [{idx+1}] {fname} β€” No grains detected, skipped.")
143
+ failed.append(fname)
144
+ continue
145
+
146
+ # Export YOLO annotation txt
147
+ annotation_txt, _ = export_yolo_annotations(masks, img_np.shape, class_id=0)
148
+ txt_path = os.path.join(labels_dir, f"{fname}.txt")
149
+ with open(txt_path, "w") as f:
150
+ f.write(annotation_txt)
151
+
152
+ # Save image to images/
153
+ img_save_path = os.path.join(images_dir, f"{fname}.jpg")
154
+ Image.fromarray(img_np).save(img_save_path, quality=95)
155
+
156
+ # Make preview
157
+ preview_pil = make_preview(img_np, masks)
158
+ previews.append((preview_pil, f"{fname} β€” {num_grains} grains"))
159
+
160
+ total_grains += num_grains
161
+ log_lines.append(f"βœ… [{idx+1}] {fname} β€” {num_grains} grains annotated.")
162
+
163
+ except Exception as e:
164
+ log_lines.append(f"❌ [{idx+1}] {fname} β€” Error: {str(e)}")
165
+ failed.append(fname)
166
+
167
+ # ── Write data.yaml ───────────────────────────────────────────────────────
168
+ yaml_content = (
169
+ "# YOLO Dataset β€” Rice Grain Segmentation\n"
170
+ "# Generated by MLBench Annotation Tool\n\n"
171
+ "path: ./dataset\n"
172
+ "train: images/train\n"
173
+ "val: images/val\n\n"
174
+ "nc: 2\n"
175
+ "names:\n"
176
+ " 0: whole_grain\n"
177
+ " 1: broken_grain\n\n"
178
+ "# NOTE: All grains are currently class 0 (whole_grain).\n"
179
+ "# Upload to Roboflow and re-label broken grains as class 1.\n"
180
+ )
181
+ with open(os.path.join(tmp_dir, "data.yaml"), "w") as f:
182
+ f.write(yaml_content)
183
+
184
+ # ── Write README ──────────────────────────────────────────────────────────
185
+ readme = (
186
+ "# Rice Grain YOLO Dataset\n\n"
187
+ "## Folder Structure\n"
188
+ "```\n"
189
+ "dataset/\n"
190
+ " images/ ← your rice photos (.jpg)\n"
191
+ " labels/ ← YOLO polygon annotations (.txt)\n"
192
+ " data.yaml ← class config for YOLO training\n"
193
+ "```\n\n"
194
+ "## Label Format (YOLO Segmentation)\n"
195
+ "Each .txt file has one line per grain:\n"
196
+ "```\n"
197
+ "class_id x1 y1 x2 y2 x3 y3 ... (normalized 0–1)\n"
198
+ "```\n\n"
199
+ "## Classes\n"
200
+ "| ID | Name |\n"
201
+ "|----|-------------|\n"
202
+ "| 0 | whole_grain |\n"
203
+ "| 1 | broken_grain |\n\n"
204
+ "## Next Steps\n"
205
+ "1. Upload this zip to **Roboflow** (Import > YOLOv8 Segmentation format)\n"
206
+ "2. Re-label broken grains as class `1` in Roboflow\n"
207
+ "3. Export from Roboflow as YOLOv8 format\n"
208
+ "4. Train: `yolo segment train data=data.yaml model=yolov8n-seg.pt epochs=100`\n"
209
+ )
210
+ with open(os.path.join(tmp_dir, "README.md"), "w") as f:
211
+ f.write(readme)
212
+
213
+ # ── Package as ZIP ────────────────────────────────────────────────────────
214
+ zip_path = os.path.join(tempfile.mkdtemp(), "rice_yolo_dataset.zip")
215
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
216
+ for root, _, files in os.walk(tmp_dir):
217
+ for file in files:
218
+ full_path = os.path.join(root, file)
219
+ arcname = os.path.relpath(full_path, tmp_dir)
220
+ zf.write(full_path, arcname)
221
+
222
+ # ── Summary ───────────────────────────────────────────────────────────────
223
+ ok_count = len(image_files) - len(failed)
224
+ summary = (
225
+ f"### βœ… Done!\n"
226
+ f"- **{ok_count} / {len(image_files)}** images processed\n"
227
+ f"- **{total_grains}** total grains annotated\n"
228
+ f"- **{len(failed)}** failed: {', '.join(failed) if failed else 'none'}\n\n"
229
+ "**Download the ZIP below β†’ upload to Roboflow β†’ label broken grains β†’ train YOLO!**\n\n"
230
+ "---\n" + "\n".join(log_lines)
231
+ )
232
+
233
+ return previews, summary, zip_path
234
+
235
+
236
+ # ─── UI ───────────────────────────────────────────────────────────────────────
237
+ CSS = """
238
+ body { font-family: 'IBM Plex Mono', monospace; }
239
+ #header {
240
+ background: #0F172A;
241
+ padding: 20px 24px 14px;
242
+ border-radius: 10px;
243
+ margin-bottom: 12px;
244
+ }
245
+ #run-btn { margin-top: 8px; background: #7C3AED !important; }
246
+ #dl-btn { margin-top: 6px; }
247
+ .gr-gallery-item img { border-radius: 6px; }
248
+ """
249
+
250
+ THEME = gr.themes.Soft(
251
+ primary_hue="violet",
252
+ secondary_hue="indigo",
253
+ neutral_hue="slate",
254
+ )
255
+
256
+ with gr.Blocks(theme=THEME, css=CSS, title="Rice YOLO Annotator") as demo:
257
+
258
+ gr.HTML("""
259
+ <div id="header">
260
+ <span style="font-size:1.9rem;font-weight:900;color:#F1F5F9;font-family:monospace;">
261
+ ML<span style="color:#EF4444;">Bench</span>
262
+ <span style="font-size:1rem;font-weight:400;color:#94A3B8;margin-left:12px;">
263
+ Rice Grain β†’ YOLO Annotation Exporter
264
+ </span>
265
+ </span>
266
+ <p style="color:#64748B;font-size:0.85rem;margin-top:6px;font-family:monospace;">
267
+ Upload up to 50 images Β· Cellpose segments each grain Β·
268
+ Download ZIP with YOLO labels ready for Roboflow
269
+ </p>
270
+ </div>
271
+ """)
272
+
273
+ with gr.Row():
274
+ # ── LEFT ──────────────────────────────────────────────────────────────
275
+ with gr.Column(scale=1):
276
+ gr.Markdown("### πŸ“‚ Upload Images")
277
+ image_input = gr.File(
278
+ file_count="multiple",
279
+ file_types=["image"],
280
+ label="Drop up to 50 rice images here",
281
+ height=180,
282
+ )
283
+
284
+ with gr.Accordion("βš™οΈ Cellpose Settings", open=False):
285
+ flow_thresh = gr.Slider(
286
+ 0.0, 1.0, value=0.4, step=0.05,
287
+ label="Flow Threshold",
288
+ info="Higher = stricter (fewer false grains)"
289
+ )
290
+ cellprob_thresh = gr.Slider(
291
+ -4.0, 4.0, value=0.0, step=0.5,
292
+ label="Cell Probability Threshold",
293
+ info="Lower = detect more grains"
294
+ )
295
+
296
+ run_btn = gr.Button(
297
+ "πŸš€ Run Cellpose & Export Annotations",
298
+ variant="primary", size="lg", elem_id="run-btn"
299
+ )
300
+
301
+ gr.Markdown("""
302
+ ### πŸ“‹ Workflow
303
+ 1. Upload 50 images here
304
+ 2. Click **Run** β€” Cellpose segments every grain
305
+ 3. Download the ZIP
306
+ 4. Upload ZIP to **Roboflow** (format: YOLOv8 Segmentation)
307
+ 5. Re-label broken grains as `broken_grain` class
308
+ 6. Export & train YOLOv8!
309
+ """)
310
+
311
+ download_btn = gr.File(
312
+ label="⬇️ Download YOLO Dataset ZIP",
313
+ interactive=False,
314
+ elem_id="dl-btn",
315
+ )
316
+
317
+ # ── RIGHT ─────────────────────────────────────────────────────────────
318
+ with gr.Column(scale=2):
319
+ gr.Markdown("### πŸ” Segmentation Previews")
320
+ gallery = gr.Gallery(
321
+ label="",
322
+ show_label=False,
323
+ columns=3,
324
+ height=460,
325
+ object_fit="contain",
326
+ )
327
+ summary_box = gr.Markdown(
328
+ value="_Results will appear here after processing..._"
329
+ )
330
+
331
+ run_btn.click(
332
+ fn=process_batch,
333
+ inputs=[image_input, flow_thresh, cellprob_thresh],
334
+ outputs=[gallery, summary_box, download_btn],
335
+ )
336
+
337
+ if __name__ == "__main__":
338
+ demo.launch(share=True)