""" timm model adapter

Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict

import torch
import torch.nn as nn

try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    try:
        # older timm layout
        from timm.models.layers.attention_pool2d import RotAttentionPool2d
        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
    except ImportError:
        # newer timm layout (timm.layers)
        from timm.layers import RotAttentionPool2d
        from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
    timm = None

from misc import freeze_batch_norm_2d


class TimmModel(nn.Module):
    """ timm model adapter

    Builds a timm trunk via `timm.create_model` and attaches a head that maps trunk
    features to `embed_dim`, using the configured pooling (e.g. 'avg', 'abs_attn',
    'rot_attn') and projection ('linear', 'mlp', or '').
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            patch_drop=None,
            pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")
        self.image_size = to_2tuple(image_size)

        # collect kwargs that are not common across all timm models
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        if patch_drop is not None:
            timm_kwargs['patch_drop_rate'] = patch_drop

        custom_pool = pool in ('abs_attn', 'rot_attn')
        if not proj and not custom_pool:
            # use the network's classifier head as the projection when no proj is
            # specified and no custom pooling is used
            self.trunk = timm.create_model(
                model_name,
                num_classes=embed_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim
        else:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            feat_size = self.trunk.default_cfg.get('pool_size', None)
            feature_ndim = 1 if not feat_size else 2
            if custom_pool:
                # attention pooling needs a 2d spatial feature map; remove the
                # default classifier and pooling
                assert feature_ndim == 2
                self.trunk.reset_classifier(0, global_pool='')
            else:
                # reset global pool if a pool config is set, otherwise leave as network default
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        head_layers = OrderedDict()

        # custom attention pooling, projects to embed_dim as part of the pool
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        # NOTE: the attention pools above already project to embed_dim, so proj is
        # typically '' when they are used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
        else:
            assert not proj, f'Unknown projection type {proj}.'

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
        """
        if not unlocked_groups:
            # lock the full trunk
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # partial locking requires timm's layer grouping helpers
            try:
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x):
        x = self.trunk(x)
        x = self.head(x)
        return x
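

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): shows how this adapter
    # is typically constructed and used as a CLIP vision tower. The architecture name
    # 'resnet50' and the embed_dim of 512 are assumptions for illustration; any timm
    # model name should work. Requires `timm` to be installed; `lock()` additionally
    # needs a recent timm with layer-group matcher support.
    model = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
    images = torch.randn(2, 3, 224, 224)    # dummy batch of 224x224 RGB images
    features = model(images)                # trunk features -> linear projection
    print(features.shape)                   # expected: torch.Size([2, 512])
    model.lock(unlocked_groups=1, freeze_bn_stats=True)  # freeze all but the last layer group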