akore committed on
Commit
54e05fc
·
verified ·
1 Parent(s): ccc5362

feat: add original_size param to forward() — boxes auto-scaled to image space

Browse files
Files changed (1) hide show
  1. modeling_rtmdet.py +30 -8
modeling_rtmdet.py CHANGED
@@ -28,7 +28,9 @@ class DetectionOutput(ModelOutput):
28
 
29
  Args:
30
  boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
31
- Detection boxes in format [x1, y1, x2, y2].
 
 
32
  scores (`torch.FloatTensor` of shape `(batch_size, num_boxes)`):
33
  Detection confidence scores.
34
  labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`):
@@ -1817,6 +1819,7 @@ class RTMDetModel(PreTrainedModel):
1817
  def forward(
1818
  self,
1819
  pixel_values=None,
 
1820
  labels=None,
1821
  output_hidden_states=None,
1822
  return_dict=None,
@@ -1826,11 +1829,15 @@ class RTMDetModel(PreTrainedModel):
1826
 
1827
  Args:
1828
  pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
1829
- Pixel values. Pixel values can be obtained using
1830
- RTMDetImageProcessor.
 
 
 
 
 
1831
  labels (`List[Dict]`, *optional*):
1832
- Labels for computing the detection loss. Expected format:
1833
- List of dicts with 'boxes' and 'labels' keys.
1834
  output_hidden_states (`bool`, *optional*):
1835
  Whether or not to return the hidden states of all layers.
1836
  return_dict (`bool`, *optional*):
@@ -1838,9 +1845,8 @@ class RTMDetModel(PreTrainedModel):
1838
 
1839
  Returns:
1840
  `DetectionOutput` or `tuple`:
1841
- If return_dict=True, `DetectionOutput` is returned.
1842
- If return_dict=False, a tuple is returned where the first element
1843
- is the detection output tensor.
1844
  """
1845
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1846
 
@@ -1886,6 +1892,22 @@ class RTMDetModel(PreTrainedModel):
1886
  max_per_img=self.config.max_detections
1887
  )
1888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1889
  if return_dict:
1890
  return results
1891
  else:
 
28
 
29
  Args:
30
  boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
31
+ Detection boxes in format [x1, y1, x2, y2]. Coordinates are in
32
+ model-input space (640×640) by default, or in original image pixel
33
+ space when ``original_size`` was passed to ``forward()``.
34
  scores (`torch.FloatTensor` of shape `(batch_size, num_boxes)`):
35
  Detection confidence scores.
36
  labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`):
 
1819
  def forward(
1820
  self,
1821
  pixel_values=None,
1822
+ original_size=None,
1823
  labels=None,
1824
  output_hidden_states=None,
1825
  return_dict=None,
 
1829
 
1830
  Args:
1831
  pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
1832
+ Pixel values resized to 640×640 by the image processor.
1833
+ original_size (`Tuple[int, int]`, *optional*):
1834
+ ``(height, width)`` of the **original** image before preprocessing.
1835
+ When supplied, the returned boxes are automatically scaled from
1836
+ 640×640 model-input space to original image pixel coordinates so
1837
+ the caller never needs to compute ``sx = orig_w / 640`` manually.
1838
+ All images in the batch are assumed to share the same original size.
1839
  labels (`List[Dict]`, *optional*):
1840
+ Labels for computing the detection loss.
 
1841
  output_hidden_states (`bool`, *optional*):
1842
  Whether or not to return the hidden states of all layers.
1843
  return_dict (`bool`, *optional*):
 
1845
 
1846
  Returns:
1847
  `DetectionOutput` or `tuple`:
1848
+ Boxes are in 640×640 space by default, or in original image space
1849
+ when ``original_size`` is provided.
 
1850
  """
1851
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1852
 
 
1892
  max_per_img=self.config.max_detections
1893
  )
1894
 
1895
+ # Scale boxes from 640×640 model space → original image space if requested
1896
+ if original_size is not None:
1897
+ orig_h, orig_w = original_size
1898
+ sx = orig_w / width # width == 640
1899
+ sy = orig_h / height # height == 640
1900
+ scaled_boxes = results.boxes.clone()
1901
+ scaled_boxes[..., 0] *= sx # x1
1902
+ scaled_boxes[..., 2] *= sx # x2
1903
+ scaled_boxes[..., 1] *= sy # y1
1904
+ scaled_boxes[..., 3] *= sy # y2
1905
+ results = DetectionOutput(
1906
+ boxes=scaled_boxes,
1907
+ scores=results.scores,
1908
+ labels=results.labels,
1909
+ )
1910
+
1911
  if return_dict:
1912
  return results
1913
  else: