ai-pokemon-card / app.py
Ron Au
feat(endpoint): Change `create/task` from GET to POST
42422ba
from statistics import mean
from threading import Lock
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from modules.details import rand_details
from modules.inference import generate_image
# Interactive API docs (Swagger UI and ReDoc) are switched off.
app = FastAPI(redoc_url=None, docs_url=None)
# Files in the local "static" directory are served under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# In-memory task registry: task_id -> task record dict (create_task shows the
# initial fields). Nothing in this module ever removes an entry, so the
# registry grows for the lifetime of the process.
tasks: dict = {}
class NewTask(BaseModel):
    """Request body accepted by the POST /task/create endpoint."""

    # Text prompt that is stored on the task and later passed to
    # generate_image. The default is the Cyrillic spelling of "pokemon".
    # The explicit `str` annotation matters: pydantic v1 inferred the type
    # from the default, but pydantic v2 rejects non-annotated attributes.
    prompt: str = "покемон"
def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among unfinished tasks.

    Unfinished tasks are those whose status is "queued" or "processing",
    ordered by their ``created_at`` timestamp. Returns 0 when ``task_id`` is
    not among them (unknown id, or the task already completed/failed).
    """
    unfinished = sorted(
        (task for task in tasks.values()
         if task["status"] in ("queued", "processing")),
        key=lambda task: task["created_at"],
    )
    # Explicit search instead of list.index() inside a bare `except:`, which
    # also swallowed KeyboardInterrupt/SystemExit and any genuine bug.
    for place, task in enumerate(unfinished, start=1):
        if task["task_id"] == task_id:
            return place
    return 0
def calculate_eta(task_id):
    """Estimate the wait, in seconds, for ``task_id`` from its initial queue slot.

    The per-task cost is the mean wall-clock duration (completed_at minus
    started_at) of every task that has a ``completed_at`` timestamp, whether
    it completed or failed. With no history yet, 40 seconds is assumed.
    The result is initial_place_in_queue * per-task cost, rounded to one
    decimal place.
    """
    durations = [
        record["completed_at"] - record["started_at"]
        for record in tasks.values()
        if "completed_at" in record
    ]
    slot = tasks[task_id]["initial_place_in_queue"]
    per_task = mean(durations) if durations else 40
    return round(slot * per_task, 1)
# Makes "is anything processing? -> claim this task" atomic. Starlette runs
# sync background tasks in a worker thread pool, so without this lock two
# calls could both see an idle queue and generate concurrently, or run the
# same task twice.
_queue_lock = Lock()


def _claim_task(task_id):
    """Mark ``task_id`` as processing if it is queued and nothing else runs.

    Returns True when the task was claimed. Callers must hold ``_queue_lock``.
    """
    task = tasks[task_id]
    if task["status"] != "queued":
        # Already running, finished, or failed: never process it again.
        return False
    if any(other["status"] == "processing" for other in tasks.values()):
        return False
    task["status"] = "processing"
    task["started_at"] = time()
    return True


def _next_queued_task_id():
    """Return the id of the first queued task in insertion order, or None.

    Also logs how many tasks are still queued. Callers must hold
    ``_queue_lock``.
    """
    queued = [task for task in tasks.values() if task["status"] == "queued"]
    if not queued:
        return None
    print(f"Tasks remaining: {len(queued)}")
    return queued[0]["task_id"]


def process_task(task_id):
    """Generate the image for ``task_id``, then keep draining the queue.

    At most one task is processing at any time. If another task is already
    processing, or ``task_id`` is no longer queued, this returns
    immediately. A task left queued is picked up later by whichever call is
    draining the queue. Each processed task ends up "completed" (with
    "value") or "failed" (with "error"), and gets "completed_at" either way.
    Remaining queued tasks are handled in a loop rather than by recursion,
    so a long queue cannot exhaust the call stack.
    """
    with _queue_lock:
        if not _claim_task(task_id):
            return
    while task_id is not None:
        task = tasks[task_id]
        try:
            # Slow, blocking inference: run it without holding the lock so
            # that other requests can still create and poll tasks meanwhile.
            value = generate_image(task["prompt"])
        except Exception as ex:
            outcome = {"status": "failed", "error": repr(ex)}
        else:
            outcome = {"value": value, "status": "completed"}
        with _queue_lock:
            # Record the result and hand off to the next queued task in one
            # lock hold, so no other thread can start a task in between.
            task.update(outcome)
            task["completed_at"] = time()
            task_id = _next_queued_task_id()
            if task_id is not None and not _claim_task(task_id):
                task_id = None
@app.head('/')
@app.get('/')
def index():
    # Serve the front-end page for "/" (GET, and HEAD as well).
    # Comments rather than a docstring, so the OpenAPI description is unchanged.
    page = "static/index.html"
    return FileResponse(page, media_type="text/html")
@app.get('/details')
def generate_details():
    # Thin wrapper around modules.details.rand_details(); whatever it returns
    # is handed back to FastAPI unchanged for serialization.
    details = rand_details()
    return details
@app.post('/task/create')
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    # Register a new queued task and schedule process_task for it to run
    # after the response is sent. The task record itself is returned.
    now = time()
    prompt = new_task.prompt
    # The id combines the creation timestamp with the raw prompt text.
    task_id = f"{now}_{prompt}"
    record = {
        "task_id": task_id,
        "created_at": now,
        "prompt": prompt,
        "status": "queued",
        "poll_count": 0,
    }
    # The record must be in `tasks` before its queue position is computed.
    tasks[task_id] = record
    record["initial_place_in_queue"] = get_place_in_queue(task_id)
    background_tasks.add_task(process_task, task_id)
    return record
@app.get('/task/poll')
def poll_task(task_id: str):
    # Refresh the queue position and ETA of a task, count the poll, and
    # return the full task record.
    task = tasks.get(task_id)
    if task is None:
        # Previously an unknown id raised KeyError, which surfaced as a 500
        # Internal Server Error; report it as a client-side 404 instead.
        raise HTTPException(status_code=404, detail="Task not found")
    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["poll_count"] += 1
    return task