ia_back / main.py
Ilyas KHIAT
test
05971d9
raw
history blame
2.67 kB
from fastapi import FastAPI, HTTPException, UploadFile, File,Request,Depends,status,BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field
from typing import Optional
from uuid import uuid4
import os
import secrets
from dotenv import load_dotenv
from rag import *
from fastapi.responses import JSONResponse, StreamingResponse
import json
from prompt import *
from fastapi.middleware.cors import CORSMiddleware
import requests
# Load variables from a .env file (if present) before any configuration is read.
load_dotenv()
## setup authorization
# Accepted bearer tokens. Only a non-empty FASTAPI_API_KEY is kept: an unset variable
# would otherwise contribute None, and an empty one "" -- which could match an empty
# bearer token and silently disable authentication. With no usable key configured,
# every authenticated request is rejected (fail closed).
api_keys = [key for key in (os.environ.get("FASTAPI_API_KEY"),) if key]
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")  # use token authentication
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Reject the request unless its bearer token matches a configured API key.

    The token is supplied by ``oauth2_scheme`` (an ``OAuth2PasswordBearer``), i.e.
    taken from the ``Authorization: Bearer <token>`` header.

    Args:
        api_key: the bearer token sent by the client.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token is
            empty or matches none of the non-empty entries of ``api_keys``.
    """
    supplied = (api_key or "").encode("utf-8")
    # Empty/None configured keys are skipped so they can never match an empty token.
    # compare_digest keeps comparison time independent of how many leading characters match.
    authorized = bool(supplied) and any(
        secrets.compare_digest(supplied, key.encode("utf-8"))
        for key in api_keys
        if key
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden",  # kept unchanged for existing clients
            # A 401 response is expected to advertise the authentication scheme it wants.
            headers={"WWW-Authenticate": "Bearer"},
        )
# Authentication is skipped only when DEV is exactly the string "True"; any other
# value (including unset) installs api_key_auth as a dependency of every route.
dev_mode = os.environ.get("DEV")
app = FastAPI() if dev_mode == "True" else FastAPI(dependencies=[Depends(api_key_auth)])

# NOTE(review): wildcard origins together with allow_credentials=True is very permissive.
# Auth here uses a bearer header rather than cookies, so credentials may not be needed -- confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic model for the JSON request body of POST /verify_sphinx.
# NOTE(review): the name is not PascalCase, but it is referenced by the endpoint signature,
# and field order / descriptions feed the generated OpenAPI schema -- keep them stable.
class verify_response_model(BaseModel):
    # The user's answer that is being checked.
    response: str = Field(description="The response from the user to the question")
    # Acceptable answers for `question` (presumably those returned by /generate_sphinx -- confirm).
    answers: list[str] = Field(description="The possible answers to the question to test if the user read the entire book")
    # The question the user was asked.
    question: str = Field(description="The question asked to the user to test if they read the entire book")
class UserInput(BaseModel):
    # JSON request body of POST /generate.
    query: str  # the user's question, forwarded to generate_stream
    # Truthy -> the answer is returned as a StreamingResponse; False or None -> a single result.
    stream: Optional[bool] = False
    # Prior conversation turns forwarded to generate_stream unchanged; the dict layout is
    # defined by that helper (presumably role/content pairs -- confirm).
    # NOTE(review): pydantic copies field defaults per instance, so this [] should not be
    # shared between requests -- confirm for the installed pydantic version.
    messages: Optional[list[dict]] = []
# endpoints
@app.post("/generate_sphinx")
async def generate_sphinx():
    """Generate a "sphinx" question together with its accepted answers.

    Returns:
        ``{"question": ..., "answers": ...}`` read from the object returned by
        ``generate_sphinx_response()``.

    Raises:
        HTTPException: 500 carrying the underlying error message if generation fails;
            an ``HTTPException`` raised by the helper is propagated unchanged.
    """
    try:
        # generate_sphinx_response is called synchronously (it is not awaited). Running it
        # in the threadpool keeps a slow call from blocking the event loop for other requests.
        sphinx: sphinx_output = await run_in_threadpool(generate_sphinx_response)
        return {"question": sphinx.question, "answers": sphinx.answers}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/verify_sphinx")
async def verify_sphinx(response: verify_response_model):
    """Check a user's answer to a sphinx question.

    Args:
        response: the user's answer together with the question and the accepted
            answers it is checked against.

    Returns:
        ``{"score": ...}`` holding the value returned by ``verify_response``.

    Raises:
        HTTPException: 500 carrying the underlying error message if verification fails;
            an ``HTTPException`` raised by the helper is propagated unchanged.
    """
    try:
        # verify_response is called synchronously (it is not awaited); run it in the
        # threadpool so a slow check does not block the event loop.
        score: bool = await run_in_threadpool(
            verify_response, response.response, response.answers, response.question
        )
        return {"score": score}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/generate")
async def generate(user_input: UserInput):
    """Answer ``user_input.query`` given the prior chat ``messages``.

    When ``user_input.stream`` is truthy, the result of
    ``generate_stream(..., stream=True)`` is wrapped in a StreamingResponse.
    Otherwise the result of ``generate_stream(..., stream=False)`` is returned
    as the response body.

    Errors raised before the response is returned are reported as HTTP 500 with
    a ``{"message": ...}`` body. Errors raised while a stream is being consumed
    happen after this function has returned and are not caught here.
    """
    try:
        print(user_input.stream, user_input.query)
        if user_input.stream:
            return StreamingResponse(
                generate_stream(user_input.query, user_input.messages, stream=True),
                media_type="application/json",
            )
        # The non-streaming call runs to completion synchronously; keep it off the
        # event loop so other requests can be served meanwhile.
        return await run_in_threadpool(
            generate_stream, user_input.query, user_input.messages, stream=False
        )
    except HTTPException:
        raise
    except Exception as e:
        # Signal failure with a 500 status, as the other endpoints do, while keeping
        # the {"message": ...} body shape that clients read.
        return JSONResponse(status_code=500, content={"message": str(e)})