|
""" Image Generation Module for AutoGPT.""" |
|
import io |
|
import os.path |
|
import uuid |
|
from base64 import b64decode |
|
|
|
import openai |
|
import requests |
|
from PIL import Image |
|
|
|
from autogpt.config import Config |
|
from autogpt.workspace import path_in_workspace |
|
|
|
# Shared configuration instance; the generator functions in this module read
# the selected image provider, API credentials and endpoint URLs from it.
CFG = Config() |
|
|
|
|
|
def generate_image(prompt: str, size: int = 256) -> str:
    """Create an image for ``prompt`` using the configured provider.

    A random ``<uuid4>.jpg`` filename is chosen and the work is handed to the
    backend selected by ``CFG.image_provider``.

    Args:
        prompt (str): Text description of the image to generate.
        size (int, optional): Edge length in pixels. Defaults to 256.
            Passed on to the DALL-E and SD WebUI backends only; the
            HuggingFace backend is called without it.

    Returns:
        str: Whatever the selected backend returns, or the literal string
            "No Image Provider Set" when ``CFG.image_provider`` is none of
            "dalle", "huggingface" or "sdwebui".
    """
    target_name = f"{uuid.uuid4()}.jpg"
    provider = CFG.image_provider

    if provider == "dalle":
        return generate_image_with_dalle(prompt, target_name, size)

    if provider == "huggingface":
        return generate_image_with_hf(prompt, target_name)

    if provider == "sdwebui":
        return generate_image_with_sd_webui(prompt, target_name, size)

    # No recognised provider configured.
    return "No Image Provider Set"
|
|
|
|
|
def generate_image_with_hf(prompt: str, filename: str) -> str:
    """Generate an image with HuggingFace's hosted Inference API.

    Args:
        prompt (str): The prompt to use
        filename (str): The filename to save the image to; it is resolved
            through ``path_in_workspace`` before writing

    Returns:
        str: A status message of the form "Saved to disk:<filename>"

    Raises:
        ValueError: If ``CFG.huggingface_api_token`` is None.
        requests.HTTPError: If the API responds with an HTTP error status.
    """
    if CFG.huggingface_api_token is None:
        raise ValueError(
            "You need to set your Hugging Face API token in the config file."
        )

    api_url = (
        f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
    )
    headers = {
        "Authorization": f"Bearer {CFG.huggingface_api_token}",
        "X-Use-Cache": "false",
    }

    response = requests.post(
        api_url,
        headers=headers,
        json={
            "inputs": prompt,
        },
    )
    # On failure the body is an error payload rather than image bytes; surface
    # the HTTP error here instead of letting Image.open fail with an unrelated
    # "cannot identify image file" message.
    response.raise_for_status()

    image = Image.open(io.BytesIO(response.content))
    print(f"Image Generated for prompt:{prompt}")

    image.save(path_in_workspace(filename))

    return f"Saved to disk:{filename}"
|
|
|
|
|
def generate_image_with_dalle(
    prompt: str,
    filename: str,
    size: int = 256,
) -> str:
    """Generate an image with DALL-E.

    Args:
        prompt (str): The prompt to use
        filename (str): The filename to save the image to; it is resolved
            through ``path_in_workspace`` before writing
        size (int, optional): Edge length of the square image in pixels.
            Values other than 256, 512 or 1024 are replaced by the nearest
            of those three. Defaults to 256.

    Returns:
        str: A status message of the form "Saved to disk:<filename>"
    """
    openai.api_key = CFG.openai_api_key

    # `size` must be a parameter: the body below reads it, and callers pass it
    # as the third positional argument.
    supported_sizes = [256, 512, 1024]
    if size not in supported_sizes:
        closest = min(supported_sizes, key=lambda x: abs(x - size))
        print(
            f"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. Setting to {closest}, was {size}."
        )
        size = closest

    response = openai.Image.create(
        prompt=prompt,
        n=1,
        size=f"{size}x{size}",
        response_format="b64_json",
    )

    print(f"Image Generated for prompt:{prompt}")

    image_data = b64decode(response["data"][0]["b64_json"])

    # The decoded bytes are written as-is; no image re-encoding happens here.
    with open(path_in_workspace(filename), mode="wb") as image_file:
        image_file.write(image_data)

    return f"Saved to disk:{filename}"
|
|
|
|
|
def generate_image_with_sd_webui(
    prompt: str,
    filename: str,
    size: int = 512,
    negative_prompt: str = "",
    extra: "dict | None" = None,
) -> str:
    """Generate an image with Stable Diffusion webui.

    Args:
        prompt (str): The prompt to use
        filename (str): The filename to save the image to; it is resolved
            through ``path_in_workspace`` before writing
        size (int, optional): Width and height of the image in pixels.
            Defaults to 512.
        negative_prompt (str, optional): The negative prompt to use.
            Defaults to "".
        extra (dict, optional): Extra parameters merged into the request
            payload; they override the built-in values on key collision.
            Defaults to None, meaning no extra parameters.

    Returns:
        str: A status message of the form "Saved to disk:<filename>"

    Raises:
        requests.HTTPError: If the webui responds with an HTTP error status.
    """
    payload = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "sampler_index": "DDIM",
        "steps": 20,
        "cfg_scale": 7.0,
        "width": size,
        "height": size,
        "n_iter": 1,
        **({} if extra is None else extra),
    }

    with requests.Session() as session:
        if CFG.sd_webui_auth:
            # Split on the first colon only, so passwords containing ":" are
            # kept intact and a value without a colon yields an empty password.
            username, _, password = CFG.sd_webui_auth.partition(":")
            session.auth = (username, password)

        # Post through the session; a bare requests.post would not send the
        # credentials configured above.
        response = session.post(
            f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
            json=payload,
        )
        response.raise_for_status()
        images = response.json()["images"]

    print(f"Image Generated for prompt:{prompt}")

    # Take the part after an optional "data:...;base64," prefix. Base64 text
    # never contains a comma, so an unprefixed string is used unchanged.
    encoded = images[0].split(",", 1)[-1]
    image = Image.open(io.BytesIO(b64decode(encoded)))
    image.save(path_in_workspace(filename))

    return f"Saved to disk:{filename}"
|
|