Addax-Data-Science committed on
Commit
db268f4
·
verified ·
1 Parent(s): 27da5c6

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +42 -38
inference.py CHANGED
@@ -133,28 +133,39 @@ class ModelInference:
133
  self, image: Image.Image, bbox: tuple[float, float, float, float]
134
  ) -> Image.Image:
135
  """
136
- Return the full image (cropping happens in get_tensor/get_classification).
137
 
138
- The official SpeciesNet pipeline crops on the float32 tensor, not
139
- the PIL image. Cropping is deferred so the bbox can be applied
140
- after pil_to_tensor + convert_image_dtype, matching the official
141
- preprocessing exactly.
142
 
143
- The bbox is stored on the returned image via the info dict so
144
- that get_tensor() can apply it at the right stage.
 
 
 
 
145
  """
146
- img = image.copy()
147
- img.info["_bbox"] = bbox
148
- return img
 
 
 
 
 
 
 
 
 
149
 
150
  def get_classification(
151
  self, crop: Image.Image
152
  ) -> list[list[str | float]]:
153
  """
154
- Run SpeciesNet classification on a single image.
155
 
156
  Args:
157
- crop: PIL Image with optional _bbox in info dict
158
 
159
  Returns:
160
  List of [class_name, confidence] lists for ALL classes.
@@ -166,7 +177,19 @@ class ModelInference:
166
  if self.model is None:
167
  raise RuntimeError("Model not loaded, call load_model() first")
168
 
169
- img_arr = self.get_tensor(crop)
 
 
 
 
 
 
 
 
 
 
 
 
170
  input_batch = torch.from_numpy(img_arr).unsqueeze(0).to(self.device)
171
 
172
  with torch.no_grad():
@@ -193,32 +216,13 @@ class ModelInference:
193
  str(i + 1): name for i, name in enumerate(self.class_names)
194
  }
195
 
196
- def get_tensor(self, image: Image.Image):
197
- """Preprocess an image into a numpy array for batch inference.
 
 
198
 
199
- Matches the official SpeciesNet preprocessing exactly:
200
- PIL -> CHW float32 [0,1] -> crop on tensor -> resize -> uint8 -> HWC /255
201
- """
202
- if image.mode != "RGB":
203
- image = image.convert("RGB")
204
-
205
- img_tensor = TF.pil_to_tensor(image)
206
  img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
207
-
208
- # Crop on the float32 tensor (matching official API)
209
- bbox = image.info.get("_bbox")
210
- if bbox:
211
- W, H = image.size
212
- x, y, w, h = bbox
213
- crop_top = int(y * H)
214
- crop_left = int(x * W)
215
- crop_h = int(h * H)
216
- crop_w = int(w * W)
217
- if crop_w > 0 and crop_h > 0:
218
- img_tensor = TF.crop(
219
- img_tensor, crop_top, crop_left, crop_h, crop_w
220
- )
221
-
222
  img_tensor = TF.resize(
223
  img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
224
  )
@@ -239,4 +243,4 @@ class ModelInference:
239
  for i in range(len(self.class_names))
240
  ]
241
  results.append(classifications)
242
- return results
 
133
  self, image: Image.Image, bbox: tuple[float, float, float, float]
134
  ) -> Image.Image:
135
  """
136
+ Crop image using normalized bounding box coordinates.
137
 
138
+ Matches SpeciesNet's preprocessing: crop using int() truncation
139
+ (not rounding) to match torchvision.transforms.functional.crop().
 
 
140
 
141
+ Args:
142
+ image: PIL Image (full resolution)
143
+ bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]
144
+
145
+ Returns:
146
+ Cropped PIL Image
147
  """
148
+ W, H = image.size
149
+ x, y, w, h = bbox
150
+
151
+ left = int(x * W)
152
+ top = int(y * H)
153
+ crop_w = int(w * W)
154
+ crop_h = int(h * H)
155
+
156
+ if crop_w <= 0 or crop_h <= 0:
157
+ return image
158
+
159
+ return image.crop((left, top, left + crop_w, top + crop_h))
160
 
161
  def get_classification(
162
  self, crop: Image.Image
163
  ) -> list[list[str | float]]:
164
  """
165
+ Run SpeciesNet classification on a cropped image.
166
 
167
  Args:
168
+ crop: Cropped and preprocessed PIL Image
169
 
170
  Returns:
171
  List of [class_name, confidence] lists for ALL classes.
 
177
  if self.model is None:
178
  raise RuntimeError("Model not loaded, call load_model() first")
179
 
180
+ if crop.mode != "RGB":
181
+ crop = crop.convert("RGB")
182
+
183
+ # Match SpeciesNet's exact preprocessing pipeline:
184
+ # PIL -> CHW float32 [0,1] -> resize -> uint8 -> /255 -> HWC
185
+ img_tensor = TF.pil_to_tensor(crop)
186
+ img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
187
+ img_tensor = TF.resize(
188
+ img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
189
+ )
190
+ img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
191
+ # HWC float32 [0, 1] (matching speciesnet's img.arr / 255)
192
+ img_arr = img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0
193
  input_batch = torch.from_numpy(img_arr).unsqueeze(0).to(self.device)
194
 
195
  with torch.no_grad():
 
216
  str(i + 1): name for i, name in enumerate(self.class_names)
217
  }
218
 
219
+ def get_tensor(self, crop: Image.Image):
220
+ """Preprocess a crop into a numpy array for batch inference."""
221
+ if crop.mode != "RGB":
222
+ crop = crop.convert("RGB")
223
 
224
+ img_tensor = TF.pil_to_tensor(crop)
 
 
 
 
 
 
225
  img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  img_tensor = TF.resize(
227
  img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
228
  )
 
243
  for i in range(len(self.class_names))
244
  ]
245
  results.append(classifications)
246
+ return results