import os
import traceback
from typing import Callable

# Set before anything from gradio_client is imported. Previously the first
# gradio_client import ran before this assignment.
# NOTE(review): presumably the gradio_client / huggingface_hub import chain
# reads this variable at import time -- confirm against the installed version.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'

from gradio_client import Client  # noqa: E402
from gradio_client.client import Job  # noqa: E402


class GradioClient(Client):
    """
    Gradio client that re-creates itself when the gradio server changes.

    The constructor arguments are remembered so that an equivalent ``Client``
    can be built again later; see ``refresh_client``.
    """

    def __init__(self, *args, **kwargs):
        """
        Create the client and record the server hash seen at start-up.

        Args:
            *args: positional arguments forwarded unchanged to ``Client``.
            **kwargs: keyword arguments forwarded unchanged to ``Client``.
        """
        # Kept so refresh_client() can rebuild a Client with the same settings.
        self.args = args
        self.kwargs = kwargs
        super().__init__(*args, **kwargs)
        self.server_hash = self.get_server_hash()

    def get_server_hash(self):
        """
        Get server hash using super without any refresh action triggered.

        Calls the parent ``submit`` directly (not the override below), so
        asking for the hash never triggers another hash check.

        Returns:
            result of the server's ``/system_hash`` endpoint
            (git hash of gradio server)
        """
        return super().submit(api_name='/system_hash').result()

    def refresh_client_if_should(self) -> None:
        """
        Rebuild the client if the server hash changed, else reset the session.
        """
        # get current hash in order to update api_name -> fn_index map
        # in case gradio server changed
        # FIXME: Could add cli api as hash
        server_hash = self.get_server_hash()
        if self.server_hash != server_hash:
            self.refresh_client()
            self.server_hash = server_hash
        else:
            self.reset_session()

    def refresh_client(self) -> None:
        """
        Ensure every client call is independent.

        Also ensure the map between api_name and fn_index is updated in case
        the server changed (e.g. restarted with new code).  This is done by
        building a fresh ``Client`` from the stored constructor arguments and
        copying all of its instance attributes onto ``self``.
        """
        # need session hash to be new every time,
        # to avoid "generator already executing"
        self.reset_session()

        client = Client(*self.args, **self.kwargs)
        for k, v in client.__dict__.items():
            setattr(self, k, v)

    def submit(
            self,
            *args,
            api_name: str | None = None,
            fn_index: int | None = None,
            result_callbacks: Callable | list[Callable] | None = None,
    ) -> Job:
        """
        Submit a job, refreshing the client and retrying once on failure.

        Args:
            *args: positional inputs forwarded to the parent ``submit``.
            api_name: forwarded to the parent ``submit``.
            fn_index: forwarded to the parent ``submit``.
            result_callbacks: forwarded to the parent ``submit``.  Previously
                this argument was accepted but dropped.

        Returns:
            The job from the last submission attempt.  It may still hold an
            exception if the retry failed too; that failure is only printed.

        Raises:
            Whatever the retry inside the ``except`` branch raises; that second
            attempt is not guarded.
        """
        # Note predict calls submit
        submit_kwargs = dict(
            api_name=api_name,
            fn_index=fn_index,
            result_callbacks=result_callbacks,
        )

        try:
            self.refresh_client_if_should()
            job = super().submit(*args, **submit_kwargs)
        except Exception as exc:
            print("Hit e=%s" % str(exc), flush=True)
            # force reconfig in case only that
            self.refresh_client()
            job = super().submit(*args, **submit_kwargs)

        # see if immediately failed
        # Reads the future's private field on purpose: it is a plain attribute
        # read, whereas calling Future.exception() would wait for completion.
        first_error = job.future._exception
        if first_error is not None:
            print("GR job failed: %s %s" % (
                str(first_error),
                ''.join(traceback.format_tb(first_error.__traceback__)),
            ), flush=True)
            # force reconfig in case only that
            self.refresh_client()
            job = super().submit(*args, **submit_kwargs)
            retry_error = job.future._exception
            if retry_error is not None:
                print("GR job failed again: %s\n%s" % (
                    str(retry_error),
                    ''.join(traceback.format_tb(retry_error.__traceback__)),
                ), flush=True)

        return job