ahadhassan commited on
Commit
b565402
·
verified ·
1 Parent(s): a72887b

Update yolo_predictor.py

Browse files
Files changed (1) hide show
  1. yolo_predictor.py +36 -42
yolo_predictor.py CHANGED
@@ -32,71 +32,65 @@ def predict_ndvi_from_rgb(ndvi_model, rgb_array):
32
  return ndvi_pred
33
 
34
  def predict_yolo(yolo_model, image_path, conf=0.001):
35
- """
36
- Predict using YOLO model on 4-channel TIFF image
37
-
38
- Args:
39
- yolo_model: Loaded YOLO model
40
- image_path: Path to 4-channel TIFF image
41
- conf: Confidence threshold
42
-
43
- Returns:
44
- results: YOLO results object
45
- """
46
- # Verify the image has 4 channels before prediction
47
  try:
48
- # Use tifffile for 32-bit TIFF support
49
  img_array = tifffile.imread(image_path)
50
-
51
- # Handle different array shapes
52
  if len(img_array.shape) == 3:
53
  if img_array.shape[0] == 4:
54
- # Shape is (4, H, W) - transpose to (H, W, 4)
55
  img_array = np.transpose(img_array, (1, 2, 0))
 
56
  elif img_array.shape[2] != 4:
57
- raise ValueError(f"YOLO model expects 4-channel images, but got {img_array.shape[2]} channels")
58
  else:
59
- raise ValueError(f"Unexpected image shape: {img_array.shape}")
60
-
61
- # Convert 32-bit float to uint8 if necessary
 
 
 
 
 
62
  if img_array.dtype != np.uint8:
63
- # Normalize to [0, 255] for RGB channels (first 3)
64
  rgb_array = img_array[:, :, :3]
 
 
 
65
  if rgb_array.max() > 1.0:
66
  rgb_array = np.clip(rgb_array / rgb_array.max() * 255, 0, 255).astype(np.uint8)
67
  else:
68
  rgb_array = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
69
-
70
- # Normalize NDVI (4th channel) from [-1, 1] to [0, 255]
71
- ndvi_array = img_array[:, :, 3]
72
  ndvi_normalized = ((ndvi_array + 1) * 127.5).astype(np.uint8)
73
-
74
- # Recombine into 4-channel uint8 array
75
  img_array = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
76
  img_array[:, :, :3] = rgb_array
77
  img_array[:, :, 3] = ndvi_normalized
78
-
79
- # Save normalized image to temporary file
 
 
80
  with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
81
  temp_path = tmp_file.name
82
- tifffile.imwrite(
83
- temp_path,
84
- img_array,
85
- photometric='rgb',
86
- compress='lzw',
87
- metadata={'axes': 'YXC', 'resolution': (1, 1)} # DPI=1
88
- )
89
  image_path = temp_path
90
-
91
- # Run YOLO prediction
 
 
 
 
92
  results = yolo_model([image_path], conf=conf)
93
-
94
- # Clean up temporary file if created
95
  if 'temp_path' in locals() and os.path.exists(temp_path):
96
  os.unlink(temp_path)
97
-
98
- return results[0] # Return first result
99
-
100
  except Exception as e:
101
  raise ValueError(f"Error processing image: {str(e)}")
102
 
 
32
  return ndvi_pred
33
 
34
  def predict_yolo(yolo_model, image_path, conf=0.001):
35
+ import tifffile
36
+
 
 
 
 
 
 
 
 
 
 
37
  try:
 
38
  img_array = tifffile.imread(image_path)
39
+ print(f"[DEBUG] Loaded image shape: {img_array.shape}, dtype: {img_array.dtype}")
40
+
41
  if len(img_array.shape) == 3:
42
  if img_array.shape[0] == 4:
 
43
  img_array = np.transpose(img_array, (1, 2, 0))
44
+ print(f"[DEBUG] Transposed image shape to (H,W,C): {img_array.shape}")
45
  elif img_array.shape[2] != 4:
46
+ raise ValueError(f"[ERROR] Expected 4 channels, got {img_array.shape[2]}")
47
  else:
48
+ raise ValueError(f"[ERROR] Unexpected image shape: {img_array.shape}")
49
+
50
+ # Confirm channel count
51
+ if img_array.shape[2] != 4:
52
+ raise ValueError(f"[ERROR] After transpose, still not 4 channels: got {img_array.shape[2]}")
53
+
54
+ print(f"[DEBUG] Image dtype before normalization: {img_array.dtype}")
55
+
56
  if img_array.dtype != np.uint8:
57
+ print(f"[DEBUG] Converting image to uint8")
58
  rgb_array = img_array[:, :, :3]
59
+ ndvi_array = img_array[:, :, 3]
60
+
61
+ # Normalize RGB
62
  if rgb_array.max() > 1.0:
63
  rgb_array = np.clip(rgb_array / rgb_array.max() * 255, 0, 255).astype(np.uint8)
64
  else:
65
  rgb_array = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
66
+
67
+ # Normalize NDVI
 
68
  ndvi_normalized = ((ndvi_array + 1) * 127.5).astype(np.uint8)
69
+
 
70
  img_array = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
71
  img_array[:, :, :3] = rgb_array
72
  img_array[:, :, 3] = ndvi_normalized
73
+
74
+ print(f"[DEBUG] Image converted to uint8 with shape: {img_array.shape}")
75
+
76
+ # Save normalized version to temp file
77
  with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
78
  temp_path = tmp_file.name
79
+ tifffile.imwrite(temp_path, img_array)
 
 
 
 
 
 
80
  image_path = temp_path
81
+
82
+ print(f"[DEBUG] Final image ready for YOLO, path: {image_path}")
83
+
84
+ # Final safety check
85
+ assert img_array.shape[2] == 4, "[FATAL] Final image does not have 4 channels."
86
+
87
  results = yolo_model([image_path], conf=conf)
88
+
 
89
  if 'temp_path' in locals() and os.path.exists(temp_path):
90
  os.unlink(temp_path)
91
+
92
+ return results[0]
93
+
94
  except Exception as e:
95
  raise ValueError(f"Error processing image: {str(e)}")
96