# AltDiffusion / app.py
# Last change: commit 6fd44fd by Yemin Shi ("update concurrency").
import io
import re
import imp
import time
import json
import base64
import requests
import gradio as gr
import ui_functions as uifn
from css_and_js import js, call_JS
from PIL import Image, PngImagePlugin, ImageChops
# Base URL of the FlagStudio API and the access token sent in the "token"
# header of every request. Both can be overridden through environment
# variables (an empty value falls back to the default) so the credential
# does not have to live in source control; the defaults keep the previously
# hard-coded values.
# NOTE(review): the default token is a JWT committed to the repository —
# consider rotating it and supplying it only via FLAGSTUDIO_TOKEN.
import os

url_host = os.environ.get("FLAGSTUDIO_URL_HOST") or "http://flagstudio.baai.ac.cn"
token = os.environ.get("FLAGSTUDIO_TOKEN") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiZjAxOGMxMzJiYTUyNDBjMzk5NTMzYTI5YjBmMzZiODMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6IjVjMmQzMjdiLWI5Y2MtNDhiZS1hZWQ4LTllMjQ4MDk4NzMxYyIsIm5iZiI6MTY2OTAwNjE5NywiZXhwIjoxOTg0MzY2MTk3LCJpYXQiOjE2NjkwMDYxOTd9.9B3MDk8wA6iWH5puXjcD19tJJ4Ox7mdpRyWZs5Kwt70"
def read_content(file_path: str) -> str:
    """Return the whole text of ``file_path``, decoded as UTF-8."""
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()
def filter_content(raw_style: str):
    """Drop the parenthesised annotation from a UI label.

    Everything from the first ASCII "(" onwards is removed, so
    "国画(traditional Chinese painting)" becomes "国画". A label with no
    "(" comes back unchanged.
    """
    head, _separator, _rest = raw_style.partition("(")
    return head
def upload_image(img, timeout=30):
    """Upload a PIL image to FlagStudio storage and return its handle.

    Two steps: ask the API for an upload link, then PUT the PNG-encoded
    image bytes to that link using the headers the API returned.

    Args:
        img: PIL image to upload; it is encoded as PNG.
        timeout: seconds to wait for each HTTP request. ``requests`` has no
            default timeout, so without one a stalled connection would block
            the calling worker forever.

    Returns:
        ``(image_id, image_url)`` as reported by the get-upload-link call.

    Raises:
        gr.Error: on a non-200 HTTP status or a non-zero API ``code``.
        requests.exceptions.Timeout: if a request exceeds ``timeout``.
    """
    link_resp = requests.post(
        url_host + "/api/v1/image/get-upload-link",
        json={},
        headers={"token": token},
        timeout=timeout,
    )
    if link_resp.status_code != 200:
        raise gr.Error(link_resp.reason)
    link_body = link_resp.json()
    if link_body["code"] != 0:
        raise gr.Error("Unknown error")
    link = link_body["data"]

    buffer = io.BytesIO()
    img.save(buffer, "PNG")
    put_resp = requests.put(
        link["url"],
        data=buffer.getvalue(),
        headers=link["headers"],
        timeout=timeout,
    )
    if put_resp.status_code != 200:
        raise gr.Error(put_resp.reason)
    return link["image_id"], link["url"]
def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None,
                poll_interval=1.0, max_wait=600, timeout=30):
    """Create ``image_num`` generation tasks on FlagStudio and wait for them.

    Args:
        seed: generation seed; converted with ``int``.
        prompt: text prompt, sent as the single entry of ``prompts``.
        width, height: requested output size in pixels.
        image_num: number of tasks to create (one create request per task).
        img: optional PIL init image for img2img; uploaded before the tasks
            are created.
        mask: optional PIL mask. It is used only together with ``img``, and
            only uploaded when its "L" conversion has a non-zero pixel.
        poll_interval: seconds to sleep between status-polling rounds.
        max_wait: seconds to keep polling before giving up; ``None`` polls
            indefinitely.
        timeout: per-HTTP-request timeout in seconds.

    Returns:
        List of RGB PIL images, in the order their tasks finished.

    Raises:
        gr.Error: on a non-200 HTTP status, a rejected prompt (code 3002), an
            NSFW result (code 6005), any other unexpected API code, or when
            ``max_wait`` elapses before every task has finished.
    """
    headers = {"token": token}
    payload = _build_task_payload(seed, prompt, width, height, img, mask)
    pending = _create_tasks(payload, image_num, headers, timeout)
    return _wait_for_tasks(pending, headers, poll_interval, max_wait, timeout)


