File size: 3,590 Bytes
4c519fd
 
5c239ba
9fbb486
47ab990
 
42422ba
9fbb486
8f246ac
5c239ba
 
27660a3
5c239ba
47ab990
5c239ba
 
 
 
42422ba
 
 
 
67f60f6
9fbb486
 
 
 
3750ff9
9fbb486
3750ff9
 
9fbb486
3750ff9
 
 
 
 
9fbb486
da30f9b
3750ff9
9fbb486
3750ff9
 
9fbb486
 
568065a
9fbb486
 
 
 
568065a
 
 
 
 
 
 
 
 
 
9fbb486
 
 
 
568065a
 
 
 
9fbb486
 
f7bfaff
9fbb486
 
 
 
 
 
3750ff9
9fbb486
 
568065a
3750ff9
 
24eb369
c6fcf99
 
47ab990
c6fcf99
5c239ba
9fbb486
 
 
 
 
42422ba
 
4c519fd
 
42422ba
5c239ba
 
 
9fbb486
0bae5c5
da30f9b
 
 
0bae5c5
 
da30f9b
 
 
 
5c239ba
 
9fbb486
7cd74ca
5c239ba
9fbb486
5c239ba
c6fcf99
5c239ba
 
c6fcf99
 
9fbb486
 
568065a
5c239ba
 
c6fcf99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from statistics import mean
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from modules.details import rand_details
from modules.inference import generate_image

# Application instance; the interactive API docs pages are turned off.
app = FastAPI(redoc_url=None, docs_url=None)

# Frontend assets are served from the local ``static`` directory.
app.mount("/static", StaticFiles(directory="static"), name="static")

# In-memory registry of every task, keyed by task_id. Nothing in this module
# ever removes entries, so it grows for the lifetime of the process.
tasks = dict()


class NewTask(BaseModel):
    """Request body accepted by ``POST /task/create``."""

    # Text prompt handed to the image generator; defaults to "покемон"
    # (Russian for "pokemon"). The explicit ``str`` annotation is required:
    # Pydantic v2 rejects non-annotated class attributes, while Pydantic v1
    # inferred ``str`` from the default, so v1 behavior is unchanged.
    prompt: str = "покемон"


