Spaces:
Sleeping
Sleeping
import io
import os
import time
from pathlib import Path
from urllib.parse import urlparse

import requests
from PIL import Image
# Base URL of the image-generation API; request() and retrieve() append
# their /v1/... routes to it.
API_ENDPOINT: str = "https://api.bfl.ml"
class ApiException(Exception):
    """
    Error reported by the API, or by an HTTP call made on its behalf.

    Attributes:
        status_code: HTTP status code of the failing response.
        detail: Error detail taken from the response body: ``None``, a string,
            or a list of error entries (normally dicts carrying a ``"msg"`` key).
    """

    def __init__(self, status_code: int, detail: "str | list | None" = None):
        # Forward the constructor arguments so ``self.args`` is populated:
        # BaseException pickling re-creates the instance from ``args``, which
        # would otherwise fail because ``status_code`` is required.
        super().__init__(status_code, detail)
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        message = self._format_message(self.detail)
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"

    @staticmethod
    def _format_message(detail) -> "str | None":
        """
        Flatten ``detail`` into a single human-readable message.

        A list of error entries becomes ``"[msg1,msg2,...]"``. A single dict is
        treated as a one-entry list. Entries that are not dicts with a ``"msg"``
        key are rendered with ``str()`` rather than raising, so that formatting
        the exception does not itself fail with a KeyError/TypeError.
        """
        if detail is None:
            return None
        if isinstance(detail, str):
            return detail
        if isinstance(detail, dict):
            detail = [detail]
        elif not isinstance(detail, (list, tuple)):
            return str(detail)
        parts = (
            str(entry["msg"]) if isinstance(entry, dict) and "msg" in entry else str(entry)
            for entry in detail
        )
        return "[" + ",".join(parts) + "]"
class ImageRequest:
    """
    Client-side handle for a single image generation request to the API.

    The request is submitted with ``request()`` and polled with ``retrieve()``.
    The result is then available via ``url()``, ``bytes()`` and ``image()``,
    or can be written to disk with ``save()``. The request id, result, URL
    and image bytes are cached on the instance, so repeated calls do not
    re-query the API.
    """

    # Seconds to wait for any single HTTP call before giving up. Without a
    # timeout, ``requests`` can block indefinitely on an unresponsive server.
    http_timeout: float = 60.0

    # Seconds to sleep between polls while the API reports status "Pending".
    poll_interval: float = 0.5

    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixels; when validating, a multiple
                of 32 in [256, 1440]
            height: Height of the image in pixels; when validating, a multiple
                of 32 in [256, 1440]
            name: Name of the model (only "flux.1-pro" passes validation)
            num_steps: Number of network evaluations; when validating, in [1, 50]
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed; omitted from the request when None
            validate: Run input validation
            launch: Directly launch the request from the constructor
            api_key: Your API key; when None it is read from the BFL_API_KEY
                environment variable

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            # Only the first failing check is reported.
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str = None
        self.result: dict = None
        self._image_bytes: bytes = None
        self._url: str = None

        self.api_key = api_key if api_key is not None else os.environ.get("BFL_API_KEY")

        if launch:
            self.request()

    def _headers(self) -> dict:
        """Return the headers shared by all API calls, including the API key."""
        return {
            "accept": "application/json",
            "x-key": self.api_key,
        }

    @staticmethod
    def _json_or_raise(response):
        """
        Decode the JSON body of a ``requests`` response.

        Raises:
            ApiException: If the body is not valid JSON (e.g. a plain-text or
                HTML error page). The raw text is attached as ``detail``.
        """
        try:
            return response.json()
        except ValueError:  # requests' JSONDecodeError subclasses ValueError
            raise ApiException(status_code=response.status_code, detail=response.text) from None

    @staticmethod
    def _detail_of(payload):
        """Return the ``detail`` field of a decoded body, or the body as text if it is not a dict."""
        return payload.get("detail") if isinstance(payload, dict) else str(payload)

    def request(self):
        """
        Request to generate the image.

        Stores the id returned by the API in ``request_id``. Does nothing if
        the request has already been submitted.

        Raises:
            ApiException: On a non-200 response, an undecodable body, or a
                response without an ``id``.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={**self._headers(), "Content-Type": "application/json"},
            json=self.request_json,
            timeout=self.http_timeout,
        )
        result = self._json_or_raise(response)
        if response.status_code != 200 or not isinstance(result, dict) or "id" not in result:
            raise ApiException(status_code=response.status_code, detail=self._detail_of(result))
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve the response.

        Submits the request first if needed. Then polls ``/v1/get_result``
        every ``poll_interval`` seconds while the status is "Pending". The
        ``result`` payload is cached, so later calls return immediately.

        Raises:
            ApiException: If a poll response has no status, reports any
                status other than "Ready"/"Pending", or is "Ready" without a
                result.
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers=self._headers(),
                params={"id": self.request_id},
                timeout=self.http_timeout,
            )
            result = self._json_or_raise(response)
            if not isinstance(result, dict) or "status" not in result:
                raise ApiException(status_code=response.status_code, detail=self._detail_of(result))
            status = result["status"]
            if status == "Ready":
                if result.get("result") is None:
                    # Without this guard the loop would re-poll immediately, with no sleep.
                    raise ApiException(status_code=200, detail="API returned status 'Ready' without a result")
                self.result = result["result"]
            elif status == "Pending":
                time.sleep(self.poll_interval)
            else:
                raise ApiException(status_code=200, detail=f"API returned status '{status}'")
        return self.result

    def bytes(self) -> bytes:
        """
        Generated image as bytes (downloaded once from ``url()``, then cached).

        Raises:
            ApiException: If the download does not return HTTP 200.
        """
        if self._image_bytes is None:
            response = requests.get(self.url(), timeout=self.http_timeout)
            if response.status_code != 200:
                raise ApiException(status_code=response.status_code)
            self._image_bytes = response.content
        return self._image_bytes

    def url(self) -> str:
        """
        Public url to retrieve the image from.

        Waits for the result via ``retrieve()`` and returns its ``"sample"`` field.
        """
        if self._url is None:
            self._url = self.retrieve()["sample"]
        return self._url

    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image.
        """
        return Image.open(io.BytesIO(self.bytes()))

    def save(self, path: str) -> str:
        """
        Save the generated image to a local path.

        The image URL's file extension is appended to ``path`` unless ``path``
        already ends with it. Missing parent directories are created.

        Args:
            path: Destination file path (``str`` or ``os.PathLike``).

        Returns:
            The path the image was actually written to.
        """
        path = os.fspath(path)
        # Use only the URL's path component, so a query string (for example
        # signed-URL parameters) does not end up in the file extension.
        suffix = Path(urlparse(self.url()).path).suffix
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes())
        return path
if __name__ == "__main__":
    # Command-line entry point: python-fire builds a CLI around ImageRequest,
    # with the constructor arguments exposed as command-line flags.
    import fire

    fire.Fire(ImageRequest)