Spaces:
Running
Running
import base64
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import requests
# MAX_LEN = 40 | |
# STEP = 2 | |
# x = np.arange(0, MAX_LEN, STEP) | |
# token_counts = [0] * (MAX_LEN//STEP) | |
# with open("prompts.json", 'r') as f: | |
# prompts = json.load(f) | |
# for prompt in prompts: | |
# tokens = len(prompt.strip().split(' ')) | |
# token_counts[min(tokens//STEP, MAX_LEN//STEP-1)] += 1 | |
# plt.xticks(x, x+1) | |
# plt.xlabel("token counts") | |
# plt.bar(x, token_counts, width=1.3) | |
# # plt.show() | |
# plt.savefig("token_counts.png") | |
## Generate image prompts
# Load the prompts to render from prompts.json. The loop below enumerates the
# result and calls .strip() on each element, so this is presumably a JSON array
# of strings (TODO confirm). JSON is UTF-8 by specification, so read it
# explicitly as UTF-8 instead of relying on the platform's default encoding,
# which can mis-decode non-ASCII prompts (e.g. cp1252 on Windows).
with open("prompts.json", encoding="utf-8") as f:
    text_prompts = json.load(f)
# Stability AI engine and endpoint used for every request.
engine_id = "stable-diffusion-v1-6"
api_host = os.getenv('API_HOST', 'https://api.stability.ai')
# The API key is read from the environment only. Never commit a literal key as a
# getenv() fallback: it leaks the secret to anyone who can read the source, and
# it makes the missing-key check below unreachable. An empty value is treated
# the same as an unset one.
api_key = os.getenv("STABILITY_API_KEY")
if not api_key:
    raise Exception("Missing Stability API key.")
# Index of the first prompt to generate. Earlier prompts are skipped so an
# interrupted run can be resumed. The default of 21 skips prompts 0-20. Override
# it with the START_INDEX environment variable.
START_INDEX = int(os.getenv("START_INDEX", "21"))
# (connect, read) timeout in seconds for each API call. Generating several
# 1024x1024 samples can take a while, but without any timeout a stalled
# connection would hang the whole batch forever.
REQUEST_TIMEOUT = (10, 300)

# For each prompt, request 3 text-to-image samples and write them to
# ./images/<idx>/v1_txt2img_<i>.png. Failures are reported and skipped
# (best effort), so one bad prompt does not abort the batch.
for idx, text in enumerate(text_prompts):
    if idx < START_INDEX:
        continue
    print(f"Start generate prompt[{idx}]: {text}")
    try:
        response = requests.post(
            f"{api_host}/v1/generation/{engine_id}/text-to-image",
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json={
                "text_prompts": [
                    {
                        "text": text.strip(),
                    }
                ],
                "cfg_scale": 7,
                "height": 1024,
                "width": 1024,
                "samples": 3,
                "steps": 30,
            },
            timeout=REQUEST_TIMEOUT,
        )
    except requests.exceptions.RequestException as err:
        # Network-level failure (timeout, DNS, connection reset). Apply the same
        # report-and-continue policy as for non-200 responses below.
        print(f"{idx} Failed!!! {err}", file=sys.stderr)
        continue
    if response.status_code != 200:
        print(f"{idx} Failed!!! {str(response.text)}", file=sys.stderr)
        continue
    print("Finished!")
    data = response.json()
    # Each artifact carries a base64-encoded PNG. Create the per-prompt output
    # directory once rather than once per image.
    out_dir = f"./images/{idx}"
    os.makedirs(out_dir, exist_ok=True)
    for i, image in enumerate(data["artifacts"]):
        img_path = f"{out_dir}/v1_txt2img_{i}.png"
        with open(img_path, "wb") as f:
            f.write(base64.b64decode(image["base64"]))