def _build_task_payload(seed, prompt, width, height, img, mask):
    """Assemble the JSON body for /api/v1/task/create, uploading any images."""
    parameters = {
        "width": width,    # output image width
        "height": height,  # output image height
        "prompts": [prompt],
        "seed": int(seed),
    }
    if img is not None:
        image_id, image_url = upload_image(img)
        parameters["init_image"] = {
            "image_id": image_id,
            "url": image_url,
            "width": img.width,
            "height": img.height,
        }
        # getextrema() on the "L" image is (min, max); max == 0 means nothing
        # was painted on the mask, so it is not sent.
        if mask is not None and mask.convert("L").getextrema()[1] > 0:
            mask_id, mask_url = upload_image(mask)
            parameters["mask_image"] = {
                "image_id": mask_id,
                "url": mask_url,
                "width": mask.width,
                "height": mask.height,
            }
    return {"type": "gen-image", "parameters": parameters}


def _create_tasks(payload, image_num, headers, timeout):
    """Submit ``payload`` ``image_num`` times and return the task handles."""
    # NOTE(review): every task is created from the same payload (same seed);
    # presumably the backend varies the output per task (e.g. seed 0 means
    # random) — confirm, otherwise a batch contains duplicate images.
    url = url_host + "/api/v1/task/create"
    tasks = []
    for _ in range(image_num):
        resp = requests.post(url, json=payload, headers=headers, timeout=timeout)
        if resp.status_code != 200:
            raise gr.Error(resp.reason)
        body = resp.json()
        if body["code"] == 3002:
            raise gr.Error("Inappropriate prompt detected.")
        if body["code"] != 0:
            raise gr.Error("Unknown error")
        tasks.append(body["data"])
    return tasks


def _download_image(url, timeout):
    """Fetch a generated image and decode it as an RGB PIL image."""
    resp = requests.get(url, timeout=timeout)
    if resp.status_code != 200:
        raise gr.Error(resp.reason)
    return Image.open(io.BytesIO(resp.content)).convert("RGB")


def _wait_for_tasks(pending, headers, poll_interval, max_wait, timeout):
    """Poll task status until every task finishes; return the downloaded images."""
    url = url_host + "/api/v1/task/status"
    deadline = None if max_wait is None else time.monotonic() + max_wait
    images = []
    while True:
        still_running = []
        for task in pending:
            resp = requests.post(url, json=task, headers=headers, timeout=timeout)
            if resp.status_code != 200:
                raise gr.Error(resp.reason)
            body = resp.json()
            code = body["code"]
            if code == 6002:
                # Still running; check again next round.
                still_running.append(task)
            elif code == 6005:
                raise gr.Error("NSFW image detected.")
            elif code == 0:
                # Finished: fetch every image the task produced.
                for img_info in body["data"]["images"]:
                    images.append(_download_image(img_info["url"], timeout))
            else:
                raise gr.Error(f"Error code: {body['code']}")
        pending = still_running
        # Return as soon as nothing is left, without a trailing sleep.
        if not pending:
            return images
        if deadline is not None and time.monotonic() >= deadline:
            raise gr.Error(f"Timed out after {max_wait} seconds waiting for the generated images.")
        time.sleep(poll_interval)
def request_images(raw_text, class_draw="通用(general)", style_draw=(), batch_size=1,
                   w=512, h=512, seed=0):
    """Build the final prompt from the text-to-image controls and generate.

    The category label (minus its "(...)" annotation) decides the suffix:
    - "国画": a fixed set of traditional-painting keywords is appended and the
      style checkboxes are ignored;
    - "通用": only the selected styles are appended;
    - anything else: the category, then the selected styles, are appended.
    Each suffix is joined with a leading ",".

    Every parameter after ``raw_text`` has a default matching the UI's
    initial value, so callers that pass only the prompt (such as
    ``gr.Examples`` with ``inputs=text``) also work.

    Returns:
        The list of images produced by ``post_reqest``.
    """
    category = filter_content(class_draw)
    if category == "国画":
        suffixes = ["国画", "水墨画", "大作", "黑白", "高清", "传统"]
    else:
        suffixes = [] if category == "通用" else [category]
        suffixes.extend(filter_content(sty) for sty in (style_draw or ()))
    prompt = raw_text + "".join(f",{suffix}" for suffix in suffixes)
    print(f"raw text is {prompt}")
    return post_reqest(seed, prompt, w, h, int(batch_size))
