import pytest
import torch

from huggingface_mae import MAEModel
# Local directory that holds the downloaded model weights and config.
huggingface_phenombeta_model_dir = "."
# Hugging Face Hub repository the local files are downloaded from (kept for
# reference; not used by the tests):
# huggingface_modelpath = "recursionpharma/test-pb-model"
@pytest.fixture
def huggingface_model():
    """Load the pretrained MAE model from the local model directory.

    The model files must already be present in
    ``huggingface_phenombeta_model_dir``. Download them from
    https://huggingface.co/recursionpharma/test-pb-model with::

        huggingface-cli download recursionpharma/test-pb-model --local-dir=.

    Returns:
        The ``MAEModel`` instance after ``eval()`` has been called on it.
    """
    # Use a distinct local name so the function's own name is not shadowed.
    model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
    model.eval()
    return model
# NOTE(review): the channel counts below are representative values chosen so
# the test can be collected; confirm them against the channel counts the model
# is expected to support.
@pytest.mark.parametrize("C", [1, 4, 6, 11])
@pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
    """Check the shape of the embeddings returned by ``predict``.

    Args:
        huggingface_model: Model under test; it must expose ``device``,
            ``return_channelwise_embeddings`` and ``predict``.
        C: Number of channels in the random input batch.
        return_channelwise_embeddings: Value assigned to the model attribute
            of the same name before predicting.

    The expected embedding width is ``384 * C`` when
    ``return_channelwise_embeddings`` is true and ``384`` otherwise.
    """
    batch_size = 2
    embedding_dim = 384

    # torch.randint's upper bound is exclusive, so values fall in [0, 254].
    example_input_array = torch.randint(
        low=0,
        high=255,
        size=(batch_size, C, 256, 256),
        dtype=torch.uint8,
        device=huggingface_model.device,
    )

    huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
    embeddings = huggingface_model.predict(example_input_array)

    expected_output_dim = (
        embedding_dim * C if return_channelwise_embeddings else embedding_dim
    )
    assert embeddings.shape == (batch_size, expected_output_dim)