import asyncio
import time
from typing import Any, AsyncGenerator, Generator, List, Optional

import gradio_client  # type: ignore


class GradioClientWrapper:
    """Wrapper around ``gradio_client.Client`` with blocking and asyncio helpers.

    Provides one-shot calls (``predict`` / ``submit``) and streaming calls
    (``predict_and_stream`` / ``submit_and_stream``) that poll a submitted
    job for its intermediate outputs.
    """

    # Seconds to wait between polls while a job has produced nothing new.
    _POLL_INTERVAL = 0.1

    def __init__(
        self,
        src: str,
        h2ogpt_key: Optional[str] = None,
        huggingface_token: Optional[str] = None,
    ):
        """Create the underlying Gradio client.

        Args:
            src: Passed to ``gradio_client.Client`` as ``src``.
            h2ogpt_key: Stored on the instance as ``h2ogpt_key``; this class
                does not read it itself.
            huggingface_token: Passed to ``gradio_client.Client`` as
                ``hf_token``.
        """
        self._client = gradio_client.Client(
            src=src, hf_token=huggingface_token, serialize=False, verbose=False
        )
        self.h2ogpt_key = h2ogpt_key

    @staticmethod
    def _raise_if_failed(job: Any) -> None:
        """Raise ``RuntimeError`` (chained to the cause) if *job* failed."""
        error = job.exception()
        if isinstance(error, BaseException):
            raise RuntimeError(f"Gradio job failed: {error}") from error

    def predict(self, *args, api_name: str) -> Any:
        """Call ``api_name`` on the client synchronously and return its result."""
        return self._client.predict(*args, api_name=api_name)

    def predict_and_stream(self, *args, api_name: str) -> Generator[str, None, None]:
        """Submit a job and yield its newest output whenever a new one appears.

        Each yielded value is the last element of ``job.outputs()`` at the
        time of the poll; a value is yielded only when the number of outputs
        has grown since the previous yield, so consumers do not receive
        duplicates.

        Raises:
            RuntimeError: If the job finished with an exception (the original
                exception is attached as ``__cause__``).
        """
        job = self._client.submit(*args, api_name=api_name)
        seen = 0  # number of outputs already observed
        while not job.done():
            outputs: List[str] = job.outputs()
            if len(outputs) > seen:
                seen = len(outputs)
                yield outputs[-1]
            else:
                # Nothing new yet: back off instead of spinning.
                time.sleep(self._POLL_INTERVAL)

        # The job may have produced output between the last poll and
        # completion (or finished before the first poll); flush it.
        outputs = job.outputs()
        if len(outputs) > seen:
            yield outputs[-1]

        self._raise_if_failed(job)

    async def submit(self, *args, api_name: str) -> Any:
        """Submit a job and await its result without blocking the event loop.

        The job returned by ``client.submit`` is bridged into asyncio with
        ``asyncio.wrap_future``.
        """
        return await asyncio.wrap_future(self._client.submit(*args, api_name=api_name))

    async def submit_and_stream(
        self, *args, api_name: str
    ) -> AsyncGenerator[Any, None]:
        """Async counterpart of ``predict_and_stream``.

        Yields the newest job output whenever the number of outputs has
        grown, awaiting ``asyncio.sleep`` between polls when nothing new is
        available.

        Raises:
            RuntimeError: If the job finished with an exception (the original
                exception is attached as ``__cause__``).
        """
        job = self._client.submit(*args, api_name=api_name)
        seen = 0  # number of outputs already observed
        while not job.done():
            outputs: List[str] = job.outputs()
            if len(outputs) > seen:
                seen = len(outputs)
                yield outputs[-1]
            else:
                # Nothing new yet: give control back to the event loop.
                await asyncio.sleep(self._POLL_INTERVAL)

        # Flush any output produced between the last poll and completion.
        outputs = job.outputs()
        if len(outputs) > seen:
            yield outputs[-1]

        self._raise_if_failed(job)