import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    """A vanilla U-Net: a 4-level encoder-decoder with skip connections."""

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            # Two 3x3 convolutions with ReLU; padding=1 preserves spatial size.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )
        # Contracting path: channel count doubles at each level.
        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)

        self.bottleneck = conv_block(512, 1024)

        # Expanding path: each transposed conv doubles the spatial size and
        # halves the channels; each decoder block then sees the upsampled
        # features concatenated with the matching encoder output, hence the
        # doubled input channel counts.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = conv_block(128, 64)

        # 1x1 conv maps the final 64 feature maps to per-pixel class logits.
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
    def forward(self, x):
        # Encoder: conv block, then 2x2 max-pool before the next level.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2, stride=2))
        enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2, stride=2))
        enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2, stride=2))

        bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2, stride=2))

        # Decoder: upsample, concatenate the skip connection along the
        # channel dimension, then apply the decoder conv block.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return self.final(dec1)
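
# Note: with four 2x2 max-pools, input height and width must be divisible by
# 16 (224 = 16 * 14 works); otherwise the floor division in pooling and the
# exact doubling in ConvTranspose2d disagree, and the skip concatenations in
# forward() fail with a spatial size mismatch.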
if __name__ == "__main__":
    # Smoke test: a 7-class segmentation head on a batch of two RGB images.
    model = UNet(in_channels=3, out_channels=7)
    fake_img = torch.rand(size=(2, 3, 224, 224))
    print(fake_img.shape)
    # torch.Size([2, 3, 224, 224])
    out = model(fake_img)
    print(out.shape)
    # torch.Size([2, 7, 224, 224])
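
    # A minimal training-step sketch, assuming integer class masks of shape
    # (N, H, W) with values in [0, 7): nn.CrossEntropyLoss consumes the raw
    # per-pixel logits the model returns, so no softmax is applied here.
    # The masks below are random dummies, purely for illustration.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    fake_mask = torch.randint(0, 7, size=(2, 224, 224))  # dummy target masks
    loss = criterion(model(fake_img), fake_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())  # scalar loss for the dummy batch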