# Chakshu123's picture
# Duplicate from Chakshu123/sketch-colorization-with-hint
# 443d045
from __future__ import annotations
import asyncio
import copy
import sys
import time
from collections import deque
from typing import Any, Deque, Dict, List, Tuple
import fastapi
from gradio.data_classes import Estimation, PredictBody, Progress, ProgressUnit
from gradio.helpers import TrackedIterable
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
class Event:
    """One queued request, bound to the websocket it arrived on."""

    def __init__(
        self,
        websocket: fastapi.WebSocket,
        session_hash: str,
        fn_index: int,
    ):
        """
        Parameters:
            websocket: connection used to talk to the client of this event
            session_hash: session identifier supplied by the caller
            fn_index: index of the function this event targets
        """
        # Identity of the event; `_id` is derived from the two identifiers.
        self.websocket = websocket
        self.session_hash: str = session_hash
        self.fn_index: int = fn_index
        self._id = f"{self.session_hash}_{self.fn_index}"

        # Mutable state; everything starts out empty.
        self.data: PredictBody | None = None
        self.token: str | None = None
        self.lost_connection_time: float | None = None

        # Latest progress snapshot and whether it still has to be sent.
        self.progress: Progress | None = None
        self.progress_pending: bool = False

    async def disconnect(self, code: int = 1000):
        """Close the event's websocket with the given close code."""
        socket = self.websocket
        await socket.close(code=code)
