FaceSegLite / utils /small_unet_pretained.py
Djulo's picture
first commit
0df579e
raw
history blame contribute delete
No virus
9.06 kB
import torch
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
from io import BytesIO
from torchvision.utils import draw_segmentation_masks
from torchvision.transforms import v2 as T
from PIL import Image
import torch
import numpy as np
import cv2
def conv_block(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
class UNet(nn.Module):
def __init__(self, num_classes):
super(UNet, self).__init__()
self.block1 = conv_block(3, 16)
self.block2 = conv_block(16, 32)
self.block3 = conv_block(32, 64)
self.block4 = conv_block(64, 128)
self.block5 = conv_block(128, 256)
self.upconv4 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.block6 = conv_block(256, 128)
self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.block7 = conv_block(128, 64)
self.upconv2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
self.block8 = conv_block(64, 32)
self.upconv1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
self.block9 = conv_block(32, 16)
self.output = nn.Conv2d(16, num_classes, kernel_size=1)
def forward(self, x):
block1 = self.block1(x)
pool1 = F.max_pool2d(block1, 2)
block2 = self.block2(pool1)
pool2 = F.max_pool2d(block2, 2)
block3 = self.block3(pool2)
pool3 = F.max_pool2d(block3, 2)
block4 = self.block4(pool3)
pool4 = F.max_pool2d(block4, 2)
block5 = self.block5(pool4)
up6 = self.upconv4(block5)
concat6 = torch.cat([up6, block4], dim=1)
block6 = self.block6(concat6)
up7 = self.upconv3(block6)
concat7 = torch.cat([up7, block3], dim=1)
block7 = self.block7(concat7)
up8 = self.upconv2(block7)
concat8 = torch.cat([up8, block2], dim=1)
block8 = self.block8(concat8)
up9 = self.upconv1(block8)
concat9 = torch.cat([up9, block1], dim=1)
block9 = self.block9(concat9)
out = self.output(block9)
return out
class FaceModel(pl.LightningModule):
def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
super().__init__()
self.model = smp.create_model(
arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
)
# preprocessing parameteres for image
params = smp.encoders.get_preprocessing_params(encoder_name)
self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))
# for image segmentation dice loss could be the best first choice
self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
def forward(self, image):
# normalize image here
image = (image - self.mean) / self.std
mask = self.model(image)
return mask
def shared_step(self, batch, stage):
image = batch["image"]
# Shape of the image should be (batch_size, num_channels, height, width)
# if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
assert image.ndim == 4
# Check that image dimensions are divisible by 32,
# encoder and decoder connected by `skip connections` and usually encoder have 5 stages of
# downsampling by factor 2 (2 ^ 5 = 32); e.g. if we have image with shape 65x65 we will have
# following shapes of features in encoder and decoder: 84, 42, 21, 10, 5 -> 5, 10, 20, 40, 80
# and we will get an error trying to concat these features
h, w = image.shape[2:]
assert h % 32 == 0 and w % 32 == 0
mask = batch["mask"]
# Shape of the mask should be [batch_size, num_classes, height, width]
# for binary segmentation num_classes = 1
assert mask.ndim == 4
# Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
assert mask.max() <= 1.0 and mask.min() >= 0
logits_mask = self.forward(image)
# Predicted mask contains logits, and loss_fn param `from_logits` is set to True
loss = self.loss_fn(logits_mask, mask)
# Lets compute metrics for some threshold
# first convert mask values to probabilities, then
# apply thresholding
prob_mask = logits_mask.sigmoid()
pred_mask = (prob_mask > 0.5).float()
# We will compute IoU metric by two ways
# 1. dataset-wise
# 2. image-wise
# but for now we just compute true positive, false positive, false negative and
# true negative 'pixels' for each image and class
# these values will be aggregated in the end of an epoch
tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")
return {
"loss": loss,
"tp": tp,
"fp": fp,
"fn": fn,
"tn": tn,
}
def shared_epoch_end(self, outputs, stage):
# aggregate step metics
tp = torch.cat([x["tp"] for x in outputs])
fp = torch.cat([x["fp"] for x in outputs])
fn = torch.cat([x["fn"] for x in outputs])
tn = torch.cat([x["tn"] for x in outputs])
# per image IoU means that we first calculate IoU score for each image
# and then compute mean over these scores
per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
# dataset IoU means that we aggregate intersection and union over whole dataset
# and then compute IoU score. The difference between dataset_iou and per_image_iou scores
# in this particular case will not be much, however for dataset
# with "empty" images (images without target class) a large gap could be observed.
# Empty images influence a lot on per_image_iou and much less on dataset_iou.
dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
metrics = {
f"{stage}_per_image_iou": per_image_iou,
f"{stage}_dataset_iou": dataset_iou,
}
self.log_dict(metrics, prog_bar=True)
def training_step(self, batch, batch_idx):
return self.shared_step(batch, "train")
def training_epoch_end(self, outputs):
return self.shared_epoch_end(outputs, "train")
def validation_step(self, batch, batch_idx):
return self.shared_step(batch, "valid")
def validation_epoch_end(self, outputs):
return self.shared_epoch_end(outputs, "valid")
def test_step(self, batch, batch_idx):
return self.shared_step(batch, "test")
def test_epoch_end(self, outputs):
return self.shared_epoch_end(outputs, "test")
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.0001)
def eval_transform(image):
target_size = (256, 256)
transforms = T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale=True), T.Resize(target_size, antialias=True)])
return transforms(image)
def load_model():
unet_model = FaceModel("FPN", "timm-mobilenetv3_small_minimal_100", in_channels=3, out_classes=1)
unet_model.load_state_dict(torch.load('../models/fpn_trained_model_small_v1.pth', map_location=torch.device('cpu')))
unet_model.eval()
return unet_model
unet_model = load_model()
def predict_with_small_unet_pretained(buffer):
"""
Predict the mask with the fpn_resnet34 model
Args:
buffer (bytes): The image bytes
Returns:
bytes: The image bytes with the mask
"""
# Use OpenCV to read the image and get its shape
image = cv2.imdecode(np.frombuffer(buffer, np.uint8), cv2.IMREAD_UNCHANGED)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Sauvegardez les dimensions originales
original_dimensions = image.shape[:2]
image = eval_transform(image)
with torch.no_grad():
predictions = unet_model(image)
pr_masks = predictions.sigmoid()
masks = (pr_masks > 0.5).squeeze(1)
image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
image = image[:3, ...]
output_image = draw_segmentation_masks(image, masks, alpha=0.5, colors="blue")
output_image = cv2.resize(output_image.permute(1, 2, 0).numpy(), (original_dimensions[1], original_dimensions[0]))
pil_image = Image.fromarray(output_image)
image_buffer = BytesIO()
pil_image.save(image_buffer, format='JPEG')
img_encoded = image_buffer.getvalue()
return img_encoded