Spaces:
Sleeping
Sleeping
🐛 [Fix] params mismatch with origin v9
Browse files- yolo/model/module.py +11 -5
yolo/model/module.py
CHANGED
|
@@ -49,14 +49,16 @@ class Pool(nn.Module):
|
|
| 49 |
class Detection(nn.Module):
|
| 50 |
"""A single YOLO Detection head for detection models"""
|
| 51 |
|
| 52 |
-
def __init__(self, in_channels: int, num_classes: int, *, reg_max: int = 16, use_group: bool = True):
|
| 53 |
super().__init__()
|
| 54 |
|
| 55 |
groups = 4 if use_group else 1
|
| 56 |
anchor_channels = 4 * reg_max
|
|
|
|
|
|
|
| 57 |
# TODO: round up head[0] channels or each head?
|
| 58 |
-
anchor_neck = max(round_up(
|
| 59 |
-
class_neck = max(
|
| 60 |
|
| 61 |
self.anchor_conv = nn.Sequential(
|
| 62 |
Conv(in_channels, anchor_neck, 3),
|
|
@@ -78,8 +80,12 @@ class MultiheadDetection(nn.Module):
|
|
| 78 |
|
| 79 |
def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
|
| 80 |
super().__init__()
|
|
|
|
| 81 |
self.heads = nn.ModuleList(
|
| 82 |
-
[
|
|
|
|
|
|
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
@@ -118,7 +124,7 @@ class RepNBottleneck(nn.Module):
|
|
| 118 |
*,
|
| 119 |
kernel_size: Tuple[int, int] = (3, 3),
|
| 120 |
residual: bool = True,
|
| 121 |
-
expand: float = 0
|
| 122 |
**kwargs
|
| 123 |
):
|
| 124 |
super().__init__()
|
|
|
|
| 49 |
class Detection(nn.Module):
|
| 50 |
"""A single YOLO Detection head for detection models"""
|
| 51 |
|
| 52 |
+
def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True):
|
| 53 |
super().__init__()
|
| 54 |
|
| 55 |
groups = 4 if use_group else 1
|
| 56 |
anchor_channels = 4 * reg_max
|
| 57 |
+
|
| 58 |
+
first_neck, in_channels = in_channels
|
| 59 |
# TODO: round up head[0] channels or each head?
|
| 60 |
+
anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
|
| 61 |
+
class_neck = max(first_neck, min(num_classes * 2, 128))
|
| 62 |
|
| 63 |
self.anchor_conv = nn.Sequential(
|
| 64 |
Conv(in_channels, anchor_neck, 3),
|
|
|
|
| 80 |
|
| 81 |
def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
|
| 82 |
super().__init__()
|
| 83 |
+
# TODO: Refactor these parts
|
| 84 |
self.heads = nn.ModuleList(
|
| 85 |
+
[
|
| 86 |
+
Detection((in_channels[3 * (idx // 3)], in_channel), num_classes, **head_kwargs)
|
| 87 |
+
for idx, in_channel in enumerate(in_channels)
|
| 88 |
+
]
|
| 89 |
)
|
| 90 |
|
| 91 |
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
|
|
| 124 |
*,
|
| 125 |
kernel_size: Tuple[int, int] = (3, 3),
|
| 126 |
residual: bool = True,
|
| 127 |
+
expand: float = 1.0,
|
| 128 |
**kwargs
|
| 129 |
):
|
| 130 |
super().__init__()
|