File size: 1,776 Bytes
b585c7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import asyncio
import time
from typing import Any, AsyncGenerator, Generator, List, Optional
import gradio_client # type: ignore
class GradioClientWrapper:
    """Thin convenience wrapper around ``gradio_client.Client``.

    Adds synchronous and asynchronous streaming helpers on top of the raw
    client.  The optional ``h2ogpt_key`` is stored as an attribute for
    callers; it is not used by any method visible in this class.
    """

    def __init__(
        self,
        src: str,
        h2ogpt_key: Optional[str] = None,
        huggingface_token: Optional[str] = None,
    ):
        """Connect to the Gradio app at *src*.

        :param src: URL or space name of the Gradio server.
        :param h2ogpt_key: optional h2oGPT API key, kept on the instance.
        :param huggingface_token: optional Hugging Face token forwarded to
            the underlying client.
        """
        self._client = gradio_client.Client(
            src=src, hf_token=huggingface_token, serialize=False, verbose=False
        )
        self.h2ogpt_key = h2ogpt_key

    def predict(self, *args, api_name: str) -> Any:
        """Synchronously call the endpoint *api_name* and return its result."""
        return self._client.predict(*args, api_name=api_name)

    def predict_and_stream(self, *args, api_name: str) -> Generator[str, None, None]:
        """Call *api_name* and yield each NEW partial output as it arrives.

        Fixes over the naive polling loop: an unchanged newest output is not
        re-yielded (the original busy-spun, flooding the consumer with
        duplicates), and an output produced between the last poll and job
        completion is still yielded.

        :raises RuntimeError: chained to the job's exception if the job failed.
        """
        job = self._client.submit(*args, api_name=api_name)
        last: Optional[str] = None
        while not job.done():
            outputs: List[str] = job.outputs()
            newest = outputs[-1] if outputs else None
            if newest is None or newest == last:
                # Nothing new yet — back off instead of busy-spinning.
                time.sleep(0.1)
                continue
            last = newest
            yield newest
        # Drain: the job may have produced output after our final poll.
        outputs = job.outputs()
        if outputs and outputs[-1] != last:
            yield outputs[-1]
        self._raise_job_exception(job)

    async def submit(self, *args, api_name: str) -> Any:
        """Submit *api_name* and await its final result without blocking."""
        return await asyncio.wrap_future(self._client.submit(*args, api_name=api_name))

    async def submit_and_stream(
        self, *args, api_name: str
    ) -> AsyncGenerator[Any, None]:
        """Async variant of :meth:`predict_and_stream` (same de-dup and
        drain behavior; sleeps via ``asyncio`` so the event loop stays free).

        :raises RuntimeError: chained to the job's exception if the job failed.
        """
        job = self._client.submit(*args, api_name=api_name)
        last: Optional[str] = None
        while not job.done():
            outputs: List[str] = job.outputs()
            newest = outputs[-1] if outputs else None
            if newest is None or newest == last:
                # Yield control to the event loop while waiting for output.
                await asyncio.sleep(0.1)
                continue
            last = newest
            yield newest
        outputs = job.outputs()
        if outputs and outputs[-1] != last:
            yield outputs[-1]
        self._raise_job_exception(job)

    @staticmethod
    def _raise_job_exception(job: Any) -> None:
        """Re-raise a failed job's exception as ``RuntimeError``.

        The original code raised a bare ``RuntimeError`` class, losing the
        message; here the cause's text is preserved and chained via ``from``.
        """
        e = job.exception()
        if isinstance(e, BaseException):
            raise RuntimeError(str(e)) from e
|