# captcha_pixelplanet/modules/model_iternet.py (revision 166850f)
import torch
import torch.nn as nn
from fastai.vision import *
from .model_vision import BaseIterVision
from .model_language import BCNLanguage
from .model_alignment import BaseAlignment
class IterNet(nn.Module):
    """Refine vision predictions with repeated language-model and alignment rounds.

    The vision module runs once per forward pass. Each vision result it
    produces is then put through ``iter_size`` rounds of:

    1. softmax over the current logits,
    2. the language model,
    3. the alignment module, which takes the language features together
       with that vision result's features.

    Config attributes read:
        model_iter_size: number of language/alignment rounds per vision
            result. ``None`` means 1; otherwise the value must be >= 1.
        dataset_max_length: maximum label length. One extra slot is added
            for the stop token.
        model_deep_supervision: whether training mode returns every
            intermediate result. ``None`` means True.

    Raises:
        ValueError: if ``model_iter_size`` resolves to a value below 1.
    """

    def __init__(self, config):
        super().__init__()
        iter_size = config.model_iter_size
        self.iter_size = 1 if iter_size is None else iter_size
        # forward() needs at least one round to produce language/alignment
        # results. With 0 (or a negative value) the eval path would fail
        # with an opaque IndexError on an empty list, so reject it up front,
        # before any submodule is built.
        if self.iter_size < 1:
            raise ValueError(
                f"model_iter_size must be >= 1, got {self.iter_size!r}"
            )
        self.max_length = config.dataset_max_length + 1  # additional stop token
        self.vision = BaseIterVision(config)
        self.language = BCNLanguage(config)
        self.alignment = BaseAlignment(config)
        deep_supervision = config.model_deep_supervision
        self.deep_supervision = True if deep_supervision is None else deep_supervision

    def forward(self, images, *args):
        """Run vision once, then refine each vision result ``iter_size`` times.

        Args:
            images: input passed unchanged to ``self.vision``.
            *args: accepted for call-site compatibility; ignored.

        Returns:
            In training mode with deep supervision, the tuple
            ``(all_a_res, all_l_res, list_v_res)``: lists of every alignment
            result, every language result and every vision result, in the
            order they were produced.

            Otherwise, the tuple ``(a_res, l_res, v_res)``: the last
            alignment, language and vision results.

            Results are indexed like dicts. This method reads 'logits' and
            'pt_lengths' from vision/alignment results and 'feature' from
            language/vision results.
        """
        list_v_res = self.vision(images)
        # The vision module may return a single result or a sequence of them.
        if not isinstance(list_v_res, (list, tuple)):
            list_v_res = [list_v_res]

        all_l_res, all_a_res = [], []
        for v_res in list_v_res:
            # The first round starts from the vision prediction itself.
            a_res = v_res
            for _ in range(self.iter_size):
                # Presumably a distribution over character classes on the
                # last axis -- TODO confirm against the vision head.
                tokens = torch.softmax(a_res['logits'], dim=-1)
                lengths = a_res['pt_lengths']
                # NOTE(review): clamp_ works in place. The 'pt_lengths' tensor
                # stored in v_res (first round) or in the previous a_res is
                # clamped too, and both are returned to the caller.
                # TODO: move to language model
                lengths.clamp_(2, self.max_length)
                l_res = self.language(tokens, lengths)
                all_l_res.append(l_res)
                # Pass the language features and this vision result's
                # features to the alignment module.
                a_res = self.alignment(l_res['feature'], v_res['feature'])
                all_a_res.append(a_res)

        if self.training and self.deep_supervision:
            return all_a_res, all_l_res, list_v_res
        return a_res, all_l_res[-1], list_v_res[-1]