Spaces:
Runtime error
Runtime error
| import os | |
| import unittest | |
| from unittest.mock import patch, MagicMock | |
# Stub for fastapi.FastAPI so the app module can be imported without the
# real framework being installed.
class MockFastAPI:
    """Minimal stand-in for ``fastapi.FastAPI``.

    Every ``add_middleware`` call is recorded in ``middleware_stack`` so tests
    can inspect the options the application passed. The route and event
    decorator factories are no-ops that hand the decorated function back
    unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and ignored.
        self.middleware_stack = []

    def add_middleware(self, middleware_class, **kwargs):
        """Append ``(middleware_class, kwargs)`` to ``middleware_stack``."""
        self.middleware_stack.append((middleware_class, kwargs))

    @staticmethod
    def _identity_decorator(func):
        """Return ``func`` untouched; shared by every decorator factory."""
        return func

    # The factories below accept extra positional/keyword arguments because
    # the real FastAPI decorators take options such as ``response_model=`` or
    # ``status_code=``; rejecting them would raise TypeError while the
    # application module is being imported.
    def on_event(self, event_type, *args, **kwargs):
        """Return a no-op decorator for a lifecycle event handler."""
        return self._identity_decorator

    def get(self, path, *args, **kwargs):
        """Return a no-op decorator for a GET route."""
        return self._identity_decorator

    def post(self, path, *args, **kwargs):
        """Return a no-op decorator for a POST route."""
        return self._identity_decorator

    def websocket(self, path, *args, **kwargs):
        """Return a no-op decorator for a WebSocket route."""
        return self._identity_decorator
# Names swapped for MagicMock instances in sys.modules while api.server is
# imported: the runtime pipeline modules, FastAPI and uvicorn.
_STUBBED_MODULE_NAMES = (
    "runtime.camera_capture",
    "runtime.mic_capture",
    "runtime.gemma_prompt_engine",
    "runtime.magenta_generation",
    "runtime.clip_scheduler",
    "fastapi",
    "fastapi.middleware.cors",
    "uvicorn",
)

with patch.dict(
    "sys.modules",
    {module_name: MagicMock() for module_name in _STUBBED_MODULE_NAMES},
):
    import fastapi
    import fastapi.middleware.cors

    # Install the stub before api.server is imported below, so the app it
    # builds records its middleware in MockFastAPI.middleware_stack.
    fastapi.FastAPI = MockFastAPI

    import sys

    # Put the current working directory on the import path for `api.server`.
    sys.path.append(os.getcwd())

    from api.server import app
class TestAPISecurity(unittest.TestCase):
    """Checks the CORS settings registered by the API app and parsed from env config."""

    def test_cors_configuration(self):
        """The app registers CORS middleware with the default origin and credentials enabled."""
        # The middleware class comes from a mocked module, so it cannot be
        # matched by type; identify the CORS entry by its "allow_origins" option.
        cors_config = next(
            (
                options
                for _middleware_class, options in app.middleware_stack
                if "allow_origins" in options
            ),
            None,
        )
        self.assertIsNotNone(cors_config, "CORSMiddleware not found in app")
        # Default should be http://localhost:1420
        self.assertEqual(cors_config["allow_origins"], ["http://localhost:1420"])
        self.assertTrue(cors_config["allow_credentials"])

    def test_cors_env_override(self):
        """Environment overrides are split into an origin list and a credentials flag."""
        # api.server was already imported, so its parsing logic is repeated
        # here against a freshly built configuration.
        from ML_Pipeline.shared.env import apply_defaults

        overrides = {
            "ALLOWED_ORIGINS": "http://myapp.com, http://another.com",
            "ALLOW_CREDENTIALS": "false",
        }
        # Set the overrides explicitly instead of relying on whatever the
        # ambient environment contains; patch.dict restores os.environ on exit.
        # NOTE(review): presumes apply_defaults() reads os.environ -- confirm.
        with patch.dict(os.environ, overrides):
            env_config = apply_defaults()
            # Read inside the patch in case the returned mapping is a live
            # view of the environment rather than a copy.
            raw_origins = env_config.get("ALLOWED_ORIGINS")
            raw_credentials = env_config.get("ALLOW_CREDENTIALS")

        # Fail with a clear message rather than AttributeError on None.
        self.assertIsNotNone(raw_origins, "ALLOWED_ORIGINS missing from env config")
        self.assertIsNotNone(raw_credentials, "ALLOW_CREDENTIALS missing from env config")

        allowed_origins = [origin.strip() for origin in raw_origins.split(",")]
        allow_credentials = raw_credentials.lower() == "true"
        self.assertEqual(allowed_origins, ["http://myapp.com", "http://another.com"])
        self.assertFalse(allow_credentials)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()