Spaces:
Sleeping
Sleeping
File size: 2,458 Bytes
f060249 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import base64
import logging
import os
import time
from io import BytesIO
from PIL import Image
from octoai.client import Client as OctoAiClient
from fair import FairClient
# Module-level logger named after this module rather than the root logger, so
# its records can be filtered/configured on their own; they still propagate to
# the root logger's handlers.
logger = logging.getLogger(__name__)
import requests
# Deployment settings. Each one can be overridden through an environment
# variable, so the script can target another deployment without source edits.
# For a local debug setup, use for example:
#   FAIRCOMPUTE_SERVER_ADDRESS=http://localhost:8000
#   FAIRCOMPUTE_ENDPOINT_ADDRESS=http://localhost:5000
#   FAIRCOMPUTE_TARGET_NODE=ef09913249aa40ecba7d0097f7622855

# FairCompute server that start_server() connects to through FairClient.
SERVER_ADDRESS = os.getenv("FAIRCOMPUTE_SERVER_ADDRESS", "https://faircompute.com:8000")
# Base URL of the inference endpoint, which is probed and queried at "/infer".
# start_server() maps port 5000; the default URL uses that same port.
ENDPOINT_ADDRESS = os.getenv("FAIRCOMPUTE_ENDPOINT_ADDRESS", "http://dikobraz.mooo.com:5000")
# Node that start_server() launches the container on.
TARGET_NODE = os.getenv("FAIRCOMPUTE_TARGET_NODE", "119eccba-2388-43c1-bdb9-02133049604c")
# Docker image that start_server() runs on TARGET_NODE.
DOCKER_IMAGE = os.getenv("FAIRCOMPUTE_DOCKER_IMAGE", "faircompute/diffusion-octo:v1")
class EndpointClient:
    """Client for the text-to-image inference endpoint at ``ENDPOINT_ADDRESS``."""

    def infer(self, prompt):
        """Send ``prompt`` to the endpoint's ``/infer`` route and decode the reply.

        Args:
            prompt: Text prompt, sent as ``{"prompt": {"text": prompt}}``.

        Returns:
            A PIL image opened from the base64 string found at
            ``response["output"]["image_b64"]``.
        """
        octo = OctoAiClient()
        url = f"{ENDPOINT_ADDRESS}/infer"
        payload = {"prompt": {"text": prompt}}
        reply = octo.infer(endpoint_url=url, inputs=payload)
        encoded = reply["output"]["image_b64"]
        raw = base64.b64decode(encoded)
        # Image.open reads lazily, so the in-memory buffer stays attached to the image.
        return Image.open(BytesIO(raw))
class ServerNotReadyException(Exception):
    """Raised by ``wait_for_server`` when the endpoint fails every probe attempt."""
def wait_for_server(retries, timeout, delay=1.0):
    """Poll the inference endpoint until it answers with a non-error HTTP status.

    Sends up to ``retries`` GET requests to ``ENDPOINT_ADDRESS``. It sleeps
    ``delay`` seconds after each failed attempt except the last one, and
    returns as soon as one request succeeds.

    Args:
        retries: Maximum number of attempts; must be at least 1.
        timeout: Per-request timeout in seconds, passed to ``requests.get``.
        delay: Seconds to sleep between failed attempts.

    Raises:
        ValueError: If ``retries`` is less than 1. Previously such a call
            returned immediately and silently reported the server as ready.
        ServerNotReadyException: If every attempt fails with a connection
            error, a timeout, or an HTTP error status. The last underlying
            error is chained as ``__cause__``.
    """
    if retries < 1:
        raise ValueError(f"retries must be at least 1, got {retries!r}")
    for attempt in range(retries):
        try:
            response = requests.get(ENDPOINT_ADDRESS, timeout=timeout)
            response.raise_for_status()
            return
        except (requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.Timeout) as e:
            if attempt == retries - 1:
                raise ServerNotReadyException("Failed to start the server") from e
            logger.info("Server is not ready yet")
            time.sleep(delay)
def start_server():
    """Launch the inference container on ``TARGET_NODE`` and wait for it to respond.

    Connects a ``FairClient`` to ``SERVER_ADDRESS``. Credentials come from the
    ``FAIRCOMPUTE_EMAIL`` / ``FAIRCOMPUTE_PASSWORD`` environment variables.
    The client then runs ``DOCKER_IMAGE`` detached, and the endpoint is polled
    with ``wait_for_server`` (10 attempts, 1 s timeout each).

    Raises:
        ServerNotReadyException: Propagated from ``wait_for_server`` if the
            endpoint never becomes reachable.
    """
    # The fallback credentials are only accepted by a local server built in debug mode.
    email = os.getenv('FAIRCOMPUTE_EMAIL', "debug-usr")
    password = os.getenv('FAIRCOMPUTE_PASSWORD', "debug-pwd")
    fair = FairClient(server_address=SERVER_ADDRESS, user_email=email, user_password=password)
    # 5000 is the port used by the default ENDPOINT_ADDRESS. The tuple is
    # presumably (host_port, container_port); confirm against FairClient.run docs.
    fair.run(node=TARGET_NODE, image=DOCKER_IMAGE, ports=[(5000, 8080)], detach=True)
    # Block until the new container answers, or raise after the last attempt.
    wait_for_server(retries=10, timeout=1.0)
def text_to_image(text):
    """Generate an image for ``text``, starting the inference server on demand.

    First makes one quick probe of the endpoint (one attempt, 1 s timeout,
    no delay). If that probe fails, ``start_server`` launches the container
    and waits for it. The prompt is then sent through ``EndpointClient``.

    Returns:
        The image returned by ``EndpointClient.infer``.

    Raises:
        ServerNotReadyException: If the server had to be started but never
            became reachable.
    """
    try:
        # Cheap liveness check; a failure here just means "not running yet".
        wait_for_server(retries=1, timeout=1.0, delay=0.0)
    except ServerNotReadyException:
        start_server()
    return EndpointClient().infer(text)
if __name__ == "__main__":
    # Demo run: render a sample prompt and save it as result.png in the working directory.
    generated = text_to_image(text="Robot dinosaur\n")
    generated.save("result.png")
|