|
import pytest |
|
import torch |
|
|
|
|
|
# Model identifier handed to MAEModel.from_pretrained by the huggingface_model fixture.
huggingface_modelpath = "recursionpharma/OpenPhenom" |
|
|
|
from .huggingface_mae import MAEModel |
|
|
|
|
|
@pytest.fixture
def huggingface_model():
    """Provide the pretrained MAE model, switched to evaluation mode.

    The model is loaded with ``MAEModel.from_pretrained`` using the
    module-level ``huggingface_modelpath`` and has ``eval()`` called on it
    before being handed to the requesting test.
    """
    model = MAEModel.from_pretrained(huggingface_modelpath)
    model.eval()
    return model
|
|
|
|
|
@pytest.mark.parametrize("C", [1, 4, 6, 11])
@pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
    """Check the output shape of ``predict`` across channel counts and modes.

    A random uint8 batch of two ``C``-channel 256x256 inputs is created on the
    model's device and passed to ``huggingface_model.predict`` after setting
    ``return_channelwise_embeddings`` on the model.

    Args:
        huggingface_model: Model supplied by the ``huggingface_model`` fixture.
        C: Number of input channels for the generated batch.
        return_channelwise_embeddings: Value assigned to the model attribute of
            the same name before predicting.

    The test expects an output of shape ``(2, 384 * C)`` when
    ``return_channelwise_embeddings`` is true and ``(2, 384)`` otherwise.
    """
    batch_size = 2
    embedding_dim = 384  # per-embedding width this test asserts on

    # torch.randint treats `high` as exclusive, so 256 is required for the
    # sampled uint8 values to span the full 0..255 range.
    example_input_array = torch.randint(
        low=0,
        high=256,
        size=(batch_size, C, 256, 256),
        dtype=torch.uint8,
        device=huggingface_model.device,
    )

    huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
    embeddings = huggingface_model.predict(example_input_array)

    expected_output_dim = (
        embedding_dim * C if return_channelwise_embeddings else embedding_dim
    )
    assert embeddings.shape == (batch_size, expected_output_dim)
|
|