yolov6 / yolov6 /models /efficientrep.py
yourusername's picture
:beers: cheers
2a27594
raw history blame
No virus
2.79 kB
from torch import nn
from yolov6.layers.common import RepVGGBlock, RepBlock, SimSPPF
class EfficientRep(nn.Module):
    '''EfficientRep Backbone
    EfficientRep is handcrafted by hardware-aware neural network design.
    With rep-style struct, EfficientRep is friendly to high-computation hardware (e.g. GPU).

    Layout: a stride-2 RepVGGBlock stem followed by four stages (ERBlock_2 ..
    ERBlock_5). Each stage is a stride-2 RepVGGBlock followed by a RepBlock;
    ERBlock_5 additionally ends with a SimSPPF block.

    Args:
        in_channels: number of channels of the input (default 3).
        channels_list: output channel counts. Entries 0-4 are used: index 0 for
            the stem and indices 1-4 for ERBlock_2 .. ERBlock_5. Extra entries
            are ignored here.
        num_repeats: values passed as ``n`` to the RepBlock of each stage.
            Entries 1-4 are used (ERBlock_2 .. ERBlock_5). Entry 0 is not read
            by this class, and extra entries are ignored.

    Raises:
        ValueError: if ``channels_list`` or ``num_repeats`` is missing, or has
            fewer than 5 entries.
    '''

    def __init__(
        self,
        in_channels=3,
        channels_list=None,
        num_repeats=None,
    ):
        super().__init__()
        # Explicit checks rather than `assert`: asserts are removed under `python -O`.
        if channels_list is None:
            raise ValueError("channels_list must be provided")
        if num_repeats is None:
            raise ValueError("num_repeats must be provided")
        # Indices 0-4 of both lists are accessed below; fail with a clear message
        # instead of an IndexError partway through construction.
        if len(channels_list) < 5:
            raise ValueError(
                f"channels_list needs at least 5 entries, got {len(channels_list)}"
            )
        if len(num_repeats) < 5:
            raise ValueError(
                f"num_repeats needs at least 5 entries, got {len(num_repeats)}"
            )

        self.stem = RepVGGBlock(
            in_channels=in_channels,
            out_channels=channels_list[0],
            kernel_size=3,
            stride=2,
        )
        # Submodules are created in the same order as before, so seeded weight
        # initialisation and Sequential child indices (state_dict keys) match.
        self.ERBlock_2 = nn.Sequential(
            *self._make_stage(channels_list[0], channels_list[1], num_repeats[1])
        )
        self.ERBlock_3 = nn.Sequential(
            *self._make_stage(channels_list[1], channels_list[2], num_repeats[2])
        )
        self.ERBlock_4 = nn.Sequential(
            *self._make_stage(channels_list[2], channels_list[3], num_repeats[3])
        )
        self.ERBlock_5 = nn.Sequential(
            *self._make_stage(channels_list[3], channels_list[4], num_repeats[4]),
            SimSPPF(
                in_channels=channels_list[4],
                out_channels=channels_list[4],
                kernel_size=5,
            ),
        )

    @staticmethod
    def _make_stage(in_ch, out_ch, n):
        '''Return [stride-2 RepVGGBlock(in_ch -> out_ch), RepBlock(out_ch -> out_ch, n)].

        A list (not a Sequential) is returned so callers can append extra layers
        before wrapping everything in a single nn.Sequential.
        '''
        return [
            RepVGGBlock(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                stride=2,
            ),
            RepBlock(
                in_channels=out_ch,
                out_channels=out_ch,
                n=n,
            ),
        ]

    def forward(self, x):
        '''Run the backbone on ``x``.

        Returns:
            A 3-tuple with the outputs of ERBlock_3, ERBlock_4 and ERBlock_5,
            ordered from shallowest to deepest. The stem and ERBlock_2 outputs
            are not returned.
        '''
        x = self.stem(x)
        x = self.ERBlock_2(x)
        out3 = self.ERBlock_3(x)
        out4 = self.ERBlock_4(out3)
        out5 = self.ERBlock_5(out4)
        return (out3, out4, out5)