|
""" |
|
This file consolidates parameters for logging, database connections, model paths, API settings, and security. |
|
""" |
|
|
|
|
|
import os |
|
import logging |
|
from datetime import timedelta |
|
from typing import Callable, List, Optional |
|
|
|
|
|
import torch |
|
from dotenv import load_dotenv |
|
from pathlib import Path |
|
from pydantic import BaseModel, Field, computed_field |
|
from pydantic_settings import BaseSettings |
|
|
|
load_dotenv() |
|
|
|
BASE_DIR = Path(__file__).resolve().parent.parent |
|
|
|
|
|
class QdrantSettings(BaseModel):
    """Connection parameters for the Qdrant vector store.

    The aliases allow the values to be populated from LOCAL_HOST /
    LOCAL_PORT when data is supplied by alias.
    NOTE(review): a plain BaseModel does not read environment variables by
    itself — confirm the parent BaseSettings wiring actually feeds these.
    """

    host: str = Field(default="localhost", validation_alias="LOCAL_HOST")
    port: int = Field(default=6333, validation_alias="LOCAL_PORT")
|
|
|
|
|
class ModelsSettings(BaseModel):
    """Identifiers of the local embedding and reranking models."""

    # Sentence-transformers embedder and cross-encoder reranker names.
    embedder_model: str = Field(default="all-MiniLM-L6-v2")
    reranker_model: str = Field(default="cross-encoder/ms-marco-MiniLM-L6-v2")
|
|
|
|
|
class LocalLLMSettings(BaseModel):
    """Location and runtime options for the local GGUF language model."""

    # Where to find the model weights (HF repo id + file inside it).
    model_path_or_repo_id: str = Field(default="TheBloke/Mistral-7B-v0.1-GGUF")
    model_file: str = Field(default="mistral-7b-v0.1.Q5_K_S.gguf")
    model_type: str = Field(default="mistral")

    # Runtime knobs; gpu_layers is left unset here.
    # NOTE(review): Settings.get_gpu_layers computes a value, but nothing in
    # this file copies it into gpu_layers — confirm the call-site wiring.
    gpu_layers: Optional[int] = None
    threads: int = Field(default=8)
    context_length: int = Field(default=4096)
    mlock: bool = Field(default=True)
|
|
|
|
|
class GenerationSettings(BaseModel):
    """Decoding parameters for local LLM text generation."""

    # Window of recent tokens considered by the repetition penalty.
    last_n_tokens: int = 128
    temperature: float = 0.3
    repetition_penalty: float = 1.2
|
|
|
|
|
class TextSplitterSettings(BaseModel):
    """Chunking parameters for the document text splitter."""

    chunk_size: int = 1000  # target size of each chunk
    chunk_overlap: int = 100  # span shared between adjacent chunks
    # Measures chunk length; the default `len` counts characters, not tokens.
    length_function: Callable = len
    is_separator_regex: bool = False  # separators treated as literal strings
    add_start_index: bool = True  # record each chunk's start offset
|
|
|
|
|
class APISettings(BaseModel):
    """Uvicorn launch options for the API application."""

    # Import string of the ASGI app, bind address, and dev auto-reload flag.
    app: str = Field(default="app.api.api:api")
    host: str = Field(default="127.0.0.1")
    port: int = Field(default=5050)
    reload: bool = Field(default=True)
|
|
|
|
|
class GeminiSettings(BaseModel):
    """Generation parameters for Gemini completions."""

    # Deterministic-leaning defaults: zero temperature plus a fixed seed.
    temperature: float = Field(default=0.0)
    top_p: float = Field(default=0.95)
    top_k: int = Field(default=20)
    candidate_count: int = Field(default=1)
    seed: int = Field(default=5)
    max_output_tokens: int = Field(default=1001)
    # default_factory so each instance gets its own list object.
    stop_sequences: List[str] = Field(default_factory=lambda: ["STOP!"])
    presence_penalty: float = Field(default=0.0)
    frequency_penalty: float = Field(default=0.0)
|
|
|
|
|
class GeminiEmbeddingSettings(BaseModel):
    """Parameters for Gemini document-embedding requests."""

    # NOTE(review): 382 is an unusual dimensionality — common values are
    # 384 or 768; confirm this is not a typo before relying on it.
    output_dimensionality: int = 382
    task_type: str = "retrieval_document"  # embedding task hint sent to the API
|
|
|
|
|
class GeminiWrapperSettings(BaseModel):
    """Gemini parameters for the wrapper/auxiliary calls.

    Mirrors GeminiSettings except for a much smaller max_output_tokens
    (100 vs 1001) suited to short utility responses.
    """

    temperature: float = Field(default=0.0)
    top_p: float = Field(default=0.95)
    top_k: int = Field(default=20)
    candidate_count: int = Field(default=1)
    seed: int = Field(default=5)
    max_output_tokens: int = Field(default=100)
    # default_factory so each instance gets its own list object.
    stop_sequences: List[str] = Field(default_factory=lambda: ["STOP!"])
    presence_penalty: float = Field(default=0.0)
    frequency_penalty: float = Field(default=0.0)
