Spaces:
Runtime error
Runtime error
from vietocr.model.backbone.cnn import CNN | |
from vietocr.model.seqmodel.transformer import LanguageTransformer | |
from vietocr.model.seqmodel.seq2seq import Seq2Seq | |
from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq | |
from torch import nn | |
class VietOCR(nn.Module):
    """OCR model: a CNN feature extractor followed by a sequence model.

    The sequence model is chosen by ``seq_modeling`` and is always stored in
    ``self.transformer`` (whatever its actual type), so the attribute name
    stays stable across all supported architectures.

    Args:
        vocab_size: Passed as the first positional argument to the selected
            sequence model.
        backbone: Passed as the first positional argument to ``CNN``.
        cnn_args: Mapping of keyword arguments forwarded to ``CNN``.
        transformer_args: Mapping of keyword arguments forwarded to the
            selected sequence model (used for every choice, not only
            ``'transformer'``).
        seq_modeling: One of ``'transformer'``, ``'seq2seq'`` or
            ``'convseq2seq'``.

    Raises:
        ValueError: If ``seq_modeling`` is not one of the supported names.
    """

    # Names accepted for ``seq_modeling``.
    _SUPPORTED_SEQ_MODELS = ('transformer', 'seq2seq', 'convseq2seq')

    def __init__(self, vocab_size,
                 backbone,
                 cnn_args,
                 transformer_args, seq_modeling='transformer'):
        super(VietOCR, self).__init__()

        # Validate up front so that a bad configuration fails before the
        # backbone is built, and with a real exception: raising a bare
        # string is itself a TypeError in Python 3.
        if seq_modeling not in self._SUPPORTED_SEQ_MODELS:
            raise ValueError(
                'Not Support Seq Model: {!r} (expected one of {})'.format(
                    seq_modeling, ', '.join(self._SUPPORTED_SEQ_MODELS)))

        self.cnn = CNN(backbone, **cnn_args)
        self.seq_modeling = seq_modeling

        if seq_modeling == 'transformer':
            self.transformer = LanguageTransformer(vocab_size, **transformer_args)
        elif seq_modeling == 'seq2seq':
            self.transformer = Seq2Seq(vocab_size, **transformer_args)
        else:  # 'convseq2seq' -- the only remaining validated option
            self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)

    def forward(self, img, tgt_input, tgt_key_padding_mask):
        """Run the CNN on ``img`` and decode with the sequence model.

        Shape:
            - img: (N, C, H, W)
            - tgt_input: (T, N)
            - tgt_key_padding_mask: (N, T)
            - output: b t v

        ``tgt_key_padding_mask`` is only forwarded when ``seq_modeling`` is
        ``'transformer'``; the other sequence models are called with the CNN
        features and ``tgt_input`` alone.

        Raises:
            ValueError: If ``self.seq_modeling`` was changed after
                construction to an unsupported value.
        """
        src = self.cnn(img)

        if self.seq_modeling == 'transformer':
            return self.transformer(
                src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask)
        if self.seq_modeling in ('seq2seq', 'convseq2seq'):
            return self.transformer(src, tgt_input)

        # Unreachable unless the attribute was mutated after __init__;
        # fail clearly instead of hitting an unbound local variable.
        raise ValueError(
            'Not Support Seq Model: {!r}'.format(self.seq_modeling))