ahadhassan commited on
Commit
e8ddde1
·
verified ·
1 Parent(s): 06137a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -19
app.py CHANGED
@@ -60,13 +60,20 @@ async def predict_ndvi_api(file: UploadFile = File(...)):
60
  async def predict_yolo_api(file: UploadFile = File(...)):
61
  """Predict YOLO results from 4-channel TIFF image"""
62
  try:
63
- # Save uploaded file temporarily
64
- with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
 
 
65
  contents = await file.read()
66
  tmp_file.write(contents)
 
67
  tmp_file_path = tmp_file.name
68
 
69
  try:
 
 
 
 
70
  # Predict using YOLO model
71
  results = predict_yolo(yolo_model, tmp_file_path)
72
 
@@ -79,23 +86,24 @@ async def predict_yolo_api(file: UploadFile = File(...)):
79
  },
80
  "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
81
  "names": results.names,
82
- "growth_stages": getattr(results, 'growth_stages', None),
83
  "orig_shape": results.orig_shape,
84
  "speed": results.speed
85
  }
86
 
87
- # Handle growth stages if present
88
- if hasattr(results, 'boxes') and hasattr(results.boxes, 'data'):
89
- # Extract growth stages from the results if available
90
- if len(results.boxes.data[0]) > 6: # Assuming growth stages are in the data
91
- growth_stages = results.boxes.data[:, 6].tolist()
92
- results_dict["growth_stages"] = growth_stages
 
93
 
94
  return JSONResponse(content=results_dict)
95
 
96
  finally:
97
  # Clean up temporary file
98
- os.unlink(tmp_file_path)
 
99
 
100
  except Exception as e:
101
  return JSONResponse(status_code=500, content={"error": str(e)})
@@ -104,13 +112,20 @@ async def predict_yolo_api(file: UploadFile = File(...)):
104
  async def predict_pipeline_api(file: UploadFile = File(...)):
105
  """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction"""
106
  try:
107
- # Save uploaded file temporarily
108
- with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
 
 
109
  contents = await file.read()
110
  tmp_file.write(contents)
 
111
  tmp_file_path = tmp_file.name
112
 
113
  try:
 
 
 
 
114
  # Run the full pipeline
115
  results = predict_pipeline(ndvi_model, yolo_model, tmp_file_path)
116
 
@@ -123,22 +138,24 @@ async def predict_pipeline_api(file: UploadFile = File(...)):
123
  },
124
  "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
125
  "names": results.names,
126
- "growth_stages": getattr(results, 'growth_stages', None),
127
  "orig_shape": results.orig_shape,
128
  "speed": results.speed
129
  }
130
 
131
- # Handle growth stages if present
132
- if hasattr(results, 'boxes') and hasattr(results.boxes, 'data'):
133
- if len(results.boxes.data[0]) > 6:
134
- growth_stages = results.boxes.data[:, 6].tolist()
135
- results_dict["growth_stages"] = growth_stages
 
 
136
 
137
  return JSONResponse(content=results_dict)
138
 
139
  finally:
140
  # Clean up temporary file
141
- os.unlink(tmp_file_path)
 
142
 
143
  except Exception as e:
144
  return JSONResponse(status_code=500, content={"error": str(e)})
 
60
  async def predict_yolo_api(file: UploadFile = File(...)):
61
  """Predict YOLO results from 4-channel TIFF image"""
62
  try:
63
+ # Save uploaded file temporarily with proper extension
64
+ file_extension = '.tiff' if file.filename.lower().endswith(('.tif', '.tiff')) else '.tiff'
65
+
66
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
67
  contents = await file.read()
68
  tmp_file.write(contents)
69
+ tmp_file.flush() # Ensure data is written
70
  tmp_file_path = tmp_file.name
71
 
72
  try:
73
+ # Verify the file was written correctly
74
+ if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
75
+ raise ValueError("Failed to create temporary file")
76
+
77
  # Predict using YOLO model
78
  results = predict_yolo(yolo_model, tmp_file_path)
79
 
 
86
  },
87
  "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
88
  "names": results.names,
 
89
  "orig_shape": results.orig_shape,
90
  "speed": results.speed
91
  }
92
 
93
+ # Handle growth stages if present in the results
94
+ if hasattr(results, 'boxes') and results.boxes is not None:
95
+ if hasattr(results.boxes, 'data') and len(results.boxes.data) > 0:
96
+ # Check if there are additional columns for growth stages
97
+ if results.boxes.data.shape[1] > 6:
98
+ growth_stages = results.boxes.data[:, 6:].tolist()
99
+ results_dict["growth_stages"] = growth_stages
100
 
101
  return JSONResponse(content=results_dict)
102
 
103
  finally:
104
  # Clean up temporary file
105
+ if os.path.exists(tmp_file_path):
106
+ os.unlink(tmp_file_path)
107
 
108
  except Exception as e:
109
  return JSONResponse(status_code=500, content={"error": str(e)})
 
112
  async def predict_pipeline_api(file: UploadFile = File(...)):
113
  """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction"""
114
  try:
115
+ # Save uploaded file temporarily with proper extension
116
+ file_extension = '.tiff' if file.filename.lower().endswith(('.tif', '.tiff')) else '.jpg'
117
+
118
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
119
  contents = await file.read()
120
  tmp_file.write(contents)
121
+ tmp_file.flush() # Ensure data is written
122
  tmp_file_path = tmp_file.name
123
 
124
  try:
125
+ # Verify the file was written correctly
126
+ if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
127
+ raise ValueError("Failed to create temporary file")
128
+
129
  # Run the full pipeline
130
  results = predict_pipeline(ndvi_model, yolo_model, tmp_file_path)
131
 
 
138
  },
139
  "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
140
  "names": results.names,
 
141
  "orig_shape": results.orig_shape,
142
  "speed": results.speed
143
  }
144
 
145
+ # Handle growth stages if present in the results
146
+ if hasattr(results, 'boxes') and results.boxes is not None:
147
+ if hasattr(results.boxes, 'data') and len(results.boxes.data) > 0:
148
+ # Check if there are additional columns for growth stages
149
+ if results.boxes.data.shape[1] > 6:
150
+ growth_stages = results.boxes.data[:, 6:].tolist()
151
+ results_dict["growth_stages"] = growth_stages
152
 
153
  return JSONResponse(content=results_dict)
154
 
155
  finally:
156
  # Clean up temporary file
157
+ if os.path.exists(tmp_file_path):
158
+ os.unlink(tmp_file_path)
159
 
160
  except Exception as e:
161
  return JSONResponse(status_code=500, content={"error": str(e)})