|
import cv2
|
|
import copy
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .utils.yolov5_utils import fuse_conv_and_bn
|
|
from .utils.weight_init import init_weights
|
|
from .yolov5.yolo import load_yolov5_ckpt
|
|
from .yolov5.common import C3, Conv
|
|
|
|
# Forward-mode selectors understood by UnetHead.forward / TextDetector.forward:
#   TEXTDET_MASK      (0) -> return only the segmentation mask
#   TEXTDET_DET       (1) -> return the (f80, f40, u40) features for the DB head
#   TEXTDET_INFERENCE (2) -> return the mask together with [f80, f40, u40]
TEXTDET_MASK, TEXTDET_DET, TEXTDET_INFERENCE = range(3)
|
|
|
|
class double_conv_up_c3(nn.Module):
    """C3 block followed by a transposed-convolution upsampling stage.

    The pipeline is ``C3 -> ConvTranspose2d(k=4, s=2, p=1) -> BatchNorm2d ->
    ReLU``. With that kernel/stride/padding the transposed convolution doubles
    the spatial size of its input.

    Args:
        in_ch: Added to ``mid_ch`` to obtain the C3 input channel count.
        mid_ch: C3 output channels (and transposed-conv input channels).
        out_ch: Channels produced by the whole block.
        act: Forwarded unchanged to ``C3`` as ``act``.
    """

    def __init__(self, in_ch, mid_ch, out_ch, act=True):
        super().__init__()
        c3_in = in_ch + mid_ch
        stages = [
            C3(c3_in, mid_ch, act=act),
            nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the sequential pipeline to ``x`` and return the result."""
        out = self.conv(x)
        return out
|
|
|
|
class double_conv_c3(nn.Module):
    """Optional 2x average-pool downsampling followed by a C3 block.

    Args:
        in_ch: Input channels of the C3 block.
        out_ch: Output channels of the C3 block.
        stride: When greater than 1 the input is first passed through
            ``nn.AvgPool2d(2, stride=2)``; the pooling factor is always 2,
            whatever the actual value of ``stride``. When ``stride <= 1`` no
            pooling is applied.
        act: Forwarded unchanged to ``C3`` as ``act``.
    """

    def __init__(self, in_ch, out_ch, stride=1, act=True):
        super(double_conv_c3, self).__init__()
        # Always define ``down``: forward() tests ``self.down is not None``,
        # and nn.Module raises AttributeError for attributes that were never
        # set, so leaving it undefined for stride <= 1 made the default
        # configuration unusable.
        self.down = nn.AvgPool2d(2, stride=2) if stride > 1 else None
        self.conv = C3(in_ch, out_ch, act=act)

    def forward(self, x):
        """Downsample ``x`` if configured, then apply the C3 block."""
        if self.down is not None:
            x = self.down(x)
        x = self.conv(x)
        return x
|
|
|
|
class UnetHead(nn.Module):
    """U-Net style decoder that turns five backbone features into a mask.

    The five positional inputs of ``forward`` are named after the original
    code (``f160, f80, f40, f20, f3``).
    NOTE(review): the numeric suffixes presumably denote spatial resolution
    for a 640px input -- confirm against the backbone.

    Args:
        act: Forwarded to every ``double_conv_c3`` / ``double_conv_up_c3``.
    """

    def __init__(self, act=True) -> None:
        super().__init__()
        self.down_conv1 = double_conv_c3(512, 512, 2, act=act)

        # (attribute name, in_ch, mid_ch, out_ch) -- registered in this order.
        up_stage_specs = (
            ("upconv0", 0, 512, 256),
            ("upconv2", 256, 512, 256),
            ("upconv3", 0, 512, 256),
            ("upconv4", 128, 256, 128),
            ("upconv5", 64, 128, 64),
        )
        for attr_name, in_ch, mid_ch, out_ch in up_stage_specs:
            setattr(self, attr_name, double_conv_up_c3(in_ch, mid_ch, out_ch, act=act))

        # Final stage: upsample to a single channel and squash to (0, 1).
        self.upconv6 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK):
        """Decode the backbone features.

        Returns:
            * ``TEXTDET_DET``: the tuple ``(f80, f40, u40)``; the later
              decoder stages are skipped.
            * ``TEXTDET_MASK``: the mask tensor only.
            * any other mode: ``(mask, [f80, f40, u40])``.
        """
        u20 = self.upconv0(self.down_conv1(f3))
        u40 = self.upconv2(torch.cat((f20, u20), dim=1))

        # Detection mode only needs the intermediate features.
        if forward_mode == TEXTDET_DET:
            return f80, f40, u40

        u80 = self.upconv3(torch.cat((f40, u40), dim=1))
        u160 = self.upconv4(torch.cat((f80, u80), dim=1))
        u320 = self.upconv5(torch.cat((f160, u160), dim=1))
        mask = self.upconv6(u320)

        if forward_mode != TEXTDET_MASK:
            return mask, [f80, f40, u40]
        return mask

    def init_weight(self, init_func):
        """Apply ``init_func`` to this module and all of its submodules."""
        self.apply(init_func)
|
|
|
|
class DBHead(nn.Module):
    """Head predicting a shrink map and a threshold map from decoder features.

    Two ``double_conv_up_c3`` stages decode the inputs, a 1x1 convolution
    projects to ``in_channels``, and two parallel branches (``binarize`` and
    ``thresh``) each upsample by 4x (two stride-2 stages) to a single channel.

    Args:
        in_channels: Channel width of the projected feature fed to both
            branches; the branches internally use ``in_channels // 4``.
        k: Steepness used by ``step_function``.
        shrink_with_sigmoid: Selects the training-time output layout, see
            ``forward``.
        act: Forwarded to the ``double_conv_up_c3`` stages.
    """

    def __init__(self, in_channels, k = 50, shrink_with_sigmoid=True, act=True):
        super().__init__()
        self.k = k
        self.shrink_with_sigmoid = shrink_with_sigmoid
        self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
        self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
        self.conv = nn.Sequential(
            nn.Conv2d(128, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.binarize = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2)
        )
        self.thresh = self._init_thresh(in_channels)

    def forward(self, f80, f40, u40, shrink_with_sigmoid=True, step_eval=False):
        """Compute the head outputs.

        Args:
            f80, f40, u40: Feature tensors, concatenated along dim 1 with the
                decoder outputs exactly as in ``UnetHead``.
            shrink_with_sigmoid: Ignored. It is overwritten by the value given
                to the constructor and kept only for interface compatibility.
            step_eval: In eval mode, return ``step_function`` of the two maps
                instead of the concatenated maps.

        Returns:
            Training mode: ``cat(shrink, threshold, binary)`` on dim 1, plus
            the raw ``binarize`` logits as a fourth channel when
            ``self.shrink_with_sigmoid`` is false.
            Eval mode: ``step_function(shrink, threshold)`` when ``step_eval``
            is true, otherwise ``cat(shrink, threshold)`` on dim 1.
        """
        shrink_with_sigmoid = self.shrink_with_sigmoid
        u80 = self.upconv3(torch.cat([f40, u40], dim = 1))
        x = self.upconv4(torch.cat([f80, u80], dim = 1))
        x = self.conv(x)
        threshold_maps = self.thresh(x)
        x = self.binarize(x)  # raw logits of the shrink map
        shrink_maps = torch.sigmoid(x)

        if self.training:
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            if shrink_with_sigmoid:
                return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
            else:
                return torch.cat((shrink_maps, threshold_maps, binary_maps, x), dim=1)
        else:
            if step_eval:
                return self.step_function(shrink_maps, threshold_maps)
            else:
                return torch.cat((shrink_maps, threshold_maps), dim=1)

    def init_weight(self, init_func):
        """Apply ``init_func`` to this module and all of its submodules."""
        self.apply(init_func)

    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
        """Build, store on ``self.thresh`` and return the threshold branch.

        Args:
            inner_channels: Feature width; the branch uses
                ``inner_channels // 4`` internally.
            serial: When true the first convolution expects one extra input
                channel.
            smooth: Forwarded to ``_init_upsample``.
            bias: Bias flag for the first convolution and, in the smooth
                variant, the 3x3 upsampling convolutions.
        """
        in_channels = inner_channels
        if serial:
            in_channels += 1
        self.thresh = nn.Sequential(
            nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
            nn.Sigmoid())
        return self.thresh

    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
        """Return a 2x upsampling module.

        With ``smooth`` false this is a single ``ConvTranspose2d(2, 2)``.
        With ``smooth`` true it is nearest-neighbour upsampling followed by a
        3x3 convolution, plus a 1x1 convolution when ``out_channels == 1``.
        """
        if smooth:
            inter_out_channels = out_channels
            if out_channels == 1:
                inter_out_channels = in_channels
            module_list = [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
            if out_channels == 1:
                # NOTE(review): padding=1 on a 1x1 kernel grows each spatial
                # dimension by 2, so this output would not line up with the
                # ``binarize`` branch -- padding=0 looks intended; confirm
                # before enabling smooth=True.
                module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
            # nn.Sequential takes modules as positional arguments; passing the
            # list itself raises TypeError, so it has to be unpacked.
            return nn.Sequential(*module_list)
        else:
            return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

    def step_function(self, x, y):
        """Return ``1 / (1 + exp(-k * (x - y)))`` element-wise."""
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
|
|
|
|
class TextDetector(nn.Module):
    """Training wrapper: YOLOv5 backbone + ``UnetHead`` (+ optional ``DBHead``).

    The backbone is loaded through ``load_yolov5_ckpt``, switched to eval mode
    and truncated after layer index 9; its ``out_indices`` are set to
    ``[1, 3, 5, 7, 9]``. In ``forward`` the backbone always runs under
    ``torch.no_grad()``.

    Args:
        weights: Passed to ``load_yolov5_ckpt`` as ``weights``.
        map_location: Passed to ``load_yolov5_ckpt`` as ``map_location``.
        forward_mode: Initial mode, one of the ``TEXTDET_*`` constants.
        act: Forwarded to ``UnetHead`` and later to ``DBHead``.
    """

    def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True):
        super().__init__()

        backbone = load_yolov5_ckpt(weights=weights, map_location=map_location)
        backbone.eval()
        tap_layers = [1, 3, 5, 7, 9]
        backbone.out_indices = tap_layers
        # Layers after the last tapped one are not needed.
        last_layer = max(tap_layers)
        backbone.model = backbone.model[:last_layer + 1]

        self.act = act
        self.seg_net = UnetHead(act=act)
        self.backbone = backbone
        self.dbnet = None  # created by initialize_db()
        self.forward_mode = forward_mode

    def train_mask(self):
        """Select mask mode; backbone in eval mode, ``seg_net`` in train mode."""
        self.forward_mode = TEXTDET_MASK
        self.backbone.eval()
        self.seg_net.train()

    def initialize_db(self, unet_weights):
        """Create the DB head and seed it from trained ``seg_net`` weights.

        ``unet_weights`` is handed to ``torch.load``; the ``'weights'`` entry
        of the loaded object is loaded into ``seg_net``. ``upconv3`` and
        ``upconv4`` are then deep-copied into the DB head, and ``upconv3``
        through ``upconv6`` are removed from ``seg_net``.
        """
        self.dbnet = DBHead(64, act=self.act)
        checkpoint = torch.load(unet_weights, map_location='cpu')
        self.seg_net.load_state_dict(checkpoint['weights'])
        self.dbnet.init_weight(init_weights)

        # Overwrite the freshly initialised stages with the trained ones.
        self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3)
        self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4)

        for stage_name in ('upconv3', 'upconv4', 'upconv5', 'upconv6'):
            delattr(self.seg_net, stage_name)

    def train_db(self):
        """Select detection mode; only ``dbnet`` is put in train mode."""
        self.forward_mode = TEXTDET_DET
        self.backbone.eval()
        self.seg_net.eval()
        self.dbnet.train()

    def forward(self, x):
        """Run the network according to ``self.forward_mode``.

        Returns the ``seg_net`` mask for ``TEXTDET_MASK``, the ``dbnet`` output
        for ``TEXTDET_DET`` and ``None`` for any other mode.
        """
        mode = self.forward_mode
        with torch.no_grad():
            feats = self.backbone(x)

        if mode == TEXTDET_MASK:
            return self.seg_net(*feats, forward_mode=mode)

        if mode == TEXTDET_DET:
            # seg_net is only a feature extractor here, so no gradients.
            with torch.no_grad():
                det_inputs = self.seg_net(*feats, forward_mode=mode)
            return self.dbnet(*det_inputs)

        return None
|
|
|
|
def get_base_det_models(model_path, device='cpu', half=False, act='leaky'):
    """Load the three inference sub-models from a combined checkpoint.

    The checkpoint at ``model_path`` must contain the keys ``'blk_det'``
    (handed to ``load_yolov5_ckpt``), ``'text_seg'`` (state dict for
    ``UnetHead``) and ``'text_det'`` (state dict for ``DBHead(64)``).

    Args:
        model_path: Passed to ``torch.load``.
        device: Used as ``map_location`` and as the target of ``.to()``.
        half: When true the models are additionally cast with ``.half()``.
        act: Forwarded to ``UnetHead`` and ``DBHead``.

    Returns:
        Tuple ``(blk_det, text_seg, text_det)``, all in eval mode and on
        ``device``.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    textdetector_dict = torch.load(model_path, map_location=device)
    blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location=device)
    text_seg = UnetHead(act=act)
    text_seg.load_state_dict(textdetector_dict['text_seg'])
    text_det = DBHead(64, act=act)
    text_det.load_state_dict(textdetector_dict['text_det'])

    # Move to the target device in both branches. Previously the half branch
    # skipped ``.to(device)``, leaving the freshly constructed heads on the
    # CPU even when a different device was requested.
    models = [m.eval().to(device) for m in (blk_det, text_seg, text_det)]
    if half:
        models = [m.half() for m in models]
    return tuple(models)
