# Source: test/client/h2ogpt_client/_gradio_client.py
# (uploaded with huggingface_hub by iblfe, commit b585c7f; 1.78 kB)
import asyncio
import time
from typing import Any, AsyncGenerator, Generator, List, Optional
import gradio_client # type: ignore
class GradioClientWrapper:
    """Thin wrapper around ``gradio_client.Client`` for calling a Gradio app.

    Offers blocking calls (``predict``, ``predict_and_stream``) and asyncio
    calls (``submit``, ``submit_and_stream``) against a named API endpoint,
    and keeps the h2oGPT API key on ``self.h2ogpt_key`` for callers to use.
    """

    # Seconds to wait between polls of a running job when no new output has
    # arrived; prevents the streaming loops from busy-spinning.
    _POLL_INTERVAL: float = 0.1

    def __init__(
        self,
        src: str,
        h2ogpt_key: Optional[str] = None,
        huggingface_token: Optional[str] = None,
    ):
        """Create the underlying ``gradio_client.Client``.

        Args:
            src: Gradio app to connect to, passed through as ``src``.
            h2ogpt_key: Optional h2oGPT API key; only stored on
                ``self.h2ogpt_key`` (this class does not use it itself).
            huggingface_token: Optional token, passed through as ``hf_token``.
        """
        self._client = gradio_client.Client(
            src=src, hf_token=huggingface_token, serialize=False, verbose=False
        )
        self.h2ogpt_key = h2ogpt_key

    def predict(self, *args, api_name: str) -> Any:
        """Call endpoint ``api_name`` with ``args`` and block until it returns."""
        return self._client.predict(*args, api_name=api_name)

    def predict_and_stream(self, *args, api_name: str) -> Generator[str, None, None]:
        """Submit a job and yield its newest output each time new output appears.

        Polls the job every ``_POLL_INTERVAL`` seconds while it runs. Each
        output is yielded at most once. After the job finishes, any output
        that arrived after the last poll is yielded too.

        Raises:
            RuntimeError: If the job finished with an exception (chained as
                ``__cause__``).
        """
        job = self._client.submit(*args, api_name=api_name)
        seen = 0
        while not job.done():
            outputs: List[str] = job.outputs()
            # Snapshot the length: the list may keep growing while we read it.
            count = len(outputs)
            if count > seen:
                seen = count
                yield outputs[count - 1]
            else:
                time.sleep(self._POLL_INTERVAL)
        # The final output may have landed between the last poll and done().
        outputs = job.outputs()
        count = len(outputs)
        if count > seen:
            yield outputs[count - 1]
        self._raise_if_failed(job)

    async def submit(self, *args, api_name: str) -> Any:
        """Submit a job and await its result via ``asyncio.wrap_future``."""
        return await asyncio.wrap_future(self._client.submit(*args, api_name=api_name))

    async def submit_and_stream(
        self, *args, api_name: str
    ) -> AsyncGenerator[Any, None]:
        """Async counterpart of ``predict_and_stream``.

        Waits with ``asyncio.sleep`` between polls, so the event loop is not
        blocked. Each output is yielded at most once, and output that arrived
        after the last poll is yielded once the job is done.

        Raises:
            RuntimeError: If the job finished with an exception (chained as
                ``__cause__``).
        """
        job = self._client.submit(*args, api_name=api_name)
        seen = 0
        while not job.done():
            outputs: List[str] = job.outputs()
            count = len(outputs)
            if count > seen:
                seen = count
                yield outputs[count - 1]
            else:
                await asyncio.sleep(self._POLL_INTERVAL)
        outputs = job.outputs()
        count = len(outputs)
        if count > seen:
            yield outputs[count - 1]
        self._raise_if_failed(job)

    @staticmethod
    def _raise_if_failed(job: Any) -> None:
        """Re-raise a finished job's exception as a chained ``RuntimeError``."""
        error = job.exception()
        if isinstance(error, BaseException):
            raise RuntimeError(f"Gradio job failed: {error!r}") from error