EasyOCR-VietOCR / vietocr /model /transformerocr.py
hantech's picture
Duplicate from hantech/VietOCR
33c0fae
raw
history blame
No virus
1.66 kB
from vietocr.model.backbone.cnn import CNN
from vietocr.model.seqmodel.transformer import LanguageTransformer
from vietocr.model.seqmodel.seq2seq import Seq2Seq
from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
from torch import nn
class VietOCR(nn.Module):
    """OCR model: a CNN feature extractor followed by a sequence model.

    The sequence model is chosen by ``seq_modeling`` and is always stored
    in ``self.transformer``, whichever implementation is selected.
    """

    # Values accepted for ``seq_modeling``; checked before any sub-module
    # is built so that a typo fails fast.
    _SEQ_MODELINGS = ('transformer', 'seq2seq', 'convseq2seq')

    def __init__(self, vocab_size,
                 backbone,
                 cnn_args,
                 transformer_args, seq_modeling='transformer'):
        """Build the CNN backbone and the selected sequence model.

        Args:
            vocab_size: forwarded as the first positional argument of the
                sequence-model constructor.
            backbone: forwarded as the first positional argument of ``CNN``.
            cnn_args: dict of extra keyword arguments for ``CNN``.
            transformer_args: dict of keyword arguments for the selected
                sequence model (used for all three kinds, despite the name).
            seq_modeling: one of ``'transformer'``, ``'seq2seq'`` or
                ``'convseq2seq'``.

        Raises:
            ValueError: if ``seq_modeling`` is not one of the supported
                names.
        """
        super(VietOCR, self).__init__()

        # Validate first: raising a real exception type (the original raised
        # a bare string, which Python turns into an unrelated TypeError) and
        # doing so before the backbone is constructed.
        if seq_modeling not in self._SEQ_MODELINGS:
            raise ValueError(
                'Not Support Seq Model: {!r} (expected one of {})'.format(
                    seq_modeling, ', '.join(self._SEQ_MODELINGS)))

        self.cnn = CNN(backbone, **cnn_args)
        self.seq_modeling = seq_modeling

        if seq_modeling == 'transformer':
            self.transformer = LanguageTransformer(vocab_size, **transformer_args)
        elif seq_modeling == 'seq2seq':
            self.transformer = Seq2Seq(vocab_size, **transformer_args)
        else:
            # 'convseq2seq' is the only validated value left.
            self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)

    def forward(self, img, tgt_input, tgt_key_padding_mask):
        """
        Run the CNN on ``img`` and decode with the sequence model.

        Shape:
            - img: (N, C, H, W)
            - tgt_input: (T, N)
            - tgt_key_padding_mask: (N, T)
            - output: b t v

        ``tgt_key_padding_mask`` is only passed on when ``seq_modeling`` is
        ``'transformer'``; the other sequence models are called with
        ``(src, tgt_input)`` alone.
        """
        src = self.cnn(img)

        if self.seq_modeling == 'transformer':
            return self.transformer(
                src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask)

        # 'seq2seq' and 'convseq2seq' share the same call signature here.
        return self.transformer(src, tgt_input)