def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among unfinished tasks.

    Unfinished tasks are those whose status is ``"queued"`` or
    ``"processing"``; they are ordered by their ``created_at`` timestamp.

    Returns:
        The position starting at 1, or 0 when ``task_id`` is not among the
        unfinished tasks (completed, failed, abandoned or unknown).
    """
    active = [task for task in tasks.values()
              if task["status"] in ("queued", "processing")]
    active.sort(key=lambda task: task["created_at"])

    active_ids = [task["task_id"] for task in active]

    try:
        return active_ids.index(task_id) + 1
    except ValueError:
        # list.index raises only ValueError; the previous bare ``except:``
        # also swallowed KeyboardInterrupt/SystemExit.
        return 0


def calculate_eta(task_id, default_duration=35):
    """Estimate, in seconds, how long until ``task_id`` is finished.

    The estimate is the task's ``initial_place_in_queue`` multiplied by the
    mean ``completed_at - started_at`` duration of all completed tasks. Until
    at least one task has completed, ``default_duration`` seconds per task is
    assumed instead.

    Args:
        task_id: Key of the task in ``tasks``.
        default_duration: Per-task duration in seconds used when no completed
            task is available to average over. Defaults to 35, the value that
            was previously hard-coded.

    Returns:
        The estimate rounded to one decimal place.
    """
    # ``status`` is set to "completed" before ``next_task`` stamps
    # ``completed_at``, so a concurrent poll can briefly see a completed task
    # whose ``completed_at`` is still None; skip it instead of raising
    # TypeError on ``None - float``.
    durations = [task["completed_at"] - task["started_at"]
                 for task in tasks.values()
                 if task["status"] == "completed"
                 and task.get("completed_at") is not None]

    per_task = mean(durations) if durations else default_duration

    return round(tasks[task_id]["initial_place_in_queue"] * per_task, 1)


def next_task(task_id):
    """Stamp ``task_id`` as finished and start the oldest queued task, if any.

    ``completed_at`` is always recorded. When at least one task is still
    ``"queued"``, a progress line is printed and the first queued task (in
    ``tasks`` insertion order) is handed to ``process_task``.
    """
    finished = tasks[task_id]
    finished["completed_at"] = time()

    waiting = [entry for entry in tasks.values() if entry["status"] == "queued"]
    if not waiting:
        return

    print(f"{task_id} {finished['status']}. Task/s remaining: {len(waiting)}")
    process_task(waiting[0]["task_id"])


def process_task(task_id):
    """Generate the image for ``task_id`` unless another task is running.

    Only one task is processed at a time: if any task is already
    ``"processing"``, this call does nothing and the task stays queued until
    ``next_task`` picks it up.

    A task whose client last polled more than 30 seconds ago is marked
    ``"abandoned"`` and skipped; the queue is advanced without generating it.

    Otherwise the task's ``status`` ends as ``"completed"`` (result stored in
    ``"value"``) or ``"failed"`` (``repr`` of the exception stored in
    ``"error"``), and ``next_task`` is then called to advance the queue.
    """
    # NOTE(review): this check and the status change below are not atomic; if
    # sync background tasks run concurrently in a threadpool, two tasks could
    # start at once — confirm whether a lock is needed.
    if any(task["status"] == "processing" for task in tasks.values()):
        return

    task = tasks[task_id]

    last_poll = task["last_poll"]
    if last_poll is not None and time() - last_poll > 30:
        task["status"] = "abandoned"
        next_task(task_id)
        # Stop here: previously execution fell through, re-marked the
        # abandoned task as "processing" and generated it anyway.
        return

    task["status"] = "processing"
    task["started_at"] = time()
    print(f"Processing {task_id}")

    # NOTE(review): next_task calls back into process_task, so the call stack
    # grows with every task drained in one chain; a very long queue could hit
    # the recursion limit.
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:
        # Record any generation error on the task so the queue keeps moving.
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        next_task(task_id)


@app.head('/')
@app.get('/')
def index():
    """Serve the frontend page ``static/index.html`` for GET and HEAD on ``/``."""
    page = FileResponse("static/index.html", media_type="text/html")
    return page


@app.get('/details')
def generate_details():
    """Return the value produced by ``rand_details()`` unchanged."""
    details = rand_details()
    return details


@app.post('/task/create')
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a new generation task and schedule it as a background task.

    The task id is the creation timestamp and the prompt joined by ``_``.
    The record is stored in ``tasks`` with status ``"queued"``; its initial
    queue position and ETA are computed immediately, and ``process_task`` is
    added to ``background_tasks``.

    Returns:
        The newly created task record.
    """
    now = time()
    prompt = new_task.prompt
    task_id = f"{now}_{prompt}"

    record = {
        "task_id": task_id,
        "status": "queued",
        "eta": None,
        "created_at": now,
        "started_at": None,
        "completed_at": None,
        "last_poll": None,
        "poll_count": 0,
        "initial_place_in_queue": None,
        "place_in_queue": None,
        "prompt": prompt,
        "value": None,
    }

    # Register first: get_place_in_queue looks the task up in ``tasks``, and
    # calculate_eta reads the place computed here.
    tasks[task_id] = record
    record["initial_place_in_queue"] = get_place_in_queue(task_id)
    record["eta"] = calculate_eta(task_id)

    background_tasks.add_task(process_task, task_id)

    return record


@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh and return the state of an existing task.

    Recomputes ``place_in_queue`` and ``eta``, records the poll time in
    ``last_poll`` (which ``process_task`` compares against a 30-second
    threshold to detect abandoned tasks) and increments ``poll_count``.

    Raises:
        HTTPException: 404 when ``task_id`` is unknown. Previously the
            resulting ``KeyError`` surfaced as a 500 Internal Server Error.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["last_poll"] = time()
    task["poll_count"] += 1

    return task