"""
U-Net based DIE model for cleaning document.
"""
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image


class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)


class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)


class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
        # inputs are NCHW; pad the upsampled x1 so its spatial size matches the
        # skip connection x2 (they can differ by one pixel for odd input sizes)
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
x = self.conv(x)
x = torch.sigmoid(x)
return x


class UNet(nn.Module):
    def __init__(self, n_channels, output_channel_dim=1, bilinear=False):
        super().__init__()
self.n_channels = n_channels
self.n_classes = output_channel_dim
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, output_channel_dim)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
        # OutConv ends with a sigmoid, so these are probabilities rather than logits
        out = self.outc(x)
        return out
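
# Shape sketch (illustrative, not part of the original file): with the
# 4-channel input that load_unet below defaults to, the output keeps the
# input's spatial size as long as H and W are divisible by 16, so the four
# 2x poolings invert cleanly:
#
#   net = UNet(n_channels=4, output_channel_dim=1)
#   y = net(torch.randn(2, 4, 256, 256))  # y.shape == (2, 1, 256, 256)
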
def add_gaussian_noise(
data: torch.Tensor
) -> torch.Tensor:
"""
Adding gaussian noise to torch tensor.
:param data: torch tensor
:return: noise perturbed tensor
"""
data_with_noise = data.clone()
data_with_noise += torch.normal(mean=0, std=0.05, size=data_with_noise.shape).to(data_with_noise.device)
data_with_noise = data_with_noise.clip(min=0, max=1)
return data_with_noise
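
# Quick check (illustrative): the clipping keeps perturbed values in [0, 1]:
#
#   noisy = add_gaussian_noise(torch.rand(1, 4, 64, 64))
#   assert 0.0 <= noisy.min() and noisy.max() <= 1.0
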
def inference_model(
model: Callable,
model_input: torch.Tensor,
device: str | torch.device,
num_of_iterations: int = 1
) -> list[torch.Tensor]:
"""
Performing model inference.
:param model: image pre-processing model
:param model_input: data to model
:param device: cuda device
:param num_of_iterations: defines how many times feed the network (recursively)
:return: predictions
"""
# inference model
with torch.no_grad():
prediction_list = []
model_input = model_input.to(device)
if len(model_input.shape) == 3:
model_input = model_input.unsqueeze(dim=0)
        # keep the original RGB channels; the remaining channel carries the
        # running cleaning estimate and is replaced after every iteration
        model_input_original_part = model_input[:, 0:3, ...]
        model_input_new = model_input
        for _ in range(num_of_iterations):
            model_input_perturbed = add_gaussian_noise(model_input_new)
            prediction = model(model_input_perturbed)
            prediction_list.append(prediction)
            model_input_new = torch.cat((model_input_original_part, prediction.detach()), dim=1)
return prediction_list
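
# Recursion sketch (read off the channel slicing above, so partly an
# assumption): the first three input channels hold the original RGB page and
# the fourth holds the current cleaning estimate, which every pass replaces
# with its latest prediction. Assuming `model` is a UNet from load_unet:
#
#   x = torch.rand(1, 4, 256, 256)  # RGB + initial estimate
#   preds = inference_model(model, x, device='cpu', num_of_iterations=3)
#   final = preds[-1]               # shape (1, 1, 256, 256)
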
def load_unet(
model_path: str,
device: str = 'cpu',
eval_mode: bool = False,
n_channels: int = 4,
bilinear: bool = False,
output_channel_dim: int = 1
) -> UNet:
print("Loading UNet model...")
# image preprocessing model
model = UNet(
n_channels=n_channels,
bilinear=bilinear,
output_channel_dim=output_channel_dim
)
    # DistributedDataParallel training prefixes parameter names with 'module.';
    # strip the prefix so the plain UNet can load the checkpoint
    state_dict = torch.load(model_path, map_location=device)
    new_state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()}
model.load_state_dict(new_state_dict)
model.to(device)
if eval_mode:
model.eval()
return model
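
# Usage sketch ('unet_die.pth' is a placeholder checkpoint path, not a file
# shipped with this script):
#
#   model = load_unet('unet_die.pth', device='cpu', eval_mode=True)
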
class UNetDIEModel:
"""
Class for Document Image Enhancement with U-Net.
"""
def __init__(
self,
*args,
**kwargs
):
"""
Initialization.
"""
self.args = kwargs['args']
        # loading the DIE (document image enhancement) model
self.die = load_unet(
model_path=self.args.die_model_path,
device=self.args.device,
eval_mode=True,
)

    def enhance_document_image(
            self,
            image_raw_list: list[torch.Tensor],
            num_of_die_iterations: int = 1,
    ) -> list[Image.Image]:
        """
        Enhance document images by removing noise.
        :param image_raw_list: original document pages as preprocessed tensors
            (torch.stack below requires tensors, not PIL images)
        :param num_of_die_iterations: number of DIE iterations
        :return: cleaned document pages as PIL images
        """
with torch.no_grad():
image_die = torch.stack(image_raw_list, dim=0)
# document image enhancement
prediction_list = inference_model(
model=self.die,
model_input=image_die,
num_of_iterations=num_of_die_iterations,
device=self.args.device
)
# transform DIE model output to image and apply post-processing
last_prediction = prediction_list[-1]
batch_size = last_prediction.size(0)
image_die_list = [T.ToPILImage()(last_prediction[idx, ...]).convert('RGB') for idx in range(batch_size)]
return image_die_list
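
# End-to-end sketch (the argparse-style namespace and the ToTensor/zero-channel
# preprocessing are assumptions; the real pipeline lives outside this file):
#
#   from argparse import Namespace
#   args = Namespace(die_model_path='unet_die.pth', device='cpu')
#   die = UNetDIEModel(args=args)
#   rgb = T.ToTensor()(Image.open('page.png').convert('RGB'))   # (3, H, W)
#   x = torch.cat([rgb, torch.zeros_like(rgb[:1])], dim=0)      # (4, H, W)
#   cleaned = die.enhance_document_image([x], num_of_die_iterations=2)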