Spaces:
Running
Running
import json
import os
from typing import Optional
from uuid import uuid4

import requests
from dotenv import load_dotenv
from fastapi import BackgroundTasks, Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field

from rag import *
from prompt import *
# Pull variables from a .env file into the process environment before any
# os.environ lookups below run.
load_dotenv()

## setup authorization
# Bearer tokens accepted by api_key_auth. Only FASTAPI_API_KEY is consulted;
# when that variable is unset the list contains a single None.
api_keys = [os.getenv("FASTAPI_API_KEY")]
# Extracts the token from an "Authorization: Bearer <token>" header.
# tokenUrl is advertised in the OpenAPI docs; no "/token" route is visible here.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Reject the request unless its bearer token matches a configured API key.

    Args:
        api_key: Token taken from the ``Authorization: Bearer <token>`` header
            by ``oauth2_scheme``.

    Raises:
        HTTPException: 401 (detail ``"Forbidden"``) when the token matches no
            configured key. Configured keys that are ``None`` or empty (for
            example an unset or blank ``FASTAPI_API_KEY``) are ignored, so a
            missing configuration rejects every request instead of accepting
            an empty token.
    """
    # Imported locally so the star imports at the top of the file cannot
    # shadow the name.
    import secrets

    authorized = False
    if api_key is not None:
        candidate = api_key.encode("utf-8")
        # compare_digest runs in constant time, which avoids leaking key
        # contents through response timing. Bytes are compared because str
        # arguments must be ASCII-only.
        authorized = any(
            secrets.compare_digest(candidate, key.encode("utf-8"))
            for key in api_keys
            if key
        )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden",
            # The bearer-token scheme expects this challenge header on a 401.
            headers={"WWW-Authenticate": "Bearer"},
        )
# Authentication is switched off only when DEV is exactly the string "True".
# Otherwise api_key_auth is attached as an app-wide dependency, so every route
# requires a valid bearer token.
dev_mode = os.environ.get("DEV")
app = FastAPI() if dev_mode == "True" else FastAPI(dependencies=[Depends(api_key_auth)])

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive. Bearer tokens sent in a header do not need credentials mode, so
# consider an explicit origin list or allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request-body model for checking a user's answer to a Sphinx question.
# All three fields are required (Field(...) carries no default).
class verify_response_model(BaseModel):
    # What the user typed as their answer.
    response: str = Field(..., description="The response from the user to the question")
    # The accepted answers produced alongside the question.
    answers: list[str] = Field(..., description="The possible answers to the question to test if the user read the entire book")
    # The question text that was shown to the user.
    question: str = Field(..., description="The question asked to the user to test if they read the entire book")
# Request-body model for the generation endpoint.
class UserInput(BaseModel):
    # The user's question or prompt.
    query: str
    # When truthy, the endpoint returns a StreamingResponse instead of a single result.
    stream: Optional[bool] = False
    # Prior chat messages forwarded to generate_stream. default_factory builds
    # a fresh list per instance instead of sharing one mutable default.
    # NOTE(review): an explicit null is still accepted and passed through as None.
    messages: Optional[list[dict]] = Field(default_factory=list)
# Endpoints
async def generate_sphinx():
    """Create a new Sphinx question along with its accepted answers.

    Returns:
        dict with ``question`` and ``answers`` copied from the object that
        ``generate_sphinx_response()`` returns.

    Raises:
        HTTPException: 500, with the exception message as detail, if anything
            fails while producing the question.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on ``app`` elsewhere.
    try:
        result: sphinx_output = generate_sphinx_response()
        payload = {"question": result.question, "answers": result.answers}
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return payload
async def verify_sphinx(response: verify_response_model):
    """Score a user's reply to a Sphinx question.

    Args:
        response: The user's reply, the accepted answers, and the question.

    Returns:
        dict with a single ``score`` key holding what ``verify_response``
        returned (annotated as bool).

    Raises:
        HTTPException: 500, with the exception message as detail, on any failure.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on ``app`` elsewhere.
    try:
        verdict: bool = verify_response(
            response.response,
            response.answers,
            response.question,
        )
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return {"score": verdict}
async def generate(user_input: UserInput):
    """Answer a user query, optionally streaming the output.

    Args:
        user_input: The query text, the stream flag, and prior chat messages.

    Returns:
        A ``StreamingResponse`` over ``generate_stream(..., stream=True)`` when
        ``user_input.stream`` is truthy; otherwise whatever
        ``generate_stream(..., stream=False)`` returns.

    Raises:
        HTTPException: 500, with the exception message as detail, if building
            the response fails. This matches the other endpoints. Previously
            the error was returned as ``{"message": ...}`` with HTTP 200, which
            clients could not tell apart from success. An HTTPException raised
            inside is passed through unchanged.
    """
    # Imported locally so the star imports at the top of the file cannot
    # shadow the name.
    import logging

    # Debug level rather than print, so raw user queries do not land in
    # stdout by default.
    logging.getLogger(__name__).debug(
        "generate: stream=%s query=%r", user_input.stream, user_input.query
    )
    try:
        if user_input.stream:
            # If generate_stream is a generator, its body only runs while the
            # response is being sent. Errors raised then surface after this
            # handler returns and are NOT caught by the except clauses below.
            # NOTE(review): media_type is application/json although the body is
            # streamed in chunks; confirm against what generate_stream yields.
            return StreamingResponse(
                generate_stream(user_input.query, user_input.messages, stream=True),
                media_type="application/json",
            )
        return generate_stream(user_input.query, user_input.messages, stream=False)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e