Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import uvicorn | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, Form | |
| from fastapi.requests import Request | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from backend.classes.embedding_model import EmbeddingModelConfig, EmbeddingModel | |
| from backend.classes.galileo_platform import GalileoPlatformConfig, GalileoPlatform | |
| from backend.classes.generative_model import GeminiModelConfig, GeminiModel, OpenAIModelConfig, OpenAIModel | |
| from backend.classes.rag_application import RAGApplicationConfig, RAGApplication | |
| from backend.classes.vector_database.milvus_vector_database import ( | |
| MilvusVectorDatabaseConfig, | |
| MilvusVectorDatabase, | |
| ) | |
| from backend.utils.utils import get_embedding_model | |
| from backend.utils.utils import ( | |
| initialize_logger, | |
| read_config, | |
| set_env_variables, | |
| create_vector_database, | |
| get_generative_model, | |
| ) | |
# FastAPI application, with the static-asset mount and the Jinja2 template
# loader used by the HTML front end.
# NOTE(review): both directories are relative to the current working
# directory (unlike the config path, which is anchored on __file__), so the
# server must be started from the repository root — confirm this is intended.
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory="backend/api/static"),
    name="static",
)
templates = Jinja2Templates(directory="backend/api/templates")

# Load .env into the process environment before the logger and the config
# below are initialised.
load_dotenv()
logger = initialize_logger()
# The YAML config lives in conf/ under the grandparent directory of this file.
config = read_config(str(Path(__file__).parent.parent / "conf" / "config.yaml"))

# Resolve the variables named under the config's "env_variables" key
# (presumably validated against the process environment — see
# set_env_variables). APP_ENV then selects which config section is active.
env_variables = set_env_variables(config["env_variables"])
app_config = config[env_variables["APP_ENV"]]
# Attach the resolved variables to the active section.
app_config["env_vars"] = env_variables
# Embedding model, configured from the active section's "embedding_model"
# entry (model name and batch size).
embedding_model_config = EmbeddingModelConfig(
    model_name=app_config["embedding_model"]["model_name"],
    batch_size=app_config["embedding_model"]["batch_size"],
)
embedding_model = get_embedding_model(EmbeddingModel, embedding_model_config)
# Milvus vector store. Both the collection name and the database file name
# derive from MILVUS_DB.
# NOTE(review): db_path is built by plain string concatenation
# (<db_path><MILVUS_DB>_milvus.db), so a configured db_path meant as a
# directory must already end with a path separator — confirm in config.yaml.
# drop_if_exists=False presumably keeps an existing collection intact.
vector_db_config = MilvusVectorDatabaseConfig(
    db_path=(
        app_config["vector_database"]["db_path"]
        + env_variables["MILVUS_DB"]
        + "_milvus.db"
    ),
    collection_name=env_variables["MILVUS_DB"],
    vector_dimensions=app_config["vector_database"]["dimensions"],
    drop_if_exists=False,
)
vector_db = create_vector_database(MilvusVectorDatabase, vector_db_config)
# Gemini is the active generative model. It receives a primary and a backup
# API key (how the backup is used is up to GeminiModel — presumably a
# fallback). The temperature env var is parsed as a float.
gemini_generative_model_config = GeminiModelConfig(
    model_name=env_variables["GOOGLE_GEMINI_MODEL"],
    api_keys=[
        env_variables["GOOGLE_GEMINI_API_KEY"],
        env_variables["GOOGLE_GEMINI_BACKUP_API_KEY"],
    ],
    temperature=float(env_variables["MODEL_TEMPERATURE"]),
)
gemini_generative_model = get_generative_model(
    GeminiModel, gemini_generative_model_config
)

# Disabled OpenAI alternative. To switch, uncomment it and pass
# openai_generative_model to RAGApplicationConfig instead of the Gemini model.
# openai_generative_model_config = OpenAIModelConfig(
#     model_name=env_variables["OPENAI_MODEL"],
#     api_key=env_variables["OPENAI_API_KEY"],
#     temperature=float(env_variables["MODEL_TEMPERATURE"]),
# )
# openai_generative_model = get_generative_model(OpenAIModel, openai_generative_model_config)
# Galileo names read once at import time. The project, log-stream and
# dataset defaults are handed to the index template; the project and protect
# stage also configure the Galileo platform object below.
default_project_name = env_variables["GALILEO_PROJECT_NAME"]
default_logstream_name = env_variables["GALILEO_LOGSTREAM_NAME"]
default_protect_stage_name = env_variables["GALILEO_PROTECT_STAGE_NAME"]
default_dataset_name = env_variables["GALILEO_DATASET_NAME"]