|
|
|
|
class TextDetBase(nn.Module):
    """Inference wrapper chaining block detector, segmentation head and DB head.

    Args:
        model_path, device, half, act: Forwarded to ``get_base_det_models``.
        fuse: When true, ``fuse()`` is called right after loading.
    """

    def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'):
        super().__init__()
        loaded = get_base_det_models(model_path, device, half, act=act)
        self.blk_det, self.text_seg, self.text_det = loaded
        if fuse:
            self.fuse()

    def fuse(self):
        """Fold BatchNorm into the preceding conv of every ``Conv`` in both heads.

        For each ``Conv`` submodule that still has a ``bn`` attribute, ``conv``
        is replaced by ``fuse_conv_and_bn(conv, bn)``, ``bn`` is removed and
        ``forward`` is rebound to ``forward_fuse``.
        """
        def _fold_bn(net):
            for layer in net.modules():
                if not (isinstance(layer, Conv) and hasattr(layer, 'bn')):
                    continue
                layer.conv = fuse_conv_and_bn(layer.conv, layer.bn)
                delattr(layer, 'bn')
                layer.forward = layer.forward_fuse
            return net

        self.text_seg = _fold_bn(self.text_seg)
        self.text_det = _fold_bn(self.text_det)

    def forward(self, features):
        """Run the full pipeline on ``features`` (the input to ``blk_det``).

        Returns:
            ``(blks[0], mask, lines)`` where ``blks`` comes from ``blk_det``,
            ``mask`` from ``text_seg`` and ``lines`` from ``text_det``.
        """
        blks, backbone_feats = self.blk_det(features, detect=True)
        mask, det_feats = self.text_seg(*backbone_feats, forward_mode=TEXTDET_INFERENCE)
        lines = self.text_det(*det_feats, step_eval=False)
        return blks[0], mask, lines
|
|
|
|
class TextDetBaseDNN:
    """Runs an ONNX export of the detector through OpenCV's DNN module.

    Args:
        input_size: Side length of the square blob fed to the network.
        model_path: Path handed to ``cv2.dnn.readNetFromONNX``.
    """

    def __init__(self, input_size, model_path):
        self.input_size = input_size
        net = cv2.dnn.readNetFromONNX(model_path)
        self.model = net
        # Output layer names are cached once and reused for every forward pass.
        self.uoln = net.getUnconnectedOutLayersNames()

    def __call__(self, im_in):
        """Run the network on ``im_in`` and return ``(blks, mask, lines_map)``.

        The image is converted with ``cv2.dnn.blobFromImage`` using a scale
        factor of 1/255 and a square target size of ``input_size``. The network
        must produce exactly three outputs.
        """
        side = self.input_size
        blob = cv2.dnn.blobFromImage(im_in, scalefactor=1 / 255.0, size=(side, side))
        self.model.setInput(blob)
        outputs = self.model.forward(self.uoln)
        blks, mask, lines_map = outputs
        return blks, mask, lines_map