Spaces:
Sleeping
Sleeping
Merge branch 'pr/12' into pr/13
Browse files- app.py +18 -1
- ndvi_predictor.py +93 -50
- resize_image.py +62 -0
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from rasterio.transform import from_bounds
|
|
| 13 |
import tempfile
|
| 14 |
import os
|
| 15 |
import logging
|
|
|
|
| 16 |
|
| 17 |
# Configure logging
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -39,6 +40,7 @@ except Exception as e:
|
|
| 39 |
async def root():
|
| 40 |
return {"message": "Welcome to the NDVI and YOLO prediction API!"}
|
| 41 |
|
|
|
|
| 42 |
@app.post("/predict_ndvi/")
|
| 43 |
async def predict_ndvi_api(file: UploadFile = File(...)):
|
| 44 |
"""Predict NDVI from RGB image"""
|
|
@@ -46,11 +48,25 @@ async def predict_ndvi_api(file: UploadFile = File(...)):
|
|
| 46 |
return JSONResponse(status_code=500, content={"error": "NDVI model not loaded"})
|
| 47 |
|
| 48 |
try:
|
|
|
|
|
|
|
|
|
|
| 49 |
contents = await file.read()
|
| 50 |
img = Image.open(BytesIO(contents)).convert("RGB")
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
pred_ndvi = predict_ndvi(ndvi_model, norm_img)
|
| 53 |
|
|
|
|
| 54 |
# Visualization image as PNG
|
| 55 |
vis_img_bytes = create_visualization(norm_img, pred_ndvi)
|
| 56 |
vis_img_bytes.seek(0)
|
|
@@ -77,6 +93,7 @@ async def predict_ndvi_api(file: UploadFile = File(...)):
|
|
| 77 |
logger.error(f"Error in predict_ndvi_api: {e}")
|
| 78 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 79 |
|
|
|
|
| 80 |
@app.post("/predict_yolo/")
|
| 81 |
async def predict_yolo_api(file: UploadFile = File(...)):
|
| 82 |
"""Predict YOLO results from 4-channel TIFF image"""
|
|
|
|
| 13 |
import tempfile
|
| 14 |
import os
|
| 15 |
import logging
|
| 16 |
+
from resize_image import resize_image_optimized, resize_image_simple
|
| 17 |
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 40 |
async def root():
|
| 41 |
return {"message": "Welcome to the NDVI and YOLO prediction API!"}
|
| 42 |
|
| 43 |
+
# Example usage in your predict_ndvi endpoint:
|
| 44 |
@app.post("/predict_ndvi/")
|
| 45 |
async def predict_ndvi_api(file: UploadFile = File(...)):
|
| 46 |
"""Predict NDVI from RGB image"""
|
|
|
|
| 48 |
return JSONResponse(status_code=500, content={"error": "NDVI model not loaded"})
|
| 49 |
|
| 50 |
try:
|
| 51 |
+
# Define target size (height, width)
|
| 52 |
+
target_size = (640, 640)
|
| 53 |
+
|
| 54 |
contents = await file.read()
|
| 55 |
img = Image.open(BytesIO(contents)).convert("RGB")
|
| 56 |
+
|
| 57 |
+
# Convert to numpy array
|
| 58 |
+
rgb_array = np.array(img)
|
| 59 |
+
|
| 60 |
+
# Resize image to target size
|
| 61 |
+
rgb_resized = resize_image_optimized(rgb_array, target_size)
|
| 62 |
+
|
| 63 |
+
# Normalize the resized image
|
| 64 |
+
norm_img = normalize_rgb(rgb_resized)
|
| 65 |
+
|
| 66 |
+
# Predict NDVI
|
| 67 |
pred_ndvi = predict_ndvi(ndvi_model, norm_img)
|
| 68 |
|
| 69 |
+
# Rest of the endpoint remains the same...
|
| 70 |
# Visualization image as PNG
|
| 71 |
vis_img_bytes = create_visualization(norm_img, pred_ndvi)
|
| 72 |
vis_img_bytes.seek(0)
|
|
|
|
| 93 |
logger.error(f"Error in predict_ndvi_api: {e}")
|
| 94 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 95 |
|
| 96 |
+
|
| 97 |
@app.post("/predict_yolo/")
|
| 98 |
async def predict_yolo_api(file: UploadFile = File(...)):
|
| 99 |
"""Predict YOLO results from 4-channel TIFF image"""
|
ndvi_predictor.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# ndvi_predictor.py
|
| 2 |
import os
|
| 3 |
-
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 4 |
os.environ["SM_FRAMEWORK"] = "tf.keras"
|
| 5 |
import segmentation_models as sm
|
| 6 |
import tensorflow as tf
|
|
@@ -11,6 +11,9 @@ import rasterio
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
from PIL import Image
|
| 13 |
import io
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Custom loss functions and activation functions
|
| 16 |
def balanced_mse_loss(y_true, y_pred):
|
|
@@ -27,31 +30,36 @@ def custom_mae(y_true, y_pred):
|
|
| 27 |
|
| 28 |
def load_model(models_dir):
|
| 29 |
"""Load NDVI prediction model with custom objects"""
|
|
|
|
| 30 |
# Define custom objects dictionary
|
| 31 |
custom_objects = {
|
| 32 |
'balanced_mse_loss': balanced_mse_loss,
|
| 33 |
'custom_mae': custom_mae
|
| 34 |
}
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def normalize_rgb(rgb):
|
| 57 |
"""Normalize RGB image to [0, 1] range using percentile normalization"""
|
|
@@ -71,7 +79,7 @@ def normalize_rgb(rgb):
|
|
| 71 |
|
| 72 |
def predict_ndvi(model, rgb_np):
|
| 73 |
"""
|
| 74 |
-
|
| 75 |
|
| 76 |
Args:
|
| 77 |
model: Loaded NDVI prediction model
|
|
@@ -81,46 +89,81 @@ def predict_ndvi(model, rgb_np):
|
|
| 81 |
ndvi_pred: Predicted NDVI as numpy array (H, W) in range [-1, 1]
|
| 82 |
"""
|
| 83 |
height, width = rgb_np.shape[:2]
|
|
|
|
|
|
|
| 84 |
tile_size = 512
|
| 85 |
-
stride = int(tile_size * 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# Initialize output arrays
|
| 88 |
ndvi_pred = np.zeros((height, width), dtype=np.float32)
|
| 89 |
weight_map = np.zeros((height, width), dtype=np.float32)
|
| 90 |
|
| 91 |
-
#
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
# Extract tile
|
| 105 |
-
tile =
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
|
| 117 |
-
valid_width = min(tile_size, width - j)
|
| 118 |
|
| 119 |
-
#
|
| 120 |
-
ndvi_pred[i:i+
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
| 124 |
|
| 125 |
# Normalize by weights
|
| 126 |
mask = weight_map > 0
|
|
|
|
| 1 |
# ndvi_predictor.py
|
| 2 |
import os
|
| 3 |
+
# os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 4 |
os.environ["SM_FRAMEWORK"] = "tf.keras"
|
| 5 |
import segmentation_models as sm
|
| 6 |
import tensorflow as tf
|
|
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
from PIL import Image
|
| 13 |
import io
|
| 14 |
+
from tensorflow.keras.models import model_from_json
|
| 15 |
+
import traceback
|
| 16 |
+
import gc
|
| 17 |
|
| 18 |
# Custom loss functions and activation functions
|
| 19 |
def balanced_mse_loss(y_true, y_pred):
|
|
|
|
| 30 |
|
| 31 |
def load_model(models_dir):
|
| 32 |
"""Load NDVI prediction model with custom objects"""
|
| 33 |
+
|
| 34 |
# Define custom objects dictionary
|
| 35 |
custom_objects = {
|
| 36 |
'balanced_mse_loss': balanced_mse_loss,
|
| 37 |
'custom_mae': custom_mae
|
| 38 |
}
|
| 39 |
|
| 40 |
+
try:
|
| 41 |
+
# Load model architecture
|
| 42 |
+
with open(os.path.join(models_dir, "model_architecture.json"), "r") as json_file:
|
| 43 |
+
model_json = json_file.read()
|
| 44 |
+
|
| 45 |
+
model = model_from_json(model_json, custom_objects=custom_objects)
|
| 46 |
+
|
| 47 |
+
# Load weights
|
| 48 |
+
model.load_weights(os.path.join(models_dir, "best_model_weights.weights.h5"))
|
| 49 |
+
|
| 50 |
+
# Compile model with custom functions
|
| 51 |
+
optimizer = tf.keras.optimizers.AdamW(learning_rate=0.0005, weight_decay=1e-4)
|
| 52 |
+
|
| 53 |
+
model.compile(
|
| 54 |
+
optimizer=optimizer,
|
| 55 |
+
loss=balanced_mse_loss,
|
| 56 |
+
metrics=[custom_mae, 'mse']
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return model
|
| 60 |
+
except Exception as e:
|
| 61 |
+
traceback.print_exc()
|
| 62 |
+
return None
|
| 63 |
|
| 64 |
def normalize_rgb(rgb):
|
| 65 |
"""Normalize RGB image to [0, 1] range using percentile normalization"""
|
|
|
|
| 79 |
|
| 80 |
def predict_ndvi(model, rgb_np):
|
| 81 |
"""
|
| 82 |
+
Faster NDVI prediction with larger tiles and more efficient processing
|
| 83 |
|
| 84 |
Args:
|
| 85 |
model: Loaded NDVI prediction model
|
|
|
|
| 89 |
ndvi_pred: Predicted NDVI as numpy array (H, W) in range [-1, 1]
|
| 90 |
"""
|
| 91 |
height, width = rgb_np.shape[:2]
|
| 92 |
+
|
| 93 |
+
# Larger tiles for faster processing
|
| 94 |
tile_size = 512
|
| 95 |
+
stride = int(tile_size * 0.75) # 25% overlap
|
| 96 |
+
|
| 97 |
+
# For smaller images, process whole image at once
|
| 98 |
+
if height <= tile_size and width <= tile_size:
|
| 99 |
+
# Pad to tile size if needed
|
| 100 |
+
pad_height = max(0, tile_size - height)
|
| 101 |
+
pad_width = max(0, tile_size - width)
|
| 102 |
+
if pad_height > 0 or pad_width > 0:
|
| 103 |
+
rgb_padded = np.pad(rgb_np, ((0, pad_height), (0, pad_width), (0, 0)), mode='reflect')
|
| 104 |
+
else:
|
| 105 |
+
rgb_padded = rgb_np
|
| 106 |
+
|
| 107 |
+
# Single prediction
|
| 108 |
+
pred = model.predict(np.expand_dims(rgb_padded, axis=0), verbose=0, batch_size=1)[0, :, :, 0]
|
| 109 |
+
return pred[:height, :width]
|
| 110 |
|
| 111 |
# Initialize output arrays
|
| 112 |
ndvi_pred = np.zeros((height, width), dtype=np.float32)
|
| 113 |
weight_map = np.zeros((height, width), dtype=np.float32)
|
| 114 |
|
| 115 |
+
# Pre-compute weights for efficiency
|
| 116 |
+
y, x = np.mgrid[0:tile_size, 0:tile_size]
|
| 117 |
+
base_weights = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
|
| 118 |
+
base_weights = np.clip(base_weights, 0, 64) / 64
|
| 119 |
+
|
| 120 |
+
# Collect all tiles first
|
| 121 |
+
tiles = []
|
| 122 |
+
positions = []
|
| 123 |
+
|
| 124 |
+
for i in range(0, height, stride):
|
| 125 |
+
for j in range(0, width, stride):
|
| 126 |
+
# Calculate actual tile bounds
|
| 127 |
+
end_i = min(i + tile_size, height)
|
| 128 |
+
end_j = min(j + tile_size, width)
|
| 129 |
+
actual_height = end_i - i
|
| 130 |
+
actual_width = end_j - j
|
| 131 |
+
|
| 132 |
# Extract tile
|
| 133 |
+
tile = rgb_np[i:end_i, j:end_j, :]
|
| 134 |
|
| 135 |
+
# Pad if necessary
|
| 136 |
+
if actual_height < tile_size or actual_width < tile_size:
|
| 137 |
+
pad_height = tile_size - actual_height
|
| 138 |
+
pad_width = tile_size - actual_width
|
| 139 |
+
tile = np.pad(tile, ((0, pad_height), (0, pad_width), (0, 0)), mode='reflect')
|
| 140 |
|
| 141 |
+
tiles.append(tile)
|
| 142 |
+
positions.append((i, j, actual_height, actual_width))
|
| 143 |
+
|
| 144 |
+
# Process all tiles in larger batches
|
| 145 |
+
batch_size = 8 # Process 8 tiles at once
|
| 146 |
+
for batch_start in range(0, len(tiles), batch_size):
|
| 147 |
+
batch_end = min(batch_start + batch_size, len(tiles))
|
| 148 |
+
batch_tiles = np.array(tiles[batch_start:batch_end])
|
| 149 |
+
|
| 150 |
+
# Predict batch
|
| 151 |
+
batch_preds = model.predict(batch_tiles, verbose=0, batch_size=batch_size)
|
| 152 |
+
|
| 153 |
+
# Apply predictions
|
| 154 |
+
for k in range(batch_end - batch_start):
|
| 155 |
+
pred = batch_preds[k, :, :, 0]
|
| 156 |
+
i, j, actual_height, actual_width = positions[batch_start + k]
|
| 157 |
|
| 158 |
+
# Use appropriate weights
|
| 159 |
+
weights = base_weights[:actual_height, :actual_width]
|
|
|
|
| 160 |
|
| 161 |
+
# Add to output
|
| 162 |
+
ndvi_pred[i:i+actual_height, j:j+actual_width] += pred[:actual_height, :actual_width] * weights
|
| 163 |
+
weight_map[i:i+actual_height, j:j+actual_width] += weights
|
| 164 |
+
|
| 165 |
+
# Clean up batch
|
| 166 |
+
del batch_tiles, batch_preds
|
| 167 |
|
| 168 |
# Normalize by weights
|
| 169 |
mask = weight_map > 0
|
resize_image.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
# Alternative: Simple resize function using PIL directly
|
| 5 |
+
def resize_image_simple(image_array, target_size):
|
| 6 |
+
"""
|
| 7 |
+
Simple resize function using PIL
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
image_array: Input image array (H, W, C)
|
| 11 |
+
target_size: Tuple (height, width)
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
Resized image array
|
| 15 |
+
"""
|
| 16 |
+
# Ensure image is in correct format
|
| 17 |
+
if image_array.max() <= 1:
|
| 18 |
+
image_array = (image_array * 255).astype(np.uint8)
|
| 19 |
+
|
| 20 |
+
# Convert to PIL Image
|
| 21 |
+
image_pil = Image.fromarray(image_array)
|
| 22 |
+
|
| 23 |
+
# Resize (PIL uses width, height format)
|
| 24 |
+
resized_pil = image_pil.resize((target_size[1], target_size[0]), Image.LANCZOS)
|
| 25 |
+
|
| 26 |
+
# Convert back to numpy array and normalize back to [0, 1]
|
| 27 |
+
resized_array = np.array(resized_pil).astype(np.float32) / 255.0
|
| 28 |
+
|
| 29 |
+
return resized_array
|
| 30 |
+
|
| 31 |
+
def resize_image_optimized(image_array, target_size):
|
| 32 |
+
"""
|
| 33 |
+
Resize image to target size with memory optimization
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
image_array: Input image array (H, W, C)
|
| 37 |
+
target_size: Tuple (height, width) representing target dimensions
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Resized image array
|
| 41 |
+
"""
|
| 42 |
+
# Convert numpy array to PIL Image
|
| 43 |
+
if image_array.dtype != np.uint8:
|
| 44 |
+
# Convert to uint8 if not already
|
| 45 |
+
if image_array.max() <= 1:
|
| 46 |
+
image_array = (image_array * 255).astype(np.uint8)
|
| 47 |
+
else:
|
| 48 |
+
image_array = image_array.astype(np.uint8)
|
| 49 |
+
|
| 50 |
+
image_pil = Image.fromarray(image_array)
|
| 51 |
+
|
| 52 |
+
# Resize image (PIL uses (width, height) format)
|
| 53 |
+
resized_pil = image_pil.resize((target_size[1], target_size[0]), Image.LANCZOS)
|
| 54 |
+
|
| 55 |
+
# Convert back to numpy array
|
| 56 |
+
result = np.array(resized_pil)
|
| 57 |
+
|
| 58 |
+
# Clean up
|
| 59 |
+
image_pil.close()
|
| 60 |
+
resized_pil.close()
|
| 61 |
+
|
| 62 |
+
return result
|