File size: 3,100 Bytes
c6fcf99
 
 
 
 
4c519fd
 
5c239ba
8f246ac
5c239ba
 
c6fcf99
5c239ba
c6fcf99
5c239ba
 
 
 
 
67f60f6
3750ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6fcf99
3750ff9
 
 
 
c6fcf99
 
 
 
5c239ba
c6fcf99
 
4c519fd
 
 
5c239ba
 
 
4c519fd
5c239ba
67f60f6
5c239ba
 
 
 
67f60f6
3750ff9
 
c6fcf99
5c239ba
 
c6fcf99
 
9421891
 
 
 
 
 
 
 
 
5c239ba
c6fcf99
5c239ba
 
c6fcf99
 
4c519fd
 
 
 
 
 
 
8f246ac
 
4c519fd
 
 
 
 
 
 
8f246ac
 
4c519fd
 
 
 
 
5c239ba
 
c6fcf99
 
 
 
 
 
 
5c239ba
 
c6fcf99
 
 
 
5c239ba
 
c6fcf99
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from statistics import mean
from time import time

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from modules.details import rand_details
from modules.inference import generate_image

app = FastAPI()

# Serve the files in the local "static" directory under the /static URL prefix.
static_files = StaticFiles(directory="static")
app.mount("/static", static_files, name="static")

# In-memory task registry mapping task_id -> task record dict (see create_task).
# NOTE(review): nothing in this file removes entries, so the registry grows
# for the lifetime of the process.
tasks = {}


def get_place_in_queue(task_id):
    """Return the 1-based position of *task_id* among the pending tasks.

    Pending tasks are ordered by their insertion order in ``tasks``, which is
    the order in which they were created.

    Returns:
        The position (1 = first in line), or 0 when *task_id* is unknown or
        its task is no longer pending.
    """
    # Compare ids with ids: the previous version searched a list of task dicts
    # for a string, so it never matched and always returned 0.
    # ``list(...)`` snapshots the values because plain ``def`` endpoints run in
    # a threadpool and another request may insert into ``tasks`` meanwhile.
    pending_ids = [
        task["task_id"]
        for task in list(tasks.values())
        if task["status"] == "pending"
    ]

    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        return 0


def calculate_eta(task_id):
    """Estimate how many seconds *task_id* will take, from its initial queue place.

    The per-task estimate is the mean end-to-end duration
    (``completed_at - created_at``) of tasks that completed successfully, or
    40 seconds when none have. The estimate is multiplied by the task's
    ``initial_place_in_queue``, which is treated as 1 when it is 0.

    Raises:
        KeyError: if *task_id* is not in ``tasks``.
    """
    # Only successful tasks are averaged, matching poll_task; failed tasks
    # can end early and would skew the estimate. The "completed_at" check
    # guards against a record whose status flipped before its timestamp was
    # written. ``list(...)`` snapshots the values against concurrent inserts.
    durations = [
        task["completed_at"] - task["created_at"]
        for task in list(tasks.values())
        if task["status"] == "completed" and "completed_at" in task
    ]

    place = tasks[task_id]["initial_place_in_queue"] or 1

    if durations:
        return mean(durations) * place
    return 40 * place


@app.get('/')
def index():
    """Serve ``static/index.html`` as the HTML landing page."""
    page = "static/index.html"
    return FileResponse(path=page, media_type="text/html")


@app.get('/task/create')
def create_task(prompt: str = "покемон"):
    """Register a new pending generation task for *prompt* and return its record.

    The task id is the creation timestamp joined to the prompt with ``_``.
    The returned record holds ``task_id``, ``created_at``, ``prompt``,
    ``initial_place_in_queue``, ``status`` ("pending") and ``poll_count`` (0).
    """
    created_at = time()
    task_id = f"{created_at}_{prompt}"

    task = {
        "task_id": task_id,
        "created_at": created_at,
        "prompt": prompt,
        # Placeholder keeps the response's key order; filled in right below.
        "initial_place_in_queue": 0,
        "status": "pending",
        "poll_count": 0,
    }
    tasks[task_id] = task
    # Computed only after insertion: before it, the new task is not yet among
    # the pending tasks, so its place could never be found.
    task["initial_place_in_queue"] = get_place_in_queue(task_id)

    print("Place in queue: ", task["initial_place_in_queue"])
    print("ETA: ", calculate_eta(task_id))

    return task


@app.get('/task/queue')
def queue_task(task_id: str):
    """Run image generation for *task_id* synchronously and return its record.

    On success the result of ``generate_image`` is stored under ``"value"``
    and the status becomes ``"completed"``. If generation raises, the status
    becomes ``"failed"`` and ``repr`` of the exception is stored under
    ``"error"``. ``"completed_at"`` is recorded in both cases.

    Raises:
        HTTPException: 404 if *task_id* is not a known task.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    try:
        value = generate_image(task["prompt"])
    except Exception as ex:
        # Deliberate best-effort: record the failure on the task rather than
        # returning a 500 to the client.
        task["error"] = repr(ex)
        task["completed_at"] = time()
        task["status"] = "failed"
    else:
        task["value"] = value
        task["completed_at"] = time()
        # Flip the status last so a concurrent poll never sees a finished
        # task that is still missing "completed_at".
        task["status"] = "completed"

    return task


@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh the queue position and ETA of *task_id* and return its record.

    ``place_in_queue`` is the task's 1-based position among the pending tasks,
    or 0 if it is no longer pending. ``eta`` is the mean duration of
    successfully completed tasks, or 40 seconds if there are none, multiplied
    by the place (treated as 1 when it is 0). The result is rounded to one
    decimal. Each call also increments ``poll_count``.

    Raises:
        HTTPException: 404 if *task_id* is not a known task.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    pending_ids = []
    completed_durations = []

    # Snapshot the values: other requests (run in the threadpool) may insert
    # into ``tasks`` while this loop runs.
    for other in list(tasks.values()):
        if other["status"] == "pending":
            pending_ids.append(other["task_id"])
        elif other["status"] == "completed" and "completed_at" in other:
            completed_durations.append(
                other["completed_at"] - other["created_at"])

    try:
        place_in_queue = pending_ids.index(task_id) + 1
    except ValueError:
        place_in_queue = 0

    if completed_durations:
        per_task = sum(completed_durations) / len(completed_durations)
    else:
        per_task = 40  # fallback until at least one task has completed
    eta = per_task * (place_in_queue or 1)

    task["place_in_queue"] = place_in_queue
    task["eta"] = round(eta, 1)
    task["poll_count"] += 1

    return task


@app.get('/details')
async def generate_details():
    """Return the result of ``rand_details()``; FastAPI serializes it to JSON."""
    details = rand_details()
    return details


# Named ``duck_quack`` so it is not shadowed by the ``test`` function defined
# for /test; ``name="test"`` keeps this route's name (and so its generated
# OpenAPI operation id and summary) exactly as before.
@app.get('/duck/quack', name="test")
async def duck_quack(query: str = "quack"):
    """Print *query* and echo it back under the ``"duck"`` key."""
    print(query)
    return {"duck": query}


@app.get('/test')
async def test(query: str = "test"):
    """Print *query* and echo it back under the ``"query"`` key."""
    response = {"query": query}
    print(query)
    return response