ai-pokemon-card / app.py
Ron Au
refactor(FastAPI): Flask -> FastAPI
c6fcf99
raw
history blame
3.09 kB
from statistics import mean
from time import time

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from modules.details import rand_details
from modules.inference import generate_image
# ASGI application instance; the route decorators in this module register on it.
app = FastAPI()
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# In-memory task registry: task_id -> task record dict (built in create_task,
# then updated by queue_task / poll_task).
# NOTE(review): process-local and nothing in this module removes entries, so it
# grows for the lifetime of the process and is not shared between workers.
tasks = {}
def place_in_queue(task_id):
    """Return the 1-based position of *task_id* among pending tasks.

    Pending tasks are ordered by their insertion order in ``tasks`` (i.e. the
    order in which they were created). Returns 0 when *task_id* is not
    registered or its status is no longer ``"pending"``.
    """
    # Compare ids to ids: searching a list of task dicts for a string id
    # would never match.
    pending_ids = [task["task_id"] for task in tasks.values()
                   if task["status"] == "pending"]
    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # Unknown, completed, or failed task: not waiting in the queue.
        return 0
def calculate_eta(task_id):
    """Estimate the wait for *task_id* from its initial queue position.

    The per-task estimate is the mean ``completed_at - created_at`` duration
    of tasks whose status is ``"completed"``. It falls back to 40 (same units
    as ``time()``, i.e. seconds) when none have completed yet. The estimate is
    multiplied by the task's ``initial_place_in_queue``, treated as at least 1.

    Raises:
        KeyError: if *task_id* is not in ``tasks``.
    """
    # Only successful runs: failed tasks also get ``completed_at`` stamped,
    # but their (often very short) durations would skew the estimate. This
    # matches how poll_task computes its ETA.
    durations = [task["completed_at"] - task["created_at"]
                 for task in tasks.values() if task["status"] == "completed"]
    place = tasks[task_id]["initial_place_in_queue"] or 1
    per_task = mean(durations) if durations else 40
    return per_task * place
@app.get('/')
def index():
    """Serve the front-end page ``static/index.html`` as ``text/html``."""
    page = "static/index.html"
    return FileResponse(media_type="text/html", path=page)
@app.get('/task/create')
def create_task(prompt: str = "покемон"):
    """Register a new pending task for *prompt* and return its record.

    The task id is ``"<creation timestamp>_<prompt>"``. The returned record
    holds the id, creation time, prompt, initial queue position, status
    ``"pending"`` and a zero poll counter. The queue position and ETA are
    printed to stdout.
    """
    created_at = time()
    task_id = f"{created_at}_{prompt}"
    task = {
        "task_id": task_id,
        "created_at": created_at,
        "prompt": prompt,
        "initial_place_in_queue": 0,  # set below, once the task is registered
        "status": "pending",
        "poll_count": 0,
    }
    # Register first: place_in_queue only looks at tasks already in the
    # registry, so asking before insertion could never find this task.
    tasks[task_id] = task
    task["initial_place_in_queue"] = place_in_queue(task_id)

    print("Place in queue: ", task["initial_place_in_queue"])
    print("ETA: ", calculate_eta(task_id))
    return task
@app.get('/task/queue')
def queue_task(task_id: str):
    """Run image generation for a registered task and record the outcome.

    Calls ``generate_image`` with the task's prompt. On success the result is
    stored under ``"value"`` and the status becomes ``"completed"``. Any
    exception marks the task ``"failed"`` and stores its ``repr`` under
    ``"error"``. ``completed_at`` is stamped in both cases. Responds with
    HTTP 404 when *task_id* is unknown.
    """
    task = tasks.get(task_id)
    if task is None:
        # Without this guard an unknown id raises KeyError inside the
        # except/finally handlers and surfaces as an opaque 500.
        raise HTTPException(status_code=404,
                            detail=f"Unknown task_id: {task_id}")
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:  # request boundary: report failure to the client
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        task["completed_at"] = time()
    return task
@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh and return the queue position and ETA for *task_id*.

    ``place_in_queue`` is the task's current 1-based position among pending
    tasks, or 0 once it is no longer pending. ``eta`` is the mean duration of
    completed tasks multiplied by that position (at least 1). It falls back to
    40 per slot when nothing has completed yet, and is rounded to one decimal.
    Every call increments ``poll_count``. Responds with HTTP 404 when
    *task_id* is unknown.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404,
                            detail=f"Unknown task_id: {task_id}")

    pending_ids = []
    completed_durations = []
    for other in tasks.values():
        if other["status"] == "pending":
            pending_ids.append(other["task_id"])
        elif other["status"] == "completed":
            completed_durations.append(
                other["completed_at"] - other["created_at"])

    # Local is named `position` so the module-level place_in_queue() is not
    # shadowed inside this function.
    try:
        position = pending_ids.index(task_id) + 1
    except ValueError:
        position = 0  # task exists but is no longer pending

    if completed_durations:
        per_task = sum(completed_durations) / len(completed_durations)
    else:
        per_task = 40
    eta = per_task * (position or 1)

    task["place_in_queue"] = position
    task["eta"] = round(eta, 1)
    task["poll_count"] += 1
    return task
@app.get('/details')
async def generate_details():
    """Return whatever ``rand_details()`` produces.

    Presumably a randomised set of card details; the payload shape is defined
    in ``modules.details``, so confirm it there.
    """
    details = rand_details()
    return details
@app.get('/duck/quack')
async def duck_quack(query: str = "quack"):
    """Debug endpoint: print *query* to stdout and echo it under ``"duck"``.

    Previously this was also named ``test``, so the later ``/test`` handler
    rebound the module name and both routes shared the route name "test".
    A distinct name keeps route names unique.
    """
    print(query)
    return {"duck": query}
@app.get('/test')
async def test(query: str = "test"):
    """Debug endpoint: print *query* to stdout and echo it under ``"query"``."""
    echoed = {"query": query}
    print(query)
    return echoed