kth-qa / main.py
erseux's picture
changed folder for static
d95db96
import json
import logging
from magic.conversational import question_handler
from schema import Answer
# With no name argument, getLogger() returns the root logger.
logger = logging.getLogger()
# Configure root logging at INFO. NOTE(review): the `encoding` argument
# requires Python 3.9+ and, per the logging docs, is only used when a
# `filename` is also given -- confirm whether file logging was intended.
logging.basicConfig(encoding='utf-8', level=logging.INFO)
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from starlette.routing import WebSocketRoute
import uvicorn
from schema import Question
from config import State
import arel
# --- Setup ---
# hot reload
async def reload_data() -> None:
    """Print a notice that the server data is being reloaded."""
    notice = "Reloading server data..."
    print(notice)
# Directories holding static assets and Jinja2 templates. Both are handed to
# the hot-reload helper, and reused below for serving and rendering.
static_path = "static"
template_path = "templates"

hotreload = arel.HotReload(
    paths=[
        arel.Path(static_path),
        arel.Path(template_path),
    ],
)

# Shared application state; passed to `question_handler` on every /api/ask call.
state = State()

app = FastAPI(
    # The hot-reload object is exposed as a WebSocket endpoint.
    routes=[WebSocketRoute("/hot-reload", hotreload, name="hot-reload")],
    on_startup=[hotreload.startup],
    on_shutdown=[hotreload.shutdown],
)

# templates
# Serve static files from the same directory that hot reload is given, so the
# two cannot drift apart if the folder is moved again.
app.mount("/static", StaticFiles(directory=static_path), name="static")
templates = Jinja2Templates(directory=template_path)
# Globals made available to every rendered template.
# NOTE(review): DEBUG is hard-coded to True -- confirm this is intended
# outside local development.
templates.env.globals["DEBUG"] = True
templates.env.globals["hotreload"] = hotreload

# CORS
# Origins allowed to make cross-origin requests to this app.
origins = [
    "http://localhost",
    "http://localhost:5001",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# test questions
# Canned responses looked up by question text in /api/ask before the real
# handler is called. JSON is UTF-8 by specification, so the encoding is stated
# explicitly instead of relying on the platform's locale default.
with open("test_response.json", "r", encoding="utf-8") as f:
    test_questions = json.load(f)
# --- Routes ---
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the ``index.html`` template for the root URL."""
    # The incoming request is passed to the template under the "request" key.
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@app.post("/api/ask", response_class=JSONResponse)
async def ask(question: Question):
    """Answer a question posted to ``/api/ask``.

    If the question text is present in ``test_questions``, the stored entry is
    returned as-is. Otherwise ``question_handler`` is awaited with the question
    and the shared ``state``. Any exception it raises is logged, and a missing
    or falsy answer yields a 404 JSON response. On success only the ``answer``
    and ``urls`` fields of the answer are returned.
    """
    text = question.question

    # Canned responses take priority over the real handler.
    if text in test_questions:
        return test_questions[text]

    try:
        result = await question_handler(question, state)
    except Exception as err:
        logger.exception(err)
        result = None

    if result:
        return result.dict(include={"answer", "urls"})

    # NOTE(review): handler failures are reported as 404 here; a 5xx status may
    # be more accurate, but changing it could affect existing clients.
    return JSONResponse(status_code=404, content={"answer": "Something went wrong."})
if __name__ == "__main__":
    # Local import: only needed when this file is executed as a script.
    from pathlib import Path

    # With reload enabled, uvicorn must be given the app as an import string
    # ("module:attribute"). The module name is derived from this file instead
    # of being hard-coded, so the string stays valid if the file is renamed.
    app_target = f"{Path(__file__).stem}:app"

    uvicorn.run(
        app_target,
        host="localhost",
        port=5001,
        reload=True,
        reload_excludes=['files/', 'logs/'],
        # NOTE(review): the leading slashes make these absolute paths from the
        # filesystem root -- confirm whether project-relative directories were
        # intended before changing them, since that alters what is watched.
        reload_dirs=['/templates', '/static'],
    )