# Galileo platform client, reusing the defaults looked up above.
galileo_platform_config = GalileoPlatformConfig(
    protect_project_name=default_project_name,
    protect_stage_name=default_protect_stage_name,
)
galileo_platform = GalileoPlatform(galileo_platform_config)
# Assemble the embedding model, vector store, generator and Galileo client
# into the single RAG pipeline used by the /search endpoint.
rag_application_config = RAGApplicationConfig(
    embedding_model=embedding_model,
    vector_db=vector_db,
    generative_model=gemini_generative_model,  # or: openai_generative_model
    galileo_platform=galileo_platform,
)
rag_app = RAGApplication(rag_application_config)
# NOTE(review): route path "/" assumed for the landing page — confirm
# against the front end.
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the landing page with the Galileo form defaults pre-filled.

    Args:
        request: The incoming request; the classic Jinja2Templates API expects
            it under the "request" key of the template context.

    Returns:
        The rendered ``index.html`` template response.
    """
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            # Defaults are module-level values read from env variables once,
            # at import time.
            "default_project_name": default_project_name,
            "default_logstream_name": default_logstream_name,
            "default_dataset_name": default_dataset_name,
        },
    )
# NOTE(review): route path "/search" assumed from the handler name — confirm
# against the URL the front end posts to.
@app.post("/search")
async def search(
    query: str = Form(...),
    top_k: int = Form(5),
    add_to_dataset: bool = Form(False),
    pii_detection: bool = Form(False),
    hallucination_detection: bool = Form(False),
    induce_hallucination: bool = Form(False),
    project_name: str = Form(...),
    logstream_name: str = Form(...),
    dataset_name: str = Form(...),
) -> JSONResponse:
    """Answer a query with the RAG pipeline and return the answer plus metrics.

    Every argument arrives as an HTML form field and is forwarded to
    ``rag_app.run``.

    Args:
        query: The user's question.
        top_k: Forwarded as ``top_k`` (presumably the number of retrieved
            chunks).
        add_to_dataset: When true, ``dataset_name`` is forwarded; otherwise
            ``None`` is passed in its place.
        pii_detection: Forwarded unchanged.
        hallucination_detection: Forwarded unchanged.
        induce_hallucination: Forwarded unchanged.
        project_name: Galileo project name, forwarded unchanged.
        logstream_name: Galileo log stream name, forwarded unchanged.
        dataset_name: Galileo dataset name, used only when ``add_to_dataset``.

    Returns:
        JSON with "message", "redacted_message", "original_message" and a
        "metrics" object holding "context_adherence" and "pii_flag".
    """
    logger.info("=" * 80)
    logger.info("SEARCH REQUEST RECEIVED")
    logger.info(f"Query: {query}")
    logger.info(f"Top K: {top_k}")
    logger.info(f"Add to Dataset: {add_to_dataset}")
    logger.info(f"PII Detection: {pii_detection}")
    logger.info(f"Hallucination Detection: {hallucination_detection}")
    logger.info(f"Induce Hallucination: {induce_hallucination}")
    logger.info("=" * 80)

    # NOTE(review): rag_app.run is synchronous and is not awaited, so it
    # occupies the event loop for the whole request. Consider
    # fastapi.concurrency.run_in_threadpool once the pipeline is confirmed
    # thread-safe.
    (
        response,
        redacted_response,
        original_response,
        context_adherence_score,
        pii_flag,
    ) = rag_app.run(
        query,
        pii_detection=pii_detection,
        top_k=top_k,
        hallucination_detection=hallucination_detection,
        induce_hallucination=induce_hallucination,
        project_name=project_name,
        logstream_name=logstream_name,
        # Only log into a dataset when the user opted in.
        dataset_name=dataset_name if add_to_dataset else None,
    )

    return JSONResponse(
        {
            "message": response,
            "redacted_message": redacted_response,
            "original_message": original_response,
            "metrics": {
                "context_adherence": context_adherence_score,
                "pii_flag": pii_flag,
            },
        }
    )
if __name__ == "__main__":
    import os  # only needed when this module is run as a script

    # PORT lets the hosting environment override the listen port; the
    # previous hard-coded 8000 remains the default.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))