"""Text-to-image client for a diffusers inference server running on FairCompute.

``text_to_image`` connects to an inference endpoint exposed through a rathole
tunnel on ``TUNNEL_NODE``. If the endpoint is unreachable, it starts the
inference container on ``INFERENCE_NODE``, starts the tunnel server/client
pair, and waits for the endpoint's healthcheck to pass before running
inference.
"""
import base64
import hashlib
import logging
import os
import textwrap
import time
from io import BytesIO

import requests
from PIL import Image

from fair import FairClient

logger = logging.getLogger(__name__)

SERVER_ADDRESS = "https://faircompute.com:8000"
INFERENCE_NODE = "magnus"
TUNNEL_NODE = "gcs-e2-micro"
# Local development settings:
# SERVER_ADDRESS = "http://localhost:8000"
# INFERENCE_NODE = "ef09913249aa40ecba7d0097f7622855"
# TUNNEL_NODE = "c312e6c4788b00c73c287ab0445d3655"

INFERENCE_DOCKER_IMAGE = "faircompute/diffusers-api-dreamshaper-8"
TUNNEL_DOCKER_IMAGE = "rapiz1/rathole"

# Errors meaning "the endpoint is not reachable (yet)", as opposed to a bug.
_UNREACHABLE_ERRORS = (
    requests.exceptions.ConnectionError,
    requests.exceptions.HTTPError,
    requests.exceptions.Timeout,
)

# Lazily created, process-wide clients reused across text_to_image() calls.
endpoint_client = None
fair_client = None


class EndpointClient:
    """HTTP client for the diffusers inference server reached through the tunnel."""

    def __init__(self, server_address, timeout):
        """Connect to ``http://<server_address>:5000`` and verify its healthcheck.

        Args:
            server_address: host address of the tunnel node.
            timeout: seconds to wait for the healthcheck response.

        Raises:
            requests.exceptions.ConnectionError / Timeout: endpoint unreachable.
            requests.exceptions.HTTPError: healthcheck returned a 4xx/5xx status.
            Exception: healthcheck reported a state other than ``"healthy"``.
        """
        self.endpoint_address = f'http://{server_address}:5000'
        # Plain string formatting, not os.path.join: the latter uses a
        # backslash separator on Windows and would produce an invalid URL.
        response = requests.get(f'{self.endpoint_address}/healthcheck', timeout=timeout)
        # Turn 4xx/5xx into HTTPError so callers' retry logic can handle it,
        # instead of failing later while decoding a non-JSON error body.
        response.raise_for_status()
        if response.json()['state'] != 'healthy':
            raise Exception("Server is not healthy")

    def infer(self, prompt):
        """Generate a 512x512 image for *prompt* and return it as a PIL image.

        Posts the prompt to the endpoint root with the DreamShaper-8 settings
        below, then decodes the ``image_base64`` field of the JSON reply.
        No request timeout is set, so this blocks until the server responds.
        """
        inputs = {
            "modelInputs": {
                "prompt": prompt,
                "num_inference_steps": 25,
                "width": 512,
                "height": 512,
            },
            "callInputs": {
                "MODEL_ID": "lykon/dreamshaper-8",
                "PIPELINE": "AutoPipelineForText2Image",
                "SCHEDULER": "DEISMultistepScheduler",
                "PRECISION": "fp16",
                "REVISION": "fp16",
            },
        }
        response = requests.post(self.endpoint_address, json=inputs).json()
        image_data = BytesIO(base64.b64decode(response["image_base64"]))
        return Image.open(image_data)


class ServerNotReadyException(Exception):
    """Raised when the inference endpoint did not become reachable in time."""


def _faircompute_password():
    """Return the FairCompute password from the environment (debug default)."""
    return os.environ.get('FAIRCOMPUTE_PASSWORD', "debug-pwd")


def _get_node_address(fc, node_name):
    """Return the ``host_address`` of the cluster node named *node_name*.

    Raises:
        LookupError: no node with that name is listed in the cluster.
    """
    for info in fc.cluster().nodes.list():
        if info['name'] == node_name:
            return info['host_address']
    raise LookupError(f"Node {node_name!r} not found in the FairCompute cluster")


def _write_config(path, text):
    """Write *text* to *path* as UTF-8, overwriting any existing file."""
    with open(path, 'w', encoding='utf-8') as file:
        file.write(text)


def create_fair_client():
    """Create a FairClient for SERVER_ADDRESS using credentials from the environment."""
    return FairClient(server_address=SERVER_ADDRESS,
                      user_email=os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr"),
                      user_password=_faircompute_password())


