Spaces:
Runtime error
Runtime error
Ved Gupta
commited on
Commit
·
4e4ad14
1
Parent(s):
d67439e
error removed and code improved
Browse files- README.md +6 -0
- app/__init__.py +1 -1
- app/api/__init__.py +1 -1
- app/api/endpoints/__init__.py +1 -1
- app/api/endpoints/items.py +1 -1
- app/api/endpoints/users.py +38 -31
- app/api/models/__init__.py +1 -1
- app/api/models/item.py +1 -1
- app/api/models/user.py +36 -6
- app/core/__init__.py +8 -1
- app/core/config.py +20 -7
- app/core/database.py +6 -3
- app/core/errors.py +13 -5
- app/core/security.py +2 -1
- app/main.py +10 -3
- app/tests/__init__.py +5 -5
- app/tests/conftest.py +2 -1
- app/tests/test_api/__init__.py +3 -2
- app/tests/test_api/test_items.py +6 -1
- app/tests/test_api/test_users.py +7 -4
- app/tests/test_core/__init__.py +2 -3
- app/tests/test_core/test_config.py +1 -1
- app/tests/test_core/test_database.py +3 -3
- app/tests/test_core/test_security.py +4 -1
- app/tests/utils/utils.py +24 -0
- app/utils/__init__.py +1 -0
- app/utils/utils.py +19 -0
- requirements.txt +2 -1
README.md
CHANGED
@@ -74,4 +74,10 @@ The project structure is organized as follows:
|
|
74 |
|
75 |
```bash
|
76 |
uvicorn app.main:app --reload
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
```
|
|
|
74 |
|
75 |
```bash
|
76 |
uvicorn app.main:app --reload
|
77 |
+
```
|
78 |
+
|
79 |
+
## Run Test
|
80 |
+
|
81 |
+
```bash
|
82 |
+
python -m unittest
|
83 |
```
|
app/__init__.py
CHANGED
@@ -10,4 +10,4 @@ def create_app() -> FastAPI:
|
|
10 |
return app
|
11 |
|
12 |
|
13 |
-
app = create_app()
|
|
|
10 |
return app
|
11 |
|
12 |
|
13 |
+
app = create_app()
|
app/api/__init__.py
CHANGED
@@ -6,4 +6,4 @@ from .endpoints import items, users
|
|
6 |
api_router = APIRouter()
|
7 |
|
8 |
api_router.include_router(items.router, prefix="/items", tags=["items"])
|
9 |
-
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
|
|
6 |
api_router = APIRouter()
|
7 |
|
8 |
api_router.include_router(items.router, prefix="/items", tags=["items"])
|
9 |
+
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
app/api/endpoints/__init__.py
CHANGED
@@ -7,4 +7,4 @@ from . import items, users
|
|
7 |
router = APIRouter()
|
8 |
|
9 |
router.include_router(items.router, prefix="/items", tags=["items"])
|
10 |
-
router.include_router(users.router, prefix="/users", tags=["users"])
|
|
|
7 |
router = APIRouter()
|
8 |
|
9 |
router.include_router(items.router, prefix="/items", tags=["items"])
|
10 |
+
router.include_router(users.router, prefix="/users", tags=["users"])
|
app/api/endpoints/items.py
CHANGED
@@ -10,4 +10,4 @@ async def read_items():
|
|
10 |
|
11 |
@router.get("/{item_id}")
|
12 |
async def read_item(item_id: int, q: str = None):
|
13 |
-
return {"item_id": item_id, "q": q}
|
|
|
10 |
|
11 |
@router.get("/{item_id}")
|
12 |
async def read_item(item_id: int, q: str = None):
|
13 |
+
return {"item_id": item_id, "q": q}
|
app/api/endpoints/users.py
CHANGED
@@ -1,50 +1,57 @@
|
|
1 |
from fastapi import APIRouter
|
2 |
|
3 |
-
from app.api.models.user import UserInDB
|
4 |
from app.core.database import SessionLocal
|
|
|
5 |
|
6 |
database = SessionLocal()
|
7 |
|
8 |
users_router = router = APIRouter()
|
9 |
|
10 |
|
11 |
-
@router.post("/", response_model=
|
12 |
-
async def create_user(user:
|
13 |
-
|
14 |
username=user.username,
|
15 |
email=user.email,
|
16 |
-
hashed_password=user.
|
17 |
)
|
18 |
-
|
19 |
-
|
|
|
|
|
20 |
|
21 |
|
22 |
-
@router.get("/{id}/", response_model=
|
23 |
async def read_user(id: int):
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
.
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
@router.delete("/{id}/", response_model=int)
|
47 |
async def delete_user(id: int):
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
1 |
from fastapi import APIRouter
|
2 |
|
3 |
+
from app.api.models.user import UserInDB, User, UpdateUser, UserResponse
|
4 |
from app.core.database import SessionLocal
|
5 |
+
from app.core.security import get_password_hash, verify_password
|
6 |
|
7 |
database = SessionLocal()
|
8 |
|
9 |
users_router = router = APIRouter()
|
10 |
|
11 |
|
12 |
+
@router.post("/", response_model=UserResponse, status_code=201)
|
13 |
+
async def create_user(user: User):
|
14 |
+
db_user = UserInDB(
|
15 |
username=user.username,
|
16 |
email=user.email,
|
17 |
+
hashed_password=get_password_hash(user.password),
|
18 |
)
|
19 |
+
database.add(db_user)
|
20 |
+
database.commit()
|
21 |
+
database.refresh(db_user)
|
22 |
+
return {**user.dict(), "id": db_user.id, "hashed_password": None}
|
23 |
|
24 |
|
25 |
+
@router.get("/{id}/", response_model=UserResponse)
|
26 |
async def read_user(id: int):
|
27 |
+
db_user = database.query(UserInDB).filter(UserInDB.id == id).first()
|
28 |
+
if not db_user:
|
29 |
+
raise HTTPException(status_code=404, detail="User not found")
|
30 |
+
return UserResponse.from_orm(db_user)
|
31 |
+
|
32 |
+
|
33 |
+
@router.put("/{id}/", response_model=User)
|
34 |
+
async def update_user(id: int, user: UpdateUser):
|
35 |
+
db_user = db.query(UserInDB).filter(UserInDB.id == id).first()
|
36 |
+
if not db_user:
|
37 |
+
raise HTTPException(status_code=404, detail="User not found")
|
38 |
+
if not verify_password(user.current_password, db_user.hashed_password):
|
39 |
+
raise HTTPException(status_code=400, detail="Incorrect password")
|
40 |
+
if user.email != db_user.email:
|
41 |
+
raise HTTPException(status_code=400, detail="Email cannot be changed")
|
42 |
+
if user.username is None:
|
43 |
+
user.username = db_user.username
|
44 |
+
else:
|
45 |
+
db_user.username = user.username
|
46 |
+
db_user.hashed_password = get_password_hash(user.password)
|
47 |
+
database.commit()
|
48 |
+
database.refresh(db_user)
|
49 |
+
return {**user.dict(), "id": db_user.id, "hashed_password": None}
|
50 |
|
51 |
|
52 |
@router.delete("/{id}/", response_model=int)
|
53 |
async def delete_user(id: int):
|
54 |
+
db_user = database.query(UserInDB).filter(UserInDB.id == id).first()
|
55 |
+
database.delete(db_user)
|
56 |
+
database.commit()
|
57 |
+
return id
|
app/api/models/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
from .item import Item
|
2 |
-
from .user import User
|
|
|
1 |
from .item import Item
|
2 |
+
from .user import User
|
app/api/models/item.py
CHANGED
@@ -13,4 +13,4 @@ class ItemCreate(ItemBase):
|
|
13 |
|
14 |
class Item(ItemBase):
|
15 |
id: int
|
16 |
-
owner_id: int
|
|
|
13 |
|
14 |
class Item(ItemBase):
|
15 |
id: int
|
16 |
+
owner_id: int
|
app/api/models/user.py
CHANGED
@@ -1,14 +1,44 @@
|
|
1 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
class UserBase(BaseModel):
|
4 |
email: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
class UserCreate(UserBase):
|
7 |
-
password: str
|
8 |
|
9 |
class User(UserBase):
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
|
1 |
from pydantic import BaseModel
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from sqlalchemy import Column, Integer, String, Boolean
|
5 |
+
from app.core.database import Base
|
6 |
+
|
7 |
|
8 |
class UserBase(BaseModel):
|
9 |
email: str
|
10 |
+
username: str
|
11 |
+
|
12 |
+
|
13 |
+
class UserResponse(UserBase):
|
14 |
+
is_active: Optional[bool]
|
15 |
+
|
16 |
+
class Config:
|
17 |
+
from_attributes = True
|
18 |
|
|
|
|
|
19 |
|
20 |
class User(UserBase):
|
21 |
+
is_active: Optional[bool]
|
22 |
+
password: str
|
23 |
+
|
24 |
+
class Config:
|
25 |
+
from_attributes = True
|
26 |
+
|
27 |
+
|
28 |
+
class UpdateUser(UserBase):
|
29 |
+
current_password: str
|
30 |
+
|
31 |
+
|
32 |
+
class UserInDB(Base):
|
33 |
+
__tablename__ = "users"
|
34 |
+
|
35 |
+
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
36 |
+
username = Column(String, unique=True, index=True)
|
37 |
+
email = Column(String, unique=True, index=True)
|
38 |
+
hashed_password = Column(String)
|
39 |
+
is_active = Column(Boolean, default=True)
|
40 |
|
41 |
+
def __init__(self, username: str, email: str, hashed_password: str):
|
42 |
+
self.username = username
|
43 |
+
self.email = email
|
44 |
+
self.hashed_password = hashed_password
|
app/core/__init__.py
CHANGED
@@ -2,4 +2,11 @@ from .config import settings
|
|
2 |
from .database import Base, engine, SessionLocal
|
3 |
from .security import get_password_hash, verify_password
|
4 |
|
5 |
-
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from .database import Base, engine, SessionLocal
|
3 |
from .security import get_password_hash, verify_password
|
4 |
|
5 |
+
__all__ = [
|
6 |
+
"settings",
|
7 |
+
"Base",
|
8 |
+
"engine",
|
9 |
+
"SessionLocal",
|
10 |
+
"get_password_hash",
|
11 |
+
"verify_password",
|
12 |
+
]
|
app/core/config.py
CHANGED
@@ -5,20 +5,26 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
5 |
|
6 |
from os import environ as env
|
7 |
|
|
|
8 |
class Settings(BaseSettings):
|
9 |
API_V1_STR: str = "/api/v1"
|
|
|
10 |
PROJECT_NAME: str = "Whisper API"
|
11 |
PROJECT_VERSION: str = "0.1.0"
|
12 |
SECRET_KEY: str = env.get("SECRET_KEY")
|
13 |
ACCESS_TOKEN_EXPIRE_MINUTES: int = env.get("ACCESS_TOKEN_EXPIRE_MINUTES")
|
|
|
14 |
SERVER_NAME: str = env.get("SERVER_NAME")
|
15 |
SERVER_HOST: AnyHttpUrl = env.get("SERVER_HOST")
|
|
|
16 |
POSTGRES_SERVER: str = env.get("POSTGRES_SERVER")
|
17 |
POSTGRES_USER: str = env.get("POSTGRES_USER")
|
18 |
POSTGRES_PASSWORD: str = env.get("POSTGRES_PASSWORD")
|
19 |
POSTGRES_DB: str = env.get("POSTGRES_DB")
|
20 |
POSTGRES_DATABASE_URL: str = env.get("POSTGRES_DATABASE_URL")
|
21 |
-
|
|
|
|
|
22 |
|
23 |
@validator("SECRET_KEY", pre=True)
|
24 |
def secret_key_must_be_set(cls, v: Optional[str], values: Dict[str, Any]) -> str:
|
@@ -33,13 +39,17 @@ class Settings(BaseSettings):
|
|
33 |
return v
|
34 |
|
35 |
@validator("SERVER_HOST", pre=True)
|
36 |
-
def server_host_must_be_set(
|
|
|
|
|
37 |
if not v:
|
38 |
raise ValueError("SERVER_HOST must be set")
|
39 |
return v
|
40 |
|
41 |
@validator("POSTGRES_SERVER", pre=True)
|
42 |
-
def postgres_server_must_be_set(
|
|
|
|
|
43 |
if not v:
|
44 |
raise ValueError("POSTGRES_SERVER must be set")
|
45 |
return v
|
@@ -51,7 +61,9 @@ class Settings(BaseSettings):
|
|
51 |
return v
|
52 |
|
53 |
@validator("POSTGRES_PASSWORD", pre=True)
|
54 |
-
def postgres_password_must_be_set(
|
|
|
|
|
55 |
if not v:
|
56 |
raise ValueError("POSTGRES_PASSWORD must be set")
|
57 |
return v
|
@@ -63,11 +75,12 @@ class Settings(BaseSettings):
|
|
63 |
return v
|
64 |
|
65 |
@validator("POSTGRES_DATABASE_URL", pre=True)
|
66 |
-
def postgres_db_url_must_be_set(
|
|
|
|
|
67 |
if not v:
|
68 |
raise ValueError("POSTGRES_DATABASE_URL must be set")
|
69 |
return v
|
70 |
|
71 |
|
72 |
-
|
73 |
-
settings = Settings()
|
|
|
5 |
|
6 |
from os import environ as env
|
7 |
|
8 |
+
|
9 |
class Settings(BaseSettings):
|
10 |
API_V1_STR: str = "/api/v1"
|
11 |
+
|
12 |
PROJECT_NAME: str = "Whisper API"
|
13 |
PROJECT_VERSION: str = "0.1.0"
|
14 |
SECRET_KEY: str = env.get("SECRET_KEY")
|
15 |
ACCESS_TOKEN_EXPIRE_MINUTES: int = env.get("ACCESS_TOKEN_EXPIRE_MINUTES")
|
16 |
+
|
17 |
SERVER_NAME: str = env.get("SERVER_NAME")
|
18 |
SERVER_HOST: AnyHttpUrl = env.get("SERVER_HOST")
|
19 |
+
|
20 |
POSTGRES_SERVER: str = env.get("POSTGRES_SERVER")
|
21 |
POSTGRES_USER: str = env.get("POSTGRES_USER")
|
22 |
POSTGRES_PASSWORD: str = env.get("POSTGRES_PASSWORD")
|
23 |
POSTGRES_DB: str = env.get("POSTGRES_DB")
|
24 |
POSTGRES_DATABASE_URL: str = env.get("POSTGRES_DATABASE_URL")
|
25 |
+
TEST_DATABASE_URL: str = env.get("TEST_DATABASE_URL")
|
26 |
+
|
27 |
+
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = ["http://localhost:3000"]
|
28 |
|
29 |
@validator("SECRET_KEY", pre=True)
|
30 |
def secret_key_must_be_set(cls, v: Optional[str], values: Dict[str, Any]) -> str:
|
|
|
39 |
return v
|
40 |
|
41 |
@validator("SERVER_HOST", pre=True)
|
42 |
+
def server_host_must_be_set(
|
43 |
+
cls, v: Optional[str], values: Dict[str, Any]
|
44 |
+
) -> AnyHttpUrl:
|
45 |
if not v:
|
46 |
raise ValueError("SERVER_HOST must be set")
|
47 |
return v
|
48 |
|
49 |
@validator("POSTGRES_SERVER", pre=True)
|
50 |
+
def postgres_server_must_be_set(
|
51 |
+
cls, v: Optional[str], values: Dict[str, Any]
|
52 |
+
) -> str:
|
53 |
if not v:
|
54 |
raise ValueError("POSTGRES_SERVER must be set")
|
55 |
return v
|
|
|
61 |
return v
|
62 |
|
63 |
@validator("POSTGRES_PASSWORD", pre=True)
|
64 |
+
def postgres_password_must_be_set(
|
65 |
+
cls, v: Optional[str], values: Dict[str, Any]
|
66 |
+
) -> str:
|
67 |
if not v:
|
68 |
raise ValueError("POSTGRES_PASSWORD must be set")
|
69 |
return v
|
|
|
75 |
return v
|
76 |
|
77 |
@validator("POSTGRES_DATABASE_URL", pre=True)
|
78 |
+
def postgres_db_url_must_be_set(
|
79 |
+
cls, v: Optional[str], values: Dict[str, Any]
|
80 |
+
) -> str:
|
81 |
if not v:
|
82 |
raise ValueError("POSTGRES_DATABASE_URL must be set")
|
83 |
return v
|
84 |
|
85 |
|
86 |
+
settings = Settings()
|
|
app/core/database.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from sqlalchemy import create_engine
|
2 |
from sqlalchemy.ext.declarative import declarative_base
|
3 |
from sqlalchemy.orm import sessionmaker
|
4 |
|
@@ -6,8 +6,11 @@ from app.core.config import settings
|
|
6 |
|
7 |
SQLALCHEMY_DATABASE_URL = settings.POSTGRES_DATABASE_URL
|
8 |
|
|
|
|
|
9 |
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
|
|
10 |
|
11 |
-
|
12 |
|
13 |
-
|
|
|
1 |
+
from sqlalchemy import create_engine, MetaData
|
2 |
from sqlalchemy.ext.declarative import declarative_base
|
3 |
from sqlalchemy.orm import sessionmaker
|
4 |
|
|
|
6 |
|
7 |
SQLALCHEMY_DATABASE_URL = settings.POSTGRES_DATABASE_URL
|
8 |
|
9 |
+
meta = MetaData()
|
10 |
+
|
11 |
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
12 |
+
Base = declarative_base(metadata=meta)
|
13 |
|
14 |
+
Base.metadata.create_all(engine)
|
15 |
|
16 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
app/core/errors.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from fastapi import Request
|
2 |
from fastapi.responses import JSONResponse
|
3 |
-
from fastapi.exceptions import HTTPException
|
|
|
4 |
|
5 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
6 |
"""
|
@@ -11,13 +12,20 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|
11 |
content={"message": exc.detail},
|
12 |
)
|
13 |
|
14 |
-
|
|
|
15 |
"""
|
16 |
Exception handler for HTTP errors
|
17 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
return JSONResponse(
|
19 |
-
status_code=
|
20 |
-
content={"message":
|
21 |
)
|
22 |
|
23 |
|
@@ -28,4 +36,4 @@ def http422_error_handler(request: Request, exc: RequestValidationError):
|
|
28 |
return JSONResponse(
|
29 |
status_code=422,
|
30 |
content={"message": "Validation error", "details": exc.errors()},
|
31 |
-
)
|
|
|
1 |
from fastapi import Request
|
2 |
from fastapi.responses import JSONResponse
|
3 |
+
from fastapi.exceptions import HTTPException, RequestValidationError
|
4 |
+
|
5 |
|
6 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
7 |
"""
|
|
|
12 |
content={"message": exc.detail},
|
13 |
)
|
14 |
|
15 |
+
|
16 |
+
async def http_error_handler(request: Request, exc: Exception):
|
17 |
"""
|
18 |
Exception handler for HTTP errors
|
19 |
"""
|
20 |
+
if isinstance(exc, HTTPException):
|
21 |
+
detail = exc.detail
|
22 |
+
status_code = exc.status_code
|
23 |
+
else:
|
24 |
+
detail = "Internal server error"
|
25 |
+
status_code = 500
|
26 |
return JSONResponse(
|
27 |
+
status_code=status_code,
|
28 |
+
content={"message": detail},
|
29 |
)
|
30 |
|
31 |
|
|
|
36 |
return JSONResponse(
|
37 |
status_code=422,
|
38 |
content={"message": "Validation error", "details": exc.errors()},
|
39 |
+
)
|
app/core/security.py
CHANGED
@@ -13,8 +13,9 @@ def get_password_hash(password: str) -> str:
|
|
13 |
"""
|
14 |
return pwd_context.hash(password)
|
15 |
|
|
|
16 |
def verify_password(password: str, hash: str) -> bool:
|
17 |
"""
|
18 |
Verifies a password against a bcrypt hash
|
19 |
"""
|
20 |
-
return pwd_context.verify(password, hash)
|
|
|
13 |
"""
|
14 |
return pwd_context.hash(password)
|
15 |
|
16 |
+
|
17 |
def verify_password(password: str, hash: str) -> bool:
|
18 |
"""
|
19 |
Verifies a password against a bcrypt hash
|
20 |
"""
|
21 |
+
return pwd_context.verify(password, hash)
|
app/main.py
CHANGED
@@ -5,7 +5,11 @@ from app.core.errors import http_error_handler
|
|
5 |
from app.core.errors import http422_error_handler
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
|
8 |
-
app
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Set all CORS enabled origins
|
11 |
if settings.BACKEND_CORS_ORIGINS:
|
@@ -20,6 +24,9 @@ if settings.BACKEND_CORS_ORIGINS:
|
|
20 |
# Include routers
|
21 |
app.include_router(api_router, prefix=settings.API_V1_STR)
|
22 |
|
23 |
-
# Error handlers
|
24 |
app.add_exception_handler(422, http422_error_handler)
|
25 |
-
app.add_exception_handler(500, http_error_handler)
|
|
|
|
|
|
|
|
5 |
from app.core.errors import http422_error_handler
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
|
8 |
+
from app.utils import print_routes
|
9 |
+
|
10 |
+
app = FastAPI(
|
11 |
+
title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
12 |
+
)
|
13 |
|
14 |
# Set all CORS enabled origins
|
15 |
if settings.BACKEND_CORS_ORIGINS:
|
|
|
24 |
# Include routers
|
25 |
app.include_router(api_router, prefix=settings.API_V1_STR)
|
26 |
|
27 |
+
# # Error handlers
|
28 |
app.add_exception_handler(422, http422_error_handler)
|
29 |
+
app.add_exception_handler(500, http_error_handler)
|
30 |
+
|
31 |
+
# Print all routes
|
32 |
+
print_routes(app)
|
app/tests/__init__.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
# File: whisper.api/app/tests/__init__.py
|
2 |
-
|
3 |
-
# Import necessary modules
|
4 |
import pytest
|
5 |
from fastapi.testclient import TestClient
|
6 |
from sqlalchemy import create_engine
|
@@ -8,19 +5,21 @@ from sqlalchemy.orm import sessionmaker
|
|
8 |
|
9 |
from app.core.config import settings
|
10 |
from app.main import app
|
11 |
-
from app.tests.utils.utils import override_get_db
|
12 |
|
13 |
# Create test database
|
14 |
SQLALCHEMY_DATABASE_URL = settings.TEST_DATABASE_URL
|
15 |
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
16 |
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
17 |
|
|
|
18 |
# Define test client
|
19 |
@pytest.fixture(scope="module")
|
20 |
def test_client():
|
21 |
with TestClient(app) as client:
|
22 |
yield client
|
23 |
|
|
|
24 |
# Define test database
|
25 |
@pytest.fixture(scope="module")
|
26 |
def test_db():
|
@@ -28,7 +27,8 @@ def test_db():
|
|
28 |
yield db
|
29 |
db.close()
|
30 |
|
|
|
31 |
# Override get_db function for testing
|
32 |
@pytest.fixture(autouse=True)
|
33 |
def override_get_db(monkeypatch):
|
34 |
-
monkeypatch.setattr("app.api.dependencies.get_db",
|
|
|
|
|
|
|
|
|
1 |
import pytest
|
2 |
from fastapi.testclient import TestClient
|
3 |
from sqlalchemy import create_engine
|
|
|
5 |
|
6 |
from app.core.config import settings
|
7 |
from app.main import app
|
8 |
+
from app.tests.utils.utils import override_get_db, get_db
|
9 |
|
10 |
# Create test database
|
11 |
SQLALCHEMY_DATABASE_URL = settings.TEST_DATABASE_URL
|
12 |
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
13 |
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
14 |
|
15 |
+
|
16 |
# Define test client
|
17 |
@pytest.fixture(scope="module")
|
18 |
def test_client():
|
19 |
with TestClient(app) as client:
|
20 |
yield client
|
21 |
|
22 |
+
|
23 |
# Define test database
|
24 |
@pytest.fixture(scope="module")
|
25 |
def test_db():
|
|
|
27 |
yield db
|
28 |
db.close()
|
29 |
|
30 |
+
|
31 |
# Override get_db function for testing
|
32 |
@pytest.fixture(autouse=True)
|
33 |
def override_get_db(monkeypatch):
|
34 |
+
monkeypatch.setattr("app.api.dependencies.get_db", get_db)
|
app/tests/conftest.py
CHANGED
@@ -12,6 +12,7 @@ from app.db.session import SessionLocal
|
|
12 |
def test_app():
|
13 |
# set up test app with client
|
14 |
from app.main import app
|
|
|
15 |
client = TestClient(app)
|
16 |
# set up test database
|
17 |
SQLALCHEMY_DATABASE_URL = settings.SQLALCHEMY_DATABASE_TEST_URL
|
@@ -33,4 +34,4 @@ def db_session():
|
|
33 |
try:
|
34 |
yield session
|
35 |
finally:
|
36 |
-
session.close()
|
|
|
12 |
def test_app():
|
13 |
# set up test app with client
|
14 |
from app.main import app
|
15 |
+
|
16 |
client = TestClient(app)
|
17 |
# set up test database
|
18 |
SQLALCHEMY_DATABASE_URL = settings.SQLALCHEMY_DATABASE_TEST_URL
|
|
|
34 |
try:
|
35 |
yield session
|
36 |
finally:
|
37 |
+
session.close()
|
app/tests/test_api/__init__.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
# File: whisper.api/app/tests/test_api/__init__.py
|
2 |
|
3 |
from fastapi.testclient import TestClient
|
4 |
-
from
|
5 |
|
6 |
client = TestClient(app)
|
7 |
|
|
|
8 |
def test_read_main():
|
9 |
response = client.get("/")
|
10 |
assert response.status_code == 200
|
11 |
-
assert response.json() == {"msg": "Hello World"}
|
|
|
1 |
# File: whisper.api/app/tests/test_api/__init__.py
|
2 |
|
3 |
from fastapi.testclient import TestClient
|
4 |
+
from app.main import app
|
5 |
|
6 |
client = TestClient(app)
|
7 |
|
8 |
+
|
9 |
def test_read_main():
|
10 |
response = client.get("/")
|
11 |
assert response.status_code == 200
|
12 |
+
assert response.json() == {"msg": "Hello World"}
|
app/tests/test_api/test_items.py
CHANGED
@@ -3,6 +3,7 @@ from app.main import app
|
|
3 |
|
4 |
client = TestClient(app)
|
5 |
|
|
|
6 |
def test_create_item():
|
7 |
data = {"name": "test", "description": "test description"}
|
8 |
response = client.post("/items/", json=data)
|
@@ -10,17 +11,20 @@ def test_create_item():
|
|
10 |
assert response.json()["name"] == "test"
|
11 |
assert response.json()["description"] == "test description"
|
12 |
|
|
|
13 |
def test_read_item():
|
14 |
response = client.get("/items/1")
|
15 |
assert response.status_code == 200
|
16 |
assert response.json()["name"] == "test"
|
17 |
assert response.json()["description"] == "test description"
|
18 |
|
|
|
19 |
def test_read_all_items():
|
20 |
response = client.get("/items/")
|
21 |
assert response.status_code == 200
|
22 |
assert len(response.json()) == 1
|
23 |
|
|
|
24 |
def test_update_item():
|
25 |
data = {"name": "updated test", "description": "updated test description"}
|
26 |
response = client.put("/items/1", json=data)
|
@@ -28,7 +32,8 @@ def test_update_item():
|
|
28 |
assert response.json()["name"] == "updated test"
|
29 |
assert response.json()["description"] == "updated test description"
|
30 |
|
|
|
31 |
def test_delete_item():
|
32 |
response = client.delete("/items/1")
|
33 |
assert response.status_code == 200
|
34 |
-
assert response.json() == {"detail": "Item deleted successfully"}
|
|
|
3 |
|
4 |
client = TestClient(app)
|
5 |
|
6 |
+
|
7 |
def test_create_item():
|
8 |
data = {"name": "test", "description": "test description"}
|
9 |
response = client.post("/items/", json=data)
|
|
|
11 |
assert response.json()["name"] == "test"
|
12 |
assert response.json()["description"] == "test description"
|
13 |
|
14 |
+
|
15 |
def test_read_item():
|
16 |
response = client.get("/items/1")
|
17 |
assert response.status_code == 200
|
18 |
assert response.json()["name"] == "test"
|
19 |
assert response.json()["description"] == "test description"
|
20 |
|
21 |
+
|
22 |
def test_read_all_items():
|
23 |
response = client.get("/items/")
|
24 |
assert response.status_code == 200
|
25 |
assert len(response.json()) == 1
|
26 |
|
27 |
+
|
28 |
def test_update_item():
|
29 |
data = {"name": "updated test", "description": "updated test description"}
|
30 |
response = client.put("/items/1", json=data)
|
|
|
32 |
assert response.json()["name"] == "updated test"
|
33 |
assert response.json()["description"] == "updated test description"
|
34 |
|
35 |
+
|
36 |
def test_delete_item():
|
37 |
response = client.delete("/items/1")
|
38 |
assert response.status_code == 200
|
39 |
+
assert response.json() == {"detail": "Item deleted successfully"}
|
app/tests/test_api/test_users.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from fastapi.testclient import TestClient
|
2 |
-
from app.
|
3 |
|
4 |
-
client = TestClient(
|
5 |
|
6 |
|
7 |
def test_create_user():
|
@@ -22,7 +22,10 @@ def test_create_user_invalid_password():
|
|
22 |
data = {"email": "test@example.com", "password": "short"}
|
23 |
response = client.post("/users/", json=data)
|
24 |
assert response.status_code == 422
|
25 |
-
assert
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
def test_read_user():
|
@@ -60,4 +63,4 @@ def test_delete_user():
|
|
60 |
def test_delete_user_not_found():
|
61 |
response = client.delete("/users/999")
|
62 |
assert response.status_code == 404
|
63 |
-
assert response.json()["detail"] == "User not found"
|
|
|
1 |
from fastapi.testclient import TestClient
|
2 |
+
from app.main import app
|
3 |
|
4 |
+
client = TestClient(app)
|
5 |
|
6 |
|
7 |
def test_create_user():
|
|
|
22 |
data = {"email": "test@example.com", "password": "short"}
|
23 |
response = client.post("/users/", json=data)
|
24 |
assert response.status_code == 422
|
25 |
+
assert (
|
26 |
+
"ensure this value has at least 6 characters"
|
27 |
+
in response.json()["detail"][0]["msg"]
|
28 |
+
)
|
29 |
|
30 |
|
31 |
def test_read_user():
|
|
|
63 |
def test_delete_user_not_found():
|
64 |
response = client.delete("/users/999")
|
65 |
assert response.status_code == 404
|
66 |
+
assert response.json()["detail"] == "User not found"
|
app/tests/test_core/__init__.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
from fastapi.testclient import TestClient
|
2 |
-
from app.
|
3 |
|
4 |
|
5 |
def test_app():
|
6 |
-
from app.main import app
|
7 |
client = TestClient(app)
|
8 |
response = client.get("/")
|
9 |
assert response.status_code == 200
|
@@ -13,4 +12,4 @@ def test_app():
|
|
13 |
def test_config():
|
14 |
assert settings.app_name == "My FastAPI Project"
|
15 |
assert settings.log_level == "debug"
|
16 |
-
assert settings.max_connection_count == 10
|
|
|
1 |
from fastapi.testclient import TestClient
|
2 |
+
from app.main import app
|
3 |
|
4 |
|
5 |
def test_app():
|
|
|
6 |
client = TestClient(app)
|
7 |
response = client.get("/")
|
8 |
assert response.status_code == 200
|
|
|
12 |
def test_config():
|
13 |
assert settings.app_name == "My FastAPI Project"
|
14 |
assert settings.log_level == "debug"
|
15 |
+
assert settings.max_connection_count == 10
|
app/tests/test_core/test_config.py
CHANGED
@@ -5,4 +5,4 @@ def test_settings():
|
|
5 |
assert settings.API_V1_STR == "/api/v1"
|
6 |
assert settings.PROJECT_NAME == "My FastAPI Project"
|
7 |
assert settings.SQLALCHEMY_DATABASE_URI == "sqlite:///./test.db"
|
8 |
-
assert settings.ACCESS_TOKEN_EXPIRE_MINUTES == 30
|
|
|
5 |
assert settings.API_V1_STR == "/api/v1"
|
6 |
assert settings.PROJECT_NAME == "My FastAPI Project"
|
7 |
assert settings.SQLALCHEMY_DATABASE_URI == "sqlite:///./test.db"
|
8 |
+
assert settings.ACCESS_TOKEN_EXPIRE_MINUTES == 30
|
app/tests/test_core/test_database.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from sqlalchemy.orm import Session
|
2 |
from app.core.database import Base
|
3 |
-
from app.models.user import User
|
4 |
-
from app.models.item import Item
|
5 |
|
6 |
|
7 |
def test_create_user(db: Session):
|
@@ -21,4 +21,4 @@ def test_create_item(db: Session):
|
|
21 |
db.refresh(item)
|
22 |
assert item.id is not None
|
23 |
assert item.name == "test item"
|
24 |
-
assert item.description == "test description"
|
|
|
1 |
from sqlalchemy.orm import Session
|
2 |
from app.core.database import Base
|
3 |
+
from app.api.models.user import User
|
4 |
+
from app.api.models.item import Item
|
5 |
|
6 |
|
7 |
def test_create_user(db: Session):
|
|
|
21 |
db.refresh(item)
|
22 |
assert item.id is not None
|
23 |
assert item.name == "test item"
|
24 |
+
assert item.description == "test description"
|
app/tests/test_core/test_security.py
CHANGED
@@ -1,17 +1,20 @@
|
|
1 |
from fastapi import HTTPException
|
2 |
from app.core.security import verify_password, get_password_hash
|
3 |
|
|
|
4 |
def test_password_hashing():
|
5 |
password = "testpassword"
|
6 |
hashed_password = get_password_hash(password)
|
7 |
assert hashed_password != password
|
8 |
|
|
|
9 |
def test_password_verification():
|
10 |
password = "testpassword"
|
11 |
hashed_password = get_password_hash(password)
|
12 |
assert verify_password(password, hashed_password)
|
13 |
assert not verify_password("wrongpassword", hashed_password)
|
14 |
|
|
|
15 |
def test_password_verification_exception():
|
16 |
password = "testpassword"
|
17 |
hashed_password = get_password_hash(password)
|
@@ -19,4 +22,4 @@ def test_password_verification_exception():
|
|
19 |
verify_password("wrongpassword", hashed_password)
|
20 |
except HTTPException as e:
|
21 |
assert e.status_code == 401
|
22 |
-
assert e.detail == "Incorrect email or password"
|
|
|
1 |
from fastapi import HTTPException
|
2 |
from app.core.security import verify_password, get_password_hash
|
3 |
|
4 |
+
|
5 |
def test_password_hashing():
|
6 |
password = "testpassword"
|
7 |
hashed_password = get_password_hash(password)
|
8 |
assert hashed_password != password
|
9 |
|
10 |
+
|
11 |
def test_password_verification():
|
12 |
password = "testpassword"
|
13 |
hashed_password = get_password_hash(password)
|
14 |
assert verify_password(password, hashed_password)
|
15 |
assert not verify_password("wrongpassword", hashed_password)
|
16 |
|
17 |
+
|
18 |
def test_password_verification_exception():
|
19 |
password = "testpassword"
|
20 |
hashed_password = get_password_hash(password)
|
|
|
22 |
verify_password("wrongpassword", hashed_password)
|
23 |
except HTTPException as e:
|
24 |
assert e.status_code == 401
|
25 |
+
assert e.detail == "Incorrect email or password"
|
app/tests/utils/utils.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlalchemy.orm import Session
|
2 |
+
from app.core.database import SessionLocal
|
3 |
+
|
4 |
+
|
5 |
+
def override_get_db():
|
6 |
+
"""
|
7 |
+
Override the get_db function for testing
|
8 |
+
"""
|
9 |
+
try:
|
10 |
+
db = SessionLocal()
|
11 |
+
yield db
|
12 |
+
finally:
|
13 |
+
db.close()
|
14 |
+
|
15 |
+
|
16 |
+
def get_db():
|
17 |
+
"""
|
18 |
+
Get a new database session
|
19 |
+
"""
|
20 |
+
try:
|
21 |
+
db = SessionLocal()
|
22 |
+
yield db
|
23 |
+
finally:
|
24 |
+
db.close()
|
app/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .utils import get_all_routes, print_routes
|
app/utils/utils.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
|
4 |
+
def get_all_routes(app):
|
5 |
+
routes = []
|
6 |
+
for route in app.routes:
|
7 |
+
routes.append(
|
8 |
+
{
|
9 |
+
"path": route.path,
|
10 |
+
"name": route.name,
|
11 |
+
"methods": list(route.methods),
|
12 |
+
}
|
13 |
+
)
|
14 |
+
return routes
|
15 |
+
|
16 |
+
|
17 |
+
def print_routes(app):
|
18 |
+
routes = get_all_routes(app)
|
19 |
+
print(json.dumps(routes, indent=4))
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ pytest
|
|
7 |
pytest-cov
|
8 |
faker
|
9 |
requests-mock
|
10 |
-
passlib
|
|
|
|
7 |
pytest-cov
|
8 |
faker
|
9 |
requests-mock
|
10 |
+
passlib
|
11 |
+
httpx
|