File size: 2,236 Bytes
6ab520d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
from unittest import mock
import pytest
from app import db
from app.auth.jwt import get_current_user
from app.feedback.schema import DisplayFeedback, Feedback
from app.user.schema import User
from fastapi.testclient import TestClient
from main import app
from sqlalchemy.orm import Session
# Canned authenticated user; substituted for the real ``get_current_user``
# dependency and expected as the user argument passed to the services.
sample_user = User(
    id=1,
    username="testuser",
    email="testuser@example.com",
    name="Test User",
    password="password",
)

# Canned feedback record handed back by the patched ``new_feedback`` service.
sample_feedback = Feedback(
    feedback="Great service!",
    image_file_name="testimage.jpg",
    predicted_class="dog",
    score=0.95,
)

# Synchronous test client bound to the application under test; shared by
# every test in this module.
client = TestClient(app)
@pytest.fixture
def mock_db_session():
    """Provide an autospecced stand-in for a SQLAlchemy ``Session`` instance.

    Autospeccing makes the double reject attributes that ``Session`` does not
    define and calls whose signature does not match the real method.
    """
    session_double = mock.create_autospec(spec=Session, instance=True)
    return session_double
@pytest.fixture
def mock_get_current_user():
    """Return the module-level ``sample_user`` as the authenticated caller."""
    current_user = sample_user
    return current_user
@mock.patch("app.feedback.router.services.new_feedback")
def test_create_feedback(mock_new_feedback, mock_db_session, mock_get_current_user):
    """POST /feedback/ returns 201 and delegates to ``services.new_feedback``.

    The DB-session and current-user dependencies are overridden only for the
    duration of the request. ``mock.patch.dict`` restores
    ``app.dependency_overrides`` in place afterwards, so these overrides cannot
    leak into other tests that share the module-level ``app`` / ``client``.
    """
    mock_new_feedback.return_value = sample_feedback
    payload = {
        "feedback": "Great service!",
        "image_file_name": "testimage.jpg",
        "predicted_class": "dog",
        "score": 0.95,
    }
    overrides = {
        db.get_db: lambda: mock_db_session,
        get_current_user: lambda: mock_get_current_user,
    }
    with mock.patch.dict(app.dependency_overrides, overrides):
        response = client.post("/feedback/", json=payload)

    # Include the body in the failure message so validation errors are visible.
    assert response.status_code == 201, response.text
    # NOTE(review): the router presumably passes a parsed request schema here,
    # not the raw dict; this comparison only holds if that schema compares
    # equal to a plain dict (pydantic v1 behaviour). Confirm against the router.
    mock_new_feedback.assert_called_once_with(payload, sample_user, mock_db_session)
@mock.patch("app.feedback.router.services.all_feedback")
def test_get_all_feedback(mock_all_feedback, mock_db_session, mock_get_current_user):
    """GET /feedback/ returns 200 and delegates to ``services.all_feedback``.

    Dependency overrides are scoped to the request via ``mock.patch.dict`` so
    ``app.dependency_overrides`` is restored afterwards and does not leak into
    other tests sharing the module-level ``app`` / ``client``.
    """
    # The patched service returns a single canned feedback entry.
    mock_all_feedback.return_value = [
        DisplayFeedback(
            id=1,
            feedback="Great service!",
            score=0.95,
            predicted_class="dog",
            image_file_name="testimage.jpg",
        )
    ]
    overrides = {
        db.get_db: lambda: mock_db_session,
        get_current_user: lambda: mock_get_current_user,
    }
    with mock.patch.dict(app.dependency_overrides, overrides):
        response = client.get("/feedback/")

    # Include the body in the failure message so errors are visible.
    assert response.status_code == 200, response.text
    mock_all_feedback.assert_called_once_with(mock_db_session, sample_user)
|