Add CORS for prod & check if CUDA is available before loading the model
backend/backend/app/utils/index.py
CHANGED
@@ -34,6 +34,10 @@ DATA_DIR = str(
     current_directory / "data"
 ) # directory containing the documents to index
 
+
+# set to at least 1 to use GPU, adjust according to your GPU memory, but must be able to fit the model
+model_kwargs = {"n_gpu_layers": 100} if DEVICE_TYPE == "cuda" else {}
+
 llm = LlamaCPP(
     model_url="https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf",
     temperature=0.1,
@@ -43,8 +47,7 @@ llm = LlamaCPP(
     # kwargs to pass to __call__()
     # generate_kwargs={},
     # kwargs to pass to __init__()
-
-    model_kwargs={"n_gpu_layers": 100},
+    model_kwargs=model_kwargs,
     # transform inputs into Llama2 format
     messages_to_prompt=messages_to_prompt,
     completion_to_prompt=completion_to_prompt,
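Note: DEVICE_TYPE is referenced in this hunk but defined elsewhere in index.py. A minimal sketch of how such a flag is typically derived is shown below; the torch-based check is an assumption for illustration, not code from this commit.

# Hypothetical sketch only: derive DEVICE_TYPE from torch's CUDA check.
# The actual definition in index.py may differ (e.g. it could come from an env var).
import torch

DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"

# Offload layers to the GPU only when CUDA is present; an empty dict keeps llama.cpp on the CPU.
model_kwargs = {"n_gpu_layers": 100} if DEVICE_TYPE == "cuda" else {}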
backend/backend/main.py
CHANGED
@@ -9,6 +9,7 @@ from app.utils.index import create_index
 from dotenv import load_dotenv
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from torch.cuda import is_available as is_cuda_available
 
 load_dotenv()
 
@@ -16,6 +17,7 @@ app = FastAPI()
 
 environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
 
+# TODO: Add reading allowed origins from environment variables
 
 if environment == "dev":
     logger = logging.getLogger("uvicorn")
@@ -28,10 +30,30 @@
         allow_headers=["*"],
     )
 
+if environment == "prod":
+    # In production, specify the allowed origins
+    allowed_origins = [
+        "https://your-production-domain.com",
+        "https://another-production-domain.com",
+        # Add more allowed origins as needed
+    ]
+
+    logger = logging.getLogger("uvicorn")
+    logger.info(f"Running in production mode - allowing CORS for {allowed_origins}")
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=allowed_origins,
+        allow_credentials=True,
+        allow_methods=["GET", "POST", "PUT", "DELETE"],
+        allow_headers=["*"],
+    )
+
+    logger.info(f"CUDA available: {is_cuda_available()}")
+
 app.include_router(chat_router, prefix="/api/chat")
 app.include_router(query_router, prefix="/api/query")
 app.include_router(search_router, prefix="/api/search")
 app.include_router(healthcheck_router, prefix="/api/healthcheck")
 
-#
+# Try to create the index first on startup
 create_index()
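Note: the TODO about reading allowed origins from environment variables is not implemented in this commit. One possible follow-up is sketched below; the ALLOWED_ORIGINS variable name and its comma-separated format are assumptions, not part of the change.

# Hypothetical follow-up for the TODO: read a comma-separated list of origins
# from an ALLOWED_ORIGINS environment variable (name and format assumed here).
import os

allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "").split(",")
    if origin.strip()
]

With, for example, ALLOWED_ORIGINS="https://your-production-domain.com,https://another-production-domain.com" set in the environment, this yields the same list that is currently hard-coded in main.py.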