| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
def autopad(k, p=None, d=1):
    """Compute 'same'-style padding for a convolution.

    Args:
        k (int | list[int]): kernel size, scalar or per-dimension.
        p (int | list[int] | None): explicit padding; computed as k // 2 when None.
        d (int): dilation; the effective kernel size is d * (k - 1) + 1.
            Defaults to 1, which reproduces the original behavior exactly.

    Returns:
        int | list[int]: padding that preserves spatial size at stride 1.
    """
    if d > 1:
        # A dilated kernel spans a wider window; pad for its effective size.
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p
|
|
class Conv(nn.Module):
    """Standard convolution block: Conv2d -> BatchNorm2d -> activation.

    The activation is SiLU by default, or Identity when act=False. The
    convolution carries no bias because BatchNorm immediately follows.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization, then the activation."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
|
|
class DSConv(nn.Module):
    """Depthwise-separable convolution: depthwise k x k, pointwise 1x1, BN, activation."""

    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        # Depthwise stage: one filter per input channel (groups=c1).
        self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
        # Pointwise stage mixes channels and sets the output width.
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        """Run the depthwise then pointwise stages, followed by BN and activation."""
        y = self.pwconv(self.dwconv(x))
        return self.act(self.bn(y))
|
|
class DS_Bottleneck(nn.Module):
    """Two stacked DSConv layers with an optional identity shortcut.

    The residual add is enabled only when requested AND the input/output
    channel counts match, since the tensors must be the same shape.
    """

    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        self.dsconv1 = DSConv(c1, c1, k=3, s=1)
        self.dsconv2 = DSConv(c1, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        y = self.dsconv2(self.dsconv1(x))
        return x + y if self.shortcut else y
|
|
class DS_C3k(nn.Module):
    """CSP-style block: one branch runs n DS_Bottlenecks, the other is a 1x1
    bypass; both are concatenated and fused with a final 1x1 conv."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        c_hidden = int(c2 * e)
        self.cv1 = Conv(c1, c_hidden, 1, 1)
        self.cv2 = Conv(c1, c_hidden, 1, 1)
        self.cv3 = Conv(2 * c_hidden, c2, 1, 1)
        self.m = nn.Sequential(
            *(DS_Bottleneck(c_hidden, c_hidden, k=k, shortcut=True) for _ in range(n))
        )

    def forward(self, x):
        main = self.m(self.cv1(x))
        bypass = self.cv2(x)
        return self.cv3(torch.cat((main, bypass), dim=1))
|
|
class DS_C3k2(nn.Module):
    """Reduce channels with a 1x1 conv, apply a DS_C3k stack (e=1.0 internally),
    then expand back to c2 with another 1x1 conv."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        c_hidden = int(c2 * e)
        self.cv1 = Conv(c1, c_hidden, 1, 1)
        self.m = DS_C3k(c_hidden, c_hidden, n=n, k=k, e=1.0)
        self.cv2 = Conv(c_hidden, c2, 1, 1)

    def forward(self, x):
        return self.cv2(self.m(self.cv1(x)))
|
|
class AdaptiveHyperedgeGeneration(nn.Module):
    """Generate a soft node-to-hyperedge participation matrix from token features.

    Learnable global hyperedge prototypes are offset by a per-sample context
    vector (global average and max pooled features), then compared against
    per-node queries via multi-head scaled dot-product similarity.
    """

    def __init__(self, in_channels, num_hyperedges, num_heads):
        """
        Args:
            in_channels (int): token feature width C; must be divisible by
                num_heads so the multi-head reshape in forward() is exact.
            num_hyperedges (int): number of hyperedges E to generate.
            num_heads (int): attention heads used for similarity scoring.

        Raises:
            ValueError: if in_channels is not divisible by num_heads
                (previously this surfaced as an opaque view() error at runtime).
        """
        super().__init__()
        if in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads
        # One learnable prototype per hyperedge, shared across the batch.
        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
        # Maps the pooled (avg ++ max) context to a per-sample prototype offset.
        self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        """
        Args:
            x (Tensor): token features of shape (B, N, C).

        Returns:
            Tensor: participation matrix of shape (B, E, N); each row is a
            softmax distribution over the N nodes.
        """
        B, N, C = x.shape
        # Global context over the node axis; mean/amax are equivalent to the
        # former adaptive_avg/max_pool1d calls but avoid permuting x twice.
        context = torch.cat((x.mean(dim=1), x.amax(dim=1)), dim=1)  # (B, 2C)
        # Per-sample prototypes: shared prototypes plus a context-dependent offset.
        P = self.global_proto.unsqueeze(0) + self.context_mapper(context).view(
            B, self.num_hyperedges, C
        )
        # Multi-head similarity between node queries and prototypes.
        q = self.query_proj(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
        sim = (q @ k) * self.scale              # (B, heads, N, E)
        sim = sim.mean(dim=1).permute(0, 2, 1)  # (B, E, N), averaged over heads
        return F.softmax(sim, dim=-1)
|
|
class HypergraphConvolution(nn.Module):
    """Two-step hypergraph message passing: nodes -> hyperedges -> nodes.

    Given node features x (B, N, C_in) and a participation matrix A (B, E, N),
    hyperedge features are aggregated as A @ x, transformed, scattered back
    with A^T, and added residually to the input when widths allow.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W_e = nn.Linear(in_channels, in_channels, bias=False)   # hyperedge transform
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)  # node transform
        self.act = nn.SiLU()

    def forward(self, x, A):
        """
        Args:
            x (Tensor): node features (B, N, in_channels).
            A (Tensor): node-to-hyperedge weights (B, E, N).

        Returns:
            Tensor: updated node features (B, N, out_channels). The residual
            is applied only when in_channels == out_channels; previously a
            width mismatch raised a runtime shape error in the add.
        """
        edge_feats = self.act(self.W_e(A.bmm(x)))                           # (B, E, C_in)
        node_feats = self.act(self.W_v(A.transpose(1, 2).bmm(edge_feats)))  # (B, N, C_out)
        # Residual only when the widths match; otherwise return the projection.
        return x + node_feats if node_feats.shape[-1] == x.shape[-1] else node_feats
