# File: flx-pulid/flux/api.py (commit aa8012e by 邬彦泽, 6.1 kB)
import io
import os
import time
from pathlib import Path
from urllib.parse import urlsplit

import requests
from PIL import Image
# Base URL of the image-generation API; endpoint paths such as "/v1/image"
# and "/v1/get_result" are appended to it.
API_ENDPOINT: str = "https://api.bfl.ml"
class ApiException(Exception):
    """Error reported by, or while talking to, the image API.

    Attributes:
        status_code: HTTP status code associated with the failure.
        detail: Error payload. Typically a plain string, a list of error
            dicts (each expected to carry a ``"msg"`` key), or None.
    """

    def __init__(self, status_code: int, detail: "str | list | dict | None" = None):
        # Forward the arguments to Exception so ``self.args`` is populated.
        # Unpickling re-invokes ``ApiException(*self.args)``, so an empty
        # ``args`` tuple would make this exception unpicklable.
        super().__init__(status_code, detail)
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        message = self._format_detail(self.detail)
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"

    @staticmethod
    def _format_detail(detail) -> "str | None":
        """Condense *detail* into a short message; never raises.

        Strings are returned as-is. Iterables of error objects become
        ``"[msg1,msg2,...]"``, using each item's ``"msg"`` when present and its
        string form otherwise.
        """
        if detail is None:
            return None
        if isinstance(detail, str):
            return detail
        # A single error object is treated like a one-element list.
        items = [detail] if isinstance(detail, dict) else detail
        try:
            iterator = iter(items)
        except TypeError:
            # Not iterable (e.g. a number): fall back to its plain string form.
            return str(detail)
        parts = [
            str(item["msg"]) if isinstance(item, dict) and "msg" in item else str(item)
            for item in iterator
        ]
        return "[" + ",".join(parts) + "]"
class ImageRequest:
    """Handle for a single image-generation request to the API.

    The request is submitted once (``request``), polled until finished
    (``retrieve``). The resulting URL, image bytes and PIL image are then
    fetched lazily and cached on the instance.
    """

    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str = None,
        timeout: "float | None" = 60.0,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixels (multiple of 32 in [256, 1440]
                when validated)
            height: Height of the image in pixels (same constraints as width)
            name: Name of the model (only "flux.1-pro" passes validation)
            num_steps: Number of network evaluations (in [1, 50] when validated)
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed; omitted from the request when None
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key; falls back to the BFL_API_KEY environment
                variable when None
            timeout: Timeout in seconds applied to every individual HTTP call
                (passed to ``requests``); None waits indefinitely

        Raises:
            ValueError: For invalid input (only the first violation is reported)
            ApiException: For errors raised from the API
        """
        if validate:
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str = None
        self.result: dict = None
        self._image_bytes: bytes = None
        self._url: str = None
        self.api_key = os.environ.get("BFL_API_KEY") if api_key is None else api_key
        # Must be set before request() may run below.
        self.timeout = timeout

        if launch:
            self.request()

    def _headers(self) -> dict:
        """Headers shared by every API call (JSON responses + API key)."""
        return {"accept": "application/json", "x-key": self.api_key}

    @staticmethod
    def _decode_json(response) -> dict:
        """Parse *response* as a JSON object, or raise ApiException.

        Error pages (e.g. an HTML 502 from a proxy) are not JSON. Turning them
        into ApiException keeps the documented error contract instead of
        leaking a JSON decode error.
        """
        try:
            payload = response.json()
        except ValueError as err:  # requests' JSONDecodeError subclasses ValueError
            raise ApiException(
                status_code=response.status_code,
                detail=f"non-JSON response: {response.text[:200]}",
            ) from err
        if not isinstance(payload, dict):
            raise ApiException(
                status_code=response.status_code,
                detail=f"unexpected response payload: {payload!r}",
            )
        return payload

    def request(self):
        """
        Request to generate the image (no-op if already submitted).

        Raises:
            ApiException: If the API rejects the request.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={**self._headers(), "Content-Type": "application/json"},
            json=self.request_json,
            timeout=self.timeout,
        )
        result = self._decode_json(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.

        Submits the request first if needed, then polls every 0.5 s until the
        API reports "Ready". The ``result`` payload is cached and returned.

        Raises:
            ApiException: On an error response or any status other than
                "Ready"/"Pending".
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers=self._headers(),
                params={"id": self.request_id},
                timeout=self.timeout,
            )
            result = self._decode_json(response)
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                # Report the real HTTP status rather than assuming 200.
                raise ApiException(
                    status_code=response.status_code,
                    detail=f"API returned status '{result['status']}'",
                )
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes (downloaded once, then cached).
        """
        if self._image_bytes is None:
            response = requests.get(self.url, timeout=self.timeout)
            if response.status_code != 200:
                raise ApiException(status_code=response.status_code)
            self._image_bytes = response.content
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from (waits for the result if needed).
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> "Image.Image":
        """
        Load the image as a PIL Image.
        """
        return Image.open(io.BytesIO(self.bytes))

    @staticmethod
    def _url_suffix(url: str) -> str:
        """File extension of *url*'s path, ignoring any query string or fragment.

        Signed storage URLs carry query parameters that may contain "/" and
        ".", so the suffix must come from the path component only.
        """
        return Path(urlsplit(url).path).suffix

    def save(self, path: "str | os.PathLike[str]") -> str:
        """
        Save the generated image to a local path.

        The image URL's extension (e.g. ".jpg") is appended to *path* unless
        *path* already ends with it. Missing parent directories are created.

        Returns:
            The path actually written, including any appended extension.
        """
        path = os.fspath(path)
        suffix = self._url_suffix(self.url)
        if not path.endswith(suffix):
            path = path + suffix
        # Download before touching the filesystem so that a failed fetch does
        # not leave an empty file behind.
        data = self.bytes
        target = Path(path)
        target.resolve().parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(data)
        return path
if __name__ == "__main__":
    # Command-line entry point: python-fire turns ImageRequest's constructor
    # arguments and public methods into CLI arguments and commands.
    import fire

    fire.Fire(ImageRequest)