khronoz committed
Commit afc8144 · 1 Parent(s): 8500091

Add CORS for prod & check if CUDA is available before loading model

backend/backend/app/utils/index.py CHANGED
@@ -34,6 +34,10 @@ DATA_DIR = str(
     current_directory / "data"
 ) # directory containing the documents to index
 
+
+# set to at least 1 to use GPU, adjust according to your GPU memory, but must be able to fit the model
+model_kwargs = {"n_gpu_layers": 100} if DEVICE_TYPE == "cuda" else {}
+
 llm = LlamaCPP(
     model_url="https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf",
     temperature=0.1,
@@ -43,8 +47,7 @@ llm = LlamaCPP(
     # kwargs to pass to __call__()
     # generate_kwargs={},
     # kwargs to pass to __init__()
-    # set to at least 1 to use GPU, adjust according to your GPU memory, but must be able to fit the model
-    model_kwargs={"n_gpu_layers": 100},
+    model_kwargs=model_kwargs,
     # transform inputs into Llama2 format
     messages_to_prompt=messages_to_prompt,
     completion_to_prompt=completion_to_prompt,
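
Note: the new model_kwargs line refers to a DEVICE_TYPE constant whose definition lies outside this hunk. A minimal sketch of how such a constant could be derived, assuming it reuses the same torch.cuda.is_available check this commit imports in main.py (the derivation shown here is an assumption, not part of the diff):

    # Hypothetical definition of DEVICE_TYPE elsewhere in index.py;
    # this commit only shows the constant being used, not defined.
    from torch.cuda import is_available as is_cuda_available

    DEVICE_TYPE = "cuda" if is_cuda_available() else "cpu"

With that in place, n_gpu_layers is only passed to LlamaCPP when a CUDA device is actually present, so CPU-only machines no longer fail trying to offload layers to a missing GPU.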
backend/backend/main.py CHANGED
@@ -9,6 +9,7 @@ from app.utils.index import create_index
 from dotenv import load_dotenv
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from torch.cuda import is_available as is_cuda_available
 
 load_dotenv()
 
@@ -16,6 +17,7 @@ app = FastAPI()
 
 environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
 
+# TODO: Add reading allowed origins from environment variables
 
 if environment == "dev":
     logger = logging.getLogger("uvicorn")
@@ -28,10 +30,30 @@ if environment == "dev":
         allow_headers=["*"],
     )
 
+if environment == "prod":
+    # In production, specify the allowed origins
+    allowed_origins = [
+        "https://your-production-domain.com",
+        "https://another-production-domain.com",
+        # Add more allowed origins as needed
+    ]
+
+    logger = logging.getLogger("uvicorn")
+    logger.info(f"Running in production mode - allowing CORS for {allowed_origins}")
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=allowed_origins,
+        allow_credentials=True,
+        allow_methods=["GET", "POST", "PUT", "DELETE"],
+        allow_headers=["*"],
+    )
+
+logger.info(f"CUDA available: {is_cuda_available()}")
+
 app.include_router(chat_router, prefix="/api/chat")
 app.include_router(query_router, prefix="/api/query")
 app.include_router(search_router, prefix="/api/search")
 app.include_router(healthcheck_router, prefix="/api/healthcheck")
 
-# try to create the index first on startup
+# Try to create the index first on startup
 create_index()
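
A possible shape for the TODO left in main.py — reading the allowed origins from an environment variable instead of hardcoding them. The variable name ALLOWED_ORIGINS and the comma-separated format are assumptions for illustration, not part of this commit:

    # Sketch only: parse a hypothetical ALLOWED_ORIGINS env var, e.g.
    # ALLOWED_ORIGINS="https://your-production-domain.com,https://another-production-domain.com"
    import os

    allowed_origins = [
        origin.strip()
        for origin in os.getenv("ALLOWED_ORIGINS", "").split(",")
        if origin.strip()  # drop empty entries when the var is unset
    ]

This would let the prod origin list be configured per deployment (alongside the existing ENVIRONMENT variable) without editing code.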