"""Vision recognizers: a single-pass model and an iteratively refined variant."""
import logging

import torch
import torch.nn as nn
from fastai.vision import *

from modules.attention import *
from modules.backbone import ResTranformer
from modules.model import Model
from modules.resnet import resnet45


class BaseVision(Model):
    """Single-pass vision model: backbone -> attention -> linear classifier.

    The backbone, the attention variant and the optional checkpoint are all
    selected from ``config``.
    """

    def __init__(self, config):
        """Build the backbone, attention module and classifier from ``config``.

        Raises:
            ValueError: if ``config.model_vision_attention`` is neither
                ``'position'`` nor ``'attention'``.
        """
        super().__init__(config)
        self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
        self.out_channels = ifnone(config.model_vision_d_model, 512)

        if config.model_vision_backbone == 'transformer':
            self.backbone = ResTranformer(config)
        else:
            self.backbone = resnet45()

        if config.model_vision_attention == 'position':
            mode = ifnone(config.model_vision_attention_mode, 'nearest')
            self.attention = PositionAttention(
                in_channels=self.out_channels,
                max_length=config.dataset_max_length + 1,  # additional stop token
                mode=mode,
            )
        elif config.model_vision_attention == 'attention':
            self.attention = Attention(
                in_channels=self.out_channels,
                max_length=config.dataset_max_length + 1,  # additional stop token
                n_feature=8*32,
            )
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError(f'{config.model_vision_attention} is not valid.')

        self.cls = nn.Linear(self.out_channels, self.charset.num_classes)

        if config.model_vision_checkpoint is not None:
            logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
            self.load(config.model_vision_checkpoint)

    def _forward(self, b_features):
        """Decode backbone features into logits and predicted lengths.

        Args:
            b_features: feature map produced by the backbone.

        Returns:
            dict with the attended vectors, logits, predicted lengths,
            attention scores, the loss weight, the model name and the input
            backbone features (under ``'b_features'``).
        """
        attn_vecs, attn_scores = self.attention(b_features)  # (N, T, E), (N, T, H, W)
        logits = self.cls(attn_vecs)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
                'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision',
                'b_features':b_features}

    def forward(self, images, *args, **kwargs):
        """Run the backbone on ``images`` (forwarding ``kwargs``) and decode."""
        features = self.backbone(images, **kwargs)  # (N, E, H, W)
        return self._forward(features)