def create_endpoint_client(fc, retries, timeout=1.0, delay=2.0):
    """Connect to the inference endpoint on TUNNEL_NODE, retrying while it is unreachable.

    Args:
        fc: FairClient used to look up the tunnel node's address.
        retries: maximum number of connection attempts.
        timeout: per-attempt healthcheck timeout in seconds.
        delay: seconds to wait between failed attempts.

    Returns:
        A connected EndpointClient.

    Raises:
        ServerNotReadyException: every attempt failed (or ``retries <= 0``).
    """
    server_address = _get_node_address(fc, TUNNEL_NODE)
    for attempt in range(1, retries + 1):
        try:
            return EndpointClient(server_address, timeout=timeout)
        except _UNREACHABLE_ERRORS as e:
            # Expected while the server/tunnel is still starting: log without a traceback.
            logger.warning("Inference endpoint %s not ready (attempt %d/%d): %s",
                           server_address, attempt, retries, e)
            # No point waiting after the final attempt.
            if attempt < retries:
                time.sleep(delay)
    raise ServerNotReadyException("Failed to start the server")


def start_tunnel(fc: FairClient):
    """Start a rathole tunnel exposing the inference server on TUNNEL_NODE.

    Runs the rathole server on TUNNEL_NODE (clients connect on port 2333, the
    service is exposed on port 5000). Then runs the rathole client on
    INFERENCE_NODE, forwarding 127.0.0.1:5001. Both containers are detached.
    The config files are written to ``server.toml`` / ``client.toml`` in the
    current directory and mounted into the containers.
    """
    # Fixed token derived from a secret, so both tunnel ends agree on it
    # without storing an extra credential.
    token = hashlib.sha256(_faircompute_password().encode()).hexdigest()

    # start tunnel node
    _write_config('server.toml', textwrap.dedent(f"""\
        [server]
        bind_addr = "0.0.0.0:2333" # port that rathole listens for clients

        [server.services.inference_server]
        token = "{token}" # token that is used to authenticate the client for the service
        bind_addr = "0.0.0.0:5000" # port that exposes service to the Internet
        """))
    fc.run(node_name=TUNNEL_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--server", "/app/config.toml"],
           volumes=[("./server.toml", "/app/config.toml")],
           network="host",
           detach=True)

    # start tunnel client next to the inference server
    server_address = _get_node_address(fc, TUNNEL_NODE)
    _write_config('client.toml', textwrap.dedent(f"""\
        [client]
        remote_addr = "{server_address}:2333" # address of the rathole server

        [client.services.inference_server]
        token = "{token}" # token that is used to authenticate the client for the service
        local_addr = "127.0.0.1:5001" # address of the service that needs to be forwarded
        """))
    fc.run(node_name=INFERENCE_NODE,
           image=TUNNEL_DOCKER_IMAGE,
           command=["--client", "/app/config.toml"],
           volumes=[("./client.toml", "/app/config.toml")],
           network="host",
           detach=True)


def start_inference_server(fc: FairClient):
    """Run the diffusers inference container on INFERENCE_NODE (detached, NVIDIA runtime).

    ``ports=[(5001, 8000)]`` presumably maps host port 5001 to container
    port 8000 (TODO confirm against FairClient.run). Host port 5001 is what
    the tunnel client forwards.
    """
    fc.run(node_name=INFERENCE_NODE,
           image=INFERENCE_DOCKER_IMAGE,
           runtime="nvidia",
           ports=[(5001, 8000)],
           detach=True)


def text_to_image(text):
    """Generate an image for *text*, starting the server and tunnel on demand.

    Reuses the cached endpoint client when available. If the endpoint is
    unreachable, it starts the inference server and the tunnel, then waits
    (up to 10 attempts) for the endpoint before running inference.

    Returns:
        The generated PIL image.

    Raises:
        ServerNotReadyException: the endpoint did not come up after a restart.
    """
    global endpoint_client
    global fair_client

    if fair_client is None:
        fair_client = create_fair_client()

    try:
        if endpoint_client is not None:
            # client is configured, try to do inference right away
            return endpoint_client.infer(text)
        # client is not configured, try connecting to the inference server, maybe it is running
        endpoint_client = create_endpoint_client(fair_client, 1)
    except _UNREACHABLE_ERRORS + (ServerNotReadyException,):
        # Drop any stale client so a failed restart is not retried through it.
        endpoint_client = None
        # inference server is not ready, start inference server and open the tunnel
        start_inference_server(fair_client)
        start_tunnel(fair_client)
        endpoint_client = create_endpoint_client(fair_client, retries=10)

    # run inference
    return endpoint_client.infer(text)


if __name__ == "__main__":
    image = text_to_image(text="Robot dinosaur\n")
    image.save("result.png")