import logging

import torch
import torch.nn as nn
from fastai.vision import *

from modules.attention import *
from modules.backbone import ResTranformer
from modules.model import Model
from modules.resnet import resnet45

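
# Vision branch of the recognizer: a feature backbone followed by an attention
# decoder and a linear character classifier.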
class BaseVision(Model):
    def __init__(self, config):
        super().__init__(config)
        self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
        self.out_channels = ifnone(config.model_vision_d_model, 512)

        if config.model_vision_backbone == 'transformer':
            self.backbone = ResTranformer(config)
        else:
            self.backbone = resnet45()

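        # Attention decoder: converts 2-D feature maps into a fixed-length
        # sequence of character feature vectors, one per decoding step.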
        if config.model_vision_attention == 'position':
            mode = ifnone(config.model_vision_attention_mode, 'nearest')
            self.attention = PositionAttention(
                in_channels=self.out_channels,
                max_length=config.dataset_max_length + 1,  # one extra step for the end token
                mode=mode,
            )
        elif config.model_vision_attention == 'attention':
            self.attention = Attention(
                in_channels=self.out_channels,
                max_length=config.dataset_max_length + 1,
                n_feature=8 * 32,
            )
        else:
            raise ValueError(f'{config.model_vision_attention} is not valid.')
        self.cls = nn.Linear(self.out_channels, self.charset.num_classes)

        if config.model_vision_checkpoint is not None:
            logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
            self.load(config.model_vision_checkpoint)

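    # Decode backbone features into per-step character logits plus predicted
    # sequence lengths; downstream modules consume this dict.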
    def _forward(self, b_features):
        attn_vecs, attn_scores = self.attention(b_features)
        logits = self.cls(attn_vecs)
        pt_lengths = self._get_length(logits)

        return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
                'attn_scores': attn_scores, 'loss_weight': self.loss_weight,
                'name': 'vision', 'b_features': b_features}

    def forward(self, images, *args, **kwargs):
        features = self.backbone(images, **kwargs)
        return self._forward(features)

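
# Iterative variant: the vision model runs several passes. A zero-initialized
# 1x1 conv maps each pass's backbone features to "extra" correction features
# that are injected into the next pass. Training with deep supervision returns
# every pass's output; inference keeps only the last.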
class BaseIterVision(BaseVision):
    def __init__(self, config):
        super().__init__(config)
        assert config.model_vision_backbone == 'transformer'
        self.iter_size = ifnone(config.model_vision_iter_size, 1)
        self.share_weights = ifnone(config.model_vision_share_weights, False)
        self.share_cnns = ifnone(config.model_vision_share_cnns, False)
        self.add_transformer = ifnone(config.model_vision_add_transformer, False)
        self.simple_trans = ifnone(config.model_vision_simple_trans, False)
        self.deep_supervision = ifnone(config.model_vision_deep_supervision, True)
        self.backbones = nn.ModuleList()
        self.trans = nn.ModuleList()
        for i in range(self.iter_size - 1):
            # None marks a pass that reuses the main backbone (shared weights).
            B = None if self.share_weights else ResTranformer(config)
            if B is not None and self.share_cnns:
                del B.resnet
            self.backbones.append(B)

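        # Correction head(s): a 1x1 conv per extra pass that produces the
        # multi-scale feature slices listed in split_sizes. Zero-initialized
        # weights keep the injected features near zero at the start of training
        # (only the conv bias contributes initially).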
        output_channel = self.out_channels
        if self.add_transformer:
            self.split_sizes = [output_channel]
        elif self.simple_trans:
            self.split_sizes = [output_channel // 16, output_channel // 16, 0,
                                output_channel // 4, output_channel // 2, output_channel]
        else:
            self.split_sizes = [output_channel // 16, output_channel // 16, output_channel // 8,
                                output_channel // 4, output_channel // 2, output_channel]
        for _ in range(self.iter_size - 1):
            conv = nn.Conv2d(output_channel, sum(self.split_sizes), 1)
            nn.init.zeros_(conv.weight)
            self.trans.append(conv)

        if config.model_vision_checkpoint is not None:
            logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
            self.load(config.model_vision_checkpoint)
        cb_init = ifnone(config.model_vision_cb_init, True)
        if cb_init:
            self.cb_init()

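    # Non-strict loading: a plain BaseVision checkpoint lacks the extra
    # backbones and correction convs, so missing keys are tolerated.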
    def load(self, source, device=None, strict=False):
        state = torch.load(source, map_location=device)
        msg = self.load_state_dict(state['model'], strict=strict)
        print(msg)

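    # Copy the main backbone's weights into every extra (non-shared) backbone
    # so each refinement pass starts from the same initialization.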
    def cb_init(self):
        model_state_dict = self.backbone.state_dict()

        for m in self.backbones:
            if m is not None:
                print('cb_init')
                msg = m.load_state_dict(model_state_dict, strict=False)
                print(msg)

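    # Inference path: run all passes but return only the final prediction.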
    def forward_test(self, images, *args):
        l_feats = self.backbone.resnet(images)
        b_feats = self.backbone.forward_transformer(l_feats)
        cnt = len(self.backbones)
        if cnt == 0:
            v_res = super()._forward(b_feats)
        for B, T in zip(self.backbones, self.trans):
            cnt -= 1
            extra_feats = T(b_feats).split(self.split_sizes, dim=1)
            if self.share_weights:
                v_res = super().forward(images, extra_feats=extra_feats)
                b_feats = v_res['b_features']  # chain features into the next pass
            else:
                if self.add_transformer:
                    if not self.share_cnns:
                        l_feats = B.resnet(images)
                    b_feats = B.forward_transformer(extra_feats[-1] + l_feats)
                else:
                    b_feats = B(images, extra_feats=extra_feats)
                # The decoder and classifier only need to run on the last pass.
                v_res = super()._forward(b_feats) if cnt == 0 else None
        return v_res

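    # Training path with deep supervision: return every pass's prediction dict
    # so each pass can receive its own loss.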
    def forward_train(self, images, *args):
        l_feats = self.backbone.resnet(images)
        b_feats = self.backbone.forward_transformer(l_feats)
        v_res = super()._forward(b_feats)

        all_v_res = [v_res]
        for B, T in zip(self.backbones, self.trans):
            extra_feats = T(v_res['b_features']).split(self.split_sizes, dim=1)
            if self.share_weights:
                v_res = super().forward(images, extra_feats=extra_feats)
            else:
                if self.add_transformer:
                    if not self.share_cnns:
                        l_feats = B.resnet(images)
                    b_feats = B.forward_transformer(extra_feats[-1] + l_feats)
                else:
                    b_feats = B(images, extra_feats=extra_feats)
                v_res = super()._forward(b_feats)
            all_v_res.append(v_res)
        return all_v_res

    def forward(self, images, *args):
        if self.training and self.deep_supervision:
            return self.forward_train(images, *args)
        else:
            return self.forward_test(images, *args)
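

# Minimal usage sketch (comment only). `Config` and the YAML path are
# assumptions about the surrounding repo, inferred from the config keys read
# above:
#
#   config = Config('configs/train_iter.yaml')   # hypothetical config file
#   model = BaseIterVision(config)
#   images = torch.randn(2, 3, 32, 128)          # (N, C, H, W) image batch
#   model.train()
#   all_res = model(images)   # list: one result dict per pass (deep supervision)
#   model.eval()
#   res = model(images)       # single dict from the final pass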