"""FastAPI front end for a single-worker image-generation queue.

Tasks are kept in the in-memory ``tasks`` dict and generated one at a time.
Creating a task schedules ``process_task`` as a background job; whichever job
finds the worker idle processes its task and then drains the remaining queue
in creation order. Clients poll ``/task/poll`` for status, queue position and
ETA; a queued task whose client stopped polling is skipped as abandoned.
"""
import threading
from statistics import mean
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from modules.details import rand_details
from modules.inference import generate_image

app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Every task ever created, keyed by task_id; entries are never removed.
tasks = {}

# Seconds since the last poll after which a queued task counts as abandoned.
ABANDON_AFTER_S = 30
# Per-task duration (seconds) assumed for ETAs until some task has completed.
DEFAULT_TASK_DURATION_S = 35

# Serialises the "is the worker idle? then claim this task" step. Plain `def`
# endpoints and background tasks run in a thread pool, so without this two
# jobs could both see an idle worker and start generating concurrently.
_worker_lock = threading.Lock()


class NewTask(BaseModel):
    """Request body for ``POST /task/create``."""

    # Annotated so pydantic registers it as a field (pydantic v2 rejects
    # un-annotated class attributes). The default is Russian for "pokemon".
    prompt: str = "покемон"


def _task_snapshot():
    """Return a list copy of all task dicts.

    Iterating the copy instead of ``tasks.values()`` lazily means a request
    thread inserting a new task cannot make the loop fail with
    "dictionary changed size during iteration".
    """
    return list(tasks.values())


def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among active tasks, or 0.

    Active tasks are those "queued" or "processing", ordered by creation
    time, so the task currently being processed is at position 1. Returns 0
    when ``task_id`` is not active (finished, failed, abandoned or unknown).
    """
    active_ids = [
        task["task_id"]
        for task in sorted(_task_snapshot(), key=lambda task: task["created_at"])
        if task["status"] in ("queued", "processing")
    ]
    try:
        return active_ids.index(task_id) + 1
    except ValueError:
        return 0


def calculate_eta(task_id):
    """Estimate the seconds until ``task_id`` finishes, rounded to 0.1.

    Multiplies the task's *initial* place in the queue by the mean duration
    of completed tasks, or by DEFAULT_TASK_DURATION_S if none has completed.
    """
    durations = [
        task["completed_at"] - task["started_at"]
        for task in _task_snapshot()
        if task["status"] == "completed"
        and task["started_at"] is not None
        and task["completed_at"] is not None
    ]
    per_task = mean(durations) if durations else DEFAULT_TASK_DURATION_S
    return round(tasks[task_id]["initial_place_in_queue"] * per_task, 1)


def _is_abandoned(task):
    """Return True if ``task`` was polled before but not within ABANDON_AFTER_S seconds."""
    last_poll = task["last_poll"]
    return last_poll is not None and time() - last_poll > ABANDON_AFTER_S


def _run_generation(task_id):
    """Generate the image for an already-claimed task and record the outcome.

    ``completed_at`` is written before ``status`` so that any reader seeing a
    finished status (e.g. ``calculate_eta``) also sees its completion time.
    """
    task = tasks[task_id]
    print(f"Processing {task_id}")
    try:
        value = generate_image(task["prompt"])
    except Exception as ex:  # any generation failure is reported via the task
        task["error"] = repr(ex)
        task["completed_at"] = time()
        task["status"] = "failed"
    else:
        task["value"] = value
        task["completed_at"] = time()
        task["status"] = "completed"


def _next_queued_id(finished_id):
    """Log how ``finished_id`` ended and return the oldest queued task id, or None."""
    queued = [task for task in _task_snapshot() if task["status"] == "queued"]
    if not queued:
        return None
    print(f"{finished_id} {tasks[finished_id]['status']}. \nTask/s remaining: {len(queued)}")
    return min(queued, key=lambda task: task["created_at"])["task_id"]


def next_task(task_id):
    """Mark ``task_id`` as finished now and start the oldest queued task, if any.

    Kept for backward compatibility; ``process_task`` drains the queue itself
    and no longer calls this.
    """
    tasks[task_id]["completed_at"] = time()
    upcoming_id = _next_queued_id(task_id)
    if upcoming_id is not None:
        process_task(upcoming_id)


def process_task(task_id):
    """Process ``task_id`` and then every task still queued, oldest first.

    Returns immediately if another task is already processing (that worker
    will reach this one) or if ``task_id`` is no longer queued (an earlier
    drain already handled it). Queued tasks whose client stopped polling are
    marked "abandoned" and skipped rather than generated. The queue is drained
    in a loop instead of by recursion, so a long queue cannot exhaust the
    recursion limit.
    """
    current_id = task_id
    while current_id is not None:
        with _worker_lock:
            if any(task["status"] == "processing" for task in _task_snapshot()):
                return
            task = tasks[current_id]
            if task["status"] != "queued":
                return
            if _is_abandoned(task):
                task["status"] = "abandoned"
                task["completed_at"] = time()
                should_run = False
            else:
                task["status"] = "processing"
                task["started_at"] = time()
                should_run = True
        if should_run:
            _run_generation(current_id)
        current_id = _next_queued_id(current_id)


@app.head('/')
@app.get('/')
def index():
    """Serve static/index.html."""
    return FileResponse(path="static/index.html", media_type="text/html")


@app.get('/details')
def generate_details():
    """Return the result of ``rand_details()``."""
    return rand_details()


@app.post('/task/create')
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Queue a generation task for ``new_task.prompt`` and return its record.

    The returned dict carries the task's initial queue position and ETA;
    ``process_task`` is scheduled as a background job for it.
    """
    created_at = time()
    task_id = f"{str(created_at)}_{new_task.prompt}"
    tasks[task_id] = {
        "task_id": task_id,
        "status": "queued",
        "eta": None,
        "created_at": created_at,
        "started_at": None,
        "completed_at": None,
        "last_poll": None,
        "poll_count": 0,
        "initial_place_in_queue": None,
        "place_in_queue": None,
        "prompt": new_task.prompt,
        "value": None,
    }
    tasks[task_id]["initial_place_in_queue"] = get_place_in_queue(task_id)
    tasks[task_id]["eta"] = calculate_eta(task_id)
    background_tasks.add_task(process_task, task_id)
    return tasks[task_id]


@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh and return the record for ``task_id``.

    Updates queue position, ETA and poll bookkeeping; ``last_poll`` is what
    keeps a queued task from being treated as abandoned.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["last_poll"] = time()
    task["poll_count"] += 1
    return task