|
""" Backend server for the frontend app |
|
This file contains the endpoints that can be called via HTTP |
|
""" |
|
|
|
from fastapi import FastAPI |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import Response |
|
import duckdb |
|
import pyarrow as pa |
|
from uvicorn import run |
|
from fastapi.staticfiles import StaticFiles |
|
|
|
from fire import Fire |
|
|
|
# ASGI application object; the /query route and the static-file mount
# defined further down in this module are registered on it.
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level DuckDB connection used by the /query endpoint.
con = duckdb.connect()

# Materialize the Parquet file as the `flights` table when the module is
# imported. The path is relative, so it resolves against the working directory.
con.query("CREATE TABLE flights AS FROM 'flights-10m.parquet'")
|
|
|
|
|
@app.get("/query/{sql_query:path}")
async def query(sql_query: str):
    """Run the SQL text from the URL path on DuckDB and return Arrow bytes.

    Args:
        sql_query: SQL statement taken from everything after ``/query/`` in
            the request path.

    Returns:
        A ``Response`` whose body is the query result serialized by
        ``arrow_to_bytes``, served as ``application/octet-stream``.
    """
    # Standard library; imported locally because the module header does not
    # import it.
    import re

    # NOTE(security): the SQL comes straight from the request path and is
    # executed as-is on the shared connection. Only expose this server to
    # trusted clients.

    # Cast every count(*) to INT. Matching is case-insensitive so COUNT(*)
    # receives the same cast as count(*).
    # NOTE(review): presumably the cast is there because the client cannot
    # consume 64-bit counts -- confirm against the frontend.
    sql_query = re.sub(
        r"count\(\*\)", "count(*)::INT", sql_query, flags=re.IGNORECASE
    )

    # `con` is only read here, never rebound, so no `global` statement is
    # required.
    result = con.query(sql_query).arrow()
    return Response(arrow_to_bytes(result), media_type="application/octet-stream")
|
|
|
|
|
# Serve the contents of the "dist" directory at the site root. This mount is
# registered after the /query route above.
app.mount(
    "/",
    StaticFiles(directory="dist", html=True),
    name="static",
)
|
|
|
|
|
def arrow_to_bytes(table: pa.Table) -> bytes:
    """Serialize an Arrow table into a byte string via a record-batch stream.

    Args:
        table: The table to serialize; its schema is used for the stream.

    Returns:
        The contents of the in-memory sink, read after the stream writer has
        been closed.
    """
    sink = pa.BufferOutputStream()
    # Leaving the `with` block closes the writer before the sink is read.
    with pa.RecordBatchStreamWriter(sink, table.schema) as writer:
        writer.write_table(table)
    # Returned directly: the original bound this to a local named `bytes`,
    # which shadowed the builtin.
    return sink.getvalue().to_pybytes()
|
|
|
|
|
def serve(port=8000, host="localhost"):
    """Start uvicorn serving ``app`` on the given host and port."""
    run(app, host=host, port=port)
|
|
|
|
|
if __name__ == "__main__":
    # Hand `serve` to Fire so it is driven from the command line.
    Fire(serve)
|
|