PokeGen / app.py
Ron Au
feat(error): Tweaks to logging and error handling
0bae5c5
raw
history blame
3.59 kB
from statistics import mean
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from modules.details import rand_details
from modules.inference import generate_image
# FastAPI application; the interactive API docs (Swagger UI and ReDoc) are disabled.
app = FastAPI(redoc_url=None, docs_url=None)
# Serve files from the local "static" directory under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# In-memory task registry: task_id -> task record dict (records are built in create_task).
tasks = dict()
class NewTask(BaseModel):
    """Request body for POST /task/create.

    ``prompt`` is stored on the task record and later passed to
    ``generate_image``. When the client omits it, it defaults to "покемон"
    (Russian for "pokemon").
    """

    # Explicit ``str`` annotation: pydantic v2 rejects non-annotated model
    # attributes, and pydantic v1 already inferred ``str`` from the default,
    # so this is equivalent on v1 and required on v2.
    prompt: str = "покемон"
def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among active tasks.

    Active tasks are those whose status is "queued" or "processing". They are
    ordered by ``created_at``, oldest first.

    Returns:
        The 1-based position. Returns 0 when the task is not active (for
        example completed, failed or abandoned) or is not in ``tasks`` at all.
    """
    active = sorted(
        (task for task in tasks.values()
         if task["status"] in ("queued", "processing")),
        key=lambda task: task["created_at"],
    )
    active_ids = [task["task_id"] for task in active]
    try:
        return active_ids.index(task_id) + 1
    except ValueError:
        # list.index raises ValueError when task_id is not active. A bare
        # ``except`` here would also swallow KeyboardInterrupt/SystemExit.
        return 0
def calculate_eta(task_id, default_duration=35):
    """Estimate the seconds until ``task_id`` finishes, rounded to 0.1 s.

    The estimate is the task's ``initial_place_in_queue`` multiplied by the
    mean processing time (``completed_at - started_at``) of tasks that have
    completed. Until any task has completed, each task is assumed to take
    ``default_duration`` seconds.

    The initial place, not the current one, is used, so the estimate does not
    shrink as the queue advances.

    Args:
        task_id: Key into ``tasks``. The record must have
            ``initial_place_in_queue`` set.
        default_duration: Fallback per-task duration in seconds, used when no
            completed task has timing data yet.

    Raises:
        KeyError: If ``task_id`` is not in ``tasks``.
    """
    # Both timestamps must be real values. ``create_task`` always inserts
    # ``completed_at`` (as None), so a key-presence test guards nothing.
    # ``process_task`` also marks a task "completed" just before ``next_task``
    # stamps ``completed_at``, so a concurrent poll can see that state.
    durations = [
        task["completed_at"] - task["started_at"]
        for task in tasks.values()
        if task["status"] == "completed"
        and task.get("completed_at") is not None
        and task.get("started_at") is not None
    ]
    per_task = mean(durations) if durations else default_duration
    eta = tasks[task_id]["initial_place_in_queue"] * per_task
    return round(eta, 1)
def next_task(task_id):
    """Stamp ``task_id``'s completion time and hand off to the next queued task.

    The next task is the first record in ``tasks`` whose status is "queued".
    Dict insertion order matches creation order, because ``create_task``
    inserts each record as it is created.

    ``process_task`` is called synchronously here, and ``process_task`` calls
    back into this function when it finishes.
    NOTE(review): each hand-off therefore adds stack frames; a very long queue
    could approach the recursion limit.
    """
    finished = tasks[task_id]
    finished["completed_at"] = time()

    waiting = [record for record in tasks.values() if record["status"] == "queued"]
    if not waiting:
        return

    print(f"{task_id} {finished['status']}. Task/s remaining: {len(waiting)}")
    process_task(waiting[0]["task_id"])
def process_task(task_id):
    """Generate the image for ``task_id`` unless another task is processing.

    Only one task runs at a time. If any task is already "processing", this
    returns at once and the task stays "queued"; ``next_task`` starts it later
    when the running task finishes.

    If the client last polled more than 30 seconds ago, the task is marked
    "abandoned" and skipped. Otherwise the task moves to "processing" and its
    ``value`` is set from ``generate_image(prompt)``. It then ends as
    "completed", or as "failed" with ``error`` holding ``repr`` of the
    exception. Whenever the task is not skipped as busy, ``next_task`` is
    called afterwards to start the next queued task.
    """
    if any(task["status"] == "processing" for task in tasks.values()):
        return

    task = tasks[task_id]
    # last_poll is None until the client's first poll, so never-polled tasks
    # are not treated as abandoned.
    if task["last_poll"] and time() - task["last_poll"] > 30:
        task["status"] = "abandoned"
        next_task(task_id)
        # Without this return the abandoned task fell through. It was flipped
        # back to "processing", generated an image nobody would collect, and
        # called next_task a second time.
        return

    task["status"] = "processing"
    task["started_at"] = time()
    print(f"Processing {task_id}")
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:
        # Top-level boundary for the background job: record the failure on
        # the task so pollers can see it, rather than losing it.
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        next_task(task_id)
@app.head('/')
@app.get('/')
def index():
    """Serve the front-end page ``static/index.html`` for GET and HEAD on ``/``."""
    page = "static/index.html"
    return FileResponse(page, media_type="text/html")
@app.get('/details')
def generate_details():
    """Return whatever ``rand_details()`` produces.

    Presumably this is a randomised set of details; see ``modules.details``
    to confirm.
    """
    details = rand_details()
    return details
@app.post('/task/create')
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a new image-generation task and schedule it in the background.

    The task id is the creation timestamp and the prompt joined with "_". The
    returned record includes the task's initial queue position and ETA.
    """
    now = time()
    prompt = new_task.prompt
    task_id = f"{str(now)}_{prompt}"

    record = dict(
        task_id=task_id,
        status="queued",
        eta=None,
        created_at=now,
        started_at=None,
        completed_at=None,
        last_poll=None,
        poll_count=0,
        initial_place_in_queue=None,
        place_in_queue=None,
        prompt=prompt,
        value=None,
    )
    # Register first: both helpers below look the task up in ``tasks``.
    tasks[task_id] = record
    record["initial_place_in_queue"] = get_place_in_queue(task_id)
    record["eta"] = calculate_eta(task_id)

    background_tasks.add_task(process_task, task_id)
    return record
@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh and return the record for ``task_id``.

    Updates the task's current queue position, ETA, last-poll timestamp and
    poll counter. ``process_task`` compares ``last_poll`` against a 30-second
    limit to decide whether a task has been abandoned.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown, for example an id that
            was never created or one lost when the in-memory ``tasks``
            registry was reset.
    """
    task = tasks.get(task_id)
    if task is None:
        # Previously tasks[task_id] raised KeyError here, which surfaced as a
        # 500 Internal Server Error instead of a clear "not found".
        raise HTTPException(status_code=404, detail="Task not found")
    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["last_poll"] = time()
    task["poll_count"] += 1
    return task