py-learn-backend / auth /models.py
Oviya
fix
da8ce94
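"""
Database models and schemas for the authentication system.
Contains:
- User model with role-based access
- Token blacklist model
- Refresh token model
- Support for both SQL and JSON backends, selected via the USE_JSON_DB
  environment variable (JSON when it equals "1", which is the default)
"""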
import os
from typing import Optional, Dict, Any
# Select the storage backend once, at import time. Any value of the
# USE_JSON_DB environment variable other than "1" (the default) selects SQL.
USE_JSON_DB = os.getenv("USE_JSON_DB", "1") == "1"
if USE_JSON_DB:
    # JSON backend
    from .json_database import JSONUsers as UserBackend
    from .json_database import JSONBlacklistedTokens as BlacklistedTokenBackend
    from .json_database import JSONRefreshTokens as RefreshTokenBackend
else:
    # SQL backend (kept for backward compatibility)
    from contextlib import closing

    import pyodbc

    def _sql_connection():
        """Return a context manager yielding a fresh database connection.

        The connection is closed when the ``with`` block exits, on success and
        on exceptions alike. ``contextlib.closing`` is used instead of the
        connection's own context manager because pyodbc's
        ``Connection.__exit__`` only commits/rolls back and does NOT close.
        """
        # Lazy import, as in the original per-method imports.
        from .database import get_db_connection
        return closing(get_db_connection())

    class UserBackend:
        """SQL implementation of user storage (``Users`` table)."""

        @staticmethod
        def find_by_username(username: str) -> Optional[Dict[str, Any]]:
            """Return the user row for *username* as a dict, or None if absent."""
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute(
                    "SELECT id, username, password_hash, role FROM Users WHERE username = ?",
                    (username,),
                )
                row = cur.fetchone()
            if row is None:
                return None
            return {
                'id': row[0],
                'username': row[1],
                'password_hash': row[2],
                'role': row[3],
            }

        @staticmethod
        def create_user(username: str, password_hash: str, role: str = 'user') -> bool:
            """Insert a new user.

            Returns:
                True on success, False if the insert raises
                ``pyodbc.IntegrityError`` (e.g. the username already exists).
            """
            with _sql_connection() as conn:
                cur = conn.cursor()
                try:
                    cur.execute(
                        "INSERT INTO Users (username, password_hash, role) VALUES (?, ?, ?)",
                        (username, password_hash, role),
                    )
                    conn.commit()
                except pyodbc.IntegrityError:
                    # e.g. duplicate username: Users.username is UNIQUE. The
                    # connection is still closed by the surrounding ``with``.
                    return False
            return True

        @staticmethod
        def get_all_users() -> list:
            """Return every user as an ``{"id", "username", "role"}`` dict, ordered by id."""
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute("SELECT id, username, role FROM Users ORDER BY id")
                rows = cur.fetchall()
            return [
                {"id": row[0], "username": row[1], "role": row[2]}
                for row in rows
            ]

        @staticmethod
        def promote_to_admin(username: str) -> bool:
            """Set *username*'s role to 'admin'; return True if a row was updated."""
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute("UPDATE Users SET role = 'admin' WHERE username = ?", (username,))
                # rowcount reflects the UPDATE above; read it before committing.
                updated = cur.rowcount > 0
                conn.commit()
            return updated

        @staticmethod
        def user_count() -> int:
            """Return the number of rows in the Users table."""
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute("SELECT COUNT(*) FROM Users")
                count = cur.fetchone()[0]
            return count

    class BlacklistedTokenBackend:
        """SQL implementation of the token blacklist (``BlacklistedTokens`` table)."""

        @staticmethod
        def is_blacklisted(token: str) -> bool:
            """Return True if *token* is stored in the blacklist."""
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute("SELECT token FROM BlacklistedTokens WHERE token = ?", (token,))
                return cur.fetchone() is not None

        @staticmethod
        def add_to_blacklist(token: str) -> bool:
            """Add *token* to the blacklist; always returns True.

            Adding a token that is already blacklisted is not an error.
            """
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute("SELECT token FROM BlacklistedTokens WHERE token = ?", (token,))
                if cur.fetchone() is not None:
                    return True
                try:
                    cur.execute("INSERT INTO BlacklistedTokens (token) VALUES (?)", (token,))
                    conn.commit()
                except pyodbc.IntegrityError:
                    # Another request inserted the same token between the
                    # SELECT and the INSERT (token is UNIQUE). The token is
                    # blacklisted either way, so this is deliberately ignored.
                    pass
            return True

    class RefreshTokenBackend:
        """SQL implementation of refresh-token storage (``RefreshTokens`` table)."""

        @staticmethod
        def find_by_token(token: str) -> Optional[str]:
            """Return the username owning *token*, or None if it is unknown."""
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute("SELECT username FROM RefreshTokens WHERE token = ?", (token,))
                row = cur.fetchone()
            return row[0] if row else None

        @staticmethod
        def create_token(username: str, token: str) -> bool:
            """Store *token* for *username*; always returns True.

            Database errors (e.g. a duplicate token) propagate to the caller.
            """
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute(
                    "INSERT INTO RefreshTokens (username, token) VALUES (?, ?)",
                    (username, token),
                )
                conn.commit()
            return True

        @staticmethod
        def delete_user_tokens(username: str) -> bool:
            """Delete every refresh token of *username*; always returns True."""
            with _sql_connection() as conn:
                cur = conn.cursor()
                cur.execute("DELETE FROM RefreshTokens WHERE username = ?", (username,))
                conn.commit()
            return True
class User:
    """User model for authentication and authorization.

    Holds one user record. The static/class helpers delegate to whichever
    ``UserBackend`` (JSON or SQL) was selected at import time.
    """

    def __init__(self, username: str, password_hash: str, role: str = 'user',
                 user_id: Optional[int] = None):
        self.id = user_id
        self.username = username
        # Never exposed by to_dict() or __repr__().
        self.password_hash = password_hash
        self.role = role

    def __repr__(self) -> str:
        # Deliberately omits password_hash so logging a User cannot leak it.
        return (f"{type(self).__name__}(id={self.id!r}, "
                f"username={self.username!r}, role={self.role!r})")

    @classmethod
    def find_by_username(cls, username: str) -> Optional['User']:
        """Look up *username* in the backend.

        Returns:
            An instance of ``cls`` built from the backend record, or None when
            the backend returns no (or an empty) record.
        """
        user_data = UserBackend.find_by_username(username)
        if not user_data:
            return None
        return cls(
            user_id=user_data.get('id'),
            username=user_data.get('username'),
            password_hash=user_data.get('password_hash'),
            role=user_data.get('role'),
        )

    @staticmethod
    def create_user(username: str, password_hash: str, role: str = 'user') -> bool:
        """Create a new user; returns the backend's success flag."""
        return UserBackend.create_user(username, password_hash, role)

    @staticmethod
    def get_all_users() -> list:
        """Return all users as listed by the backend.

        Intended for admins only; no authorization check is performed here.
        """
        return UserBackend.get_all_users()

    @staticmethod
    def promote_to_admin(username: str) -> bool:
        """Promote *username* to the 'admin' role; returns the backend's result."""
        return UserBackend.promote_to_admin(username)

    @staticmethod
    def user_count() -> int:
        """Return the total number of users known to the backend."""
        return UserBackend.user_count()

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-safe dict of the public fields (password_hash excluded)."""
        return {
            "id": self.id,
            "username": self.username,
            "role": self.role,
        }
class BlacklistedToken:
    """Facade over the configured ``BlacklistedTokenBackend`` for revoked JWT tokens."""

    @staticmethod
    def add_to_blacklist(token: str) -> bool:
        """Record *token* as revoked and return the backend's result flag."""
        outcome = BlacklistedTokenBackend.add_to_blacklist(token)
        return outcome

    @staticmethod
    def is_blacklisted(token: str) -> bool:
        """Report whether *token* has been revoked, according to the backend."""
        found = BlacklistedTokenBackend.is_blacklisted(token)
        return found
class RefreshToken:
    """Facade over the configured ``RefreshTokenBackend`` for stored refresh tokens."""

    @staticmethod
    def create_token(username: str, token: str) -> bool:
        """Persist *token* as a refresh token owned by *username*."""
        stored = RefreshTokenBackend.create_token(username, token)
        return stored

    @staticmethod
    def find_by_token(token: str) -> Optional[str]:
        """Resolve *token* to its owner's username (None when not found)."""
        owner = RefreshTokenBackend.find_by_token(token)
        return owner

    @staticmethod
    def delete_user_tokens(username: str) -> bool:
        """Drop every refresh token belonging to *username*."""
        removed = RefreshTokenBackend.delete_user_tokens(username)
        return removed
# Table-creation DDL; only relevant to the SQL backend.
def get_table_definitions():
    """Return the T-SQL DDL for the authentication tables.

    Keys are 'users', 'blacklisted_tokens' and 'refresh_tokens'. Each value is
    guarded by ``IF OBJECT_ID(...) IS NULL``, so it only creates a table that
    does not exist yet. 'users' must be run before 'refresh_tokens', because
    RefreshTokens.username is a foreign key to Users(username).
    """
    users_ddl = """
IF OBJECT_ID('Users', 'U') IS NULL
CREATE TABLE Users (
id INT IDENTITY(1,1) PRIMARY KEY,
username NVARCHAR(100) UNIQUE NOT NULL,
password_hash NVARCHAR(500) NOT NULL,
role NVARCHAR(50) DEFAULT 'user'
)
"""
    blacklist_ddl = """
IF OBJECT_ID('BlacklistedTokens', 'U') IS NULL
CREATE TABLE BlacklistedTokens (
id INT IDENTITY(1,1) PRIMARY KEY,
token NVARCHAR(1000) UNIQUE NOT NULL,
created_at DATETIME DEFAULT GETDATE()
)
"""
    refresh_ddl = """
IF OBJECT_ID('RefreshTokens', 'U') IS NULL
CREATE TABLE RefreshTokens (
id INT IDENTITY(1,1) PRIMARY KEY,
username NVARCHAR(100) NOT NULL,
token NVARCHAR(1000) UNIQUE NOT NULL,
created_at DATETIME DEFAULT GETDATE(),
FOREIGN KEY (username) REFERENCES Users(username) ON DELETE CASCADE
)
"""
    # Insertion order (users first) matches the dependency order noted above.
    return dict(
        users=users_ddl,
        blacklisted_tokens=blacklist_ddl,
        refresh_tokens=refresh_ddl,
    )