# Sunday01 - commit 9dce458 ("up")
import cv2
import copy
import torch
import torch.nn as nn
from .utils.yolov5_utils import fuse_conv_and_bn
from .utils.weight_init import init_weights
from .yolov5.yolo import load_yolov5_ckpt
from .yolov5.common import C3, Conv
# Forward modes understood by UnetHead.forward / TextDetector.forward:
#   TEXTDET_MASK      -> return only the text mask
#   TEXTDET_DET       -> return the intermediate features consumed by DBHead
#   TEXTDET_INFERENCE -> return the mask together with the DBHead features
TEXTDET_MASK, TEXTDET_DET, TEXTDET_INFERENCE = 0, 1, 2
class double_conv_up_c3(nn.Module):
    """C3 block followed by a 2x transposed-conv upsample, BatchNorm and ReLU.

    Args:
        in_ch: extra input channels; the C3 block consumes ``in_ch + mid_ch``
            channels (callers pass 0 when the input has exactly ``mid_ch``).
        mid_ch: output width of the C3 block.
        out_ch: channels produced by the upsampling stage.
        act: activation option forwarded to ``C3``.
    """

    def __init__(self, in_ch, mid_ch, out_ch, act=True):
        super().__init__()
        layers = [
            C3(in_ch + mid_ch, mid_ch, act=act),
            # kernel 4 / stride 2 / padding 1 doubles H and W exactly
            nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the C3 + upsample stack to ``x``."""
        out = self.conv(x)
        return out
class double_conv_c3(nn.Module):
    """Optional 2x average-pool downsample followed by a C3 block.

    Args:
        in_ch: input channels of the C3 block.
        out_ch: output channels of the C3 block.
        stride: when > 1, a 2x2 / stride-2 ``AvgPool2d`` runs before C3.
            NOTE(review): the pool always downsamples by 2, whatever value
            ``stride`` has -- confirm this is intended for stride > 2.
        act: activation option forwarded to ``C3``.
    """

    def __init__(self, in_ch, out_ch, stride=1, act=True):
        super(double_conv_c3, self).__init__()
        # Always define ``down``. It used to be assigned only when stride > 1,
        # so forward() raised AttributeError for the default stride=1.
        self.down = nn.AvgPool2d(2, stride=2) if stride > 1 else None
        self.conv = C3(in_ch, out_ch, act=act)

    def forward(self, x):
        """Downsample ``x`` if configured, then apply the C3 block."""
        if self.down is not None:
            x = self.down(x)
        return self.conv(x)
class UnetHead(nn.Module):
    """U-Net style decoder over five backbone feature maps.

    Shape comments use ``C@S`` = channels @ spatial size, assuming a 640x640
    input image. The last stage (``upconv6``) produces a one-channel map
    squashed by a Sigmoid: the text mask.
    """

    def __init__(self, act=True) -> None:
        super().__init__()
        # Registration order is significant: init_weight()/apply() visits
        # submodules in this order.
        self.down_conv1 = double_conv_c3(512, 512, stride=2, act=act)
        self.upconv0 = double_conv_up_c3(in_ch=0, mid_ch=512, out_ch=256, act=act)
        self.upconv2 = double_conv_up_c3(in_ch=256, mid_ch=512, out_ch=256, act=act)
        self.upconv3 = double_conv_up_c3(in_ch=0, mid_ch=512, out_ch=256, act=act)
        self.upconv4 = double_conv_up_c3(in_ch=128, mid_ch=256, out_ch=128, act=act)
        self.upconv5 = double_conv_up_c3(in_ch=64, mid_ch=128, out_ch=64, act=act)
        self.upconv6 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK):
        """Decode the backbone features.

        Returns:
            TEXTDET_DET: the tuple ``(f80, f40, u40)`` (input for DBHead).
            TEXTDET_MASK: the mask tensor.
            any other mode (e.g. TEXTDET_INFERENCE): ``(mask, [f80, f40, u40])``.
        """
        def fuse_up(layer, skip, deep):
            # join skip and decoder features on the channel axis, then upsample
            return layer(torch.cat([skip, deep], dim=1))

        d10 = self.down_conv1(f3)                  # 512@10
        u20 = self.upconv0(d10)                    # 256@20
        u40 = fuse_up(self.upconv2, f20, u20)      # 256@40
        if forward_mode == TEXTDET_DET:
            return f80, f40, u40

        u80 = fuse_up(self.upconv3, f40, u40)      # 256@80
        u160 = fuse_up(self.upconv4, f80, u80)     # 128@160
        u320 = fuse_up(self.upconv5, f160, u160)   # 64@320
        mask = self.upconv6(u320)                  # 1@640
        if forward_mode == TEXTDET_MASK:
            return mask
        return mask, [f80, f40, u40]

    def init_weight(self, init_func):
        """Apply ``init_func`` recursively to every submodule."""
        self.apply(init_func)
class DBHead(nn.Module):
    """DBNet-style text-line head producing shrink, threshold and binary maps.

    The head finishes the U-Net decoder on the features handed over by
    ``UnetHead`` (``f80``, ``f40``, ``u40``) and projects the result to
    ``in_channels`` with a 1x1 conv. Two branches then each upsample 4x:

    * ``binarize`` produces logits, which a sigmoid turns into shrink maps;
    * ``thresh`` produces threshold maps and ends in a Sigmoid.

    Args:
        in_channels: width of the 1x1 projection; both branches use
            ``in_channels // 4`` hidden channels.
        k: steepness of the differentiable step in :meth:`step_function`.
        shrink_with_sigmoid: default for :meth:`forward`; when False, the
            training output also carries the raw binarize logits.
        act: activation option forwarded to the C3-based up-conv blocks.
    """

    def __init__(self, in_channels, k = 50, shrink_with_sigmoid=True, act=True):
        super().__init__()
        self.k = k
        self.shrink_with_sigmoid = shrink_with_sigmoid
        # Same layout as UnetHead.upconv3/upconv4, so trained weights can be copied in.
        self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
        self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
        self.conv = nn.Sequential(
            nn.Conv2d(128, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.binarize = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
        )
        self.thresh = self._init_thresh(in_channels)

    def forward(self, f80, f40, u40, shrink_with_sigmoid=None, step_eval=False):
        """Predict DB maps from the U-Net features.

        Args:
            f80, f40, u40: features returned by ``UnetHead`` in TEXTDET_DET or
                TEXTDET_INFERENCE mode.
            shrink_with_sigmoid: per-call override of ``self.shrink_with_sigmoid``.
                ``None`` (the default) uses the value set at construction.
            step_eval: in eval mode, return only the :meth:`step_function`
                binary map instead of shrink/threshold maps.

        Returns:
            The result depends on the mode:

            * Training: ``cat(shrink, threshold, binary)`` along dim 1, plus
              the binarize logits as a 4th channel when
              ``shrink_with_sigmoid`` is False.
            * Eval with ``step_eval``: the binary map.
            * Eval otherwise: ``cat(shrink, threshold)`` along dim 1.
        """
        if shrink_with_sigmoid is None:
            shrink_with_sigmoid = self.shrink_with_sigmoid
        u80 = self.upconv3(torch.cat([f40, u40], dim=1))  # 256@80
        x = self.upconv4(torch.cat([f80, u80], dim=1))  # 128@160
        x = self.conv(x)
        threshold_maps = self.thresh(x)
        x = self.binarize(x)  # single-channel logits, upsampled 4x
        shrink_maps = torch.sigmoid(x)
        if self.training:
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            if shrink_with_sigmoid:
                return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
            return torch.cat((shrink_maps, threshold_maps, binary_maps, x), dim=1)
        if step_eval:
            return self.step_function(shrink_maps, threshold_maps)
        return torch.cat((shrink_maps, threshold_maps), dim=1)

    def init_weight(self, init_func):
        """Apply ``init_func`` recursively to every submodule."""
        self.apply(init_func)

    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
        """Build the threshold-map branch, store it as ``self.thresh`` and return it.

        The branch runs: Conv3x3 -> BN -> ReLU -> 2x upsample -> BN -> ReLU ->
        2x upsample to one channel -> Sigmoid. ``serial`` adds one extra input
        channel. ``smooth`` uses nearest-upsample + conv instead of transposed
        convs.
        """
        in_channels = inner_channels
        if serial:
            in_channels += 1
        self.thresh = nn.Sequential(
            nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
            nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
            nn.Sigmoid(),
        )
        return self.thresh

    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
        """Return a module that upsamples 2x and maps in_channels -> out_channels.

        With ``smooth=False`` this is a stride-2 ``ConvTranspose2d``. With
        ``smooth=True`` it is nearest-neighbour ``Upsample`` followed by a 3x3
        conv, plus a 1x1 projection to a single channel when
        ``out_channels == 1``.
        """
        if not smooth:
            return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        inter_out_channels = in_channels if out_channels == 1 else out_channels
        module_list = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias),
        ]
        if out_channels == 1:
            # 1x1 projection must not pad: padding=1 would grow the map by one
            # pixel per side and break the later concat with the shrink maps.
            module_list.append(
                nn.Conv2d(inter_out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True))
        # nn.Sequential takes modules as positional arguments, not as a list.
        return nn.Sequential(*module_list)

    def step_function(self, x, y):
        """Apply differentiable binarization, ``sigmoid(k * (x - y))``.

        This is mathematically identical to ``1 / (1 + exp(-k * (x - y)))``.
        ``torch.sigmoid`` avoids the ``exp`` overflow of that explicit form,
        which produced NaN gradients for large ``k * (y - x)``.
        """
        return torch.sigmoid(self.k * (x - y))
class TextDetector(nn.Module):
    """Training wrapper: YOLOv5 backbone + U-Net mask head + optional DB head.

    The workflow visible here has three steps:

    1. ``train_mask`` trains the mask head.
    2. ``initialize_db`` builds the DB head from trained U-Net weights.
    3. ``train_db`` trains the DB head.

    The backbone always runs under ``torch.no_grad``.

    Args:
        weights: checkpoint passed to ``load_yolov5_ckpt``.
        map_location: device mapping used when loading that checkpoint.
        forward_mode: TEXTDET_MASK or TEXTDET_DET (the modes ``forward`` supports).
        act: activation option forwarded to the heads.
    """

    def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True):
        super(TextDetector, self).__init__()
        yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location=map_location)
        yolov5s_backbone.eval()
        # Keep layers 0..9 and expose five intermediate outputs, which are fed to
        # UnetHead.forward(f160, f80, f40, f20, f3).
        # NOTE(review): assumes the loaded model honours ``out_indices`` -- confirm.
        out_indices = [1, 3, 5, 7, 9]
        yolov5s_backbone.out_indices = out_indices
        yolov5s_backbone.model = yolov5s_backbone.model[:max(out_indices)+1]
        self.act = act
        self.seg_net = UnetHead(act=act)
        self.backbone = yolov5s_backbone
        self.dbnet = None  # created by initialize_db()
        self.forward_mode = forward_mode

    def train_mask(self):
        """Switch to mask training: only ``seg_net`` is put in train mode."""
        self.forward_mode = TEXTDET_MASK
        self.backbone.eval()
        self.seg_net.train()

    def initialize_db(self, unet_weights):
        """Create the DB head, reusing up-conv layers from trained U-Net weights.

        The steps are:

        1. Load ``unet_weights`` (a checkpoint dict with a ``'weights'`` entry)
           into ``seg_net``.
        2. Initialise a fresh ``DBHead(64)`` and replace its
           ``upconv3``/``upconv4`` with copies of the trained U-Net ones.
        3. Delete ``seg_net.upconv3``..``upconv6``.

        After step 3, ``seg_net`` can only run in TEXTDET_DET mode.
        """
        self.dbnet = DBHead(64, act=self.act)
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        self.seg_net.load_state_dict(torch.load(unet_weights, map_location='cpu')['weights'])
        self.dbnet.init_weight(init_weights)
        self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3)
        self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4)
        del self.seg_net.upconv3
        del self.seg_net.upconv4
        del self.seg_net.upconv5
        del self.seg_net.upconv6

    def train_db(self):
        """Switch to DB training: freeze backbone and seg_net, train ``dbnet``.

        Requires ``initialize_db`` to have been called first.
        """
        self.forward_mode = TEXTDET_DET
        self.backbone.eval()
        self.seg_net.eval()
        self.dbnet.train()

    def forward(self, x):
        """Run the model in the current ``forward_mode``.

        Returns:
            TEXTDET_MASK: the U-Net mask.
            TEXTDET_DET: the DB head output (gradients flow only through ``dbnet``).

        Raises:
            ValueError: if ``forward_mode`` is neither TEXTDET_MASK nor TEXTDET_DET
                (previously such modes silently returned ``None``).
            RuntimeError: in TEXTDET_DET mode before ``initialize_db`` was called.
        """
        forward_mode = self.forward_mode
        if forward_mode not in (TEXTDET_MASK, TEXTDET_DET):
            raise ValueError(f'unsupported forward_mode for TextDetector: {forward_mode!r}')
        with torch.no_grad():
            outs = self.backbone(x)
        if forward_mode == TEXTDET_MASK:
            return self.seg_net(*outs, forward_mode=forward_mode)
        if self.dbnet is None:
            raise RuntimeError('DB head is not initialized; call initialize_db() first')
        with torch.no_grad():
            outs = self.seg_net(*outs, forward_mode=forward_mode)
        return self.dbnet(*outs)
def get_base_det_models(model_path, device='cpu', half=False, act='leaky'):
    """Load the block detector, text-mask head and text-line head from one file.

    Args:
        model_path: checkpoint containing ``'blk_det'``, ``'text_seg'`` and
            ``'text_det'`` entries.
        device: target device for all three models.
        half: convert the models to float16 after moving them.
        act: activation option for ``UnetHead`` / ``DBHead``.

    Returns:
        ``(blk_det, text_seg, text_det)``, each in eval mode on ``device``.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    textdetector_dict = torch.load(model_path, map_location=device)
    blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location=device)
    text_seg = UnetHead(act=act)
    text_seg.load_state_dict(textdetector_dict['text_seg'])
    text_det = DBHead(64, act=act)
    text_det.load_state_dict(textdetector_dict['text_det'])
    # load_state_dict copies into the existing (CPU-constructed) parameters, so
    # the models must be moved explicitly. Previously the half=True path skipped
    # this, leaving text_seg/text_det on CPU.
    models = []
    for model in (blk_det, text_seg, text_det):
        model = model.eval().to(device)
        if half:
            model = model.half()
        models.append(model)
    return tuple(models)
