# Nick
# Restore full project files after history cleanup
# 1c1a0c5
"""
UNet model definition for lane segmentation.
Includes DoubleConv block and UNet architecture.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- DoubleConv Block ---
class DoubleConv(nn.Module):
    """
    Two consecutive (3x3 Conv => BatchNorm => ReLU) stages, the basic
    building block of the UNet encoder and decoder.

    The first stage maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both convolutions use ``padding=1``, so the
    spatial size of the input is preserved.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Build the six-layer sequential stack.

        Args:
            in_channels (int): Number of input channels
            out_channels (int): Number of output channels
        """
        super().__init__()
        stages = []
        # Stage 1 consumes in_channels, stage 2 consumes stage 1's output.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        # Attribute name and layer order are kept so state_dict keys
        # (double_conv.0 ... double_conv.5) stay the same.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply both conv/norm/activation stages to the input.

        Args:
            x (torch.Tensor): Input tensor
        Returns:
            torch.Tensor: Output tensor
        """
        features = self.double_conv(x)
        return features
# --- UNet Model ---
class UNet(nn.Module):
    """
    U-Net: Convolutional Networks for Biomedical Image Segmentation

    Layout: four DoubleConv + 2x max-pool encoder stages (64, 128, 256, 512
    channels), a 1024-channel bottleneck, and four decoder stages. Each
    decoder stage upsamples with a stride-2 transposed convolution,
    concatenates the matching encoder output (skip connection) along the
    channel axis and applies a DoubleConv. A 1x1 convolution followed by a
    sigmoid produces the final map.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1):
        """
        Args:
            in_channels (int): Number of input channels
            out_channels (int): Number of output channels
        """
        super().__init__()
        # Encoder path.
        self.encoder1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.encoder2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.encoder3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.encoder4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder path. Each decoder DoubleConv takes twice the upconv's
        # output channels because the skip connection is concatenated in.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``."""
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h == 0 and pad_w == 0:
            # Sizes already agree (input divisible by 16): nothing to do.
            return upsampled
        # F.pad takes (left, right, top, bottom) for the last two dims;
        # split the difference as evenly as possible between both sides.
        return F.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of UNet.

        Max pooling floors odd spatial sizes, so after upsampling a decoder
        map can be one pixel smaller than its skip connection. Such maps are
        zero-padded to the skip's size before concatenation, which lets the
        network accept heights/widths that are not multiples of 16.

        Args:
            x: Input tensor of shape (B, C, H, W). H and W must be at least
                16 so that four 2x poolings leave a non-empty bottleneck.
        Returns:
            Output tensor of shape (B, out_channels, H, W) with sigmoid
            applied, i.e. values in (0, 1).
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self._match_size(self.upconv4(bottleneck), enc4)
        dec4 = self.decoder4(torch.cat([dec4, enc4], dim=1))

        dec3 = self._match_size(self.upconv3(dec4), enc3)
        dec3 = self.decoder3(torch.cat([dec3, enc3], dim=1))

        dec2 = self._match_size(self.upconv2(dec3), enc2)
        dec2 = self.decoder2(torch.cat([dec2, enc2], dim=1))

        dec1 = self._match_size(self.upconv1(dec2), enc1)
        dec1 = self.decoder1(torch.cat([dec1, enc1], dim=1))

        out = self.final_conv(dec1)
        return torch.sigmoid(out)
# --- Model Summary and FLOPs (Optional) ---
if __name__ == "__main__":
    # Smoke test: build the model, push one dummy batch through it, then
    # print optional diagnostics if the helper packages are available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(in_channels=3, out_channels=1).to(device)
    # Inference-only check: eval mode leaves BatchNorm running statistics
    # untouched, and no_grad avoids building an autograd graph.
    model.eval()
    dummy_input = torch.randn(1, 3, 256, 256).to(device)
    with torch.no_grad():
        output = model(dummy_input)
    print(f"Output shape: {output.shape}")

    # Model summary (optional: requires the third-party `torchinfo` package).
    try:
        from torchinfo import summary
    except ImportError:
        print("torchinfo is not installed; skipping model summary.")
    else:
        print(summary(model, input_size=(1, 3, 256, 256), device=device))

    # FLOPs and parameters (optional: requires the third-party `thop` package).
    # NOTE(review): thop's README names the first returned value MACs, not
    # FLOPs -- confirm the label below is the intended one.
    try:
        from thop import profile
    except ImportError:
        print("thop is not installed; skipping FLOPs/parameter count.")
    else:
        flops, params = profile(model, inputs=(dummy_input,))
        print(f"FLOPs: {flops:,}")
        print(f"Parameters: {params:,}")