"""HTTP API for creating, running and polling image-generation tasks."""

from statistics import mean
from time import time

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from modules.details import rand_details
from modules.inference import generate_image

# Per-task duration (seconds) assumed until at least one task has completed.
DEFAULT_TASK_DURATION_SECONDS = 40

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")

# In-memory task registry keyed by task id. Dicts keep insertion order, so
# iterating it yields tasks in creation (queue) order.
tasks = {}


def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among pending tasks.

    Returns 0 when the task is unknown or is no longer pending.
    """
    # Compare ids with ids; searching a list of task dicts for a string id
    # can never match.
    pending_ids = [
        task["task_id"]
        for task in tasks.values()
        if task["status"] == "pending"
    ]
    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        return 0


def _average_completed_duration():
    """Return the mean duration of completed tasks, or the default guess."""
    durations = [
        task["completed_at"] - task["created_at"]
        for task in tasks.values()
        # queue_task sets the status before the timestamp, so require both.
        if task["status"] == "completed" and "completed_at" in task
    ]
    return mean(durations) if durations else DEFAULT_TASK_DURATION_SECONDS


def _get_task_or_404(task_id):
    """Return the stored task, or raise a 404 if the id is unknown."""
    try:
        return tasks[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Task not found") from None


def calculate_eta(task_id):
    """Estimate seconds until ``task_id`` finishes from its initial place.

    The estimate is the average completed-task duration multiplied by the
    task's initial queue position (a position of 0 is treated as 1).

    Raises:
        KeyError: if ``task_id`` is not in the registry.
    """
    place = tasks[task_id]["initial_place_in_queue"] or 1
    return _average_completed_duration() * place


@app.get("/")
def index():
    """Serve the static index page."""
    return FileResponse(path="static/index.html", media_type="text/html")


@app.get("/task/create")
def create_task(prompt: str = "покемон"):
    """Register a new pending task for ``prompt`` and return its record."""
    created_at = time()
    task_id = f"{str(created_at)}_{prompt}"
    tasks[task_id] = {
        "task_id": task_id,
        "created_at": created_at,
        "prompt": prompt,
        # Placeholder: the real place is only known once the task is stored.
        "initial_place_in_queue": 0,
        "status": "pending",
        "poll_count": 0,
    }
    tasks[task_id]["initial_place_in_queue"] = get_place_in_queue(task_id)
    print("Place in queue: ", get_place_in_queue(task_id))
    print("ETA: ", calculate_eta(task_id))
    return tasks[task_id]


@app.get("/task/queue")
def queue_task(task_id: str):
    """Run image generation for a task and record the outcome.

    On success the result is stored under ``"value"`` and the status becomes
    ``"completed"``; on any exception the status becomes ``"failed"`` and the
    exception's ``repr`` is stored under ``"error"``. ``"completed_at"`` is
    set in both cases.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown.
    """
    task = _get_task_or_404(task_id)
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        task["completed_at"] = time()
    return task


@app.get("/task/poll")
def poll_task(task_id: str):
    """Refresh and return a task's queue position, ETA and poll count.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown.
    """
    task = _get_task_or_404(task_id)
    place_in_queue = get_place_in_queue(task_id)
    # A task that is not pending has place 0; still report one task's worth.
    eta = _average_completed_duration() * (place_in_queue or 1)
    task["place_in_queue"] = place_in_queue
    task["eta"] = round(eta, 1)
    task["poll_count"] += 1
    return task


@app.get("/details")
async def generate_details():
    """Return the result of ``rand_details()``."""
    return rand_details()


@app.get("/duck/quack")
async def duck_quack(query: str = "quack"):
    """Print ``query`` and echo it back under the ``"duck"`` key."""
    print(query)
    return {"duck": query}


@app.get("/test")
async def test(query: str = "test"):
    """Print ``query`` and echo it back under the ``"query"`` key."""
    print(query)
    return {"query": query}