Spaces:
Sleeping
Sleeping
File size: 3,590 Bytes
4c519fd 5c239ba 9fbb486 47ab990 42422ba 9fbb486 8f246ac 5c239ba 27660a3 5c239ba 47ab990 5c239ba 42422ba 67f60f6 9fbb486 3750ff9 9fbb486 3750ff9 9fbb486 3750ff9 9fbb486 da30f9b 3750ff9 9fbb486 3750ff9 9fbb486 568065a 9fbb486 568065a 9fbb486 568065a 9fbb486 f7bfaff 9fbb486 3750ff9 9fbb486 568065a 3750ff9 24eb369 c6fcf99 47ab990 c6fcf99 5c239ba 9fbb486 42422ba 4c519fd 42422ba 5c239ba 9fbb486 0bae5c5 da30f9b 0bae5c5 da30f9b 5c239ba 9fbb486 7cd74ca 5c239ba 9fbb486 5c239ba c6fcf99 5c239ba c6fcf99 9fbb486 568065a 5c239ba c6fcf99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from statistics import mean
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from modules.details import rand_details
from modules.inference import generate_image
# Application instance; the interactive API docs (Swagger UI / ReDoc) are disabled.
app = FastAPI(redoc_url=None, docs_url=None)
# Front-end assets are served from the local "static" directory under /static.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# In-memory task registry keyed by task_id; each value is the record dict built
# in create_task. It is a plain module-level dict, so it is lost on restart.
tasks = dict()
class NewTask(BaseModel):
    """Request body for POST /task/create.

    Attributes:
        prompt: Text prompt stored on the task and later passed to
            ``generate_image``; defaults to the Russian word for "pokemon"
            when the client omits it.
    """

    # The explicit ``str`` annotation declares this as a model field.
    # Pydantic v1 inferred the type from the default value, but Pydantic v2
    # rejects non-annotated attributes when the class is created.
    prompt: str = "покемон"
