| | import io |
| | import os |
| | import time |
| | from pathlib import Path |
| |
|
| | import requests |
| | from PIL import Image |
| |
|
# Base URL of the BFL API; request paths such as "/v1/image" are appended to it.
API_ENDPOINT: str = "https://api.bfl.ml"
| |
|
| |
|
| | class ApiException(Exception): |
| | def __init__(self, status_code: int, detail: str | list[dict] | None = None): |
| | super().__init__() |
| | self.detail = detail |
| | self.status_code = status_code |
| |
|
| | def __str__(self) -> str: |
| | return self.__repr__() |
| |
|
| | def __repr__(self) -> str: |
| | if self.detail is None: |
| | message = None |
| | elif isinstance(self.detail, str): |
| | message = self.detail |
| | else: |
| | message = "[" + ",".join(d["msg"] for d in self.detail) + "]" |
| | return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" |
| |
|
| |
|
class ImageRequest:
    """
    Handle for a single image generation request against the API.

    The request is submitted on construction unless ``launch=False``. The
    generation result, the image URL and the downloaded image bytes are each
    fetched lazily on first access and cached on the instance.
    """

    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int | None = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str | None = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixels (multiple of 32, 256 to 1440)
            height: Height of the image in pixels (multiple of 32, 256 to 1440)
            name: Name of the model (only "flux.1-pro" passes validation)
            num_steps: Number of network evaluations (1 to 50)
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed; omitted from the request when None
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key; when None, the BFL_API_KEY environment
                variable is used

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            self.api_key = api_key

        if launch:
            self.request()

    def request(self):
        """
        Request to generate the image; does nothing if already submitted.

        Raises:
            ApiException: If the API answers with a non-200 status or a body
                that is not JSON.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
        )
        result = self._json_or_raise(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=self._detail_of(result))
        # Reuse the already-decoded body instead of parsing it a second time.
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.

        Submits the request first if needed, then polls every 0.5 s while the
        API reports "Pending". The result is cached on ``self.result``.

        Raises:
            ApiException: If the response is not JSON, carries no "status", or
                reports any status other than "Ready"/"Pending".
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
            )
            result = self._json_or_raise(response)
            if not isinstance(result, dict) or "status" not in result:
                raise ApiException(status_code=response.status_code, detail=self._detail_of(result))
            status = result["status"]
            if status == "Ready":
                self.result = result["result"]
            elif status == "Pending":
                time.sleep(0.5)
            else:
                raise ApiException(
                    status_code=response.status_code,
                    detail=f"API returned status '{status}'",
                )
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes (downloaded once from ``url``, then cached).

        Raises:
            ApiException: If the download does not answer with status 200.
        """
        if self._image_bytes is None:
            response = requests.get(self.url)
            if response.status_code == 200:
                self._image_bytes = response.content
            else:
                raise ApiException(status_code=response.status_code)
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from (the result's "sample" field).
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str | os.PathLike):
        """
        Save the generated image to a local path.

        The extension of the image URL is appended to ``path`` unless it
        already ends with it; missing parent directories are created.

        Args:
            path: Destination file path, with or without the extension.
        """
        from urllib.parse import urlparse

        # Only look at the URL's path component: the URL may carry a query
        # string, which Path(url).suffix would treat as part of the file name.
        suffix = Path(urlparse(self.url).path).suffix
        path = os.fspath(path)
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)

    @staticmethod
    def _json_or_raise(response: "requests.Response"):
        """Decode a response body as JSON, raising ApiException if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            # Error bodies (e.g. from a proxy) may not be JSON; keep the
            # documented ApiException contract instead of a decode error.
            raise ApiException(
                status_code=response.status_code, detail=response.text or None
            ) from None

    @staticmethod
    def _detail_of(payload) -> str | list[dict] | None:
        """Return ``payload["detail"]`` when payload is a dict, else None."""
        return payload.get("detail") if isinstance(payload, dict) else None
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: hand ImageRequest to python-fire.
    import fire

    fire.Fire(ImageRequest)
| |
|