Spaces:
Paused
Paused
| import pytest | |
| from fastapi.testclient import TestClient | |
| from fastapi import HTTPException | |
| from typing import Union, Any | |
| from app import ask_gpt4, app, http_exception_handler | |
| from models.query_model import QueryModel | |
# Shared TestClient wrapping the FastAPI ``app``; the tests below send their
# HTTP requests through it.
client = TestClient(app=app)
def test_ask_gpt4(query_params, model, expected_output):
    """Exercise the ``/ask_gpt4/`` endpoint end to end.

    Posts ``query_params["user_input"]`` together with ``model`` and expects
    an HTTP 200 whose JSON body carries either a ``"response"`` key (which
    must equal ``expected_output``) or an ``"error"`` key (which must be one
    of the known error messages listed below).

    NOTE(review): no ``@pytest.mark.parametrize`` or fixture providing
    ``query_params``, ``model`` and ``expected_output`` is visible here; they
    must be supplied elsewhere (e.g. ``conftest.py``) or pytest reports a
    missing fixture -- TODO confirm.

    :param query_params: mapping with at least a ``"user_input"`` key.
    :param model: model identifier sent as the ``"model"`` field.
    :param expected_output: exact text expected under ``"response"``.
    :return: None
    """
    # Plain ``def`` on purpose: ``TestClient`` is synchronous and nothing is
    # awaited, whereas an ``async def`` test is not executed by pytest unless
    # an async plugin picks it up.
    known_errors = (
        "ChatGPT response does not contain text attribute.",
        # Add other known errors here
    )

    response = client.post(
        "/ask_gpt4/",
        json={"user_input": query_params["user_input"], "model": model},
    )
    assert response.status_code == 200, response.text

    json_response = response.json()
    # Failures are reported in the body of a 200 response, so exactly one of
    # the two keys is expected.
    assert "response" in json_response or "error" in json_response, json_response

    if "error" in json_response:
        # Only errors we already know about are tolerated.
        assert json_response["error"] in known_errors, json_response["error"]
    else:
        assert json_response["response"] == expected_output
async def test_http_exception_handler(
    status_code: int,
    detail: Union[str, dict],
    expected_result: Any,
) -> None:
    """Check that ``http_exception_handler`` maps an ``HTTPException`` as expected.

    Builds an ``HTTPException`` from ``status_code`` and ``detail``, awaits
    ``http_exception_handler`` on it, and compares the returned value with
    ``expected_result`` using ``==``.

    NOTE(review): this is a coroutine test, so it needs an async pytest plugin
    (and any marker that plugin requires) to actually run; its parameters also
    need a parametrize/fixture that is not visible here -- TODO confirm.
    NOTE(review): the handler is called with the exception only; handlers
    registered on a FastAPI/Starlette app are normally invoked as
    ``handler(request, exc)`` -- confirm the signature in ``app``.

    :param status_code: HTTP status code for the exception, e.g. 400 or 404.
    :param detail: exception detail, either a string or a dict.
    :param expected_result: value the handler is expected to return.
    :return: None
    """
    handled = await http_exception_handler(
        HTTPException(status_code=status_code, detail=detail)
    )
    assert handled == expected_result
# @pytest.mark.parametrize(
#     "user_input,expected_status_code",
#     [
#         ("What is the capital of France?", 200),
#         ("", 400),  # Invalid query
#     ],
# )
# def test_ask_gpt4_route(user_input: str, expected_status_code: int):
#     query = QueryModel(user_input=user_input)
#     response = client.post("/ask_gpt4/", json=query.dict())
#
#     assert response.status_code == expected_status_code
#
#     if expected_status_code == 200:
#         assert "response" in response.json()
#     else:
#         assert "error" in response.json()