fancyfeast committed
Commit 6982e15
1 parent: 69eecf7

Initial commit

Files changed (3)
  1. Models.py +1159 -0
  2. app.py +41 -0
  3. requirements.txt +5 -0
Models.py ADDED
@@ -0,0 +1,1159 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+ import torch
5
+ import torch.backends.cuda
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from transformers.activations import QuickGELUActivation
11
+ import math
12
+ from einops.layers.torch import Rearrange
13
+ import einops
14
+
15
+
16
+ MODEL_CONFIGS = {
17
+ # Custom models trained from scratch
18
+ # "Standard" definitions:
19
+ # name | layers | width | heads
20
+ # B | 12 | 768 | 12
21
+ # L | 24 | 1024 | 16
22
+ # H | 32 | 1280 | 16
23
+ # G | 48 | 1664 | 16
24
+ # e | 56 | 1792 | 16
25
+ # 22 | 48 | 6144 | 48
26
+
27
+ # B/16, 224, PaLM, GELU
28
+ 'CustomTest6': {
29
+ 'class': 'CLIPLikeModel',
30
+ 'embedding_dim': 768,
31
+ 'num_attention_heads': 12,
32
+ 'activation_cls': nn.GELU,
33
+ 'num_channels': 3,
34
+ 'patch_size': 16,
35
+ 'use_palm_alt': True,
36
+ 'num_layers': 12,
37
+ 'use_mha_alt': False,
38
+ 'good_dropout': False,
39
+ },
40
+
41
+ # GAP head + Sinusoidal positional embeddings + 448 image size
42
+ 'CustomTest18': {
43
+ 'class': 'CLIPLikeModel',
44
+ 'embedding_dim': 768,
45
+ 'num_attention_heads': 12,
46
+ 'activation_cls': nn.GELU,
47
+ 'num_channels': 3,
48
+ 'patch_size': 16,
49
+ 'use_palm_alt': True,
50
+ 'num_layers': 12,
51
+ 'use_mha_alt': False,
52
+ 'good_dropout': False,
53
+ 'use_gap_head': True,
54
+ 'sine_positional_embeddings': True,
55
+ },
56
+
57
+ # SW Model + B/16 + ASL + 448 image size
58
+ # cutout_max_pct = 0
59
+ # mixup_alpha = 0.8
60
+ # noise_level = 2
61
+ # random_resize_method = true
62
+ # total_labels = 6549
63
+ 'SWModel1': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': False},
64
+
65
+ # Sinusoidal positional embeddings
66
+ 'SWModel2': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
67
+
68
+ # Sinusoidal positional embeddings + 224 image size + L/14
69
+ 'SWModel3': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
70
+
71
+ # Sinusoidal positional embeddings + 224 image size + G/14
72
+ 'SWModel4': {'class': 'ViT', 'num_blocks': 48, 'patch_size': 14, 'd_model': 1664, 'mlp_dim': 1664*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
73
+
74
+ # Sinusoidal positional embeddings + focal loss
75
+ 'SWModel5': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
76
+
77
+ 'SWModel6': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
78
+
79
+ 'SWModel7': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
80
+ 'SWModel8': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
81
+ 'SWModel9': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
82
+ 'SWModel10': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
83
+ 'SWModel11': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0, 'use_sine': True},
84
+
85
+ # Trying head_mean_after
86
+ 'SWModel12': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'head_mean_after': True},
87
+
88
+ # Fat boy
89
+ 'SWModel13': {'class': 'ViT', 'num_blocks': 6, 'patch_size': 16, 'd_model': 1536, 'mlp_dim': 1536*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
90
+
91
+ # L/14
92
+ 'SWModel14': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
93
+ 'SWModel15': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-5, 'use_sine': True},
94
+ 'SWModel16': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.10, 'layerscale_init': 1e-1, 'use_sine': True},
95
+ 'SWModel16f': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.10, 'layerscale_init': 1e-1, 'use_sine': True},
96
+ 'SWModel22': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.20, 'layerscale_init': 1e-1, 'use_sine': True},
97
+ 'SWModel25': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 16, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.15, 'layerscale_init': 1e-1, 'use_sine': True, 'cnn_stem': 'conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=1024;ln;relu;conv:c=1024,s=1,k=1,p=0'},
98
+
99
+ # CNN stem
100
+ 'SWModel18': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;bn;relu;conv:c=128;bn;relu;conv:c=256;bn;relu;conv:c=512;bn;relu;conv:c=768,s=1,k=1'},
101
+ 'SWModel19': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;bn;relu;conv:c=128;bn;relu;conv:c=128,s=1;bn;relu;conv:c=256;bn;relu;conv:c=256,s=1;bn;relu;conv:c=512;bn;relu;conv:c=768,s=1,k=1,p=0'},
102
+ 'SWModel20': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;relu;conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=768,s=1,k=1,p=0'},
103
+ 'SWModel21': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;gelu;conv:c=128;ln;gelu;conv:c=256;ln;gelu;conv:c=512;ln;gelu;conv:c=768,s=1,k=1,p=0'},
104
+ 'SWModel23': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;relu;conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=768,s=1,k=1,p=0'},
105
+ 'SWModel24': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;relu;conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=768,s=1,k=1,p=0'},
106
+
107
+ # H/14
108
+ 'SWModel17': {'class': 'ViT', 'num_blocks': 32, 'patch_size': 14, 'd_model': 1280, 'mlp_dim': 1280*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
109
+ 'SWModel26': {'class': 'ViT', 'num_blocks': 32, 'patch_size': 14, 'd_model': 1280, 'mlp_dim': 1280*4, 'num_heads': 16, 'stochdepth_rate': 0.15, 'layerscale_init': 1e-1, 'use_sine': True},
110
+ }
111
+
112
+
113
+ class VisionModel(nn.Module):
114
+ image_size: int
115
+ n_tags: int
116
+
117
+ def __init__(self, image_size: int, n_tags: int):
118
+ super().__init__()
119
+
120
+ self.image_size = image_size
121
+ self.n_tags = n_tags
122
+
123
+ @staticmethod
124
+ def load_model(path: Path | str, device: str | None = None) -> 'VisionModel':
125
+ """
126
+ Load a model from a directory.
127
+ :param path: The directory containing the model.
128
+ :param device: Optional device to move the loaded model to.
+ :return: The loaded model.
129
+ """
130
+ with open(Path(path) / 'config.json', 'r') as f:
131
+ config = json.load(f)
132
+
133
+ if (Path(path) / 'model.safetensors').exists():
134
+ from safetensors.torch import load_file
135
+ resume = load_file(Path(path) / 'model.safetensors', device='cpu')
136
+ else:
137
+ resume = torch.load(Path(path) / 'model.pt', map_location=torch.device('cpu'))
138
+
139
+ model_classes = VisionModel.__subclasses__()
140
+ model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
141
+
142
+ model = model_cls(**{k: v for k, v in config.items() if k != 'class'})
143
+ model.load(resume['model'])
144
+ if device is not None:
145
+ model = model.to(device)
146
+
147
+ return model
148
+
149
+ @staticmethod
150
+ def from_config(config: dict) -> 'VisionModel':
151
+ model_classes = VisionModel.__subclasses__()
152
+ model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
153
+ return model_cls(**{k: v for k, v in config.items() if k != 'class'})
154
+
155
+ def get_optimized_parameters(self, lr: float):
156
+ raise NotImplementedError
157
+
158
+ def save(self):
159
+ raise NotImplementedError
160
+
161
+ def load(self, state_dict):
162
+ raise NotImplementedError
163
+
164
+
165
+ def basic_calculate_loss(preds: dict[str, torch.Tensor], batch: dict, pos_weight: torch.Tensor | None, loss_type: str):
166
+ def asl_helper(preds, target):
167
+ p = F.softmax(preds, dim=1)
168
+ xs_pos = p.clamp(min=1e-6)
169
+ xs_neg = (1 - p).clamp(min=1e-6)
170
+
171
+ los_pos = torch.log(torch.gather(xs_pos, 1, target.unsqueeze(1))).sum()
172
+ los_neg = torch.log(xs_neg)
173
+ los_neg = los_neg.sum() - torch.gather(los_neg, 1, target.unsqueeze(1)).sum()
174
+ loss = los_pos + los_neg
175
+
176
+ return -loss
177
+
178
+ if loss_type == "ce":
179
+ loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'])
180
+ elif loss_type == "weighted":
181
+ loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight)
182
+ elif loss_type == "focal":
183
+ gamma = 2
184
+ p = torch.sigmoid(preds['tags'])
185
+ ce_loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], reduction='none')
186
+ p_t = p * batch['tags'] + (1 - p) * (1 - batch['tags'])
187
+ loss = ce_loss * ((1 - p_t) ** gamma)
188
+ loss = loss.mean()
189
+ elif loss_type == "focal2":
190
+ gamma = 2
191
+ p = torch.sigmoid(preds['tags'])
192
+ ce_loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], reduction='none')
193
+ p_t = p * batch['tags'] + (1 - p) * (1 - batch['tags'])
194
+ loss = ce_loss * ((1 - p_t) ** gamma) * 256
195
+ loss = loss.mean()
196
+ elif loss_type == "asl":
197
+ p = torch.sigmoid(preds['tags'])
198
+ xs_pos = p
199
+ xs_neg = 1 - p
200
+
201
+ los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
202
+ los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
203
+ loss = los_pos + los_neg
204
+ loss = -loss.sum()
205
+
206
+ # Rating
207
+ loss = loss + asl_helper(preds['rating'], batch['rating'])
208
+
209
+ # Score
210
+ loss = loss + asl_helper(preds['score'], batch['score'])
211
+ elif loss_type == "asl2":
212
+ p = torch.sigmoid(preds['tags'])
213
+ xs_pos = p
214
+ xs_neg = 1 - p
215
+
216
+ los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
217
+ los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
218
+ loss = -los_pos - los_neg
219
+ loss = loss.sum()
220
+ elif loss_type == "asl3":
221
+ p = torch.sigmoid(preds['tags'])
222
+ xs_pos = p
223
+ xs_neg = 1 - p
224
+
225
+ los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
226
+ los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
227
+ loss = -los_pos - los_neg
228
+ loss = loss.mean()
229
+ elif loss_type == "asl4":
230
+ p = torch.sigmoid(preds['tags'])
231
+ xs_pos = p
232
+ xs_neg = 1 - p
233
+
234
+ los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
235
+ los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
236
+ loss = -los_pos - los_neg
237
+ loss = loss.mean() * 128
238
+ elif loss_type == "asl5":
239
+ loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight) * 128
240
+ elif loss_type == "asl6":
241
+ loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight) * 256
242
+ elif loss_type == "asl7":
243
+ loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight) * 2
244
+ else:
245
+ raise ValueError(f"Invalid loss type: {loss_type}")
246
+
247
+ return loss
248
+
249
+
250
+ class CLIPMlp(nn.Module):
251
+ def __init__(self, hidden_size: int, intermediate_size: int, activation_cls):
252
+ super().__init__()
253
+ self.activation_fn = activation_cls()
254
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
255
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
256
+
257
+ def forward(self, hidden_states: torch.Tensor):
258
+ hidden_states = self.fc1(hidden_states)
259
+ hidden_states = self.activation_fn(hidden_states)
260
+ hidden_states = self.fc2(hidden_states)
261
+ return hidden_states
262
+
263
+
264
+ class FastCLIPAttention2(nn.Module):
265
+ """Fast Attention module for CLIP-like. This is NOT a drop-in replacement for CLIPAttention, since it adds additional flexibility. Mainly uses xformers."""
266
+ def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: Optional[int] = None, norm_qk: bool = False):
267
+ super().__init__()
268
+ self.out_seq_len = out_seq_len
269
+ self.embed_dim = hidden_size
270
+ self.out_dim = out_dim
271
+ self.norm_qk = norm_qk
272
+ self.num_heads = num_attention_heads
273
+ self.head_dim = hidden_size // num_attention_heads
274
+ assert self.head_dim * num_attention_heads == self.embed_dim, "embed_dim must be divisible by num_attention_heads"
275
+
276
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
277
+ self.kv_proj = nn.Linear(self.embed_dim, self.embed_dim * 2)
278
+ self.out_proj = nn.Linear(self.embed_dim, self.out_dim)
279
+
280
+ if self.norm_qk:
281
+ self.query_norm = nn.LayerNorm(self.embed_dim)
282
+ self.key_norm = nn.LayerNorm(self.embed_dim)
283
+
284
+ #def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
285
+ # return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
286
+
287
+ def forward(self, query_states: torch.Tensor, kv_states: torch.Tensor) -> torch.Tensor:
288
+ bsz, src_len, embed_dim = kv_states.size()
289
+ if self.out_seq_len is not None:
290
+ tgt_len = self.out_seq_len
291
+ else:
292
+ tgt_len = src_len
293
+
294
+ kv_states = self.kv_proj(kv_states) # (bsz, src_len, embed_dim * 2)
295
+ q_states = self.q_proj(query_states[:, :tgt_len]) # (bsz, tgt_len, embed_dim)
296
+
297
+ # NOTE: It is not clear if LayerNorm should be applied to the embed_dim, or to the head_dim
298
+ if self.norm_qk:
299
+ q_states = self.query_norm(q_states).type(q_states.dtype)
300
+ k_states = self.key_norm(kv_states[:, :, :embed_dim]).type(kv_states.dtype)
301
+ v_states = kv_states[:, :, embed_dim:]
302
+ else:
303
+ k_states = kv_states[:, :, :embed_dim]
304
+ v_states = kv_states[:, :, embed_dim:]
305
+
306
+ q_states = q_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, tgt_len, head_dim)
307
+ k_states = k_states.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, src_len, head_dim)
308
+ v_states = v_states.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, src_len, head_dim)
309
+
310
+ # Performs scale of query_states, attention, and softmax
311
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
312
+ x = F.scaled_dot_product_attention(q_states, k_states, v_states) # (bsz, num_heads, tgt_len, head_dim)
313
+ x = x.transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim) # (bsz, tgt_len, embed_dim)
314
+
315
+ # Projection
316
+ x = self.out_proj(x) # (bsz, tgt_len, out_dim)
317
+
318
+ return x
319
+
320
+
321
+ class SkipInit(nn.Module):
322
+ def __init__(self, hidden_size: int, channel_wise: bool, init_scale: float):
323
+ super().__init__()
324
+ self.hidden_size = hidden_size
325
+ self.channel_wise = channel_wise
326
+ self.init_scale = init_scale
327
+
328
+ if self.channel_wise:
329
+ self.scale = nn.Parameter(torch.ones(hidden_size) * init_scale)
330
+ else:
331
+ self.scale = nn.Parameter(torch.tensor(init_scale))
332
+
333
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
334
+ return x * self.scale
335
+
336
+
337
+ class FastCLIPEncoderLayer(nn.Module):
338
+ def __init__(
339
+ self,
340
+ hidden_size: int,
341
+ num_attention_heads: int,
342
+ out_seq_len: Optional[int],
343
+ activation_cls = QuickGELUActivation,
344
+ use_palm_alt: bool = False,
345
+ norm_qk: bool = False,
346
+ skip_init: Optional[float] = None,
347
+ stochastic_depth: Optional[float] = None,
348
+ ):
349
+ super().__init__()
350
+
351
+ self.use_palm_alt = use_palm_alt
352
+ self.stochastic_depth = stochastic_depth
353
+
354
+ self.self_attn = FastCLIPAttention2(
355
+ hidden_size=hidden_size,
356
+ out_dim=hidden_size,
357
+ num_attention_heads=num_attention_heads,
358
+ out_seq_len=out_seq_len,
359
+ norm_qk=norm_qk,
360
+ )
361
+ self.mlp = CLIPMlp(hidden_size, 4 * hidden_size, activation_cls)
362
+ self.layer_norm1 = nn.LayerNorm(hidden_size)
363
+ if not use_palm_alt:
364
+ self.layer_norm2 = nn.LayerNorm(hidden_size)
365
+
366
+ if skip_init is not None:
367
+ self.attn_skip_init = SkipInit(hidden_size, channel_wise=True, init_scale=skip_init)
368
+ self.mlp_skip_init = SkipInit(hidden_size, channel_wise=True, init_scale=skip_init)
369
+ else:
370
+ self.attn_skip_init = nn.Identity()
371
+ self.mlp_skip_init = nn.Identity()
372
+
373
+ def forward(self, hidden_states: torch.Tensor):
374
+ residual = hidden_states
375
+ hidden_states = self.layer_norm1(hidden_states)
376
+
377
+ if not self.use_palm_alt:
378
+ hidden_states = self.self_attn(query_states=hidden_states, kv_states=hidden_states)
379
+ hidden_states = self.attn_skip_init(hidden_states)
380
+ hidden_states = hidden_states + residual[:, :hidden_states.size(1)]
381
+
382
+ residual = hidden_states
383
+ hidden_states = self.layer_norm2(hidden_states)
384
+ hidden_states = self.mlp(hidden_states)
385
+ hidden_states = self.mlp_skip_init(hidden_states)
386
+ hidden_states = hidden_states + residual
387
+ else:
388
+ # An alternative implementation inspired by the PALM paper
389
+ # By performing the attention and MLP in parallel it's possible to fuse the linear projections of the attention and MLP layers
390
+ # We don't do that here yet, but that supposedly improves efficiency without hurting performance
391
+ attn = self.self_attn(query_states=hidden_states, kv_states=hidden_states)
392
+ attn = self.attn_skip_init(attn)
393
+ mlp = self.mlp(hidden_states[:, :attn.size(1)])
394
+ mlp = self.mlp_skip_init(mlp)
395
+
396
+ if self.stochastic_depth is not None:
397
+ attn = torchvision.ops.stochastic_depth(attn, self.stochastic_depth, mode='row', training=self.training)
398
+ mlp = torchvision.ops.stochastic_depth(mlp, self.stochastic_depth, mode='row', training=self.training)
399
+
400
+ hidden_states = residual[:, :attn.size(1)] + attn + mlp
401
+
402
+ return hidden_states
403
+
404
+
405
+ def sinusoidal_position_embedding(width: int, height: int, depth: int, dtype, device, temperature = 10000):
406
+ """
407
+ Sinusoidal position embedding. Returns a flat tensor of shape (h * w, d).
408
+ """
409
+ assert depth % 4 == 0, "Embedding dimension must be divisible by 4."
410
+
411
+ y, x = torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device), indexing="ij")
412
+ omega = torch.arange(depth // 4, device=device) / (depth // 4 - 1)
413
+ omega = 1. / (temperature ** omega)
414
+
415
+ y = y.flatten()[:, None] * omega[None, :]
416
+ x = x.flatten()[:, None] * omega[None, :]
417
+ embedding = torch.cat([x.sin(), x.cos(), y.sin(), y.cos()], dim=1)
418
+
419
+ return embedding.type(dtype)
420
+
421
+
422
+ class CLIPEmbeddingLayer(nn.Module):
423
+ def __init__(self, hidden_size: int, num_channels: int, image_size: int, patch_size: int, patch_dropout: float = 0.0, good_dropout: bool = False, dpn: bool = False, sine_positional_embeddings: bool = False):
424
+ super().__init__()
425
+
426
+ assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
427
+
428
+ seq_len = (image_size // patch_size) ** 2
429
+ self.patch_dropout = patch_dropout
430
+ self.hidden_size = hidden_size
431
+ self.good_dropout = good_dropout
432
+ self.dpn = dpn
433
+ self.sine_positional_embeddings = sine_positional_embeddings
434
+ self.patch_size = patch_size
435
+
436
+ self.patch_embeddings = nn.Conv2d(
437
+ in_channels=num_channels,
438
+ out_channels=hidden_size,
439
+ kernel_size=patch_size,
440
+ stride=patch_size,
441
+ bias=False,
442
+ )
443
+ if not self.sine_positional_embeddings:
444
+ self.positional_embeddings = nn.Embedding(seq_len, hidden_size)
445
+ self.register_buffer("position_ids", torch.arange(seq_len))
446
+
447
+ if self.dpn:
448
+ self.to_patch_embeddings = nn.Sequential(
449
+ Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
450
+ nn.LayerNorm(3 * patch_size * patch_size),
451
+ nn.Linear(3 * patch_size * patch_size, hidden_size),
452
+ nn.LayerNorm(hidden_size),
453
+ )
454
+ else:
455
+ self.to_patch_embeddings = nn.Conv2d(
456
+ in_channels=num_channels,
457
+ out_channels=hidden_size,
458
+ kernel_size=patch_size,
459
+ stride=patch_size,
460
+ bias=False,
461
+ )
462
+
463
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
464
+ B, C, H, W = pixel_values.shape
465
+ assert H % self.patch_size == 0, f"Input image height ({H}) needs to be divisible by the patch size ({self.patch_size})."
466
+ assert W % self.patch_size == 0, f"Input image width ({W}) needs to be divisible by the patch size ({self.patch_size})."
467
+
468
+ if self.dpn:
469
+ patches = self.to_patch_embeddings(pixel_values)
470
+ else:
471
+ patches = self.to_patch_embeddings(pixel_values)
472
+ patches = patches.flatten(2).transpose(1, 2)
473
+
474
+ seq_len = patches.shape[1]
475
+ patch_dropout = int(math.ceil((1.0 - self.patch_dropout) * seq_len))
476
+
477
+ if self.sine_positional_embeddings:
478
+ position_embeddings = sinusoidal_position_embedding(W // self.patch_size, H // self.patch_size, self.hidden_size, pixel_values.dtype, pixel_values.device)
479
+ else:
480
+ position_embeddings = self.positional_embeddings(self.position_ids)
481
+
482
+ if patch_dropout == seq_len or not self.training:
483
+ embeddings = patches + position_embeddings
484
+ elif self.good_dropout:
485
+ # Pick random patches to drop out
486
+ # The "good_dropout" variant uses random permutations for each batch item, but is slightly slower and involves more code
487
+
488
+ # The below method is a nice trick to generate a batch of random permutations.
489
+ # Torch (as of 1.13) doesn't have a built-in function to do this, and a for loop of torch.randperm is slow.
490
+ # Based on some benchmarks I measured the generation of the mask and the fetching to be only 50% slower than the non-"good_dropout" variant.
491
+ # And the time taken here is only a fraction of the time spent performing the embedding convolution.
492
+ # Generate a matrix of random numbers between 0 and 1 of shape (B, seq_len)
493
+ patch_mask = torch.rand(B, seq_len, device=patches.device)
494
+ # For each batch tensor, use argsort to convert the random numbers into a permutation of the patch indices
495
+ patch_mask = torch.argsort(patch_mask, dim=1)
496
+ # Truncate
497
+ patch_mask = patch_mask[:, :patch_dropout]
498
+
499
+ embeddings = patches.gather(1, patch_mask.unsqueeze(-1).expand(-1, -1, self.hidden_size)) + position_embeddings[patch_mask]
500
+ else:
501
+ # The non-"good_dropout" variant uses a single random permutation for all batch items, but is faster and uses less code
502
+ indices = torch.randperm(seq_len, device=pixel_values.device)[:patch_dropout]
503
+ embeddings = patches[:, indices, :] + position_embeddings[indices.expand(1, -1)]
504
+
505
+ return embeddings
506
+
507
+
508
+ class MHAPoolingHead(nn.Module):
509
+ def __init__(self, hidden_size: int, num_attention_heads: int, activation_cls, out_dim: int, alt_style: bool, norm_qk: bool):
510
+ super().__init__()
511
+
512
+ self.out_dim = out_dim if not alt_style else hidden_size
513
+
514
+ self.probe = nn.Parameter(torch.randn(hidden_size))
515
+
516
+ self.mlp = CLIPMlp(hidden_size, 4 * hidden_size, activation_cls)
517
+ self.layer_norm = nn.LayerNorm(hidden_size)
518
+ self.pooling_head = nn.Linear(hidden_size, 1)
519
+
520
+ self.self_attn = FastCLIPAttention2(
521
+ hidden_size=hidden_size,
522
+ out_dim=self.out_dim,
523
+ num_attention_heads=num_attention_heads,
524
+ out_seq_len=1,
525
+ norm_qk=norm_qk,
526
+ )
527
+ self.mlp = CLIPMlp(self.out_dim, 4 * self.out_dim, activation_cls)
528
+ self.layer_norm1 = nn.LayerNorm(hidden_size)
529
+ self.layer_norm2 = nn.LayerNorm(self.out_dim)
530
+
531
+ if alt_style:
532
+ self.final_proj = nn.Linear(hidden_size, out_dim)
533
+ else:
534
+ self.final_proj = nn.Identity()
535
+
536
+ def forward(self, hidden_states: torch.Tensor):
537
+ hidden_states = self.layer_norm1(hidden_states)
538
+ query_states = self.probe.unsqueeze(0).unsqueeze(0).expand(hidden_states.size(0), 1, -1)
539
+
540
+ hidden_states = self.self_attn(query_states=query_states, kv_states=hidden_states)
541
+ # We don't use a residual connection here because the out_dim is different from the hidden_size
542
+
543
+ residual = hidden_states
544
+ hidden_states = self.layer_norm2(hidden_states)
545
+ hidden_states = self.mlp(hidden_states)
546
+ hidden_states = hidden_states + residual
547
+ hidden_states = self.final_proj(hidden_states)
548
+
549
+ return hidden_states.squeeze(1)
550
+
551
+
552
+ class GAPHead(nn.Module):
553
+ def __init__(self, hidden_size: int, out_dim: int):
554
+ super().__init__()
555
+
556
+ self.norm = nn.LayerNorm(hidden_size)
557
+ self.proj = nn.Linear(hidden_size, out_dim)
558
+
559
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
560
+ x = x.mean(dim=1)
561
+ x = self.norm(x)
562
+ x = self.proj(x)
563
+ return x
564
+
565
+
566
+ class CLIPLikeModel(VisionModel):
567
+ def __init__(
568
+ self,
569
+ n_tags: int,
570
+ embedding_dim: int,
571
+ num_attention_heads: int,
572
+ activation_cls,
573
+ num_channels: int,
574
+ image_size: int,
575
+ patch_size: int,
576
+ patch_dropout: float,
577
+ use_palm_alt: bool,
578
+ num_layers: int,
579
+ use_mha_alt: bool,
580
+ loss_type: str,
581
+ good_dropout: bool=False,
582
+ dpn: bool=False,
583
+ sine_positional_embeddings: bool=False,
584
+ norm_qk: bool = False,
585
+ no_wd_bias: bool = False,
586
+ use_gap_head: bool = False,
587
+ skip_init: Optional[float] = None,
588
+ stochastic_depth: Optional[float] = None,
589
+ ):
590
+ super().__init__(image_size, n_tags)
591
+
592
+ out_dim = n_tags
593
+ self.n_tags = n_tags
594
+ self.loss_type = loss_type
595
+ self.no_wd_bias = no_wd_bias
596
+
597
+ stochastic_depth_space = torch.linspace(0, stochastic_depth, num_layers) if stochastic_depth is not None else None
598
+
599
+ self.embedding_layer = CLIPEmbeddingLayer(embedding_dim, num_channels, image_size, patch_size, patch_dropout, good_dropout, dpn, sine_positional_embeddings)
600
+ self.pre_layer_norm = nn.LayerNorm(embedding_dim)
601
+ self.encoder_layers = nn.ModuleList([FastCLIPEncoderLayer(
602
+ hidden_size=embedding_dim,
603
+ num_attention_heads=num_attention_heads,
604
+ out_seq_len=None,
605
+ activation_cls=activation_cls,
606
+ use_palm_alt=use_palm_alt,
607
+ norm_qk=norm_qk,
608
+ skip_init=skip_init,
609
+ stochastic_depth=stochastic_depth_space[i].item() if stochastic_depth_space is not None else None,
610
+ ) for i in range(num_layers)])
611
+
612
+ if use_gap_head:
613
+ self.pooling_head = GAPHead(embedding_dim, out_dim)
614
+ else:
615
+ self.pooling_head = MHAPoolingHead(embedding_dim, num_attention_heads, activation_cls, out_dim, use_mha_alt, norm_qk=norm_qk)
616
+
617
+ def forward(self, batch):
618
+ hidden_states = self.embedding_layer(batch['image'])
619
+ hidden_states = self.pre_layer_norm(hidden_states)
620
+
621
+ for layer in self.encoder_layers:
622
+ hidden_states = layer(hidden_states)
623
+
624
+ preds = self.pooling_head(hidden_states)
625
+
626
+ result = {
627
+ 'tags': preds,
628
+ }
629
+
630
+ return result
631
+
632
+ def calculate_loss(self, preds, batch, pos_weight):
633
+ return basic_calculate_loss(preds, batch, pos_weight, self.loss_type)
634
+
635
+ def get_optimized_parameters(self, lr: float):
636
+ if self.no_wd_bias:
637
+ return self.get_optimized_parameters_no_wd_bias()
638
+ else:
639
+ return self.parameters()
640
+
641
+ def get_optimized_parameters_no_wd_bias(self):
642
+ decay = []
643
+ no_decay = []
644
+
645
+ for name, param in self.named_parameters():
646
+ if not param.requires_grad:
647
+ continue
648
+
649
+ if len(param.shape) == 1 or name.endswith(".bias"):
650
+ no_decay.append(param)
651
+ print(f'No decay: {name}')
652
+ else:
653
+ decay.append(param)
654
+
655
+ return [
656
+ {'params': decay},
657
+ {'params': no_decay, 'weight_decay': 0.},
658
+ ]
659
+
660
+ def save(self):
661
+ return self.state_dict()
662
+
663
+ def load(self, state_dict):
664
+ self.load_state_dict(state_dict)
665
+
666
+
667
+ class MaskedAutoEncoderViT(nn.Module):
668
+ def __init__(
669
+ self,
670
+ n_tags: int,
671
+
672
+ embedding_dim: int,
673
+ num_attention_heads: int,
674
+ activation_cls,
675
+ num_channels: int,
676
+ image_size: int,
677
+ patch_size: int,
678
+ num_layers: int,
679
+ loss_type: str,
680
+ sine_positional_embeddings: bool=False,
681
+
682
+ decoder_embedding_dim: int = 512,
683
+ decoder_num_attention_heads: int = 8,
684
+ decoder_num_layers: int = 6,
685
+ decoder_force_projection: bool = False,
686
+
687
+ masking_ratio: float = 0.75,
688
+ mae_loss_weight: float = 1.0,
689
+ mae_normalize_targets: bool = False,
690
+ mae_post_norm: bool = False,
691
+ ):
692
+ super().__init__()
693
+
694
+ self.n_tags = n_tags
695
+ self.seq_len = (image_size // patch_size) ** 2
696
+ self.embedding_dim = embedding_dim
697
+ self.decoder_embedding_dim = decoder_embedding_dim
698
+ self.sine_positional_embeddings = sine_positional_embeddings
699
+ self.image_size = image_size
700
+ self.patch_size = patch_size
701
+ self.masking_ratio = masking_ratio
702
+ self.loss_type = loss_type
703
+ self.mae_loss_weight = mae_loss_weight
704
+ self.mae_normalize_targets = mae_normalize_targets
705
+
706
+ if not self.sine_positional_embeddings:
707
+ self.positional_embeddings = nn.Embedding(self.seq_len, embedding_dim)
708
+ self.decoder_positional_embeddings = nn.Embedding(self.seq_len, decoder_embedding_dim)
709
+ self.register_buffer("position_ids", torch.arange(self.seq_len))
710
+
711
+ self.to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
712
+ self.patch_embedder = nn.Linear(num_channels * patch_size * patch_size, embedding_dim)
713
+
714
+ # Encoder
715
+ self.pre_layer_norm = nn.LayerNorm(embedding_dim)
716
+ self.encoder_layers = nn.ModuleList([FastCLIPEncoderLayer(
717
+ hidden_size=embedding_dim,
718
+ num_attention_heads=num_attention_heads,
719
+ out_seq_len=None,
720
+ activation_cls=activation_cls,
721
+ use_palm_alt=True,
722
+ norm_qk=False,
723
+ skip_init=None,
724
+ ) for _ in range(num_layers)])
725
+
726
+ # Head for classification
727
+ self.pooling_head = GAPHead(embedding_dim, n_tags)
728
+
729
+ # Decoder
730
+ if embedding_dim != decoder_embedding_dim or decoder_force_projection:
731
+ self.encoder_to_decoder_proj = nn.Linear(embedding_dim, decoder_embedding_dim)
732
+ else:
733
+ self.encoder_to_decoder_proj = nn.Identity()
734
+ self.decoder_pre_layer_norm = nn.LayerNorm(decoder_embedding_dim)
735
+ self.decoder_layers = nn.ModuleList([FastCLIPEncoderLayer(
736
+ hidden_size=decoder_embedding_dim,
737
+ num_attention_heads=decoder_num_attention_heads,
738
+ out_seq_len=None,
739
+ activation_cls=activation_cls,
740
+ use_palm_alt=True,
741
+ norm_qk=False,
742
+ skip_init=None,
743
+ ) for _ in range(decoder_num_layers)])
744
+
745
+ if mae_post_norm:
746
+ self.decoder_to_pixel_values = nn.Sequential(
747
+ nn.LayerNorm(decoder_embedding_dim),
748
+ nn.Linear(decoder_embedding_dim, num_channels * patch_size * patch_size)
749
+ )
750
+ else:
751
+ self.decoder_to_pixel_values = nn.Linear(decoder_embedding_dim, num_channels * patch_size * patch_size)
752
+ self.mask_token = nn.Parameter(torch.zeros(decoder_embedding_dim))
753
+ torch.nn.init.normal_(self.mask_token, std=0.02)
754
+
755
+ def forward(self, batch):
756
+ pixel_values = batch['image']
757
+ device = pixel_values.device
758
+ B, C, H, W = pixel_values.shape
759
+ assert H % self.patch_size == 0, f"Input image height ({H}) needs to be divisible by the patch size ({self.patch_size})."
760
+ assert W % self.patch_size == 0, f"Input image width ({W}) needs to be divisible by the patch size ({self.patch_size})."
761
+
762
+ # Convert image to patches (B, seq_len, C * patch_size * patch_size)
763
+ patches = self.to_patches(pixel_values)
764
+ seq_len = patches.shape[1]
765
+ num_masked = int(self.masking_ratio * seq_len)
766
+
767
+ # For each batch tensor, use argsort to convert the random numbers into a permutation of the patch indices
768
+ # From this we can get the masked and unmasked indices
769
+ patch_mask = torch.rand(B, seq_len, device=device)
770
+ patch_mask = torch.argsort(patch_mask, dim=1)
771
+ masked_indices, unmasked_indices = patch_mask[:, :num_masked], patch_mask[:, num_masked:]
772
+ batch_range = torch.arange(B, device=device)[:, None]
773
+
774
+ # Masked and unmasked patches
775
+ unmasked_patches = patches[batch_range, unmasked_indices]
776
+ masked_patches = patches[batch_range, masked_indices]
777
+
778
+ # Embed unmasked patches for the encoder (B, seq_len, embedding_dim)
779
+ tokens = self.patch_embedder(unmasked_patches)
780
+
781
+ if self.sine_positional_embeddings:
782
+ position_embeddings = sinusoidal_position_embedding(W // self.patch_size, H // self.patch_size, self.embedding_dim, pixel_values.dtype, device)
783
+ decoder_position_embeddings = sinusoidal_position_embedding(W // self.patch_size, H // self.patch_size, self.decoder_embedding_dim, pixel_values.dtype, device)
784
+ else:
785
+ position_embeddings = self.positional_embeddings(self.position_ids)
786
+ decoder_position_embeddings = self.decoder_positional_embeddings(self.position_ids)
787
+
788
+ # Add position embeddings
789
+ tokens = tokens + position_embeddings[unmasked_indices]
790
+
791
+ # Run the encoder
792
+ encoded_tokens = self.pre_layer_norm(tokens)
793
+
794
+ for layer in self.encoder_layers:
795
+ encoded_tokens = layer(encoded_tokens)
796
+
797
+ # Label predictions
798
+ if self.training:
799
+ preds = self.pooling_head(encoded_tokens)
800
+ else:
801
+ # During inference, classify using the entire image
802
+ # But we'll do the usual for the MAE part, just so we can see how MAE is performing during validation
803
+ tokens = self.patch_embedder(patches)
804
+ tokens = tokens + position_embeddings
805
+ tokens = self.pre_layer_norm(tokens)
806
+ for layer in self.encoder_layers:
807
+ tokens = layer(tokens)
808
+ preds = self.pooling_head(tokens)
809
+
810
+ # Projection for the decoder and position embeddings
811
+ decoder_tokens = self.encoder_to_decoder_proj(encoded_tokens)
812
+ decoder_tokens = decoder_tokens + decoder_position_embeddings[unmasked_indices]
813
+
814
+ # Fill in the masked patches
815
+ mask_tokens = einops.repeat(self.mask_token, 'd -> b n d', b = B, n = num_masked)
816
+ mask_tokens = mask_tokens + decoder_position_embeddings[masked_indices]
817
+ decoder_tokens = torch.cat([decoder_tokens, mask_tokens], dim=1)
818
+
819
+ # Run the decoder
820
+ decoded_tokens = self.decoder_pre_layer_norm(decoder_tokens)
821
+
822
+ for layer in self.decoder_layers:
823
+ decoded_tokens = layer(decoded_tokens)
824
+
825
+ # Only predict the masked patches
826
+ # All the masked patches are at the end of the sequence
827
+ decoded_tokens = decoded_tokens[:, -num_masked:]
828
+ pred_pixel_values = self.decoder_to_pixel_values(decoded_tokens)
829
+
830
+ # Calculate the mae loss
831
+ if self.mae_normalize_targets:
832
+ # Normalize each patch by its mean and variance. The ViCHA paper says this provides better results
833
+ means = masked_patches.mean(dim=-1, keepdim=True)
834
+ vars = masked_patches.var(dim=-1, keepdim=True)
835
+ target = (masked_patches - means) / (vars + 1e-6)**0.5
836
+ mae_loss = F.mse_loss(pred_pixel_values, target)
837
+ else:
838
+ mae_loss = F.mse_loss(pred_pixel_values, masked_patches)
839
+ mae_loss = mae_loss * self.mae_loss_weight
840
+
841
+ return {
842
+ 'tags': preds,
843
+ 'mae_loss': mae_loss,
844
+ }
845
+
846
+ def calculate_loss(self, preds, batch, pos_weight):
847
+ return basic_calculate_loss(preds, batch, pos_weight, self.loss_type) + preds['mae_loss']
848
+
849
+ def get_optimized_parameters(self, lr: float):
850
+ return self.parameters()
851
+
852
+ def save(self):
853
+ return self.state_dict()
854
+
855
+ def load(self, state_dict):
856
+ self.load_state_dict(state_dict)
857
+
858
+
859
+ class StochDepth(nn.Module):
860
+ def __init__(self, drop_rate: float, scale_by_keep: bool = False):
861
+ super().__init__()
862
+ self.drop_rate = drop_rate
863
+ self.scale_by_keep = scale_by_keep
864
+
865
+ def forward(self, x):
866
+ if not self.training:
867
+ return x
868
+
869
+ batch_size = x.shape[0]
870
+ r = torch.rand((batch_size, 1, 1), device=x.device)
871
+ keep_prob = 1 - self.drop_rate
872
+ binary_tensor = torch.floor(keep_prob + r)
873
+ if self.scale_by_keep:
874
+ x = x / keep_prob
875
+
876
+ return x * binary_tensor
877
+
878
+
879
+ class SkipInitChannelwise(nn.Module):
880
+ def __init__(self, channels, init_val=1e-6):
881
+ super().__init__()
882
+ self.channels = channels
883
+ self.init_val = init_val
884
+ self.skip = nn.Parameter(torch.ones(channels) * init_val)
885
+
886
+ def forward(self, x):
887
+ return x * self.skip
888
+
889
+
890
+ class PosEmbedding(nn.Module):
891
+ def __init__(self, d_model: int, max_len: int, use_sine: bool, patch_size: int):
892
+ super().__init__()
893
+ self.d_model = d_model
894
+ self.max_len = max_len
895
+ self.use_sine = use_sine
896
+ self.patch_size = patch_size
897
+
898
+ if not self.use_sine:
899
+ self.embedding = nn.Embedding(max_len, d_model)
900
+ nn.init.trunc_normal_(self.embedding.weight, std=0.02)
901
+ self.register_buffer("position_ids", torch.arange(max_len))
902
+
903
+ def forward(self, x, width: int, height: int):
904
+ if self.use_sine:
905
+ position_embeddings = sinusoidal_position_embedding(width // self.patch_size, height // self.patch_size, self.d_model, x.dtype, x.device)
906
+ else:
907
+ position_embeddings = self.embedding(self.position_ids)
908
+
909
+ return x + position_embeddings
910
+
911
+
912
+ class MLPBlock(nn.Module):
913
+ def __init__(self, d_model: int, d_ff: int, stochdepth_rate: float):
914
+ super().__init__()
915
+ self.linear1 = nn.Linear(d_model, d_ff)
916
+ self.linear2 = nn.Linear(d_ff, d_model)
917
+ self.activation = nn.GELU()
918
+ if stochdepth_rate > 0:
919
+ self.stochdepth = StochDepth(stochdepth_rate, scale_by_keep=True)
920
+ else:
921
+ self.stochdepth = None
922
+
923
+ def forward(self, x):
924
+ x = self.linear1(x)
925
+ x = self.activation(x)
926
+ if self.stochdepth is not None:
927
+ x = self.stochdepth(x)
928
+ x = self.linear2(x)
929
+ return x
930
+
931
+
932
+ class ViTBlock(nn.Module):
933
+ def __init__(self, num_heads: int, d_model: int, d_ff: int, layerscale_init: float, stochdepth_rate: float):
934
+ super().__init__()
935
+ self.num_heads = num_heads
936
+ self.d_model = d_model
937
+
938
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
939
+
940
+ # MHA
941
+ self.norm1 = nn.LayerNorm(d_model)
942
+ self.qkv_proj = nn.Linear(d_model, d_model * 3)
943
+ self.out_proj = nn.Linear(d_model, d_model)
944
+ self.skip_init1 = SkipInitChannelwise(channels=d_model, init_val=layerscale_init)
945
+ self.stochdepth1 = StochDepth(stochdepth_rate, scale_by_keep=True) if stochdepth_rate > 0 else None
946
+
947
+ # MLP
948
+ self.norm2 = nn.LayerNorm(d_model)
949
+ self.mlp = MLPBlock(d_model, d_ff, stochdepth_rate)
950
+ self.skip_init2 = SkipInitChannelwise(channels=d_model, init_val=layerscale_init)
951
+ self.stochdepth2 = StochDepth(stochdepth_rate, scale_by_keep=True) if stochdepth_rate > 0 else None
952
+
953
+ def forward(self, x):
954
+ bsz, src_len, embed_dim = x.shape
955
+
956
+ out = x
957
+ out = self.norm1(out)
958
+
959
+ # MHA
960
+ qkv_states = self.qkv_proj(out).split(self.d_model, dim=-1)
961
+ q_states = qkv_states[0].view(bsz, src_len, self.num_heads, embed_dim // self.num_heads).transpose(1, 2) # (bsz, num_heads, src_len, embed_dim // num_heads)
962
+ k_states = qkv_states[1].view(bsz, src_len, self.num_heads, embed_dim // self.num_heads).transpose(1, 2) # (bsz, num_heads, src_len, embed_dim // num_heads)
963
+ v_states = qkv_states[2].view(bsz, src_len, self.num_heads, embed_dim // self.num_heads).transpose(1, 2) # (bsz, num_heads, src_len, embed_dim // num_heads)
964
+
965
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
966
+ out = F.scaled_dot_product_attention(q_states, k_states, v_states) # (bsz, num_heads, tgt_len, head_dim)
967
+ out = out.transpose(1, 2).contiguous().view(bsz, src_len, embed_dim) # (bsz, tgt_len, embed_dim)
968
+
969
+ out = self.out_proj(out)
970
+
971
+ out = self.skip_init1(out)
972
+ if self.stochdepth1 is not None:
973
+ out = self.stochdepth1(out)
974
+ x = out + x
975
+
976
+ out = self.norm2(x)
977
+ out = self.mlp(out)
978
+ out = self.skip_init2(out)
979
+ if self.stochdepth2 is not None:
980
+ out = self.stochdepth2(out)
981
+
982
+ out = out + x
983
+
984
+ return out
985
+
986
+
987
+ def CaiT_LayerScale_init(network_depth):
988
+ if network_depth <= 18:
989
+ return 1e-1
990
+ elif network_depth <= 24:
991
+ return 1e-5
992
+ else:
993
+ return 1e-6
994
+
995
+
996
+ class CNNLayerNorm(nn.Module):
997
+ def __init__(self, d_model: int):
998
+ super().__init__()
999
+ self.norm = nn.LayerNorm(d_model)
1000
+
1001
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1002
+ x = x.transpose(1, 3)
1003
+ x = self.norm(x)
1004
+ x = x.transpose(1, 3)
1005
+ return x
1006
+
1007
+
1008
+ class CNNStem(nn.Module):
1009
+ def __init__(self, config: str):
1010
+ super().__init__()
1011
+ self.config = config
1012
+
1013
+ layers = []
1014
+ channels = 3
1015
+
1016
+ for line in config.split(";"):
1017
+ ty, line = line.split(":") if ":" in line else (line, "")
1018
+ options = line.split(",")
1019
+ options = [o.split("=") for o in options] if line else []
1020
+ options = {k: v for k, v in options}
1021
+
1022
+ if ty == 'conv':
1023
+ layers.append(nn.Conv2d(
1024
+ in_channels=channels,
1025
+ out_channels=int(options['c']),
1026
+ kernel_size=int(options['k'] if 'k' in options else 3),
1027
+ stride=int(options['s'] if 's' in options else 2),
1028
+ bias=True,
1029
+ padding=int(options['p'] if 'p' in options else 1),
1030
+ ))
1031
+ channels = int(options['c'])
1032
+ elif ty == 'bn':
1033
+ layers.append(nn.BatchNorm2d(channels))
1034
+ elif ty == 'ln':
1035
+ layers.append(CNNLayerNorm(channels))
1036
+ elif ty == 'relu':
1037
+ layers.append(nn.ReLU())
1038
+ elif ty == 'gelu':
1039
+ layers.append(nn.GELU())
1040
+
1041
+ self.conv = nn.Sequential(*layers)
1042
+
1043
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1044
+ return self.conv(x)
1045
+
1046
+
1047
+ class ViT(VisionModel):
1048
+ def __init__(self,
1049
+ n_tags: int,
1050
+ image_size: int,
1051
+ num_blocks: int,
1052
+ patch_size: int,
1053
+ d_model: int,
1054
+ mlp_dim: int,
1055
+ num_heads: int,
1056
+ stochdepth_rate: float,
1057
+ use_sine: bool,
1058
+ loss_type: str,
1059
+ layerscale_init: Optional[float] = None,
1060
+ head_mean_after: bool = False,
1061
+ cnn_stem: str | None = None,
1062
+ patch_dropout: float = 0.0,
1063
+ ):
1064
+ super().__init__(image_size, n_tags)
1065
+
1066
+ #assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
1067
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
1068
+
1069
+ out_dim = n_tags
1070
+ self.n_tags = n_tags
1071
+ self.loss_type = loss_type
1072
+ self.patch_size = patch_size
1073
+ self.head_mean_after = head_mean_after
1074
+ self.patch_dropout = patch_dropout
1075
+
1076
+ layerscale_init = CaiT_LayerScale_init(num_blocks) if layerscale_init is None else layerscale_init
1077
+ self.patch_embeddings = nn.Conv2d(
1078
+ in_channels=3,
1079
+ out_channels=d_model,
1080
+ kernel_size=patch_size,
1081
+ stride=patch_size,
1082
+ bias=True,
1083
+ ) if cnn_stem is None else CNNStem(cnn_stem)
1084
+ self.pos_embedding = PosEmbedding(d_model, (image_size // patch_size) ** 2, use_sine=use_sine, patch_size=patch_size)
1085
+
1086
+ self.blocks = nn.ModuleList([
1087
+ ViTBlock(num_heads, d_model, mlp_dim, layerscale_init, stochdepth_rate)
1088
+ for _ in range(num_blocks)
1089
+ ])
1090
+
1091
+ self.norm = nn.LayerNorm(d_model)
1092
+ self.head = nn.Linear(d_model, out_dim)
1093
+
1094
+ def forward(self, batch, return_embeddings=False, return_loss: bool = False, pos_weight = None):
1095
+ B, C, H, W = batch['image'].shape
1096
+ assert H % self.patch_size == 0, f"Input image height ({H}) needs to be divisible by the patch size ({self.patch_size})."
1097
+ assert W % self.patch_size == 0, f"Input image width ({W}) needs to be divisible by the patch size ({self.patch_size})."
1098
+
1099
+ x = self.patch_embeddings(batch['image']) # (bsz, d_model, patch_num, patch_num)
1100
+ x = x.flatten(2).transpose(1, 2) # (bsz, patch_num ** 2, d_model)
1101
+ x = self.pos_embedding(x, W, H) # (bsz, patch_num ** 2, d_model)
1102
+
1103
+ # Patch dropout
1104
+ seq_len = x.shape[1]
1105
+ patch_dropout = int(math.ceil((1.0 - self.patch_dropout) * seq_len))
1106
+
1107
+ if patch_dropout != seq_len:
1108
+ # Generate a matrix of random numbers between 0 and 1 of shape (B, seq_len)
1109
+ patch_mask = torch.rand(B, seq_len, device=x.device)
1110
+ # For each batch tensor, use argsort to convert the random numbers into a permutation of the patch indices
1111
+ patch_mask = torch.argsort(patch_mask, dim=1)
1112
+ # Truncate
1113
+ patch_mask = patch_mask[:, :patch_dropout]
1114
+
1115
+ x = x.gather(1, patch_mask.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
1116
+
1117
+ #indices = torch.randperm(seq_len, device=x.device)[:patch_dropout]
1118
+ #x = x[:, indices, :]
1119
+
1120
+ # Transformer
1121
+ for block in self.blocks:
1122
+ x = block(x)
1123
+
1124
+ # Head
1125
+ result = {}
1126
+
1127
+ x = self.norm(x)
1128
+ if self.head_mean_after:
1129
+ x = self.head(x)
1130
+ x = x.mean(dim=1)
1131
+ else:
1132
+ x = x.mean(dim=1)
1133
+ if return_embeddings:
1134
+ result['embeddings'] = x
1135
+ x = self.head(x)
1136
+
1137
+ result['tags'] = x
1138
+
1139
+ if return_loss:
1140
+ result['loss'] = self.calculate_loss(result, batch, pos_weight)
1141
+
1142
+ return result
1143
+
1144
+ def calculate_loss(self, preds, batch, pos_weight):
1145
+ return basic_calculate_loss(preds, batch, pos_weight, self.loss_type)
1146
+
1147
+ def get_optimized_parameters(self, lr: float):
1148
+ return self.parameters()
1149
+
1150
+ def save(self):
1151
+ return self.state_dict()
1152
+
1153
+ def load(self, state_dict):
1154
+ if 'head.weight' in state_dict and 'head.bias' in state_dict and state_dict['head.weight'].shape[0] == (self.n_tags + 9):
1155
+ # Support old models which included 3 rating and 6 score dimensions
1156
+ state_dict['head.weight'] = state_dict['head.weight'][:self.n_tags]
1157
+ state_dict['head.bias'] = state_dict['head.bias'][:self.n_tags]
1158
+
1159
+ self.load_state_dict(state_dict)
app.py ADDED
@@ -0,0 +1,41 @@
+ import gradio as gr
+ from Models import VisionModel
+ import huggingface_hub
+ from PIL import Image
+ import torch.amp.autocast_mode
+ from pathlib import Path
+
+
+ MODEL_REPO = "fancyfeast/joytag"
+
+
+ @torch.no_grad()
+ def predict(image: Image.Image):
+     with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+         preds = model(image)
+         tag_preds = preds['tags'].sigmoid().cpu()
+
+     return {top_tags[i]: tag_preds[i] for i in range(len(top_tags))}
+
+
+ print("Downloading model...")
+ path = huggingface_hub.snapshot_download(MODEL_REPO)
+ print("Loading model...")
+ model = VisionModel.load_model(path)
+ model.eval()
+
+ with open(Path(path) / 'top_tags.txt', 'r') as f:
+     top_tags = [line.strip() for line in f.readlines() if line.strip()]
+
+ print("Starting server...")
+
+ gradio_app = gr.Interface(
+     predict,
+     inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
+     outputs=[gr.Label(label="Result", num_top_classes=5)],
+     title="JoyTag",
+ )
+
+
+ if __name__ == '__main__':
+     gradio_app.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==2.1.2
+ transformers==4.36.2
+ torchvision==0.16.2
+ einops==0.7.0
+ safetensors==0.4.1
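
A minimal smoke-test sketch for Models.py above (not part of the commit): it instantiates one of the MODEL_CONFIGS entries through VisionModel.from_config and pushes a random batch through it. The n_tags, image_size, and loss_type values are placeholders for this sketch only; the real values come from a checkpoint's config.json.

import torch
from Models import MODEL_CONFIGS, VisionModel

# 'SWModel2' is the B/16 ViT with sinusoidal position embeddings from MODEL_CONFIGS.
# n_tags, image_size, and loss_type below are placeholder values for this sketch only.
config = dict(MODEL_CONFIGS['SWModel2'], n_tags=100, image_size=224, loss_type='ce')

model = VisionModel.from_config(config)  # the 'class' key selects the VisionModel subclass
model.eval()

with torch.no_grad():
    out = model({'image': torch.randn(2, 3, 224, 224)})

print(out['tags'].shape)  # expected: torch.Size([2, 100])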