class Queue:
    """
    Queue of websocket events feeding a fixed number of concurrent jobs.

    Events are added with `push` and consumed by the `start_processing` loop,
    which hands them (one at a time, or as a batch) to `process_events`
    whenever a slot in `active_jobs` is free.
    """

    def __init__(
        self,
        live_updates: bool,
        concurrency_count: int,
        update_intervals: float,
        max_size: int | None,
        blocks_dependencies: List,
    ):
        """
        Parameters:
            live_updates: if True, estimations are broadcast each time a job starts; if False, `start` launches the periodic `notify_clients` loop instead
            concurrency_count: number of job slots in `active_jobs`
            update_intervals: seconds between two rounds of `notify_clients`
            max_size: maximum number of waiting events, or None for no limit
            blocks_dependencies: per-function settings, indexed by `fn_index`; the "batch" and "max_batch_size" keys are read here
        """
        self.event_queue: Deque[Event] = deque()
        self.events_pending_reconnection = []
        self.stopped = False
        self.max_thread_count = concurrency_count
        self.update_intervals = update_intervals
        # One slot per concurrent job: None when free, else the events it runs.
        self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
        self.delete_lock = asyncio.Lock()
        self.server_path = None
        self.duration_history_total = 0
        self.duration_history_count = 0
        self.avg_process_time = 0
        self.avg_concurrent_process_time = None
        self.queue_duration = 1
        self.live_updates = live_updates
        self.sleep_when_free = 0.05
        self.progress_update_sleep_when_free = 0.1
        self.max_size = max_size
        self.blocks_dependencies = blocks_dependencies
        self.access_token = ""

    async def start(self, progress_tracking=False):
        """
        Launch the background loops of the queue.
        Parameters:
            progress_tracking: also launch the progress forwarding loop
        """
        run_coro_in_background(self.start_processing)
        if progress_tracking:
            run_coro_in_background(self.start_progress_tracking)
        if not self.live_updates:
            run_coro_in_background(self.notify_clients)

    def close(self):
        """Flag the queue as stopped; the background loops exit on their next check."""
        self.stopped = True

    def resume(self):
        """Clear the stopped flag."""
        self.stopped = False

    def set_url(self, url: str):
        """Set the base url that the prediction and reset requests are sent to."""
        self.server_path = url

    def set_access_token(self, token: str):
        """Set the bearer token sent with prediction requests."""
        self.access_token = token

    def get_active_worker_count(self) -> int:
        """Return the number of occupied slots in `active_jobs`."""
        return sum(job is not None for job in self.active_jobs)

    def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
        """
        Pop the next event and, for batched functions, the queued events that can run with it.
        Returns:
            (events, batch): `events` is None when the queue is empty; `batch` is the "batch" setting of the popped event's function (False when the queue is empty)
        """
        if not self.event_queue:
            return None, False

        first_event = self.event_queue.popleft()
        events = [first_event]

        event_fn_index = first_event.fn_index
        batch = self.blocks_dependencies[event_fn_index]["batch"]

        if batch:
            batch_size = self.blocks_dependencies[event_fn_index]["max_batch_size"]
            # The popped event already occupies one place of the batch.
            rest_of_batch = [
                event for event in self.event_queue if event.fn_index == event_fn_index
            ][: batch_size - 1]
            events.extend(rest_of_batch)
            for event in rest_of_batch:
                self.event_queue.remove(event)

        return events, batch

    async def start_processing(self) -> None:
        """Loop until stopped, starting a job whenever an event and a free slot are both available."""
        while not self.stopped:
            # Nothing to do while the queue is empty or every slot is taken.
            if not self.event_queue or None not in self.active_jobs:
                await asyncio.sleep(self.sleep_when_free)
                continue

            # Using mutex to avoid editing a list in use
            async with self.delete_lock:
                events, batch = self.get_events_in_batch()

            if events:
                self.active_jobs[self.active_jobs.index(None)] = events
                task = run_coro_in_background(self.process_events, events, batch)
                run_coro_in_background(self.broadcast_live_estimations)
                set_task_name(task, events[0].session_hash, events[0].fn_index, batch)

    async def start_progress_tracking(self) -> None:
        """Loop until stopped, sending pending progress updates of active events to their clients."""
        while not self.stopped:
            if not any(self.active_jobs):
                await asyncio.sleep(self.progress_update_sleep_when_free)
                continue

            for job in self.active_jobs:
                if job is None:
                    continue
                for event in job:
                    if event.progress_pending and event.progress:
                        # Clear the flag before awaiting the send.
                        event.progress_pending = False
                        client_awake = await self.send_message(
                            event, event.progress.dict()
                        )
                        if not client_awake:
                            await self.clean_event(event)

            await asyncio.sleep(self.progress_update_sleep_when_free)

    def set_progress(
        self,
        event_id: str,
        iterables: List[TrackedIterable] | None,
    ):
        """
        Store a progress snapshot on the active event(s) whose `_id` matches.
        Parameters:
            event_id: `_id` of the event to update
            iterables: tracked iterables to snapshot; None leaves everything untouched
        """
        if iterables is None:
            return
        for job in self.active_jobs:
            if job is None:
                continue
            for evt in job:
                if evt._id == event_id:
                    progress_data: List[ProgressUnit] = [
                        ProgressUnit(
                            index=iterable.index,
                            length=iterable.length,
                            unit=iterable.unit,
                            progress=iterable.progress,
                            desc=iterable.desc,
                        )
                        for iterable in iterables
                    ]
                    evt.progress = Progress(progress_data=progress_data)
                    evt.progress_pending = True

    def push(self, event: Event) -> int | None:
        """
        Add event to queue, or return None if Queue is full
        Parameters:
            event: Event to add to Queue
        Returns:
            rank of submitted Event
        """
        queue_len = len(self.event_queue)
        if self.max_size is not None and queue_len >= self.max_size:
            return None
        self.event_queue.append(event)
        return queue_len

    async def clean_event(self, event: Event) -> None:
        """
        Remove the event from the waiting queue if it is still there.
        Parameters:
            event: Event to remove
        """
        # Look up and remove under the same lock so the check cannot go stale.
        async with self.delete_lock:
            try:
                self.event_queue.remove(event)
            except ValueError:
                # Already removed (or never queued): nothing to clean.
                pass

    async def broadcast_live_estimations(self) -> None:
        """
        Broadcast estimations to the waiting clients, only when live updates are enabled.
        """
        if self.live_updates:
            await self.broadcast_estimations()

    async def gather_event_data(self, event: Event) -> bool:
        """
        Ask the client for the event's data unless the event already has it.
        Parameters:
            event: Event to fill in
        Returns:
            False if the "send_data" request could not be delivered, True otherwise
        """
        if not event.data:
            client_awake = await self.send_message(event, {"msg": "send_data"})
            if not client_awake:
                return False
            event.data = await self.get_message(event)
        return True

    async def notify_clients(self) -> None:
        """
        Notify clients about events statuses in the queue periodically.
        """
        while not self.stopped:
            await asyncio.sleep(self.update_intervals)
            if self.event_queue:
                await self.broadcast_estimations()

    async def broadcast_estimations(self) -> None:
        """Send the current estimation, with its own rank, to every waiting event."""
        estimation = self.get_estimation()
        # Send all messages concurrently
        await asyncio.gather(
            *[
                self.send_estimation(event, estimation, rank)
                for rank, event in enumerate(self.event_queue)
            ]
        )

    async def send_estimation(
        self, event: Event, estimation: Estimation, rank: int
    ) -> Estimation:
        """
        Send estimation about ETA to the client.
        Parameters:
            event: Event to notify
            estimation: estimation to send; its `rank` (and `rank_eta` once an average is known) is overwritten
            rank: position of the event in the queue
        Returns:
            the updated estimation
        """
        estimation.rank = rank

        if self.avg_concurrent_process_time is not None:
            estimation.rank_eta = (
                estimation.rank * self.avg_concurrent_process_time
                + self.avg_process_time
            )
            if None not in self.active_jobs:
                # Add estimated amount of time for a thread to get empty
                estimation.rank_eta += self.avg_concurrent_process_time
        client_awake = await self.send_message(event, estimation.dict())
        if not client_awake:
            await self.clean_event(event)
        return estimation

    def update_estimation(self, duration: float) -> None:
        """
        Update the running averages with the duration of a finished job.
        Parameters:
            duration: processing time of the finished job
        """
        self.duration_history_total += duration
        self.duration_history_count += 1
        self.avg_process_time = (
            self.duration_history_total / self.duration_history_count
        )
        # Divide by the slot count, capped by the number of jobs measured so far.
        self.avg_concurrent_process_time = self.avg_process_time / min(
            self.max_thread_count, self.duration_history_count
        )
        self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)

    def get_estimation(self) -> Estimation:
        """Build an Estimation from the current queue size and averages."""
        return Estimation(
            queue_size=len(self.event_queue),
            avg_event_process_time=self.avg_process_time,
            avg_event_concurrent_process_time=self.avg_concurrent_process_time,
            queue_eta=self.queue_duration,
        )

    def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
        """
        Collect url, headers, query/path params and client address of a websocket.
        Parameters:
            websocket: connection to describe
        Returns:
            dictionary with the keys "url", "headers", "query_params", "path_params" and "client"
        """
        return {
            "url": str(websocket.url),
            "headers": dict(websocket.headers),
            "query_params": dict(websocket.query_params),
            "path_params": dict(websocket.path_params),
            "client": dict(host=websocket.client.host, port=websocket.client.port),  # type: ignore
        }

    async def call_prediction(self, events: List[Event], batch: bool):
        """
        Post the data of the events to the predict endpoint.
        Parameters:
            events: events to run; the body of the first one is reused as request body
            batch: whether the data of all events is sent as one batch
        Returns:
            the awaited AsyncRequest
        """
        data = events[0].data
        assert data is not None, "No event data"
        token = events[0].token
        data.event_id = events[0]._id if not batch else None
        try:
            data.request = self.get_request_params(events[0].websocket)
        except ValueError:
            pass

        if batch:
            # Events without data are left out of both the data and the request list.
            data.data = list(zip(*[event.data.data for event in events if event.data]))
            data.request = [
                self.get_request_params(event.websocket)
                for event in events
                if event.data
            ]
            data.batched = True

        response = await AsyncRequest(
            method=AsyncRequest.Method.POST,
            url=f"{self.server_path}api/predict",
            json=dict(data),
            headers={"Authorization": f"Bearer {self.access_token}"},
            cookies={"access-token": token} if token is not None else None,
        )
        return response

    async def process_events(self, events: List[Event], batch: bool) -> None:
        """
        Run one job: collect data, call the prediction and report the result to the clients.
        Parameters:
            events: events of the job, as stored in `active_jobs`
            batch: whether the events are processed as one batch
        """
        awake_events: List[Event] = []
        try:
            for event in events:
                client_awake = await self.gather_event_data(event)
                if client_awake:
                    client_awake = await self.send_message(
                        event, {"msg": "process_starts"}
                    )
                if client_awake:
                    awake_events.append(event)
            if not awake_events:
                return
            begin_time = time.time()
            response = await self.call_prediction(awake_events, batch)
            if response.has_exception:
                for event in awake_events:
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": {"error": str(response.exception)},
                            "success": False,
                        },
                    )
            elif response.json.get("is_generating", False):
                old_response = response
                while response.json.get("is_generating", False):
                    # Python 3.7 doesn't have named tasks.
                    # In order to determine if a task was cancelled, we
                    # ping the websocket to see if it was closed mid-iteration.
                    if sys.version_info < (3, 8):
                        # `event` is the event the most recent loop ended on.
                        is_alive = await self.send_message(event, {"msg": "alive?"})
                        if not is_alive:
                            return
                    old_response = response
                    open_ws = []
                    for event in awake_events:
                        still_open = await self.send_message(
                            event,
                            {
                                "msg": "process_generating",
                                "output": old_response.json,
                                "success": old_response.status == 200,
                            },
                        )
                        open_ws.append(still_open)
                    # Keep only the events whose client received the update.
                    awake_events = [
                        e for e, is_open in zip(awake_events, open_ws) if is_open
                    ]
                    if not awake_events:
                        return
                    response = await self.call_prediction(awake_events, batch)
                # A failed final call is reported as is; otherwise the last
                # generated output is sent as the completed result.
                if response.status != 200:
                    relevant_response = response
                else:
                    relevant_response = old_response
                for event in awake_events:
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": relevant_response.json,
                            "success": relevant_response.status == 200,
                        },
                    )
            else:
                output = copy.deepcopy(response.json)
                split_per_event = batch and "data" in output
                # Transpose the batched data once so that entry i belongs to event i.
                per_event_data = (
                    list(zip(*response.json.get("data"))) if split_per_event else []
                )
                for position, event in enumerate(awake_events):
                    if split_per_event:
                        output["data"] = per_event_data[position]
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": output,
                            "success": response.status == 200,
                        },
                    )
            end_time = time.time()
            if response.status == 200:
                self.update_estimation(end_time - begin_time)
        finally:
            for event in awake_events:
                try:
                    await event.disconnect()
                except Exception:
                    # Best effort: the connection may already be closed.
                    pass
            self.active_jobs[self.active_jobs.index(events)] = None
            for event in awake_events:
                await self.clean_event(event)
                # Always reset the state of the iterator
                # If the job finished successfully, this has no effect
                # If the job is cancelled, this will enable future runs
                # to start "from scratch"
                await self.reset_iterators(event.session_hash, event.fn_index)

    async def send_message(self, event, data: Dict) -> bool:
        """
        Send a json message over the event's websocket.
        Parameters:
            event: Event whose websocket is used
            data: message to send
        Returns:
            True if the message was sent; False if sending raised, in which case the event is also removed from the queue
        """
        try:
            await event.websocket.send_json(data=data)
            return True
        except Exception:
            # Not a bare except: task cancellation (a BaseException on
            # Python 3.8+) must propagate instead of being reported as a
            # dropped client.
            await self.clean_event(event)
            return False

    async def get_message(self, event) -> PredictBody | None:
        """
        Receive a json message from the event's websocket and parse it as a PredictBody.
        Parameters:
            event: Event whose websocket is used
        Returns:
            the parsed body, or None if receiving or parsing raised, in which case the event is also removed from the queue
        """
        try:
            data = await event.websocket.receive_json()
            return PredictBody(**data)
        except Exception:
            # Same reasoning as in send_message: let cancellation through.
            await self.clean_event(event)
            return None

    async def reset_iterators(self, session_hash: str, fn_index: int):
        """
        Post a reset request for the given session and function.
        Parameters:
            session_hash: session whose state is reset
            fn_index: index of the function whose state is reset
        """
        await AsyncRequest(
            method=AsyncRequest.Method.POST,
            url=f"{self.server_path}reset",
            json={
                "session_hash": session_hash,
                "fn_index": fn_index,
            },
        )