from __future__ import annotations

import asyncio
import copy
import sys
import time
from collections import deque
from typing import Any, Deque, Dict, List, Tuple

import fastapi

from gradio.data_classes import Estimation, PredictBody, Progress, ProgressUnit
from gradio.helpers import TrackedIterable
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name


class Event:
    def __init__(
        self,
        websocket: fastapi.WebSocket,
        session_hash: str,
        fn_index: int,
    ):
        self.websocket = websocket
        self.session_hash: str = session_hash
        self.fn_index: int = fn_index
        self._id = f"{self.session_hash}_{self.fn_index}"
        self.data: PredictBody | None = None
        self.lost_connection_time: float | None = None
        self.token: str | None = None
        self.progress: Progress | None = None
        self.progress_pending: bool = False

    async def disconnect(self, code: int = 1000):
        await self.websocket.close(code=code)

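# Illustrative sketch (not part of the original module): each websocket client
# is wrapped in an Event and submitted to the Queue defined below, e.g.
#
#     event = Event(websocket, session_hash="abc123", fn_index=0)
#     rank = queue.push(event)  # integer rank, or None if the queue is full
#
# Here `queue` and `websocket` are assumed to be an existing Queue instance
# and an accepted fastapi.WebSocket connection, respectively.
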
class Queue:
    def __init__(
        self,
        live_updates: bool,
        concurrency_count: int,
        update_intervals: float,
        max_size: int | None,
        blocks_dependencies: List,
    ):
        self.event_queue: Deque[Event] = deque()
        self.events_pending_reconnection = []
        self.stopped = False
        self.max_thread_count = concurrency_count
        self.update_intervals = update_intervals
        self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
        self.delete_lock = asyncio.Lock()
        self.server_path = None
        self.duration_history_total = 0
        self.duration_history_count = 0
        self.avg_process_time = 0
        self.avg_concurrent_process_time = None
        self.queue_duration = 1
        self.live_updates = live_updates
        self.sleep_when_free = 0.05
        self.progress_update_sleep_when_free = 0.1
        self.max_size = max_size
        self.blocks_dependencies = blocks_dependencies
        self.access_token = ""

    async def start(self, progress_tracking=False):
        run_coro_in_background(self.start_processing)
        if progress_tracking:
            run_coro_in_background(self.start_progress_tracking)
        if not self.live_updates:
            run_coro_in_background(self.notify_clients)

    def close(self):
        self.stopped = True

    def resume(self):
        self.stopped = False

    def set_url(self, url: str):
        self.server_path = url

    def set_access_token(self, token: str):
        self.access_token = token

    def get_active_worker_count(self) -> int:
        count = 0
        for worker in self.active_jobs:
            if worker is not None:
                count += 1
        return count

    def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
        if not self.event_queue:
            return None, False

        first_event = self.event_queue.popleft()
        events = [first_event]

        event_fn_index = first_event.fn_index
        batch = self.blocks_dependencies[event_fn_index]["batch"]

        if batch:
            batch_size = self.blocks_dependencies[event_fn_index]["max_batch_size"]
            rest_of_batch = [
                event for event in self.event_queue if event.fn_index == event_fn_index
            ][: batch_size - 1]
            events.extend(rest_of_batch)
            for event in rest_of_batch:
                self.event_queue.remove(event)

        return events, batch

    async def start_processing(self) -> None:
        while not self.stopped:
            if not self.event_queue:
                await asyncio.sleep(self.sleep_when_free)
                continue

            if None not in self.active_jobs:
                await asyncio.sleep(self.sleep_when_free)
                continue

            # Using a mutex to avoid editing a list that is in use
            async with self.delete_lock:
                events, batch = self.get_events_in_batch()

            if events:
                self.active_jobs[self.active_jobs.index(None)] = events
                task = run_coro_in_background(self.process_events, events, batch)
                run_coro_in_background(self.broadcast_live_estimations)
                set_task_name(task, events[0].session_hash, events[0].fn_index, batch)

    async def start_progress_tracking(self) -> None:
        while not self.stopped:
            if not any(self.active_jobs):
                await asyncio.sleep(self.progress_update_sleep_when_free)
                continue

            for job in self.active_jobs:
                if job is None:
                    continue
                for event in job:
                    if event.progress_pending and event.progress:
                        event.progress_pending = False
                        client_awake = await self.send_message(
                            event, event.progress.dict()
                        )
                        if not client_awake:
                            await self.clean_event(event)

            await asyncio.sleep(self.progress_update_sleep_when_free)

    def set_progress(
        self,
        event_id: str,
        iterables: List[TrackedIterable] | None,
    ):
        if iterables is None:
            return
        for job in self.active_jobs:
            if job is None:
                continue
            for evt in job:
                if evt._id == event_id:
                    progress_data: List[ProgressUnit] = []
                    for iterable in iterables:
                        progress_unit = ProgressUnit(
                            index=iterable.index,
                            length=iterable.length,
                            unit=iterable.unit,
                            progress=iterable.progress,
                            desc=iterable.desc,
                        )
                        progress_data.append(progress_unit)
                    evt.progress = Progress(progress_data=progress_data)
                    evt.progress_pending = True

    def push(self, event: Event) -> int | None:
        """
        Add an event to the queue, or return None if the queue is full.
        Parameters:
            event: Event to add to the queue
        Returns:
            rank of the submitted Event
        """
        queue_len = len(self.event_queue)
        if self.max_size is not None and queue_len >= self.max_size:
            return None
        self.event_queue.append(event)
        return queue_len

    async def clean_event(self, event: Event) -> None:
        if event in self.event_queue:
            async with self.delete_lock:
                self.event_queue.remove(event)

    async def broadcast_live_estimations(self) -> None:
        """
        Broadcast estimations sequentially rather than concurrently;
        otherwise disconnected clients could be cleaned up twice.
        """
        if self.live_updates:
            await self.broadcast_estimations()

    async def gather_event_data(self, event: Event) -> bool:
        """
        Gather input data for the event, requesting it from the client if needed.
        Parameters:
            event: the Event to collect data for
        """
        if not event.data:
            client_awake = await self.send_message(event, {"msg": "send_data"})
            if not client_awake:
                return False
            event.data = await self.get_message(event)
        return True

    async def notify_clients(self) -> None:
        """
        Periodically notify clients about the status of their events in the queue.
        """
        while not self.stopped:
            await asyncio.sleep(self.update_intervals)
            if self.event_queue:
                await self.broadcast_estimations()

    async def broadcast_estimations(self) -> None:
        estimation = self.get_estimation()
        # Send all messages concurrently
        await asyncio.gather(
            *[
                self.send_estimation(event, estimation, rank)
                for rank, event in enumerate(self.event_queue)
            ]
        )
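    # Worked example of the ETA arithmetic used below (illustrative numbers):
    # with avg_process_time = 2.0s and max_thread_count = 2, update_estimation()
    # yields avg_concurrent_process_time = 1.0s once at least two jobs have
    # completed. An event at rank 3 then gets
    #     rank_eta = 3 * 1.0 + 2.0 = 5.0 seconds,
    # plus one more avg_concurrent_process_time (1.0s) if every worker slot
    # is currently occupied.
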
    async def send_estimation(
        self, event: Event, estimation: Estimation, rank: int
    ) -> Estimation:
        """
        Send an ETA estimation to the client.
        Parameters:
            event: the Event to notify
            estimation: the current queue-wide Estimation
            rank: the event's position in the queue
        """
        estimation.rank = rank

        if self.avg_concurrent_process_time is not None:
            estimation.rank_eta = (
                estimation.rank * self.avg_concurrent_process_time
                + self.avg_process_time
            )
            if None not in self.active_jobs:
                # Add the estimated time for a worker slot to free up
                estimation.rank_eta += self.avg_concurrent_process_time
        client_awake = await self.send_message(event, estimation.dict())
        if not client_awake:
            await self.clean_event(event)
        return estimation

    def update_estimation(self, duration: float) -> None:
        """
        Update the running average durations with the latest job's duration.
        Parameters:
            duration: how long the last job took, in seconds
        """
        self.duration_history_total += duration
        self.duration_history_count += 1
        self.avg_process_time = (
            self.duration_history_total / self.duration_history_count
        )
        self.avg_concurrent_process_time = self.avg_process_time / min(
            self.max_thread_count, self.duration_history_count
        )
        self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)

    def get_estimation(self) -> Estimation:
        return Estimation(
            queue_size=len(self.event_queue),
            avg_event_process_time=self.avg_process_time,
            avg_event_concurrent_process_time=self.avg_concurrent_process_time,
            queue_eta=self.queue_duration,
        )

    def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
        return {
            "url": str(websocket.url),
            "headers": dict(websocket.headers),
            "query_params": dict(websocket.query_params),
            "path_params": dict(websocket.path_params),
            "client": dict(host=websocket.client.host, port=websocket.client.port),  # type: ignore
        }

    async def call_prediction(self, events: List[Event], batch: bool):
        data = events[0].data
        assert data is not None, "No event data"
        token = events[0].token
        data.event_id = events[0]._id if not batch else None
        try:
            data.request = self.get_request_params(events[0].websocket)
        except ValueError:
            pass

        if batch:
            data.data = list(zip(*[event.data.data for event in events if event.data]))
            data.request = [
                self.get_request_params(event.websocket)
                for event in events
                if event.data
            ]
            data.batched = True

        response = await AsyncRequest(
            method=AsyncRequest.Method.POST,
            url=f"{self.server_path}api/predict",
            json=dict(data),
            headers={"Authorization": f"Bearer {self.access_token}"},
            cookies={"access-token": token} if token is not None else None,
        )
        return response
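    # Message protocol used by process_events (as sent over the websocket):
    #   {"msg": "send_data"}           -> ask the client for its input payload
    #   {"msg": "process_starts"}      -> the job has been picked up by a worker
    #   {"msg": "process_generating"}  -> intermediate output from a generator fn
    #   {"msg": "process_completed"}   -> final output (or error) for the job
    # The generating/completed messages also carry "output" and "success" fields.
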
    async def process_events(self, events: List[Event], batch: bool) -> None:
        awake_events: List[Event] = []
        try:
            for event in events:
                client_awake = await self.gather_event_data(event)
                if client_awake:
                    client_awake = await self.send_message(
                        event, {"msg": "process_starts"}
                    )
                if client_awake:
                    awake_events.append(event)
            if not awake_events:
                return
            begin_time = time.time()
            response = await self.call_prediction(awake_events, batch)
            if response.has_exception:
                for event in awake_events:
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": {"error": str(response.exception)},
                            "success": False,
                        },
                    )
            elif response.json.get("is_generating", False):
                old_response = response
                while response.json.get("is_generating", False):
                    # Python 3.7 doesn't have named tasks.
                    # In order to determine if a task was cancelled, we
                    # ping the websocket to see if it was closed mid-iteration.
                    if sys.version_info < (3, 8):
                        is_alive = await self.send_message(
                            awake_events[0], {"msg": "alive?"}
                        )
                        if not is_alive:
                            return
                    old_response = response
                    open_ws = []
                    for event in awake_events:
                        client_open = await self.send_message(
                            event,
                            {
                                "msg": "process_generating",
                                "output": old_response.json,
                                "success": old_response.status == 200,
                            },
                        )
                        open_ws.append(client_open)
                    awake_events = [
                        e for e, is_open in zip(awake_events, open_ws) if is_open
                    ]
                    if not awake_events:
                        return
                    response = await self.call_prediction(awake_events, batch)
                for event in awake_events:
                    if response.status != 200:
                        relevant_response = response
                    else:
                        relevant_response = old_response
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": relevant_response.json,
                            "success": relevant_response.status == 200,
                        },
                    )
            else:
                output = copy.deepcopy(response.json)
                for e, event in enumerate(awake_events):
                    if batch and "data" in output:
                        output["data"] = list(zip(*response.json.get("data")))[e]
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": output,
                            "success": response.status == 200,
                        },
                    )
            end_time = time.time()
            if response.status == 200:
                self.update_estimation(end_time - begin_time)
        finally:
            for event in awake_events:
                try:
                    await event.disconnect()
                except Exception:
                    pass
            self.active_jobs[self.active_jobs.index(events)] = None
            for event in awake_events:
                await self.clean_event(event)
                # Always reset the state of the iterator.
                # If the job finished successfully, this has no effect.
                # If the job was cancelled, this allows future runs
                # to start from scratch.
                await self.reset_iterators(event.session_hash, event.fn_index)

    async def send_message(self, event, data: Dict) -> bool:
        try:
            await event.websocket.send_json(data=data)
            return True
        except Exception:
            await self.clean_event(event)
            return False

    async def get_message(self, event) -> PredictBody | None:
        try:
            data = await event.websocket.receive_json()
            return PredictBody(**data)
        except Exception:
            await self.clean_event(event)
            return None

    async def reset_iterators(self, session_hash: str, fn_index: int):
        await AsyncRequest(
            method=AsyncRequest.Method.POST,
            url=f"{self.server_path}reset",
            json={
                "session_hash": session_hash,
                "fn_index": fn_index,
            },
        )
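
if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming no running Gradio server):
    # exercises only push() and its max_size handling; the websocket
    # argument is a placeholder and no messages are ever sent.
    demo_queue = Queue(
        live_updates=True,
        concurrency_count=2,
        update_intervals=10,
        max_size=1,
        blocks_dependencies=[{"batch": False, "max_batch_size": 4}],
    )
    demo_event = Event(websocket=None, session_hash="demo", fn_index=0)  # type: ignore
    print(demo_queue.push(demo_event))  # 0 -> rank of the submitted event
    print(demo_queue.push(demo_event))  # None -> queue is already full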