|
|
|
|
|
class PostgresSettings(BaseModel):
    """SQLAlchemy connection settings for Postgres.

    The URL comes from the DATABASE_URL environment variable. The read is
    deferred to model instantiation via default_factory instead of being
    evaluated in the class body, so merely importing this module no longer
    raises a bare KeyError when the variable is absent — a missing variable
    still raises KeyError, but only when PostgresSettings() is constructed.
    """

    url: str = Field(default_factory=lambda: os.environ["DATABASE_URL"])
    echo: bool = False  # set True to log emitted SQL statements
|
|
|
|
|
class Settings(BaseSettings):
    """Top-level application settings aggregating all sub-configurations.

    Nested models use default_factory so every Settings() instance gets
    fresh sub-model objects. Secret values come from the environment
    (populated by load_dotenv() at module import); their reads are deferred
    to instantiation via default_factory, so importing this module does not
    raise on a missing variable — constructing Settings() does (KeyError).

    The redundant no-op __init__ (which only called super().__init__) has
    been removed; BaseSettings' constructor is used directly.
    """

    qdrant: QdrantSettings = Field(default_factory=QdrantSettings)
    local_llm: LocalLLMSettings = Field(default_factory=LocalLLMSettings)
    models: ModelsSettings = Field(default_factory=ModelsSettings)
    local_generation: GenerationSettings = Field(default_factory=GenerationSettings)
    text_splitter: TextSplitterSettings = Field(default_factory=TextSplitterSettings)
    api: APISettings = Field(default_factory=APISettings)
    gemini_generation: GeminiSettings = Field(default_factory=GeminiSettings)
    gemini_embedding: GeminiEmbeddingSettings = Field(
        default_factory=GeminiEmbeddingSettings
    )
    gemini_wrapper: GeminiWrapperSettings = Field(
        default_factory=GeminiWrapperSettings
    )
    postgres: PostgresSettings = Field(default_factory=PostgresSettings)

    # Feature flags and auth lifetimes.
    use_gemini: bool = True  # route generation through Gemini instead of the local LLM
    max_delta: float = 0.15
    max_cookie_lifetime: timedelta = timedelta(seconds=3000)
    password_reset_token_lifetime: timedelta = timedelta(seconds=3000)

    # Evaluated once per instantiation; falls back to CPU when CUDA is absent.
    device: str = Field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )
    base_dir: Path = BASE_DIR

    stream: bool = True  # stream LLM output token by token

    # Secrets: deferred environment reads (default_factory) so a missing
    # variable fails at Settings() construction, not at module import.
    secret_pepper: str = Field(default_factory=lambda: os.environ["SECRET_PEPPER"])
    # .replace("\r", "") strips stray carriage returns left by
    # Windows-edited .env files, which would corrupt the algorithm name.
    jwt_algorithm: str = Field(
        default_factory=lambda: os.environ["JWT_ALGORITHM"].replace("\r", "")
    )
    api_key: str = Field(default_factory=lambda: os.environ["GEMINI_API_KEY"])

    @computed_field
    @property
    def get_gpu_layers(self) -> int:
        """Number of model layers to offload: 20 on CUDA, 0 on CPU."""
        # NOTE(review): local_llm.gpu_layers is never populated from this
        # value anywhere in this file — confirm the call-site wiring.
        return 20 if self.device == "cuda" else 0
|
|
|
|
|
# Singleton settings instance built at import time; raises KeyError if a
# required environment variable (see Settings) is missing.
settings = Settings()

# Module-level logging setup: INFO level, minimal format, to the default
# stream handler (stderr).
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s: %(message)s",
    handlers=[logging.StreamHandler()],
)
|
|
|
if __name__ == "__main__": |
|
|
|
def bold_text(text: str): |
|
return "\033[1m" + text + "\033[0m" |
|
|
|
print(bold_text("--- Successfully loaded settings ---")) |
|
print(f"{bold_text("Base Directory:")} {settings.base_dir}") |
|
print(f"{bold_text("Running on device:")} {settings.device}") |
|
print(f"{bold_text("Qdrant Host:")} {settings.qdrant.host}") |
|
print(f"{bold_text("LLM GPU Layers:")} {settings.local_llm.gpu_layers}") |
|
|
|
|
|
|
|
print(bold_text("\n--- Full settings model dump (secrets masked) ---")) |
|
print(settings.model_dump()) |
|
|
|
print(bold_text("\n--- Secret fields (from .env file) ---")) |
|
print(f"{bold_text("Postgres URL:")} {settings.postgres.url}") |
|
print(f"{bold_text("JWT Algorithm:")} {settings.jwt_algorithm}") |
|
print(f"{bold_text("Secret Pepper:")} {settings.secret_pepper}") |
|
|
|
print(f"{bold_text("Gemini API Key:")} {settings.api_key}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Route names reachable without an authenticated user; "" is the root route.
# NOTE(review): presumably consumed by auth middleware elsewhere — confirm
# how these strings are matched against request paths.
url_user_not_required = ["login", "", "viewer", "message_with_docs", "new_user", "health"]
|
|