from openai import OpenAI
# from google.colab import userdata
import os
import gradio as gr
# openai_key = userdata.get('OPENAI_API_KEY')
# SD_API_KEY = userdata.get('SD_API_KEY')
# Read the API keys from environment variables
openai_key = os.environ.get('OPENAI_API_KEY')
SD_API_KEY = os.environ.get('SD_API_KEY')
if not openai_key or not SD_API_KEY:
    # Fail early with a clear message instead of crashing later with a TypeError
    raise RuntimeError("OPENAI_API_KEY and SD_API_KEY must be set as environment variables.")
os.environ["OPENAI_API_KEY"] = openai_key
# ์ž๋™ ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ
def novel_keyword(nobel_input):
system_prompt = "๋‹น์‹ ์€ ์ฃผ์–ด์ง€๋Š” ๋‚ด์šฉ์— ์–ด์šธ๋ฆฌ๋Š” ์ด๋ฏธ์ง€๋ฅผ ์ถ”์ฒœํ•˜๋Š” ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€๋ฅผ ๋ฌ˜์‚ฌํ•˜๋Š” ํ‚ค์›Œ๋“œ๋ฅผ ์˜์–ด๋กœ, ์ฝค๋งˆ๋กœ ์•Œ๋ ค์ฃผ์„ธ์š”"
client = OpenAI()
completion = client.chat.completions.create(
model="gpt-3.5-turbo",
temperature=0,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": nobel_input}
])
return completion.choices[0].message.content
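
# Illustrative call (hypothetical input and output; the model's reply will vary):
#   novel_keyword("A detective walks beneath neon signs on a rainy night")
#   -> "rainy night, neon signs, detective, trench coat, cinematic lighting"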
# ์Šคํ…Œ์ด๋ธ” ๋””ํ“จ์ „ ๊ธฐ๋ณธ ์ด๋ฏธ์ง€ API ์š”์ฒญ
import requests
import json
def sd_call(prompt, width, height):
    url = "https://stablediffusionapi.com/api/v3/text2img"
    payload = json.dumps({
        "key": SD_API_KEY,
        "prompt": prompt,
        "negative_prompt": "ng_deepnegative_v1_75t, (worst quality:1.4), (low quality:1.4), (normal quality:1.4), lowres, (nsfw:1.4)",
        "width": width,
        "height": height,
        "samples": "1",
        "num_inference_steps": "20",
        "seed": None,
        "guidance_scale": 7.5,
        "safety_checker": "yes",
        "multi_lingual": "no",
        "panorama": "no",
        "self_attention": "no",
        "upscale": "no",
        "embeddings_model": None,
        "webhook": None,
        "track_id": None
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.request("POST", url, headers=headers, data=payload)
    response_data = json.loads(response.text)
    if response_data["status"] == "success":
        image_link = response_data["output"][0]
        fetch_result = "์ด๋ฏธ์ง€ ์ƒ์„ฑ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค"
    elif response_data["status"] == "processing":
        image_link = None
        fetch_result = response_data["fetch_result"]
    else:
        image_link = None
        fetch_result = "์‹คํŒจ์ž…๋‹ˆ๋‹ค ๋‹ค์‹œ ์‹คํ–‰ํ•ด ์ฃผ์„ธ์š”"
    return image_link, fetch_result
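
# sd_call returns (image URL, completion message) on success, or (None, fetch URL) while the
# job is still processing. The fetch URL lands in the status textbox, and the refresh button
# passes it to sd_recall below to poll for the finished image.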
def sd_recall(fetch_result):
    url = fetch_result
    payload = json.dumps({
        "key": SD_API_KEY
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.request("POST", url, headers=headers, data=payload)
    response_data = json.loads(response.text)
    image_link = response_data["output"]
    if not image_link:
        image_link = None
    else:
        image_link = image_link[0]
    return image_link
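
# edit_load_img just forwards the generated image, so that selecting the edit tab
# copies generator_img into the org_img mask editor (see the edit_tab.select wiring below).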
def edit_load_img(img_url):
    return img_url
import base64
def image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read())
        encoded_string = encoded_image.decode('utf-8')
    return encoded_string
from io import BytesIO
from PIL import Image
def save_base64_image_from_url(url):
    file_path = "edit_img.png"
    response = requests.get(url)
    base64_string = response.text
    image_data = base64.b64decode(base64_string)
    image = Image.open(BytesIO(image_data))
    image.save(file_path)
    return file_path
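
# The inpaint request below sets "base64": "yes", so the links in its "output" field resolve
# to base64 text rather than binary images; save_base64_image_from_url downloads that text,
# decodes it, and saves it as a local PNG that Gradio can display.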
def edit_img_generator(input_img, prompt):
    # Save the background image and the painted mask layer to local files
    background_img = input_img['background']
    background_img = Image.fromarray(background_img)
    background_img_path = "background_img.png"
    background_img.save(background_img_path)
    mask_img = input_img['layers'][0]
    mask_img_arr = Image.fromarray(mask_img)
    mask_img_arr_path = "masked_img.png"
    mask_img_arr.save(mask_img_arr_path)
    background_img_base64 = image_to_base64(background_img_path)
    mask_img_base64 = image_to_base64(mask_img_arr_path)
    # Inpaint request
    url = "https://stablediffusionapi.com/api/v3/inpaint"
    payload = json.dumps({
        "key": SD_API_KEY,
        "prompt": prompt,
        "negative_prompt": None,
        "init_image": background_img_base64,
        "mask_image": mask_img_base64,
        "width": "512",
        "height": "512",
        "samples": "1",
        "num_inference_steps": "30",
        "safety_checker": "no",
        "enhance_prompt": "yes",
        "guidance_scale": 7.5,
        "strength": 0.7,
        "base64": "yes",
        "seed": None,
        "webhook": None,
        "track_id": None
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.request("POST", url, headers=headers, data=payload)
    response_data = json.loads(response.text)
    if response_data["status"] == "success":
        image_link = save_base64_image_from_url(response_data["output"][0])
        fetch_result = ""
    elif response_data["status"] == "processing":
        image_link = None
        fetch_result = response_data["fetch_result"]
    else:
        image_link = None
        fetch_result = "์‹คํŒจ์ž…๋‹ˆ๋‹ค ๋‹ค์‹œ ์‹คํ–‰ํ•ด ์ฃผ์„ธ์š”"
    print("Result:", response_data)
    return image_link, fetch_result
# edit ์ด๋ฏธ์ง€ ๋‹ค์‹œ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
def sd_edit_recall(fetch_result):
edit_url = sd_recall(fetch_result)
image_link = save_base64_image_from_url(edit_url)
return image_link
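
# Gradio UI: two tabs.
#   Tab 1 "์‚ฝํ™” ์ƒ์„ฑ" (illustration generation): prompt box, GPT prompt helper, width/height
#         sliders, and generate/refresh buttons wired to sd_call / sd_recall.
#   Tab 2 "์ด๋ฏธ์ง€ ํŽธ์ง‘" (image editing): the generated image is loaded into an ImageMask,
#         the painted area is inpainted via edit_img_generator, and refreshed via sd_edit_recall.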
with gr.Blocks(theme=gr.themes.Default()) as app:
    with gr.Tab("์‚ฝํ™” ์ƒ์„ฑ"):
        with gr.Row():
            # 1
            gr.Markdown(
                value="""
# ์‚ฝํ™” ์ƒ์„ฑ
์งง์€ ์†Œ์„ค์— ์–ด์šธ๋ฆฌ๋Š” ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
"""
            )
        with gr.Row():
            # column1
            with gr.Column(scale=5):
                # 2
                pos_prompt = gr.Textbox(
                    label="์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž‘์„ฑํ•ด ์ฃผ์„ธ์š”",
                    value="ultra realistic close up portrait ((beautiful pale cyberpunk female with heavy black eyeliner))",
                    lines=8,
                    interactive=True,
                )
                with gr.Row():
                    # 3
                    auto_prompt_generator = gr.Textbox(
                        label="์ž๋™ ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ",
                        lines=6,
                        placeholder="์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ์œ„ํ•œ ์†Œ์„ค ๋‚ด์šฉ์„ ์ž‘์„ฑํ•ด ์ฃผ์„ธ์š”.\n์ž๋™์œผ๋กœ ํ”„๋กฌํ”„ํŠธ๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.",
                        scale=7,
                    )
                    # 4
                    prompt_generator_btn = gr.Button(scale=1, value="์ž๋™\n์ƒ์„ฑ")
                with gr.Group():
                    with gr.Row():
                        # 5
                        img_width = gr.Slider(
                            label="Width", maximum=1024, value=512, interactive=True
                        )
                        # 6
                        img_height = gr.Slider(
                            label="Height", maximum=1024, value=512, interactive=True
                        )
            # column2
            with gr.Column(scale=3):
                # 7
                output_status = gr.Textbox(
                    show_label=False, lines=1, placeholder="์ด๋ฏธ์ง€ ์ƒํƒœ๊ฐ€ ์ถœ๋ ฅ๋ฉ๋‹ˆ๋‹ค."
                )
                # 8
                generator_img = gr.Image(
                    value="https://pub-3626123a908346a7a8be8d9295f44e26.r2.dev/generations/8c595b57-563c-4417-9bfb-96aaebbb30b3-0.png",
                    label="์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.",
                )
                # 9
                generator_img_btn = gr.Button(value="์ด๋ฏธ์ง€ ์ƒ์„ฑ")
                # 10
                refresh_img_btn = gr.Button(value="์ด๋ฏธ์ง€ ์ƒˆ๋กœ๊ณ ์นจ")
        # Auto-generate prompt button click
        prompt_generator_btn.click(
            fn=novel_keyword,
            inputs=[auto_prompt_generator],
            outputs=[pos_prompt]
        )
        # Generate image button click
        generator_img_btn.click(
            fn=sd_call,
            inputs=[pos_prompt, img_width, img_height],
            outputs=[generator_img, output_status]
        )
        # Refresh image button click
        refresh_img_btn.click(
            fn=sd_recall,
            inputs=[output_status],
            outputs=[generator_img]
        )
    with gr.Tab("์ด๋ฏธ์ง€ ํŽธ์ง‘") as edit_tab:
        with gr.Row():
            # 1
            gr.Markdown(
                value="""
# ์ด๋ฏธ์ง€ ํŽธ์ง‘
์ƒ์„ฑํ•œ ์ด๋ฏธ์ง€๋ฅผ ํŽธ์ง‘ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
""")
        with gr.Row():
            # 2
            edit_prompt = gr.Textbox(
                label="์ด๋ฏธ์ง€ ์ˆ˜์ • ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž‘์„ฑํ•ด ์ฃผ์„ธ์š”",
                value="black hair",
                lines=5,
                interactive=True,
                scale=7
            )
            # 3
            edit_btn = gr.Button(
                value="์ด๋ฏธ์ง€ ํŽธ์ง‘",
                scale=1
            )
        with gr.Row():
            # 4
            edit_status = gr.Textbox(
                show_label=False,
                lines=1,
                placeholder="ํŽธ์ง‘ ์ด๋ฏธ์ง€ ์ƒํƒœ๊ฐ€ ์ถœ๋ ฅ๋ฉ๋‹ˆ๋‹ค.",
                scale=7
            )
            # 5
            refresh_edit_btn = gr.Button(
                value="์ด๋ฏธ์ง€ ์ƒˆ๋กœ๊ณ ์นจ",
                scale=1
            )
        with gr.Row():
            # 6
            org_img = gr.ImageMask(
                label="์ˆ˜์ • ์ „ ์ด๋ฏธ์ง€",
                image_mode='RGB',
                brush=gr.Brush(
                    default_size=20,
                    colors=["#FFFFFF"],
                    color_mode="fixed",
                ),
                show_label=True
            )
            # 7
            edit_img = gr.Image(
                label="์ˆ˜์ • ํ›„ ์ด๋ฏธ์ง€",
            )
        # When the edit tab is selected, load the generated image into the mask editor
        edit_tab.select(
            fn=edit_load_img,
            inputs=[generator_img],
            outputs=[org_img]
        )
        # Edit image button click
        edit_btn.click(
            fn=edit_img_generator,
            inputs=[org_img, edit_prompt],
            outputs=[edit_img, edit_status]
        )
        # Refresh edited image button click
        refresh_edit_btn.click(
            fn=sd_edit_recall,
            inputs=[edit_status],
            outputs=[edit_img]
        )
app.launch(debug=True)