# vietocr/model/transformerocr.py
from vietocr.model.backbone.cnn import CNN
from vietocr.model.seqmodel.transformer import LanguageTransformer
from vietocr.model.seqmodel.seq2seq import Seq2Seq
from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
from torch import nn

class VietOCR(nn.Module):
    def __init__(self, vocab_size,
                 backbone,
                 cnn_args,
                 transformer_args, seq_modeling='transformer'):
        super(VietOCR, self).__init__()

        # CNN backbone extracts visual features; the sequence model decodes them into characters
        self.cnn = CNN(backbone, **cnn_args)
        self.seq_modeling = seq_modeling

        if seq_modeling == 'transformer':
            self.transformer = LanguageTransformer(vocab_size, **transformer_args)
        elif seq_modeling == 'seq2seq':
            self.transformer = Seq2Seq(vocab_size, **transformer_args)
        elif seq_modeling == 'convseq2seq':
            self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)
        else:
            raise ValueError('Unsupported seq_modeling: {}'.format(seq_modeling))

    def forward(self, img, tgt_input, tgt_key_padding_mask):
        """
        Shape:
            - img: (N, C, H, W)
            - tgt_input: (T, N)
            - tgt_key_padding_mask: (N, T)
            - output: (N, T, V)
        """
        # Visual features from the CNN backbone, fed to the sequence model as the source sequence
        src = self.cnn(img)

        if self.seq_modeling == 'transformer':
            outputs = self.transformer(src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask)
        elif self.seq_modeling == 'seq2seq':
            outputs = self.transformer(src, tgt_input)
        elif self.seq_modeling == 'convseq2seq':
            outputs = self.transformer(src, tgt_input)

        return outputs
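

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the upstream module: it shows how
    # VietOCR is instantiated and what tensor shapes forward() expects.
    # The cnn_args/transformer_args below are example values in the spirit of
    # the vgg_transformer config; in practice they come from your config file,
    # and vocab_size=233 is an arbitrary placeholder.
    import torch

    model = VietOCR(
        vocab_size=233,
        backbone='vgg19_bn',
        cnn_args={
            'pretrained': False,
            'ss': [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]],
            'ks': [[2, 2], [2, 2], [2, 1], [2, 1], [1, 1]],
            'hidden': 256,
        },
        transformer_args={
            'd_model': 256, 'nhead': 8,
            'num_encoder_layers': 6, 'num_decoder_layers': 6,
            'dim_feedforward': 2048, 'max_seq_length': 1024,
            'pos_dropout': 0.1, 'trans_dropout': 0.1,
        },
        seq_modeling='transformer',
    )
    model.eval()

    img = torch.rand(1, 3, 32, 128)                           # (N, C, H, W)
    tgt_input = torch.randint(0, 233, (10, 1))                # (T, N) token ids
    tgt_padding_mask = torch.zeros(1, 10, dtype=torch.bool)   # (N, T), True = padding

    with torch.no_grad():
        out = model(img, tgt_input, tgt_padding_mask)
    print(out.shape)  # expected (N, T, vocab_size) for the transformer head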