Laksh recursionaut kiankaydee committed
Commit 6877289 · unverified · 1 Parent(s): b3e68bb

add mae model (#15)


* add mae model

* push to hugging face

* minor cleanup

* more changes

* set mask ratio to 0

* successful validation

* rearrange

* remove testing code

* remove unused function

* add a predict method

* update to correct version of phenom-beta

* add Kian's PR suggestion

* add test

* add comment to download model

* add multiple channel test

* allow channelwise embs

* clean some dead code

* add reconstruction notebook with example; runs fine on CPU

* fix up

* update notebook

* remove the need for hydra

---------

Co-authored-by: Laksh <laksh.arumugam@recursionpharma.com>
Co-authored-by: kian-kd <kian.kd@recursionpharma.com>

.gitignore ADDED
@@ -0,0 +1,32 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # model artifacts
+ *.pickle
+ *.ckpt
+ *.safetensors
generate_reconstructions.ipynb ADDED
The diff for this file is too large to render. See raw diff
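Since the notebook diff cannot be rendered here, below is a minimal sketch (not part of the commit) of the reconstruction flow it demonstrates, based on MAEModel.forward from huggingface_mae.py further down. The local checkpoint directory and the dummy input are assumptions; with mask_ratio set to 0.0 in the shipped config, no patches are dropped, so the reconstruction covers the whole image.

import torch

from huggingface_mae import MAEModel

# Assumed local checkpoint path; see the download comment in test_huggingface_mae.py.
model = MAEModel.from_pretrained("models/phenom_beta_huggingface")
model.eval()

# Dummy 6-channel 256x256 uint8 crop standing in for the sample/*.jp2 images.
imgs = torch.randint(0, 255, (1, 6, 256, 256), dtype=torch.uint8)
with torch.no_grad():
    latent, reconstruction, mask = model(imgs)

# reconstruction holds flattened patches (N, tokens, patch_size**2); the notebook
# presumably unflattens them back into images using the helpers in mae_utils.py.
print(latent.shape, reconstruction.shape, mask.shape)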
 
huggingface_mae.py ADDED
@@ -0,0 +1,293 @@
+ from typing import Dict, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from loss import FourierLoss
+ from normalizer import Normalizer
+ from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
+ from mae_utils import flatten_images
+ from vit import (
+     generate_2d_sincos_pos_embeddings,
+     sincos_positional_encoding_vit,
+     vit_small_patch16_256,
+ )
+
+ TensorDict = Dict[str, torch.Tensor]
+
+
+ class MAEConfig(PretrainedConfig):
+     model_type = "MAE"
+
+     def __init__(
+         self,
+         mask_ratio=0.75,
+         encoder=None,
+         decoder=None,
+         loss=None,
+         optimizer=None,
+         input_norm=None,
+         fourier_loss=None,
+         fourier_loss_weight=0.0,
+         lr_scheduler=None,
+         use_MAE_weight_init=False,
+         crop_size=-1,
+         mask_fourier_loss=True,
+         return_channelwise_embeddings=False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.mask_ratio = mask_ratio
+         self.encoder = encoder
+         self.decoder = decoder
+         self.loss = loss
+         self.optimizer = optimizer
+         self.input_norm = input_norm
+         self.fourier_loss = fourier_loss
+         self.fourier_loss_weight = fourier_loss_weight
+         self.lr_scheduler = lr_scheduler
+         self.use_MAE_weight_init = use_MAE_weight_init
+         self.crop_size = crop_size
+         self.mask_fourier_loss = mask_fourier_loss
+         self.return_channelwise_embeddings = return_channelwise_embeddings
+
+
+ class MAEModel(PreTrainedModel):
+     config_class = MAEConfig
+
+     # Loss metrics
+     TOTAL_LOSS = "loss"
+     RECON_LOSS = "reconstruction_loss"
+     FOURIER_LOSS = "fourier_loss"
+
+     def __init__(self, config: MAEConfig):
+         super().__init__(config)
+
+         self.mask_ratio = config.mask_ratio
+
+         # Could use Hydra to instantiate instead
+         self.encoder = MAEEncoder(
+             vit_backbone=sincos_positional_encoding_vit(
+                 vit_backbone=vit_small_patch16_256(global_pool="avg")
+             ),
+             max_in_chans=11,  # upper limit on number of input channels
+             channel_agnostic=True,
+         )
+         self.decoder = CAMAEDecoder(
+             depth=8,
+             embed_dim=512,
+             mlp_ratio=4,
+             norm_layer=nn.LayerNorm,
+             num_heads=16,
+             num_modalities=6,
+             qkv_bias=True,
+             tokens_per_modality=256,
+         )
+         self.input_norm = torch.nn.Sequential(
+             Normalizer(),
+             nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
+         )
+
+         self.fourier_loss_weight = config.fourier_loss_weight
+         self.mask_fourier_loss = config.mask_fourier_loss
+         self.return_channelwise_embeddings = config.return_channelwise_embeddings
+         self.tokens_per_channel = 256  # hardcode the number of tokens per channel since we are patch16 crop 256
+
+         # loss stuff
+         self.loss = torch.nn.MSELoss(reduction="none")
+
+         self.fourier_loss = FourierLoss(num_multimodal_modalities=6)
+         if self.fourier_loss_weight > 0 and self.fourier_loss is None:
+             raise ValueError(
+                 "FourierLoss weight is activated but no fourier_loss was defined in constructor"
+             )
+         elif self.fourier_loss_weight >= 1:
+             raise ValueError(
+                 "FourierLoss weight is too large to do mixing factor, weight should be < 1"
+             )
+
+         self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])
+
+         # projection layer between the encoder and decoder
+         self.encoder_decoder_proj = nn.Linear(
+             self.encoder.embed_dim, self.decoder.embed_dim, bias=True
+         )
+
+         self.decoder_pred = nn.Linear(
+             self.decoder.embed_dim,
+             self.patch_size**2
+             * (1 if self.encoder.channel_agnostic else self.in_chans),
+             bias=True,
+         )  # linear layer from decoder embedding to input dims
+
+         # overwrite decoder pos embeddings based on encoder params
+         self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(  # type: ignore[assignment]
+             self.decoder.embed_dim,
+             length=self.encoder.vit_backbone.patch_embed.grid_size[0],
+             use_class_token=self.encoder.vit_backbone.cls_token is not None,
+             num_modality=(
+                 self.decoder.num_modalities if self.encoder.channel_agnostic else 1
+             ),
+         )
+
+         if config.use_MAE_weight_init:
+             w = self.encoder.vit_backbone.patch_embed.proj.weight.data
+             torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+             torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
+             torch.nn.init.normal_(self.decoder.mask_token, std=0.02)
+
+             self.apply(self._MAE_init_weights)
+
+     def setup(self, stage: str) -> None:
+         super().setup(stage)
+
+     def _MAE_init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             torch.nn.init.xavier_uniform_(m.weight)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @staticmethod
+     def decode_to_reconstruction(
+         encoder_latent: torch.Tensor,
+         ind_restore: torch.Tensor,
+         proj: torch.nn.Module,
+         decoder: MAEDecoder | CAMAEDecoder,
+         pred: torch.nn.Module,
+     ) -> torch.Tensor:
+         """Feed forward the encoder latent through the decoders necessary projections and transformations."""
+         decoder_latent_projection = proj(
+             encoder_latent
+         )  # projection from encoder.embed_dim to decoder.embed_dim
+         decoder_tokens = decoder.forward_masked(
+             decoder_latent_projection, ind_restore
+         )  # decoder.embed_dim output
+         predicted_reconstruction = pred(
+             decoder_tokens
+         )  # linear projection to input dim
+         return predicted_reconstruction[:, 1:, :]  # drop class token
+
+     def forward(
+         self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         imgs = self.input_norm(imgs)
+         latent, mask, ind_restore = self.encoder.forward_masked(
+             imgs, self.mask_ratio, constant_noise
+         )  # encoder blocks
+         reconstruction = self.decode_to_reconstruction(
+             latent,
+             ind_restore,
+             self.encoder_decoder_proj,
+             self.decoder,
+             self.decoder_pred,
+         )
+         return latent, reconstruction, mask
+
+     def compute_MAE_loss(
+         self,
+         reconstruction: torch.Tensor,
+         img: torch.Tensor,
+         mask: torch.Tensor,
+     ) -> Tuple[torch.Tensor, Dict[str, float]]:
+         """Computes final loss and returns specific values of component losses for metric reporting."""
+         loss_dict = {}
+         img = self.input_norm(img)
+         target_flattened = flatten_images(
+             img,
+             patch_size=self.patch_size,
+             channel_agnostic=self.encoder.channel_agnostic,
+         )
+
+         loss: torch.Tensor = self.loss(
+             reconstruction, target_flattened
+         )  # should be with MSE or MAE (L1) with reduction='none'
+         loss = loss.mean(
+             dim=-1
+         )  # average over embedding dim -> mean loss per patch (N,L)
+         loss = (loss * mask).sum() / mask.sum()  # mean loss on masked patches only
+         loss_dict[self.RECON_LOSS] = loss.item()
+
+         # compute fourier loss
+         if self.fourier_loss_weight > 0:
+             floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
+             if not self.mask_fourier_loss:
+                 floss = floss.mean()
+             else:
+                 floss = floss.mean(dim=-1)
+                 floss = (floss * mask).sum() / mask.sum()
+
+             loss_dict[self.FOURIER_LOSS] = floss.item()
+
+         # here we use a mixing factor to keep the loss magnitude appropriate with fourier
+         if self.fourier_loss_weight > 0:
+             loss = (1 - self.fourier_loss_weight) * loss + (
+                 self.fourier_loss_weight * floss
+             )
+         return loss, loss_dict
+
+     def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
+         img = batch["pixels"]
+         latent, reconstruction, mask = self(img.clone())
+         full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
+         return {
+             "loss": full_loss,
+             **loss_dict,  # type: ignore[dict-item]
+         }
+
+     def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
+         return self.training_step(batch, batch_idx)
+
+     def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
+         self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
+         for key, value in outputs.items():
+             if key.endswith("loss"):
+                 self.metrics[key].update(value)
+
+     def on_validation_batch_end(  # type: ignore[override]
+         self,
+         outputs: TensorDict,
+         batch: TensorDict,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
+
+     def predict(self, imgs: torch.Tensor) -> torch.Tensor:
+         imgs = self.input_norm(imgs)
+         X = self.encoder.vit_backbone.forward_features(
+             imgs
+         )  # 3d tensor N x num_tokens x dim
+         if self.return_channelwise_embeddings:
+             N, _, d = X.shape
+             num_channels = imgs.shape[1]
+             X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
+             pooled_segments = X_reshaped.mean(
+                 dim=2
+             )  # Resulting shape: (N, num_channels, d)
+             latent = pooled_segments.view(N, num_channels * d).contiguous()
+         else:
+             latent = X[:, 1:, :].mean(dim=1)  # 1 + 256 * C tokens
+         return latent
+
+     def save_pretrained(self, save_directory: str, **kwargs):
+         filename = kwargs.pop("filename", "model.safetensors")
+         modelpath = f"{save_directory}/{filename}"
+         self.config.save_pretrained(save_directory)
+         torch.save({"state_dict": self.state_dict()}, modelpath)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         filename = kwargs.pop("filename", "model.safetensors")
+
+         modelpath = f"{pretrained_model_name_or_path}/{filename}"
+         config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         state_dict = torch.load(modelpath, map_location="cpu")
+         model = cls(config, *model_args, **kwargs)
+         model.load_state_dict(state_dict["state_dict"])
+         return model
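A minimal usage sketch (not part of the commit) of the new predict method for embedding extraction, mirroring test_huggingface_mae.py below. The local checkpoint directory is an assumption based on the download comment in that test's fixture.

import torch

from huggingface_mae import MAEModel

model = MAEModel.from_pretrained("models/phenom_beta_huggingface")
model.eval()

imgs = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8)  # N x C x 256 x 256 uint8 crops
with torch.no_grad():
    emb = model.predict(imgs)  # (2, 384): mean over all patch tokens
    model.return_channelwise_embeddings = True
    emb_per_channel = model.predict(imgs)  # (2, 6 * 384): one 384-d block per channel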
mae_modules.py CHANGED
@@ -7,8 +7,8 @@ import torch.nn as nn
  from timm.models.helpers import checkpoint_seq
  from timm.models.vision_transformer import Block, Mlp, VisionTransformer
 
- from .masking import transformer_random_masking
- from .vit import channel_agnostic_vit
+ from masking import transformer_random_masking
+ from vit import channel_agnostic_vit
 
  # If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
  # leverage the flattening and unflattening utilities as needed from mae_utils.py.
models/phenom_beta_huggingface/config.json ADDED
@@ -0,0 +1,85 @@
+ {
+   "_attn_implementation_autoset": true,
+   "apply_loss_unmasked": false,
+   "architectures": [
+     "MAEModel"
+   ],
+   "crop_size": -1,
+   "decoder": {
+     "_target_": "mae_modules.CAMAEDecoder",
+     "depth": 8,
+     "embed_dim": 512,
+     "mlp_ratio": 4,
+     "norm_layer": {
+       "_partial_": true,
+       "_target_": "torch.nn.LayerNorm",
+       "eps": 1e-06
+     },
+     "num_heads": 16,
+     "num_modalities": 6,
+     "qkv_bias": true,
+     "tokens_per_modality": 256
+   },
+   "encoder": {
+     "_target_": "mae_modules.MAEEncoder",
+     "channel_agnostic": true,
+     "max_in_chans": 11,
+     "vit_backbone": {
+       "_target_": "vit.sincos_positional_encoding_vit",
+       "vit_backbone": {
+         "_target_": "vit.vit_small_patch16_256",
+         "global_pool": "avg"
+       }
+     }
+   },
+   "fourier_loss": {
+     "_target_": "loss.FourierLoss",
+     "num_multimodal_modalities": 6
+   },
+   "fourier_loss_weight": 0.0,
+   "input_norm": {
+     "_args_": [
+       {
+         "_target_": "normalizer.Normalizer"
+       },
+       {
+         "_target_": "torch.nn.InstanceNorm2d",
+         "affine": false,
+         "num_features": null,
+         "track_running_stats": false
+       }
+     ],
+     "_target_": "torch.nn.Sequential"
+   },
+   "layernorm_unfreeze": true,
+   "loss": {
+     "_target_": "torch.nn.MSELoss",
+     "reduction": "none"
+   },
+   "lr_scheduler": {
+     "_partial_": true,
+     "_target_": "torch.optim.lr_scheduler.OneCycleLR",
+     "anneal_strategy": "cos",
+     "max_lr": 0.0001,
+     "pct_start": 0.1
+   },
+   "mask_fourier_loss": true,
+   "mask_ratio": 0.0,
+   "model_type": "MAE",
+   "norm_pix_loss": false,
+   "num_blocks_to_freeze": 0,
+   "optimizer": {
+     "_partial_": true,
+     "_target_": "timm.optim.lion.Lion",
+     "betas": [
+       0.9,
+       0.95
+     ],
+     "lr": 0.0001,
+     "weight_decay": 0.05
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.46.1",
+   "trim_encoder_blocks": null,
+   "use_MAE_weight_init": false
+ }
normalizer.py ADDED
@@ -0,0 +1,7 @@
+ import torch
+
+
+ class Normalizer(torch.nn.Module):
+     def forward(self, pixels: torch.Tensor) -> torch.Tensor:
+         pixels = pixels.float()
+         return pixels / 255.0
requirements.in ADDED
@@ -0,0 +1,14 @@
+ huggingface-hub
+ timm
+ torch>=2.3
+ torchmetrics
+ torchvision
+ tqdm
+ transformers
+ xformers
+ zarr
+ hydra-core
+ pytorch-lightning>=2.1
+ isort
+ ruff
+ pytest
requirements.txt CHANGED
@@ -1,9 +1,213 @@
- huggingface-hub==0.18.0
- timm==0.9.7
- torch==2.1.0+cu121
- torchmetrics==1.2.0
- torchvision==0.16.0+cu121
- tqdm==4.66.1
- transformers==4.35.2
- xformers==0.0.22.post7
- zarr==2.16.1
+ #
+ # This file is autogenerated by pip-compile with Python 3.10
+ # by the following command:
+ #
+ #    pip-compile --no-emit-index-url --output-file=requirements.txt requirements.in
+ #
+ --trusted-host pypi.ngc.nvidia.com
+
+ aiohappyeyeballs==2.4.3
+     # via aiohttp
+ aiohttp==3.10.10
+     # via fsspec
+ aiosignal==1.3.1
+     # via aiohttp
+ antlr4-python3-runtime==4.9.3
+     # via
+     #   hydra-core
+     #   omegaconf
+ asciitree==0.3.3
+     # via zarr
+ async-timeout==4.0.3
+     # via aiohttp
+ attrs==24.2.0
+     # via aiohttp
+ certifi==2024.8.30
+     # via requests
+ charset-normalizer==3.4.0
+     # via requests
+ exceptiongroup==1.2.2
+     # via pytest
+ fasteners==0.19
+     # via zarr
+ filelock==3.16.1
+     # via
+     #   huggingface-hub
+     #   torch
+     #   transformers
+     #   triton
+ frozenlist==1.4.1
+     # via
+     #   aiohttp
+     #   aiosignal
+ fsspec[http]==2024.10.0
+     # via
+     #   huggingface-hub
+     #   pytorch-lightning
+     #   torch
+ huggingface-hub==0.26.1
+     # via
+     #   -r requirements.in
+     #   timm
+     #   tokenizers
+     #   transformers
+ hydra-core==1.3.2
+     # via -r requirements.in
+ idna==3.10
+     # via
+     #   requests
+     #   yarl
+ iniconfig==2.0.0
+     # via pytest
+ isort==5.13.2
+     # via -r requirements.in
+ jinja2==3.1.4
+     # via torch
+ lightning-utilities==0.11.8
+     # via
+     #   pytorch-lightning
+     #   torchmetrics
+ markupsafe==3.0.2
+     # via jinja2
+ mpmath==1.3.0
+     # via sympy
+ multidict==6.1.0
+     # via
+     #   aiohttp
+     #   yarl
+ networkx==3.4.2
+     # via torch
+ numcodecs==0.13.1
+     # via zarr
+ numpy==1.26.4
+     # via
+     #   numcodecs
+     #   torchmetrics
+     #   torchvision
+     #   transformers
+     #   xformers
+     #   zarr
+ nvidia-cublas-cu12==12.4.5.8
+     # via
+     #   nvidia-cudnn-cu12
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cuda-cupti-cu12==12.4.127
+     # via torch
+ nvidia-cuda-nvrtc-cu12==12.4.127
+     # via torch
+ nvidia-cuda-runtime-cu12==12.4.127
+     # via torch
+ nvidia-cudnn-cu12==9.1.0.70
+     # via torch
+ nvidia-cufft-cu12==11.2.1.3
+     # via torch
+ nvidia-curand-cu12==10.3.5.147
+     # via torch
+ nvidia-cusolver-cu12==11.6.1.9
+     # via torch
+ nvidia-cusparse-cu12==12.3.1.170
+     # via
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-nccl-cu12==2.21.5
+     # via torch
+ nvidia-nvjitlink-cu12==12.4.127
+     # via
+     #   nvidia-cusolver-cu12
+     #   nvidia-cusparse-cu12
+     #   torch
+ nvidia-nvtx-cu12==12.4.127
+     # via torch
+ omegaconf==2.3.0
+     # via hydra-core
+ packaging==24.1
+     # via
+     #   huggingface-hub
+     #   hydra-core
+     #   lightning-utilities
+     #   pytest
+     #   pytorch-lightning
+     #   torchmetrics
+     #   transformers
+ pillow==11.0.0
+     # via torchvision
+ pluggy==1.5.0
+     # via pytest
+ propcache==0.2.0
+     # via yarl
+ pytest==8.3.3
+     # via -r requirements.in
+ pytorch-lightning==2.4.0
+     # via -r requirements.in
+ pyyaml==6.0.2
+     # via
+     #   huggingface-hub
+     #   omegaconf
+     #   pytorch-lightning
+     #   timm
+     #   transformers
+ regex==2024.9.11
+     # via transformers
+ requests==2.32.3
+     # via
+     #   huggingface-hub
+     #   transformers
+ ruff==0.7.0
+     # via -r requirements.in
+ safetensors==0.4.5
+     # via
+     #   timm
+     #   transformers
+ sympy==1.13.1
+     # via torch
+ timm==1.0.11
+     # via -r requirements.in
+ tokenizers==0.20.1
+     # via transformers
+ tomli==2.0.2
+     # via pytest
+ torch==2.5.0
+     # via
+     #   -r requirements.in
+     #   pytorch-lightning
+     #   timm
+     #   torchmetrics
+     #   torchvision
+     #   xformers
+ torchmetrics==1.5.0
+     # via
+     #   -r requirements.in
+     #   pytorch-lightning
+ torchvision==0.20.0
+     # via
+     #   -r requirements.in
+     #   timm
+ tqdm==4.66.5
+     # via
+     #   -r requirements.in
+     #   huggingface-hub
+     #   pytorch-lightning
+     #   transformers
+ transformers==4.45.2
+     # via -r requirements.in
+ triton==3.1.0
+     # via torch
+ typing-extensions==4.12.2
+     # via
+     #   huggingface-hub
+     #   lightning-utilities
+     #   multidict
+     #   pytorch-lightning
+     #   torch
+ urllib3==2.2.3
+     # via requests
+ xformers==0.0.28.post2
+     # via -r requirements.in
+ yarl==1.16.0
+     # via aiohttp
+ zarr==2.18.3
+     # via -r requirements.in
+
+ # The following packages are considered to be unsafe in a requirements file:
+ # setuptools
sample/AA41_s1_1.jp2 ADDED
sample/AA41_s1_2.jp2 ADDED
sample/AA41_s1_3.jp2 ADDED
sample/AA41_s1_4.jp2 ADDED
sample/AA41_s1_5.jp2 ADDED
sample/AA41_s1_6.jp2 ADDED
test_huggingface_mae.py ADDED
@@ -0,0 +1,32 @@
+ import pytest
+ import torch
+
+ from huggingface_mae import MAEModel
+
+ huggingface_phenombeta_model_dir = "models/phenom_beta_huggingface"
+ # huggingface_modelpath = "recursionpharma/test-pb-model"
+
+
+ @pytest.fixture
+ def huggingface_model():
+     # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
+     # huggingface-cli download recursionpharma/test-pb-model --local-dir=models/phenom_beta_huggingface
+     huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
+     huggingface_model.eval()
+     return huggingface_model
+
+
+ @pytest.mark.parametrize("C", [1, 4, 6, 11])
+ @pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
+ def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
+     example_input_array = torch.randint(
+         low=0,
+         high=255,
+         size=(2, C, 256, 256),
+         dtype=torch.uint8,
+         device=huggingface_model.device,
+     )
+     huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
+     embeddings = huggingface_model.predict(example_input_array)
+     expected_output_dim = 384 * C if return_channelwise_embeddings else 384
+     assert embeddings.shape == (2, expected_output_dim)
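With the pinned dependencies from requirements.txt installed and the checkpoint downloaded as described in the fixture comment above, this test should run with pytest test_huggingface_mae.py; it exercises predict() on 1, 4, 6, and 11 input channels, with and without channel-wise embeddings.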