class BaseIterVision(BaseVision):
    """Vision model that refines its backbone features over several iterations.

    On top of the base backbone, ``iter_size - 1`` refinement steps are built.
    Each step owns a 1x1 convolution (``self.trans``) that maps the previous
    features to "extra" features, and, unless weights are shared, its own
    ``ResTranformer`` (``self.backbones``) that consumes them.
    """

    def __init__(self, config):
        """Build the refinement stages, then load / initialise their weights."""
        super().__init__(config)
        # The refinement stages call ``resnet`` / ``forward_transformer`` on
        # the backbone, which is why only the transformer backbone is accepted.
        assert config.model_vision_backbone == 'transformer'
        self.iter_size = ifnone(config.model_vision_iter_size, 1)
        self.share_weights = ifnone(config.model_vision_share_weights, False)
        self.share_cnns = ifnone(config.model_vision_share_cnns, False)
        self.add_transformer = ifnone(config.model_vision_add_transformer, False)
        self.simple_trans = ifnone(config.model_vision_simple_trans, False)
        self.deep_supervision = ifnone(config.model_vision_deep_supervision, True)

        # Channel counts of the chunks the 1x1 conv output is split into.
        # Loop-invariant, so computed once up front.
        # NOTE(review): presumably one chunk per injection point inside
        # ResTranformer -- confirm against modules.backbone.
        output_channel = self.out_channels
        if self.add_transformer:
            self.split_sizes = [output_channel]
        elif self.simple_trans:
            # Same layout as the default below, but the third chunk is empty.
            self.split_sizes = [output_channel//16, output_channel//16, 0,
                                output_channel//4, output_channel//2, output_channel]
        else:
            self.split_sizes = [output_channel//16, output_channel//16, output_channel//8,
                                output_channel//4, output_channel//2, output_channel]

        self.backbones = nn.ModuleList()
        self.trans = nn.ModuleList()
        for _ in range(self.iter_size - 1):
            # With shared weights the base backbone is reused, so the slot
            # holds ``None`` purely to keep ``backbones`` aligned with ``trans``.
            B = None if self.share_weights else ResTranformer(config)
            if self.share_cnns and B is not None:
                # Drop the per-stage CNN; the guard avoids an AttributeError
                # when share_weights and share_cnns are both enabled.
                # NOTE(review): with share_cnns but without add_transformer the
                # forward passes still call ``B(images, ...)`` -- confirm that
                # ResTranformer.forward works without its ``resnet``.
                del B.resnet
            self.backbones.append(B)

            conv = nn.Conv2d(output_channel, sum(self.split_sizes), 1)
            # Zero weights: the conv output initially depends only on its bias.
            torch.nn.init.zeros_(conv.weight)
            self.trans.append(conv)

        # BaseVision.__init__ already attempted this load, but at that point
        # ``backbones`` / ``trans`` did not exist yet; load again to fill them.
        if config.model_vision_checkpoint is not None:
            logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
            self.load(config.model_vision_checkpoint)

        cb_init = ifnone(config.model_vision_cb_init, True)
        if cb_init:
            self.cb_init()

    def load(self, source, device=None, strict=False):
        """Load the ``'model'`` entry of the checkpoint at ``source``.

        Args:
            source: checkpoint path or file object accepted by ``torch.load``.
            device: forwarded to ``torch.load`` as ``map_location``.
            strict: forwarded to ``load_state_dict``; non-strict by default so
                that checkpoints lacking the refinement stages still load.
        """
        # SECURITY: torch.load unpickles its input; only load trusted files.
        state = torch.load(source, map_location=device)
        msg = self.load_state_dict(state['model'], strict=strict)
        logging.info('%s', msg)

    def cb_init(self):
        """Copy the base backbone's weights into every refinement backbone."""
        model_state_dict = self.backbone.state_dict()
        for m in self.backbones:
            # Slots are ``None`` when weights are shared; nothing to copy then.
            if m is not None:
                logging.info('cb_init')
                # Non-strict: a stage whose ``resnet`` was removed has no
                # matching parameters for the base backbone's CNN keys.
                msg = m.load_state_dict(model_state_dict, strict=False)
                logging.info('%s', msg)

    def forward_test(self, images, *args):
        """Run all refinement iterations and return only the final result.

        Returns:
            The result dict (see ``BaseVision._forward``) of the last
            iteration, or of the base backbone when there are no refinement
            stages.
        """
        l_feats = self.backbone.resnet(images)
        b_feats = self.backbone.forward_transformer(l_feats)

        remaining = len(self.backbones)
        if remaining == 0:
            return super()._forward(b_feats)

        v_res = None
        for B, T in zip(self.backbones, self.trans):
            remaining -= 1
            extra_feats = T(b_feats).split(self.split_sizes, dim=1)
            if self.share_weights:
                v_res = super().forward(images, extra_feats=extra_feats)
                # Feed the refined features to the next iteration, as
                # forward_train does; otherwise every iteration would be
                # driven by the initial features.
                b_feats = v_res['b_features']
            else:
                if self.add_transformer:
                    if not self.share_cnns:
                        l_feats = B.resnet(images)
                    b_feats = B.forward_transformer(extra_feats[-1] + l_feats)
                else:
                    b_feats = B(images, extra_feats=extra_feats)
                # Intermediate results are not needed at test time, so only
                # the last iteration is decoded.
                v_res = super()._forward(b_feats) if remaining == 0 else None
        return v_res

    def forward_train(self, images, *args):
        """Run all refinement iterations and return every intermediate result.

        Returns:
            list of result dicts (see ``BaseVision._forward``): the base
            backbone's result first, then one per refinement stage.
        """
        l_feats = self.backbone.resnet(images)
        b_feats = self.backbone.forward_transformer(l_feats)
        v_res = super()._forward(b_feats)

        all_v_res = [v_res]
        for B, T in zip(self.backbones, self.trans):
            # Each stage is conditioned on the previous stage's features.
            extra_feats = T(v_res['b_features']).split(self.split_sizes, dim=1)
            if self.share_weights:
                v_res = super().forward(images, extra_feats=extra_feats)
            else:
                if self.add_transformer:
                    if not self.share_cnns:
                        l_feats = B.resnet(images)
                    b_feats = B.forward_transformer(extra_feats[-1] + l_feats)
                else:
                    b_feats = B(images, extra_feats=extra_feats)
                v_res = super()._forward(b_feats)
            all_v_res.append(v_res)
        return all_v_res

    def forward(self, images, *args):
        """Dispatch to ``forward_train`` or ``forward_test``.

        The per-iteration list is returned only while training with deep
        supervision enabled; in every other case the single final result is
        returned.
        """
        if self.training and self.deep_supervision:
            return self.forward_train(images, *args)
        return self.forward_test(images, *args)