File size: 2,892 Bytes
c577148
 
 
 
 
b2f519f
c577148
 
 
b2f519f
 
c577148
 
 
 
 
b2f519f
 
c577148
 
 
 
 
 
 
 
 
 
 
b2f519f
c577148
b2f519f
 
 
 
c577148
 
 
b2f519f
c577148
 
 
 
 
b2f519f
 
 
 
c577148
 
 
 
 
 
 
 
 
 
3de5d81
 
c577148
 
 
b2f519f
 
 
 
 
 
 
c577148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2f519f
c577148
 
 
 
 
 
b2f519f
 
 
 
c577148
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# This module runs the single worker thread that handles all Gradio calls for major t2i or i2i processing.
# Other Gradio calls (such as those from extensions) are not affected.
# Routing all major calls through one thread makes moving models significantly faster.


import threading
import time
import traceback

# --- Locks -----------------------------------------------------------------
# Lock reserved for model unloading; it is not acquired anywhere in this
# module's visible code (NOTE(review): presumably used by callers elsewhere).
unload_lock = threading.Lock()
# Guards last_id, waiting_list and finished_list.
lock = threading.Lock()
# Guards active_generations.
generation_lock = threading.Lock()

# --- Queue state -------------------------------------------------------------
# Id handed to the most recently submitted Task (incremented per submission).
last_id = 0
# FIFO of Tasks not yet picked up by the worker (appended at the end, popped from the front).
waiting_list = []
# Tasks the worker has completed, waiting to be collected by their submitter.
finished_list = []
# Exception raised by the most recently completed task, or None if it succeeded.
last_exception = None
# Number of Task.work calls currently executing their function.
active_generations = 0

class Task:
    """One queued call, executed later by the worker thread via ``work()``.

    Attributes:
        task_id: Id assigned by the submitter, used to find the finished task.
        func, args, kwargs: The call to perform, as ``func(*args, **kwargs)``.
        result: Return value of the call; stays None if it has not run or it raised.
        exception: The exception the call raised, or None.
    """

    def __init__(self, task_id, func, args, kwargs):
        self.task_id = task_id
        self.func, self.args, self.kwargs = func, args, kwargs
        self.result = self.exception = None

    def work(self):
        """Run the stored call, recording its result or exception.

        While the call runs, the module-level ``active_generations`` counter
        is incremented, and it is decremented again no matter how the call ends.
        The module-level ``last_exception`` is set to the raised exception,
        or reset to None on success. Exceptions are printed, not re-raised.
        """
        global last_exception, active_generations
        try:
            with generation_lock:
                active_generations += 1

            self.result = self.func(*self.args, **self.kwargs)
        except Exception as err:
            traceback.print_exc()
            print(err)
            self.exception = last_exception = err
        else:
            self.exception = last_exception = None
        finally:
            # Always undo the increment above, even when the call failed.
            with generation_lock:
                active_generations -= 1


def loop():
    """Worker-thread body: run queued Tasks one at a time, forever.

    Every 10 ms it checks ``waiting_list``. If a Task is waiting, it pops the
    oldest one, runs it, and appends it to ``finished_list`` for the submitter
    to collect. The list mutations happen under ``lock``; the Task itself
    runs outside the lock so new tasks can be submitted meanwhile.
    """
    while True:
        time.sleep(0.01)
        if not waiting_list:
            continue

        with lock:
            current = waiting_list.pop(0)

        current.work()

        with lock:
            finished_list.append(current)

def wait_for_all_generations(timeout=None, poll_interval=0.1):
    """Wait until no Task is executing its function.

    Polls ``active_generations`` (read under ``generation_lock``) every
    ``poll_interval`` seconds.

    Args:
        timeout: Maximum number of seconds to wait. The default None waits
            indefinitely, as the original version did.
        poll_interval: Seconds to sleep between checks.

    Returns:
        True once no generation is active; False if ``timeout`` elapsed first.

    Note: ``Task.work`` counts a task as active while its function runs.
    Calling this from inside such a function therefore never succeeds, so
    pass a ``timeout`` if that can happen.
    """
    # Use a monotonic clock so wall-clock adjustments cannot shorten or
    # extend the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        with generation_lock:
            if active_generations == 0:
                return True
        if deadline is not None and time.monotonic() >= deadline:
            return False
        time.sleep(poll_interval)

def async_run(func, *args, **kwargs):
    """Queue ``func(*args, **kwargs)`` for the worker thread without waiting.

    Returns:
        The integer id of the newly queued Task. Look for a Task with this
        id in ``finished_list`` to collect its outcome.
    """
    global last_id
    with lock:
        # Allocate the id and enqueue atomically so ids match queue order.
        last_id += 1
        new_id = last_id
        waiting_list.append(Task(task_id=new_id, func=func, args=args, kwargs=kwargs))
    return new_id


def run_and_wait_result(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` on the worker thread and block until done.

    Returns:
        The call's return value. If the call raised, nothing is re-raised
        here: the exception is stored on the Task and in the module-level
        ``last_exception``, and None is returned.

    NOTE(review): after our task finishes, this also waits (via
    wait_for_all_generations) for any task the worker has started since,
    including tasks submitted by other callers. Under continuous load,
    returning may therefore be delayed.
    """
    wanted_id = async_run(func, *args, **kwargs)
    while True:
        time.sleep(0.01)
        # Scan a shallow snapshot so the worker can append concurrently.
        done = next(
            (t for t in finished_list.copy() if t.task_id == wanted_id),
            None,
        )
        if done is None:
            continue

        with lock:
            finished_list.remove(done)

        # Wait for all generations to complete before returning.
        wait_for_all_generations()

        return done.result