Excitometer / tests /test_model.py
szili2011's picture
Upload 17 files
c248b9b verified
import pytest
import torch
from model import ExciteMeterModel # Adjust the import based on your actual file structure
from torch.testing import assert_allclose
def test_model_architecture():
"""
Test the architecture of the ExciteMeterModel to ensure it has the correct layers and dimensions.
"""
model = ExciteMeterModel()
# Check the type of the model
assert isinstance(model, ExciteMeterModel), "Model is not of type ExciteMeterModel"
# Check if model contains expected layers
assert hasattr(model, 'conv1'), "Model does not have the expected conv1 layer"
assert hasattr(model, 'rnn'), "Model does not have the expected rnn layer"
assert hasattr(model, 'fc'), "Model does not have the expected fc layer"
def test_model_forward_pass():
"""
Test the model's forward pass with a dummy input to ensure it produces output without errors.
"""
model = ExciteMeterModel()
# Create a dummy input tensor (e.g., batch size of 1, 1 channel, 16000 samples)
dummy_input = torch.randn(1, 1, 16000) # Adjust shape based on your model input
# Perform a forward pass
output = model(dummy_input)
# Check if the output is a tensor with expected shape
assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
assert output.shape == torch.Size([1]), "Output shape is incorrect"
def test_model_loading():
"""
Test loading of a pre-trained model to ensure it loads without errors.
"""
model_path = 'path/to/your/model.pth'
model = ExciteMeterModel()
# Load model weights
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval() # Set to evaluation mode
# Create a dummy input tensor for inference
dummy_input = torch.randn(1, 1, 16000)
output = model(dummy_input)
# Validate the output
assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
assert output.shape == torch.Size([1]), "Output shape is incorrect"
def test_model_prediction():
"""
Test that the model's prediction is within a reasonable range.
"""
model = ExciteMeterModel()
# Create a dummy input tensor (e.g., batch size of 1, 1 channel, 16000 samples)
dummy_input = torch.randn(1, 1, 16000)
# Perform a forward pass
with torch.no_grad():
output = model(dummy_input)
# Check if the output is within the expected range
assert output.item() >= 0.0, "Output should be non-negative"
assert output.item() <= 1.0, "Output should be less than or equal to 1"
if __name__ == "__main__":
pytest.main()