File size: 2,999 Bytes
4c519fd
 
5c239ba
9fbb486
47ab990
 
42422ba
9fbb486
8f246ac
5c239ba
 
27660a3
5c239ba
47ab990
5c239ba
 
 
 
42422ba
 
 
 
67f60f6
9fbb486
 
 
 
3750ff9
9fbb486
3750ff9
 
9fbb486
3750ff9
 
 
 
 
9fbb486
3750ff9
 
9fbb486
3750ff9
 
9fbb486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3750ff9
9fbb486
 
 
 
 
 
 
 
 
3750ff9
 
24eb369
c6fcf99
 
47ab990
c6fcf99
5c239ba
9fbb486
 
 
 
 
42422ba
 
4c519fd
 
42422ba
5c239ba
 
 
4c519fd
42422ba
9fbb486
5c239ba
 
 
9fbb486
5c239ba
9fbb486
5c239ba
c6fcf99
5c239ba
 
c6fcf99
 
9fbb486
 
5c239ba
 
c6fcf99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import threading
from statistics import mean
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel

from modules.details import rand_details
from modules.inference import generate_image

# The interactive API documentation pages (/docs and /redoc) are disabled.
app = FastAPI(docs_url=None, redoc_url=None)

# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# In-memory store of every task, keyed by task_id. Each value is that task's
# state dict (prompt, status, timestamps, result or error, poll counters).
# It is shared by the request handlers and the background worker; nothing in
# this file persists it.
tasks = {}


class NewTask(BaseModel):
    """Request body accepted by ``POST /task/create``."""

    # Text prompt passed to the image generator. The default is the Russian
    # word for "pokemon". The explicit ``str`` annotation is required by
    # pydantic v2, which rejects non-annotated fields, and is harmless on v1.
    prompt: str = "покемон"


def get_place_in_queue(task_id):
    """Return the 1-based position of *task_id* among unfinished tasks.

    Tasks whose status is "queued" or "processing" are ordered by their
    ``created_at`` timestamp, and the position of *task_id* in that order is
    returned.

    Returns:
        The position (1 = first), or 0 when *task_id* is not pending, i.e. it
        is completed, failed, or unknown.
    """
    # Snapshot the values before filtering. Request handlers and the
    # background worker touch `tasks` from different threads, and iterating a
    # dict that grows raises RuntimeError.
    pending = [task for task in list(tasks.values())
               if task["status"] in ("queued", "processing")]
    pending.sort(key=lambda task: task["created_at"])

    pending_ids = [task["task_id"] for task in pending]

    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # Not pending: already finished/failed, or never created.
        return 0


def calculate_eta(task_id, *, default_duration=40):
    """Estimate the total processing time, in seconds, for *task_id*.

    The estimate is the task's ``initial_place_in_queue`` multiplied by the
    mean duration (``completed_at - started_at``) of every task that has
    finished so far. Failed tasks are included, because they also get
    ``completed_at``. Until any task has finished, *default_duration* seconds
    per task is assumed.

    The place recorded at creation time is used, so the estimate does not
    shrink as the queue advances.

    Args:
        task_id: Id of an existing task in ``tasks``.
        default_duration: Per-task duration to assume when no task has
            finished yet. Defaults to 40, the previously hard-coded value.

    Returns:
        The estimate rounded to one decimal place.
    """
    # Snapshot the values; other threads may insert into `tasks` concurrently.
    durations = [task["completed_at"] - task["started_at"]
                 for task in list(tasks.values()) if "completed_at" in task]

    per_task = mean(durations) if durations else default_duration
    eta = tasks[task_id]["initial_place_in_queue"] * per_task

    return round(eta, 1)


# Held by whichever thread is currently working through the task queue, so
# only one image generation runs at a time. process_task is scheduled through
# BackgroundTasks as a plain function, and such functions are run in a worker
# thread pool, so several calls can overlap.
_worker_lock = threading.Lock()


def _queued_tasks():
    """Return a snapshot of the tasks whose status is "queued", oldest first."""
    # Copy the values before filtering: request threads may insert into
    # `tasks` concurrently, and iterating a dict that grows raises RuntimeError.
    queued = [task for task in list(tasks.values()) if task["status"] == "queued"]
    queued.sort(key=lambda task: task["created_at"])
    return queued


def _run_task(task_id):
    """Generate the image for one task, recording status, timings and any error."""
    task = tasks[task_id]
    task["status"] = "processing"
    task["started_at"] = time()

    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:
        # Record the failure instead of raising so the rest of the queue runs.
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        task["completed_at"] = time()


def process_task(task_id):
    """Process *task_id*, then every remaining queued task, one at a time.

    If another thread is already working through the queue, this call returns
    immediately and that thread picks the task up. Tasks are processed in a
    loop rather than recursively, so a long backlog cannot exhaust the call
    stack. A task whose status is no longer "queued" is never run again.

    Args:
        task_id: Id of the task whose creation scheduled this call. It is
            processed first if it is still queued; otherwise the oldest queued
            task is processed first.
    """
    while True:
        if not _worker_lock.acquire(blocking=False):
            # Another thread owns the queue and will reach this task.
            return
        try:
            if tasks[task_id]["status"] == "queued":
                current = task_id
            else:
                queued = _queued_tasks()
                current = queued[0]["task_id"] if queued else None

            while current is not None:
                _run_task(current)

                queued = _queued_tasks()
                if queued:
                    print(f"Tasks remaining: {len(queued)}")
                current = queued[0]["task_id"] if queued else None
        finally:
            _worker_lock.release()

        # A task created just before the release may have had its own call
        # bounce off the held lock; re-check so it is not left stranded.
        if not _queued_tasks():
            return


@app.head('/')
@app.get('/')
def index():
    """Serve ``static/index.html`` as HTML for GET and HEAD requests to ``/``."""
    page_path = "static/index.html"
    return FileResponse(page_path, media_type="text/html")


@app.get('/details')
def generate_details():
    """Respond to ``GET /details`` with the value produced by ``rand_details()``."""
    details = rand_details()
    return details


@app.post('/task/create')
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a new image-generation task and schedule its processing.

    The task id is ``"<created_at>_<prompt>"``. The task starts as "queued"
    with ``poll_count`` 0 and records its initial place in the queue.
    ``process_task`` is scheduled through *background_tasks*.

    Returns:
        A snapshot (shallow copy) of the new task's state.
    """
    created_at = time()
    task_id = f"{created_at}_{new_task.prompt}"

    task = {
        "task_id": task_id,
        "created_at": created_at,
        "prompt": new_task.prompt,
        "status": "queued",
        "poll_count": 0,
    }
    tasks[task_id] = task

    # Computed after insertion: the place is this task's index among pending ones.
    task["initial_place_in_queue"] = get_place_in_queue(task_id)

    background_tasks.add_task(process_task, task_id)

    # Return a copy. A worker thread may already be adding keys such as
    # "started_at" or "value" to the stored dict while the response is being
    # serialized.
    return dict(task)


@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh and return the state of *task_id*.

    Updates the task's current ``place_in_queue``, its ``eta`` estimate and
    its ``poll_count`` before returning it.

    Raises:
        HTTPException: 404 if no task with this id exists.

    Returns:
        A snapshot (shallow copy) of the task's state.
    """
    task = tasks.get(task_id)
    if task is None:
        # Without this check an unknown id raises KeyError and surfaces as HTTP 500.
        raise HTTPException(status_code=404, detail="Task not found")

    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["poll_count"] += 1

    # Copy for serialization; the worker thread may still be updating the task.
    return dict(task)