def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among unfinished tasks.

    Tasks whose status is "queued" or "processing" are ordered by their
    ``created_at`` timestamp, so a task that is currently processing is
    counted as part of the queue.

    Args:
        task_id: Key of the task in the module-level ``tasks`` registry.

    Returns:
        The task's position starting at 1, or 0 when the task is neither
        queued nor processing (completed, failed, abandoned, or absent).
    """
    # Snapshot the records before filtering: sync endpoints and background
    # tasks run in a thread pool, and another request inserting into
    # ``tasks`` mid-iteration would raise "dictionary changed size during
    # iteration" on a live view.
    pending = [
        task for task in list(tasks.values())
        if task["status"] in ("queued", "processing")
    ]
    pending.sort(key=lambda task: task["created_at"])
    pending_ids = [task["task_id"] for task in pending]
    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # list.index raises ValueError when the id is not in the list.
        return 0
def calculate_eta(task_id, default_duration=35):
    """Estimate, in seconds, how long until ``task_id`` is finished.

    The estimate is the task's *initial* place in the queue multiplied by
    the mean wall-clock duration of the tasks completed so far. When no
    task has completed yet, ``default_duration`` is used instead of the mean.

    Args:
        task_id: Key of the task in the module-level ``tasks`` registry; its
            ``initial_place_in_queue`` must already be set.
        default_duration: Per-task duration in seconds assumed before any
            completion time is available (previously hard-coded as 35).

    Returns:
        The estimate rounded to one decimal place.
    """
    # Filter on completed_at actually being set, not on key presence. Every
    # record is created with "completed_at": None, and process_task marks a
    # task "completed" before next_task stamps completed_at. A concurrent
    # poll could otherwise compute ``None - started_at`` and raise TypeError.
    durations = [
        task["completed_at"] - task["started_at"]
        for task in list(tasks.values())  # snapshot; see get_place_in_queue
        if task["status"] == "completed" and task["completed_at"] is not None
    ]
    initial_place_in_queue = tasks[task_id]["initial_place_in_queue"]
    per_task = mean(durations) if durations else default_duration
    return round(initial_place_in_queue * per_task, 1)
def next_task(task_id):
    """Finish bookkeeping for ``task_id`` and start the next queued task.

    Stamps ``completed_at`` on the task that just stopped running. Then, if
    any task is still "queued", logs the remaining count and hands the first
    one (in insertion order) to ``process_task``. The hand-off is a direct
    call, so the next task runs synchronously on the caller's thread.

    Args:
        task_id: Key of the task that just finished, failed, or was abandoned.
    """
    tasks[task_id]["completed_at"] = time()
    # Snapshot so a concurrent insert from a request thread cannot raise
    # "dictionary changed size during iteration".
    waiting = [task for task in list(tasks.values()) if task["status"] == "queued"]
    if waiting:
        print(f"{task_id} {tasks[task_id]['status']}. Task/s remaining: {len(waiting)}")
        # dicts preserve insertion order, so waiting[0] was inserted first.
        process_task(waiting[0]["task_id"])
def process_task(task_id):
    """Run image generation for ``task_id`` unless another task is running.

    Behaviour, in order:

    * Returns immediately if any task is already "processing"; this task
      stays "queued" and is picked up later by ``next_task``.
    * If the client last polled more than 30 seconds ago, the task is marked
      "abandoned", the queue is advanced, and no generation is done.
    * Otherwise the task is marked "processing" and ``generate_image`` is
      called with its prompt. The outcome is recorded as "completed" (result
      in ``value``) or "failed" (exception repr in ``error``).
    * ``next_task`` is always called afterwards to advance the queue.

    Args:
        task_id: Key of the task in the module-level ``tasks`` registry.
    """
    # Snapshot so a concurrent insert cannot break the iteration.
    if any(task["status"] == "processing" for task in list(tasks.values())):
        return
    task = tasks[task_id]
    if task["last_poll"] and time() - task["last_poll"] > 30:
        task["status"] = "abandoned"
        next_task(task_id)
        # Stop here. Falling through would overwrite "abandoned" with
        # "processing" and still spend a generation on a task nobody awaits.
        return
    task["status"] = "processing"
    task["started_at"] = time()
    print(f"Processing {task_id}")
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:
        # Top-level boundary of a background job: record the failure on the
        # task so pollers can see it, rather than losing it in the worker.
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        next_task(task_id)
@app.head('/')
@app.get('/')
def index():
    """Serve the single-page front end for both GET and HEAD on '/'."""
    page = "static/index.html"
    return FileResponse(page, media_type="text/html")
@app.get('/details')
def generate_details():
    """Return a fresh result from ``rand_details``; its shape is defined there."""
    details = rand_details()
    return details
@app.post('/task/create')
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a new generation task and schedule it in the background.

    The task id is the creation timestamp and the prompt joined with "_".
    The record is stored in ``tasks`` before its queue position and ETA are
    computed, because both helpers look the task up in that registry.

    Returns:
        The task record dict (the same object that is stored in ``tasks``).
    """
    now = time()
    prompt = new_task.prompt
    task_id = f"{now}_{prompt}"
    record = {
        "task_id": task_id,
        "status": "queued",
        "eta": None,
        "created_at": now,
        "started_at": None,
        "completed_at": None,
        "last_poll": None,
        "poll_count": 0,
        "initial_place_in_queue": None,
        "place_in_queue": None,
        "prompt": prompt,
        "value": None,
    }
    tasks[task_id] = record
    record["initial_place_in_queue"] = get_place_in_queue(task_id)
    record["eta"] = calculate_eta(task_id)
    background_tasks.add_task(process_task, task_id)
    return record
@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh and return the status record of ``task_id``.

    Each poll does the following:

    * recomputes ``place_in_queue`` and ``eta``;
    * records the poll time in ``last_poll`` (process_task treats a task
      whose last poll is more than 30 s old as abandoned);
    * increments ``poll_count``.

    Raises:
        HTTPException: 404 when ``task_id`` is not in ``tasks``, for example
            an id issued before a server restart (the registry is an
            in-memory dict).
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["last_poll"] = time()
    task["poll_count"] += 1
    return task
|