Subh775 committed on
Commit 3ad00f2 · verified · Parent(s): 8eb3166

finalized v1

Files changed (1)
app.py +0 -338
app.py CHANGED
@@ -1,341 +1,3 @@
- # import os
- # import io
- # import base64
- # import threading
- # import traceback
- # import gc
- # from typing import Optional
-
- # from flask import Flask, request, jsonify, send_from_directory
- # from PIL import Image
- # import numpy as np
- # import requests
- # import torch
-
- # # Set environment variables for CPU-only operation
- # os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
- # os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
- # os.environ.setdefault("FONTCONFIG_FILE", "/etc/fonts/fonts.conf")
- # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
- # os.environ.setdefault("OMP_NUM_THREADS", "4")
- # os.environ.setdefault("MKL_NUM_THREADS", "4")
- # os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")
-
- # # Create writable fontconfig cache
- # os.makedirs("/tmp/.fontconfig", exist_ok=True)
- # os.makedirs("/tmp/.matplotlib", exist_ok=True)
-
- # # Limit torch threads
- # try:
- #     torch.set_num_threads(4)
- # except Exception:
- #     pass
-
- # import supervision as sv
- # from rfdetr import RFDETRSegPreview
-
- # app = Flask(__name__, static_folder="static", static_url_path="/")
-
- # # Checkpoint URL & local path
- # CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs/resolve/main/checkpoint_best_total.pth"
- # CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
-
- # MODEL_LOCK = threading.Lock()
- # MODEL = None
-
-
- # def download_file(url: str, dst: str, chunk_size: int = 8192):
- #     """Download file if not exists"""
- #     if os.path.exists(dst) and os.path.getsize(dst) > 0:
- #         print(f"[INFO] Checkpoint already exists at {dst}")
- #         return dst
- #     print(f"[INFO] Downloading weights from {url} -> {dst}")
- #     try:
- #         r = requests.get(url, stream=True, timeout=180)
- #         r.raise_for_status()
- #         with open(dst, "wb") as fh:
- #             for chunk in r.iter_content(chunk_size=chunk_size):
- #                 if chunk:
- #                     fh.write(chunk)
- #         print("[INFO] Download complete.")
- #         return dst
- #     except Exception as e:
- #         print(f"[ERROR] Download failed: {e}")
- #         raise
-
-
- # def init_model():
- #     """Lazily initialize the RF-DETR model and cache it in global MODEL."""
- #     global MODEL
- #     with MODEL_LOCK:
- #         if MODEL is not None:
- #             print("[INFO] Model already loaded, returning cached instance")
- #             return MODEL
- #         try:
- #             # Ensure checkpoint present
- #             if not os.path.exists(CHECKPOINT_PATH):
- #                 print("[INFO] Checkpoint not found, downloading...")
- #                 download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
- #             else:
- #                 print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")
-
- #             print("[INFO] Loading RF-DETR model (CPU mode)...")
- #             MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
-
- #             # Try to optimize for inference
- #             try:
- #                 print("[INFO] Optimizing model for inference...")
- #                 MODEL.optimize_for_inference()
- #                 print("[INFO] Model optimization complete")
- #             except Exception as e:
- #                 print(f"[WARN] optimize_for_inference() skipped/failed: {e}")
-
- #             print("[INFO] Model ready for inference")
- #             return MODEL
- #         except Exception as e:
- #             print(f"[ERROR] Model initialization failed: {e}")
- #             traceback.print_exc()
- #             raise
-
-
- # def decode_data_url(data_url: str) -> Image.Image:
- #     """Decode data URL to PIL Image"""
- #     if data_url.startswith("data:"):
- #         _, b64 = data_url.split(",", 1)
- #         data = base64.b64decode(b64)
- #     else:
- #         try:
- #             data = base64.b64decode(data_url)
- #         except Exception:
- #             raise ValueError("Invalid image data")
- #     return Image.open(io.BytesIO(data)).convert("RGB")
-
-
- # def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
- #     """Encode PIL Image to data URL"""
- #     buf = io.BytesIO()
- #     pil_img.save(buf, format=fmt, optimize=False)
- #     buf.seek(0)
- #     return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
-
-
- # def annotate_segmentation(image: Image.Image, detections: sv.Detections) -> Image.Image:
- #     """
- #     Annotate image with segmentation masks using supervision library.
- #     This matches the visualization from rfdetr_seg_infer.py script.
- #     """
- #     try:
- #         # Define color palette
- #         palette = sv.ColorPalette.from_hex([
- #             "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
- #             "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
- #         ])
-
- #         # Calculate optimal text scale based on image resolution
- #         text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
-
- #         print(f"[INFO] Creating annotators with text_scale={text_scale}")
-
- #         # Create annotators
- #         mask_annotator = sv.MaskAnnotator(color=palette)
- #         polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)
- #         label_annotator = sv.LabelAnnotator(
- #             color=palette,
- #             text_color=sv.Color.BLACK,
- #             text_scale=text_scale,
- #             text_position=sv.Position.CENTER_OF_MASS
- #         )
-
- #         # Create labels with confidence scores
- #         labels = [
- #             f"Tulsi {float(conf):.2f}"
- #             for conf in detections.confidence
- #         ]
-
- #         print(f"[INFO] Annotating {len(labels)} detections")
-
- #         # Apply annotations step by step
- #         out = image.copy()
- #         print("[INFO] Applying mask annotation...")
- #         out = mask_annotator.annotate(out, detections)
- #         print("[INFO] Applying polygon annotation...")
- #         out = polygon_annotator.annotate(out, detections)
- #         print("[INFO] Applying label annotation...")
- #         out = label_annotator.annotate(out, detections, labels)
-
- #         print("[INFO] Annotation complete")
- #         return out
-
- #     except Exception as e:
- #         print(f"[ERROR] Annotation failed: {e}")
- #         traceback.print_exc()
- #         # Return original image if annotation fails
- #         return image
-
-
- # @app.route("/", methods=["GET"])
- # def index():
- #     """Serve the static UI"""
- #     index_path = os.path.join(app.static_folder or "static", "index.html")
- #     if os.path.exists(index_path):
- #         return send_from_directory(app.static_folder, "index.html")
- #     return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})
-
-
- # @app.route("/health", methods=["GET"])
- # def health():
- #     """Health check endpoint"""
- #     model_loaded = MODEL is not None
- #     return jsonify({
- #         "status": "healthy",
- #         "model_loaded": model_loaded,
- #         "checkpoint_exists": os.path.exists(CHECKPOINT_PATH)
- #     })
-
-
- # @app.route("/predict", methods=["POST"])
- # def predict():
- #     """
- #     Accepts:
- #       - multipart/form-data with file field "file"
- #       - or JSON {"image": "<data:url...>", "conf": 0.05}
- #     Returns JSON:
- #       {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
- #     """
- #     print("\n[INFO] ========== New prediction request ==========")
-
- #     try:
- #         print("[INFO] Initializing model...")
- #         model = init_model()
- #         print("[INFO] Model ready")
- #     except Exception as e:
- #         error_msg = f"Model initialization failed: {e}"
- #         print(f"[ERROR] {error_msg}")
- #         return jsonify({"error": error_msg}), 500
-
- #     # Parse input
- #     img: Optional[Image.Image] = None
- #     conf_threshold = 0.05
-
- #     # Check if file uploaded
- #     if "file" in request.files:
- #         file = request.files["file"]
- #         print(f"[INFO] Processing uploaded file: {file.filename}")
- #         try:
- #             img = Image.open(file.stream).convert("RGB")
- #         except Exception as e:
- #             error_msg = f"Invalid uploaded image: {e}"
- #             print(f"[ERROR] {error_msg}")
- #             return jsonify({"error": error_msg}), 400
- #         conf_threshold = float(request.form.get("conf", conf_threshold))
- #     else:
- #         # Try JSON payload
- #         payload = request.get_json(silent=True)
- #         if not payload or "image" not in payload:
- #             return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
- #         try:
- #             print("[INFO] Decoding image from data URL...")
- #             img = decode_data_url(payload["image"])
- #         except Exception as e:
- #             error_msg = f"Invalid image data: {e}"
- #             print(f"[ERROR] {error_msg}")
- #             return jsonify({"error": error_msg}), 400
- #         conf_threshold = float(payload.get("conf", conf_threshold))
-
- #     print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
-
- #     # Optionally downscale large images to reduce memory usage
- #     MAX_SIZE = 1024
- #     if max(img.size) > MAX_SIZE:
- #         w, h = img.size
- #         scale = MAX_SIZE / float(max(w, h))
- #         new_w, new_h = int(round(w * scale)), int(round(h * scale))
- #         print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
- #         img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-
- #     # Run inference with no_grad for memory efficiency
- #     try:
- #         print("[INFO] Running inference...")
- #         with torch.no_grad():
- #             detections = model.predict(img, threshold=conf_threshold)
-
- #         print(f"[INFO] Raw detections: {len(detections)} objects")
-
- #         # Check if detections exist
- #         if len(detections) == 0 or not hasattr(detections, 'confidence') or len(detections.confidence) == 0:
- #             print("[INFO] No detections above threshold")
- #             # Return original image
- #             data_url = encode_pil_to_dataurl(img, fmt="PNG")
- #             return jsonify({
- #                 "annotated": data_url,
- #                 "confidences": [],
- #                 "count": 0
- #             })
-
- #         print(f"[INFO] Detections have {len(detections.confidence)} confidence scores")
- #         print(f"[INFO] Confidence range: {min(detections.confidence):.3f} - {max(detections.confidence):.3f}")
-
- #         # Check if masks exist
- #         if hasattr(detections, 'masks') and detections.masks is not None:
- #             print(f"[INFO] Masks present: shape={np.array(detections.masks).shape if hasattr(detections.masks, '__len__') else 'unknown'}")
- #         else:
- #             print("[WARN] No masks found in detections!")
-
- #         # Annotate image using supervision library
- #         print("[INFO] Starting annotation...")
- #         annotated_pil = annotate_segmentation(img, detections)
-
- #         # Extract confidence scores
- #         confidences = [float(conf) for conf in detections.confidence]
- #         print(f"[INFO] Final confidences: {confidences}")
-
- #         # Encode to data URL
- #         print("[INFO] Encoding annotated image...")
- #         data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
-
- #         # Clean up
- #         del detections
- #         gc.collect()
-
- #         print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
-
- #         return jsonify({
- #             "annotated": data_url,
- #             "confidences": confidences,
- #             "count": len(confidences)
- #         })
-
- #     except Exception as e:
- #         error_msg = f"Inference failed: {e}"
- #         print(f"[ERROR] {error_msg}")
- #         traceback.print_exc()
- #         return jsonify({"error": error_msg}), 500
-
-
- # if __name__ == "__main__":
- #     print("\n" + "="*60)
- #     print("Starting Tulsi Leaf Segmentation Server")
- #     print("="*60 + "\n")
-
- #     # Warm model in background thread
- #     def warm():
- #         try:
- #             print("[INFO] Starting model warmup in background...")
- #             init_model()
- #             print("[INFO] ✓ Model warmup complete - ready for predictions")
- #         except Exception as e:
- #             print(f"[ERROR] ✗ Model warmup failed: {e}")
- #             traceback.print_exc()
-
- #     threading.Thread(target=warm, daemon=True).start()
-
- #     # Run Flask app
- #     app.run(
- #         host="0.0.0.0",
- #         port=int(os.environ.get("PORT", 7860)),
- #         debug=False
- #     )
-
  import os
  import io
  import base64
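
Note: the removed block documented the Space's HTTP contract, which is useful for reference. POST /predict accepts either a multipart upload in the "file" field (with an optional "conf" form field) or a JSON body with an "image" data URL and optional "conf", and responds with {"annotated": <PNG data URL>, "confidences": [...], "count": N}. Below is a minimal client sketch against that contract, assuming it is unchanged in the current app.py; BASE_URL and leaf.jpg are placeholders.

import base64
import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the deployed Space URL

# Option 1: multipart upload with an optional confidence threshold form field
with open("leaf.jpg", "rb") as fh:  # leaf.jpg is a placeholder input image
    resp = requests.post(
        f"{BASE_URL}/predict",
        files={"file": ("leaf.jpg", fh, "image/jpeg")},
        data={"conf": "0.05"},
        timeout=300,
    )
resp.raise_for_status()
result = resp.json()
print(result["count"], result["confidences"])

# Option 2: JSON body carrying a base64 data URL
with open("leaf.jpg", "rb") as fh:
    b64 = base64.b64encode(fh.read()).decode("ascii")
resp = requests.post(
    f"{BASE_URL}/predict",
    json={"image": f"data:image/jpeg;base64,{b64}", "conf": 0.05},
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["count"])

Both paths return the annotated image as a PNG data URL in "annotated"; decode it with base64.b64decode after stripping the "data:image/png;base64," prefix.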