|
|
class AdaptiveHypergraphComputation(nn.Module):
    """Flatten a feature map into tokens, run adaptive hypergraph message
    passing over them, and reshape the result back into a feature map."""

    def __init__(self, in_channels, out_channels, num_hyperedges, num_heads):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(in_channels, num_hyperedges, num_heads)
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, _, H, W = x.shape
        # (B, C, H, W) -> (B, N, C) token layout with N = H * W.
        tokens = x.flatten(2).permute(0, 2, 1)
        A = self.adaptive_hyperedge_gen(tokens)
        out = self.hypergraph_conv(tokens, A)
        # Restore the spatial layout: (B, N, C_out) -> (B, C_out, H, W).
        return out.permute(0, 2, 1).view(B, -1, H, W)
|
|
class C3AH(nn.Module):
    """CSP block whose main branch is adaptive hypergraph computation; a 1x1
    bypass branch is concatenated in before the final fuse conv."""

    def __init__(self, c1, c2, num_hyperedges, num_heads, e=0.5):
        super().__init__()
        c_hidden = int(c1 * e)
        self.cv1 = Conv(c1, c_hidden, 1, 1)
        self.cv2 = Conv(c1, c_hidden, 1, 1)
        self.ahc = AdaptiveHypergraphComputation(c_hidden, c_hidden, num_hyperedges, num_heads)
        self.cv3 = Conv(2 * c_hidden, c2, 1, 1)

    def forward(self, x):
        hyper = self.ahc(self.cv2(x))
        bypass = self.cv1(x)
        return self.cv3(torch.cat((hyper, bypass), dim=1))
|
|
class HyperACE(nn.Module):
    """Hypergraph-based adaptive cross-scale enhancement.

    Fuses four backbone feature maps (resampled to the P4 resolution) into one
    map, splits it channel-wise into a high-order part (k parallel C3AH
    branches), a low-order part (l stacked DS_C3k blocks), and a pass-through
    part, then fuses the three back together.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_hyperedges=16,
        num_heads=8,
        k=2,
        l=1,
        c_h=0.5,
        c_l=0.25
    ):
        super().__init__()
        c2, c3, c4, c5 = in_channels
        c_mid = c4  # fused width follows the P4 feature map
        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
        # Channel split: high-order / low-order / pass-through shortcut.
        self.c_h = int(c_mid * c_h)
        self.c_l = int(c_mid * c_l)
        self.c_s = c_mid - self.c_h - self.c_l
        self.high_order_branch = nn.ModuleList(
            C3AH(self.c_h, self.c_h, num_hyperedges=num_hyperedges, num_heads=num_heads, e=1.0)
            for _ in range(k)
        )
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
        self.low_order_branch = nn.Sequential(
            *(DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l))
        )
        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x):
        """x is a list/tuple of four maps [B2, B3, B4, B5]; returns one fused map at B4 resolution."""
        b2, b3, b4, b5 = x
        target = b4.shape[2:]

        def resize(t):
            # Bring a map to the P4 spatial resolution.
            return F.interpolate(t, size=target, mode='bilinear', align_corners=False)

        fused = self.fuse_conv(torch.cat((resize(b2), resize(b3), b4, resize(b5)), dim=1))
        x_h, x_l, x_s = fused.split([self.c_h, self.c_l, self.c_s], dim=1)

        high = self.high_order_fuse(torch.cat([m(x_h) for m in self.high_order_branch], dim=1))
        low = self.low_order_branch(x_l)
        return self.final_fuse(torch.cat((high, low, x_s), dim=1))
|
|
class GatedFusion(nn.Module):
    """Residual fusion with a learnable per-channel gate.

    The gate starts at zero, so the module is initially an identity on f_in
    and learns during training how much of h to blend in.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        """Return f_in plus the gated contribution of h."""
        return self.gamma * h + f_in