def img2img(prompt, image_and_mask):
    """Run image-to-image generation on the sketch widget's image/mask pair.

    The output keeps the source aspect ratio, with its shorter side set to
    512 px (the longer side is truncated to an int). One image is requested
    with seed 0.

    Args:
        prompt: text prompt for the edit.
        image_and_mask: mapping with an "image" PIL image and an optional
            "mask" PIL image, or ``None`` when nothing has been uploaded.

    Raises:
        gr.Error: when no image has been uploaded, or from ``post_reqest``.
    """
    if not image_and_mask or image_and_mask.get("image") is None:
        raise gr.Error("请先上传图片(Please upload an image first)")
    source = image_and_mask["image"]
    if source.width <= source.height:
        width = 512
        height = int((width / source.width) * source.height)
    else:
        height = 512
        width = int((height / source.height) * source.width)
    return post_reqest(0, prompt, width, height, 1, source, image_and_mask.get("mask"))
# Sample prompts listed under the text-to-image tab.
examples = [
    # Chinese prompts (three ask for the traditional-painting style).
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
    # English and mixed English/Chinese prompts.
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
if __name__ == "__main__":
    # Build and launch the Gradio UI: a text-to-image tab and an image-to-image
    # tab between the header/footer HTML, styled by style.css.
    # NOTE(review): the nesting below was reconstructed because the source lost
    # its indentation — confirm the layout against the deployed app.
    block = gr.Blocks(css=read_content('style.css'))
    with block:
        gr.HTML(read_content("header.html"))
        with gr.Tabs(elem_id='tabss') as tabs:
            # ---------------- Text-to-image tab ----------------
            with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):
                with gr.Group():
                    with gr.Box():
                        # Prompt textbox and generate button on one row.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(
                                label="Prompt",
                                show_label=False,
                                max_lines=1,
                                placeholder="Input text(输入文字)",
                                interactive=True,
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )
                            btn = gr.Button("Generate image").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )
                    # Generation category; request_images strips the "(...)"
                    # annotation with filter_content before using the label.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        class_draw = gr.Radio(choices=["通用(general)","国画(traditional Chinese painting)",], value="通用(general)", show_label=True, label='生成类型(type)')
                        # Earlier dropdown variant with more categories, kept for reference:
                        # class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
                        # "照片,摄影(picture photography)", "油画(oil painting)",
                        # "铅笔素描(pencil sketch)", "CG",
                        # "水彩画(watercolor painting)", "水墨画(ink and wash)",
                        # "插画(illustrations)", "3D", "图生图(img2img)"],
                        # label="生成类型(type)",
                        # show_label=True,
                        # value="通用(general)")
                    # Optional style keywords appended to the prompt.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                       "概念艺术(concept art)", "Warming lighting",
                                                       "Dramatic lighting", "Natural lighting",
                                                       "虚幻引擎(unreal engine)", "4k", "8k",
                                                       "充满细节(full details)"],
                                                      label="画面风格(style)",
                                                      show_label=True,
                                                      )
                    # Number of images (a string "1".."4"; request_images
                    # converts it with int()) and the seed.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        # Earlier slider variant, kept for reference:
                        # sample_size = gr.Slider(minimum=1,
                        # maximum=4,
                        # step=1,
                        # label="生成数量(number)",
                        # show_label=True,
                        # interactive=True,
                        # )
                        sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
                        seed = gr.Number(0, label='seed', interactive=True)
                    # Output size: 512-1024 px in steps of 64.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        w = gr.Slider(512,1024,value=512, step=64, label="width")
                        h = gr.Slider(512,1024,value=512, step=64, label="height")
                    gallery = gr.Gallery(
                        label="Generated images", show_label=False, elem_id="gallery"
                    ).style(grid=[2,2])
                # NOTE(review): only `text` is wired as input here, so if Gradio
                # ever calls fn for the examples (e.g. example caching),
                # request_images receives just the prompt — check it accepts that.
                gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
                # Pick one generated image and send it to the img2img tab.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图',show_label=True,value="图片1(img1)")
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
                        margin=False,
                        rounded=(True, True, True, True),
                    )
                # Hidden hint texts; they are outputs of the copy-to-img2img
                # handler below, which presumably makes them visible.
                with gr.Row():
                    prompt = gr.Markdown("提示(Prompt):", visible=False)
                with gr.Row():
                    move_prompt_zh = gr.Markdown("请移至图生图部分进行编辑(拉到顶部)", visible=False)
                with gr.Row():
                    move_prompt_en = gr.Markdown("Please move to the img2img section for editing(Pull to the top)", visible=False)
                # Enter in the textbox and the button both run request_images
                # with all text-to-image controls and fill the gallery.
                text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                # Presumably rebuilds the image-choice dropdown to match the
                # selected count — see ui_functions.change_img_choices.
                sample_size.change(
                    fn=uifn.change_img_choices,
                    inputs=[sample_size],
                    outputs=[img_choices]
                )
            # ---------------- Image-to-image tab ----------------
            with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
                # Prompt plus two "Generate" buttons; the mask button starts
                # hidden, and both run the same img2img handler.
                with gr.Row(elem_id="prompt_row"):
                    img2img_prompt = gr.Textbox(label="Prompt",
                                                elem_id='img2img_prompt_input',
                                                placeholder="神奇的森林,流淌的河流.",
                                                lines=1,
                                                max_lines=1,
                                                value="",
                                                show_label=False).style()
                    img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
                                                 elem_id="img2img_mask_btn")
                    img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
                gr.Markdown('#### 输入图像')
                # Upload widget with a sketch tool; with type='pil' the handler
                # receives the image and the painted mask.
                with gr.Row().style(equal_height=False):
                    #with gr.Column():
                    img2img_image_mask = gr.Image(
                        value=None,
                        source="upload",
                        interactive=True,
                        tool="sketch",
                        type='pil',
                        elem_id="img2img_mask",
                        image_mode="RGBA"
                    )
                gr.Markdown('#### 编辑后的图片')
                with gr.Row():
                    output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
                        grid=[4,4,4] )
                # Usage hints (Chinese and English).
                with gr.Row():
                    gr.Markdown('提示(prompt):')
                with gr.Row():
                    gr.Markdown('请选择一张图像掩盖掉一部分区域,并输入文本描述')
                with gr.Row():
                    gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
                gr.Markdown('# 编辑设置',visible=False)
                # Copy the image chosen in the text-to-image tab into the
                # img2img input; the handler's outputs include the tab
                # container and the hidden hint texts.
                output_txt2img_copy_to_input_btn.click(
                    uifn.copy_img_to_input,
                    [gallery, img_choices],
                    [tabs, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
                )
                img2img_func = img2img
                img2img_inputs = [img2img_prompt, img2img_image_mask]
                img2img_outputs = [output_img2img_gallery]
                img2img_btn_mask.click(
                    img2img_func,
                    img2img_inputs,
                    img2img_outputs
                )

                # Returns the (fn, inputs, outputs) triple for the editor button.
                def img2img_submit_params():
                    return (img2img_func,
                            img2img_inputs,
                            img2img_outputs)

                img2img_btn_editor.click(*img2img_submit_params())
                # GENERATE ON ENTER
                # No Python handler (fn=None): only the JS snippet runs in the
                # browser, presumably clicking the first visible button in
                # #prompt_row (the editor "Generate" button).
                img2img_prompt.submit(None, None, None,
                                      _js=call_JS("clickFirstVisibleButton",
                                                  rowId="prompt_row"))
        gr.HTML(read_content("footer.html"))
        # gr.Image('./contributors.png')
    # Queue holds up to 512 pending events, processed by 256 concurrent workers.
    block.queue(max_size=512, concurrency_count=256).launch()