# Image_Classifier_with_CNN: api/tests/test_router_model.py
# Source: commit 459b8f5 by iBrokeTheCode ("chore: Add model UI files").
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import UploadFile
from httpx import ASGITransport, AsyncClient

from app.auth.jwt import get_current_user
from main import app
# NOTE: Run tests with: pytest tests/test_router_model.py -v
@pytest.mark.asyncio
async def test_predict():
    """A valid PNG upload to ``/model/predict`` returns the mocked prediction.

    Authentication, file hashing, the model call, ``os.path.exists`` and
    ``open`` are all replaced with mocks, so the test only exercises the
    router's request/response handling.
    """
    image_bytes = b"fake-image-data"

    mock_user = MagicMock()
    mock_user.id = 1
    # Bypass JWT validation: the endpoint receives ``mock_user`` as the current
    # user, so the bearer token below is not checked by ``get_current_user``.
    app.dependency_overrides[get_current_user] = lambda: mock_user
    try:
        with patch(
            "app.model.router.utils.get_file_hash", return_value="fakehash123"
        ), patch(
            "app.model.router.model_predict",
            new_callable=AsyncMock,
            return_value=("cat", 0.95),
        ) as mock_model_predict, patch(
            # NOTE(review): presumably forces the "upload not stored yet" path.
            # Assuming the router does ``import os``, this patches the shared
            # ``os.path`` module, i.e. it is process-wide while active.
            "app.model.router.os.path.exists",
            return_value=False,
        ), patch(
            # Any ``open()`` call (presumably the router saving the upload)
            # gets a MagicMock instead of touching the filesystem.
            "builtins.open",
            new_callable=MagicMock,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as ac:
                response = await ac.post(
                    "/model/predict",
                    files={"file": ("test_image.png", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
    finally:
        # ``app`` is imported from ``main`` and shared by every test; remove the
        # override so the auth bypass cannot leak into other tests.
        app.dependency_overrides.pop(get_current_user, None)

    assert response.status_code == 200
    mock_model_predict.assert_awaited_once()
    response_data = response.json()
    assert response_data["success"] is True
    assert response_data["prediction"] == "cat"
    assert response_data["score"] == 0.95
    # The mocked hash is what the router reports as the image file name.
    assert response_data["image_file_name"] == "fakehash123"
@pytest.mark.asyncio
async def test_predict_fails_bad_extension():
    """Uploading a ``.pdf`` filename to ``/model/predict`` is rejected with 400.

    The multipart part keeps the ``image/png`` content type, so only the
    filename extension is "wrong". The same mocks as the happy-path test are
    installed so that nothing real runs if the request got past validation.
    """
    image_bytes = b"fake-image-data"

    mock_user = MagicMock()
    mock_user.id = 1
    # Bypass JWT validation: the endpoint receives ``mock_user`` as the current
    # user, so the bearer token below is not checked by ``get_current_user``.
    app.dependency_overrides[get_current_user] = lambda: mock_user
    try:
        with patch(
            "app.model.router.utils.get_file_hash", return_value="fakehash123"
        ), patch(
            "app.model.router.model_predict",
            new_callable=AsyncMock,
            return_value=("cat", 0.95),
        ) as mock_model_predict, patch(
            # Assuming the router does ``import os``, this patches the shared
            # ``os.path`` module, i.e. it is process-wide while active.
            "app.model.router.os.path.exists",
            return_value=False,
        ), patch(
            # Any ``open()`` call gets a MagicMock instead of touching disk.
            "builtins.open",
            new_callable=MagicMock,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as ac:
                response = await ac.post(
                    "/model/predict",
                    files={"file": ("test_image.pdf", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
    finally:
        # ``app`` is imported from ``main`` and shared by every test; remove the
        # override so the auth bypass cannot leak into other tests.
        app.dependency_overrides.pop(get_current_user, None)

    assert response.status_code == 400
    assert response.json() == {"detail": "File type is not supported."}
    # A rejected upload must never reach the model.
    mock_model_predict.assert_not_awaited()