|
|
class YOLO13Encoder(nn.Module):
    """Backbone producing four pyramid features [p2, p3, p4, p5].

    Each stage halves the spatial resolution (the stem is stride 1) and
    doubles the channel width, starting from base_channels * 2 at p2.
    """

    def __init__(self, in_channels, base_channels=32):
        super().__init__()
        c = base_channels
        self.stem = DSConv(in_channels, c, k=3, s=1)
        self.p2 = self._stage(c, c * 2, n=1)
        self.p3 = self._stage(c * 2, c * 4, n=2)
        self.p4 = self._stage(c * 4, c * 8, n=2)
        self.p5 = self._stage(c * 8, c * 16, n=1)
        # Channel widths of the four returned feature maps.
        self.out_channels = [c * 2, c * 4, c * 8, c * 16]

    @staticmethod
    def _stage(c_in, c_out, n):
        # One pyramid level: stride-2 downsample then a DS_C3k2 refinement.
        return nn.Sequential(
            DSConv(c_in, c_out, k=3, s=(2, 2)),
            DS_C3k2(c_out, c_out, n=n),
        )

    def forward(self, x):
        """Return the four pyramid maps, shallowest first."""
        feats = []
        y = self.stem(x)
        for stage in (self.p2, self.p3, self.p4, self.p5):
            y = stage(y)
            feats.append(y)
        return feats
|
|
class YOLO13FullPADDecoder(nn.Module):
    """Top-down decoder with full-pipeline aggregation-and-distribution.

    At every decoder level the HyperACE map is resized to that level's
    resolution, projected to the level's width, and gate-fused into the
    feature; the result is upsampled, refined, and added to the encoder
    skip connection of the next (shallower) level.
    """

    def __init__(self, encoder_channels, hyperace_out_c, out_channels_final):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        # Decoder widths mirror the encoder widths level-for-level.
        c_d5, c_d4, c_d3, c_d2 = c_p5, c_p4, c_p3, c_p2

        # 1x1 projections of the HyperACE map to each decoder width.
        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

        # Learnable-gate fusion at each level.
        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)

        # 1x1 skip projections from the encoder features.
        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

        # Refinement blocks applied after each upsample.
        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
        self.final_conv = Conv(c_d2, out_channels_final, 1, 1)

    def forward(self, enc_feats, h_ace):
        """
        Args:
            enc_feats: encoder pyramid [p2, p3, p4, p5], shallowest first.
            h_ace: HyperACE output map, resized per level inside this method.

        Returns:
            Tensor: final map at p2 resolution with out_channels_final channels.
        """
        p2, p3, p4, p5 = enc_feats

        def inject(feat, proj, fuse):
            # Resize the HyperACE map to feat's resolution, project, gate-fuse.
            h = F.interpolate(h_ace, size=feat.shape[2:], mode='bilinear', align_corners=False)
            return fuse(feat, proj(h))

        def upsample_to(feat, ref):
            # Bilinearly resize feat to ref's spatial resolution.
            return F.interpolate(feat, size=ref.shape[2:], mode='bilinear', align_corners=False)

        d5 = self.skip_p5(p5)
        d4 = self.up_d5(upsample_to(inject(d5, self.h_to_d5, self.fusion_d5), p4)) + self.skip_p4(p4)
        d3 = self.up_d4(upsample_to(inject(d4, self.h_to_d4, self.fusion_d4), p3)) + self.skip_p3(p3)
        d2 = self.up_d3(upsample_to(inject(d3, self.h_to_d3, self.fusion_d3), p2)) + self.skip_p2(p2)

        return self.final_conv(self.final_d2(inject(d2, self.h_to_d2, self.fusion_d2)))