|
from unittest.mock import AsyncMock, MagicMock, patch |
|
|
|
import pytest |
|
from app.auth.jwt import get_current_user |
|
from fastapi import UploadFile |
|
from httpx import AsyncClient |
|
from main import app |
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_predict():
    """POST a PNG upload to ``/model/predict`` and check the success payload.

    The ``get_current_user`` dependency is overridden with a mock, and the
    ``get_file_hash`` helper, ``model_predict``, ``os.path.exists`` and the
    built-in ``open`` are patched for the duration of the request. The test
    expects a 200 response whose JSON carries the patched hash and prediction.
    """
    image_bytes = b"fake-image-data"

    mock_current_user = MagicMock()
    mock_current_user.return_value = "testtoken"
    app.dependency_overrides[get_current_user] = lambda: mock_current_user

    # Create the patchers up front so they can be entered with one flat
    # ``with`` statement instead of a pyramid of nested blocks.
    hash_patch = patch(
        "app.model.router.utils.get_file_hash", return_value="fakehash123"
    )
    predict_patch = patch(
        "app.model.router.model_predict",
        new_callable=AsyncMock,
        return_value=("cat", 0.95),
    )
    exists_patch = patch("app.model.router.os.path.exists", return_value=False)
    open_patch = patch("builtins.open", new_callable=MagicMock)

    try:
        with hash_patch, predict_patch, exists_patch, open_patch:
            # NOTE(review): newer httpx releases dropped the ``app=`` shortcut
            # in favour of ``transport=ASGITransport(app=app)`` - confirm the
            # pinned httpx version before upgrading.
            async with AsyncClient(app=app, base_url="http://test") as ac:
                response = await ac.post(
                    "/model/predict",
                    files={"file": ("test_image.png", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
    finally:
        # ``app`` is shared module state: drop the override so it cannot
        # leak into tests that run after this one.
        app.dependency_overrides.pop(get_current_user, None)

    assert response.status_code == 200

    response_data = response.json()
    assert response_data["success"] is True
    assert response_data["prediction"] == "cat"
    assert response_data["score"] == 0.95
    assert response_data["image_file_name"] == "fakehash123"
|
|
|
|
|
@pytest.mark.asyncio
async def test_predict_fails_bad_extension():
    """POST a ``.pdf`` upload to ``/model/predict`` and expect a 400 rejection.

    The ``get_current_user`` dependency is overridden with a mock, and the
    ``get_file_hash`` helper, ``model_predict``, ``os.path.exists`` and the
    built-in ``open`` are patched for the duration of the request. The test
    expects a 400 response with the "File type is not supported." detail.
    """
    image_bytes = b"fake-image-data"

    mock_current_user = MagicMock()
    mock_current_user.return_value = "testtoken"
    app.dependency_overrides[get_current_user] = lambda: mock_current_user

    # Create the patchers up front so they can be entered with one flat
    # ``with`` statement instead of a pyramid of nested blocks.
    hash_patch = patch(
        "app.model.router.utils.get_file_hash", return_value="fakehash123"
    )
    predict_patch = patch(
        "app.model.router.model_predict",
        new_callable=AsyncMock,
        return_value=("cat", 0.95),
    )
    exists_patch = patch("app.model.router.os.path.exists", return_value=False)
    open_patch = patch("builtins.open", new_callable=MagicMock)

    try:
        with hash_patch, predict_patch, exists_patch, open_patch:
            # NOTE(review): newer httpx releases dropped the ``app=`` shortcut
            # in favour of ``transport=ASGITransport(app=app)`` - confirm the
            # pinned httpx version before upgrading.
            async with AsyncClient(app=app, base_url="http://test") as ac:
                response = await ac.post(
                    "/model/predict",
                    # The file name carries a .pdf extension while the
                    # declared content type stays image/png.
                    files={"file": ("test_image.pdf", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
    finally:
        # ``app`` is shared module state: drop the override so it cannot
        # leak into tests that run after this one.
        app.dependency_overrides.pop(get_current_user, None)

    assert response.status_code == 400
    assert response.json() == {"detail": "File type is not supported."}
|
|