File size: 1,776 Bytes
b585c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import asyncio
import time
from typing import Any, AsyncGenerator, Generator, List, Optional

import gradio_client  # type: ignore


class GradioClientWrapper:
    """Wrapper around ``gradio_client.Client`` with blocking, streaming and async helpers.

    The streaming helpers submit a job, poll it, and yield the newest output
    each time the job has produced a new one, including the output that is
    available once the job is done.
    """

    # Seconds to wait between job polls in the streaming helpers. Sleeping on
    # every iteration (not only while no output exists yet) keeps the sync loop
    # from spinning and lets the event loop run other tasks in the async variant.
    # Can be overridden per instance.
    poll_interval: float = 0.1

    def __init__(
        self,
        src: str,
        h2ogpt_key: Optional[str] = None,
        huggingface_token: Optional[str] = None,
    ):
        """Create the underlying Gradio client.

        Args:
            src: Forwarded to ``gradio_client.Client`` as ``src``.
            h2ogpt_key: Stored as ``self.h2ogpt_key``; not used by this class itself.
            huggingface_token: Forwarded to ``gradio_client.Client`` as ``hf_token``.
        """
        self._client = gradio_client.Client(
            src=src, hf_token=huggingface_token, serialize=False, verbose=False
        )
        self.h2ogpt_key = h2ogpt_key

    def predict(self, *args, api_name: str) -> Any:
        """Call ``api_name`` with ``args`` via ``Client.predict`` and return the result."""
        return self._client.predict(*args, api_name=api_name)

    def predict_and_stream(self, *args, api_name: str) -> Generator[str, None, None]:
        """Submit a job and yield its newest output each time a new one appears.

        After the job is done, the newest output is yielded once more if it was
        not already yielded, so a late final output is never dropped.

        Raises:
            RuntimeError: if the job finished with an exception (chained as ``__cause__``).
        """
        job = self._client.submit(*args, api_name=api_name)
        seen = 0  # number of outputs already accounted for
        while not job.done():
            outputs: List[str] = job.outputs()
            if len(outputs) > seen:
                seen = len(outputs)
                yield outputs[-1]
            time.sleep(self.poll_interval)

        # The job may have appended its last output(s) between the previous poll
        # and done() turning true; yield the newest one rather than dropping it.
        final_outputs = job.outputs()
        if len(final_outputs) > seen:
            yield final_outputs[-1]
        self._raise_if_failed(job)

    async def submit(self, *args, api_name: str) -> Any:
        """Submit a job and await its result without blocking the event loop.

        ``asyncio.wrap_future`` requires the object returned by ``Client.submit``
        to be a ``concurrent.futures.Future``.
        """
        return await asyncio.wrap_future(self._client.submit(*args, api_name=api_name))

    async def submit_and_stream(
        self, *args, api_name: str
    ) -> AsyncGenerator[str, None]:
        """Async counterpart of :meth:`predict_and_stream`.

        Awaits ``asyncio.sleep`` on every poll so other tasks keep running while
        the job is in progress.

        Raises:
            RuntimeError: if the job finished with an exception (chained as ``__cause__``).
        """
        job = self._client.submit(*args, api_name=api_name)
        seen = 0  # number of outputs already accounted for
        while not job.done():
            outputs: List[str] = job.outputs()
            if len(outputs) > seen:
                seen = len(outputs)
                yield outputs[-1]
            await asyncio.sleep(self.poll_interval)

        # Same late-output drain as in predict_and_stream.
        final_outputs = job.outputs()
        if len(final_outputs) > seen:
            yield final_outputs[-1]
        self._raise_if_failed(job)

    @staticmethod
    def _raise_if_failed(job: Any) -> None:
        """Raise ``RuntimeError`` chained to the job's exception, if it has one."""
        error = job.exception()
        if isinstance(error, BaseException):
            raise RuntimeError(f"Gradio job failed: {error!r}") from error