Demos / backend /api /main.py
nikhile-galileo's picture
Updating app with latest changes
4ee29ab
from pathlib import Path

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.requests import Request
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from backend.classes.embedding_model import EmbeddingModelConfig, EmbeddingModel
from backend.classes.galileo_platform import GalileoPlatformConfig, GalileoPlatform
from backend.classes.generative_model import GeminiModelConfig, GeminiModel, OpenAIModelConfig, OpenAIModel
from backend.classes.rag_application import RAGApplicationConfig, RAGApplication
from backend.classes.vector_database.milvus_vector_database import (
    MilvusVectorDatabaseConfig,
    MilvusVectorDatabase,
)
from backend.utils.utils import get_embedding_model
from backend.utils.utils import (
    initialize_logger,
    read_config,
    set_env_variables,
    create_vector_database,
    get_generative_model,
)
# --- Application, static assets and configuration --------------------------
# Directory containing this module. Bundled resources (static files, templates,
# config) are resolved relative to it rather than to the process working
# directory, so the app starts no matter where it is launched from.
_API_DIR = Path(__file__).parent

app = FastAPI()
app.mount("/static", StaticFiles(directory=str(_API_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(_API_DIR / "templates"))

# Load a .env file (if any) into the process environment before any
# environment-derived settings are read below.
load_dotenv()
logger = initialize_logger()

# The YAML config lives in the "conf" directory one level above this module.
config = read_config(str(Path(_API_DIR.parent, "conf/config.yaml")))
# Resolve the environment variables declared under "env_variables" in the
# config; the result is used as a str-keyed mapping below (validation of
# required variables, if any, happens inside set_env_variables -- confirm).
env_variables = set_env_variables(config["env_variables"])
# Pick the config section named by APP_ENV and attach the resolved
# environment variables to it.
app_config = config[env_variables["APP_ENV"]]
app_config["env_vars"] = env_variables
# --- Embedding model --------------------------------------------------------
_embedding_settings = app_config["embedding_model"]
embedding_model_config = EmbeddingModelConfig(
    model_name=_embedding_settings["model_name"],
    batch_size=_embedding_settings["batch_size"],
)
embedding_model = get_embedding_model(EmbeddingModel, embedding_model_config)

# --- Vector database (Milvus) -----------------------------------------------
# MILVUS_DB names both the collection and the database path, which is built
# as <db_path><MILVUS_DB>_milvus.db (plain concatenation, so db_path is
# expected to carry its own trailing separator).
_vector_db_settings = app_config["vector_database"]
_milvus_db_name = env_variables["MILVUS_DB"]
vector_db_config = MilvusVectorDatabaseConfig(
    db_path=_vector_db_settings["db_path"] + _milvus_db_name + "_milvus.db",
    collection_name=_milvus_db_name,
    vector_dimensions=_vector_db_settings["dimensions"],
    # Presumably keeps an existing collection instead of recreating it on
    # startup -- confirm in MilvusVectorDatabase.
    drop_if_exists=False,
)
vector_db = create_vector_database(MilvusVectorDatabase, vector_db_config)
# --- Generative model (Gemini) ----------------------------------------------
_gemini_model_name = env_variables["GOOGLE_GEMINI_MODEL"]
# Primary key first, backup key second; the list is handed to
# GeminiModelConfig in this order.
_gemini_api_keys = [
    env_variables["GOOGLE_GEMINI_API_KEY"],
    env_variables["GOOGLE_GEMINI_BACKUP_API_KEY"],
]
gemini_generative_model_config = GeminiModelConfig(
    model_name=_gemini_model_name,
    api_keys=_gemini_api_keys,
    temperature=float(env_variables["MODEL_TEMPERATURE"]),
)
gemini_generative_model = get_generative_model(
    GeminiModel, gemini_generative_model_config
)

# Alternative backend: to use OpenAI instead, uncomment the lines below and
# pass `openai_generative_model` as `generative_model` to RAGApplicationConfig.
# openai_generative_model_config = OpenAIModelConfig(
#     model_name=env_variables["OPENAI_MODEL"],
#     api_key=env_variables["OPENAI_API_KEY"],
#     temperature=float(env_variables["MODEL_TEMPERATURE"]),
# )
# openai_generative_model = get_generative_model(OpenAIModel, openai_generative_model_config)
# --- Galileo names, resolved once at import time ----------------------------
default_project_name = env_variables["GALILEO_PROJECT_NAME"]
default_logstream_name = env_variables["GALILEO_LOGSTREAM_NAME"]
default_protect_stage_name = env_variables["GALILEO_PROTECT_STAGE_NAME"]
default_dataset_name = env_variables["GALILEO_DATASET_NAME"]

# --- Galileo platform -------------------------------------------------------
# Reuse the values bound above rather than re-reading env_variables, so the
# Protect configuration and the defaults exposed elsewhere in this module come
# from a single lookup and cannot drift apart.
galileo_platform_config = GalileoPlatformConfig(
    protect_project_name=default_project_name,
    protect_stage_name=default_protect_stage_name,
)
galileo_platform = GalileoPlatform(galileo_platform_config)

# --- RAG application --------------------------------------------------------
rag_application_config = RAGApplicationConfig(
    embedding_model=embedding_model,
    vector_db=vector_db,
    generative_model=gemini_generative_model,
    # generative_model=openai_generative_model,  # alternative OpenAI backend
    galileo_platform=galileo_platform,
)
rag_app = RAGApplication(rag_application_config)
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render ``index.html``, pre-filling the form with the Galileo defaults.

    The defaults are module-level values taken from ``env_variables`` when the
    module was imported; they are not re-read on each request.
    """
    page_context = {
        "request": request,
        "default_project_name": default_project_name,
        "default_logstream_name": default_logstream_name,
        "default_dataset_name": default_dataset_name,
    }
    return templates.TemplateResponse("index.html", page_context)
@app.post("/search")
async def search(
    query: str = Form(...),
    top_k: int = Form(5),
    add_to_dataset: bool = Form(False),
    pii_detection: bool = Form(False),
    hallucination_detection: bool = Form(False),
    induce_hallucination: bool = Form(False),
    project_name: str = Form(...),
    logstream_name: str = Form(...),
    dataset_name: str = Form(...),
) -> JSONResponse:
    """Answer ``query`` with the RAG pipeline and return the result as JSON.

    Args:
        query: The user's question.
        top_k: Forwarded to ``rag_app.run`` as ``top_k``.
        add_to_dataset: When true, ``dataset_name`` is forwarded to
            ``rag_app.run``; otherwise ``None`` is passed in its place.
        pii_detection: Forwarded to ``rag_app.run``.
        hallucination_detection: Forwarded to ``rag_app.run``.
        induce_hallucination: Forwarded to ``rag_app.run``.
        project_name: Galileo project name, forwarded to ``rag_app.run``.
        logstream_name: Galileo log stream name, forwarded to ``rag_app.run``.
        dataset_name: Only used when ``add_to_dataset`` is true.

    Returns:
        A JSON object with ``message``, ``redacted_message``,
        ``original_message`` and ``metrics`` (``context_adherence``,
        ``pii_flag``), taken from the 5-tuple returned by ``rag_app.run``.
    """
    logger.info("=" * 80)
    logger.info("SEARCH REQUEST RECEIVED")
    # NOTE(review): the raw query is logged even when PII detection is
    # requested; consider redacting it if these logs are retained.
    logger.info(f"Query: {query}")
    logger.info(f"Top K: {top_k}")
    logger.info(f"Add to Dataset: {add_to_dataset}")
    logger.info(f"PII Detection: {pii_detection}")
    logger.info(f"Hallucination Detection: {hallucination_detection}")
    logger.info(f"Induce Hallucination: {induce_hallucination}")
    logger.info("=" * 80)

    # rag_app.run is synchronous (it was called without await). Calling it
    # directly inside this coroutine would block the event loop -- and every
    # other request, including page and static-file loads -- for the whole
    # (presumably network-bound) retrieval/generation call. Run it in the
    # worker thread pool instead.
    # NOTE(review): concurrent requests now invoke rag_app.run from different
    # threads; confirm its components (vector DB, model clients, Galileo
    # logging) are thread-safe, or serialise the call with a lock.
    (
        response,
        redacted_response,
        original_response,
        context_adherence_score,
        pii_flag,
    ) = await run_in_threadpool(
        rag_app.run,
        query,
        pii_detection=pii_detection,
        top_k=top_k,
        hallucination_detection=hallucination_detection,
        induce_hallucination=induce_hallucination,
        project_name=project_name,
        logstream_name=logstream_name,
        dataset_name=dataset_name if add_to_dataset else None,
    )

    return JSONResponse(
        {
            "message": response,
            "redacted_message": redacted_response,
            "original_message": original_response,
            "metrics": {
                "context_adherence": context_adherence_score,
                "pii_flag": pii_flag,
            },
        }
    )
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces at port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")