class TextDetBase(nn.Module):
    """Inference wrapper bundling block detection, text-mask and text-line models.

    Args:
        model_path: checkpoint understood by ``get_base_det_models``.
        device: device the models are loaded onto.
        half: run the models in float16.
        fuse: fold BatchNorm into the preceding conv of ``Conv`` modules in
            the two text heads.
        act: activation option for the text heads.
    """

    def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'):
        super().__init__()
        loaded = get_base_det_models(model_path, device, half, act=act)
        self.blk_det, self.text_seg, self.text_det = loaded
        if fuse:
            self.fuse()

    def fuse(self):
        """Fuse conv + BN in every ``Conv`` of ``text_seg`` and ``text_det``."""
        def fuse_conv_bn_inplace(net):
            for module in net.modules():
                if not isinstance(module, Conv) or not hasattr(module, 'bn'):
                    continue
                module.conv = fuse_conv_and_bn(module.conv, module.bn)  # folded conv
                delattr(module, 'bn')  # BN is now baked into the conv weights
                module.forward = module.forward_fuse  # skip the BN step at runtime
            return net

        self.text_seg = fuse_conv_bn_inplace(self.text_seg)
        self.text_det = fuse_conv_bn_inplace(self.text_det)

    def forward(self, features):
        """Run all three models on the input.

        Returns:
            ``(first element of the block-detector output, mask, lines)``.
        """
        detections, feats = self.blk_det(features, detect=True)
        mask, feats = self.text_seg(*feats, forward_mode=TEXTDET_INFERENCE)
        lines = self.text_det(*feats, step_eval=False)
        return detections[0], mask, lines
class TextDetBaseDNN:
    """Run the exported ONNX text-detection model through OpenCV's DNN module.

    Args:
        input_size: side length of the square blob fed to the network.
        model_path: path to the ONNX file.
    """

    def __init__(self, input_size, model_path):
        self.input_size = input_size
        net = cv2.dnn.readNetFromONNX(model_path)
        self.model = net
        self.uoln = net.getUnconnectedOutLayersNames()

    def __call__(self, im_in):
        """Run the network on ``im_in``.

        Returns:
            ``(blks, mask, lines_map)``, the three unconnected outputs in the
            order reported by ``getUnconnectedOutLayersNames``.
        """
        side = self.input_size
        # scale pixel values by 1/255 and resize to a side x side blob
        blob = cv2.dnn.blobFromImage(im_in, scalefactor=1 / 255.0, size=(side, side))
        self.model.setInput(blob)
        blks, mask, lines_map = self.model.forward(self.uoln)
        return blks, mask, lines_map