|
import importlib |
|
import gradio as gr |
|
import io |
|
from PIL import Image, PngImagePlugin |
|
import base64 |
|
import requests |
|
import json |
|
import ui_functions as uifn |
|
from css_and_js import js, call_JS |
|
import re |
|
from share_btn import community_icon_html, loading_icon_html, share_js |
|
|
|
|
|
# Default values for the text-to-image settings.
txt2img_defaults = {
    "prompt": "",
    "ddim_steps": 50,
    "toggles": [1, 2, 3],
    "sampler_name": "k_lms",
    "ddim_eta": 0.0,
    "n_iter": 1,
    "batch_size": 1,
    "cfg_scale": 7.5,
    "seed": "",
    "height": 512,
    "width": 512,
    "fp": None,
    "variant_amount": 0.0,
    "variant_seed": "",
    "submit_on_enter": "Yes",
}

# Default values for the image-to-image settings.
# "toggles" holds indices into ``img2img_toggles`` and "resize_mode" is an
# index into ``img2img_resize_modes`` (both lists are defined below).
img2img_defaults = {
    "prompt": "",
    "ddim_steps": 50,
    "toggles": [1, 4, 5],
    "sampler_name": "k_lms",
    "ddim_eta": 0.0,
    "n_iter": 1,
    "batch_size": 1,
    "cfg_scale": 5.0,
    "denoising_strength": 0.75,
    "mask_mode": 1,
    "resize_mode": 0,
    "seed": "",
    "height": 512,
    "width": 512,
    "fp": None,
}

# Module-level switches read while the UI is being built.
sample_img2img = None    # initial value handed to the img2img image inputs
job_manager = None       # when falsy, no job-manager UI is drawn or wrapped
RealESRGAN = True        # the model dropdown is visible while this is not None
show_embeddings = False  # visibility of the embeddings file input

# Labels for the resize-mode radio; order matters because the default is
# selected by index.
img2img_resize_modes = [
    "Just resize",
    "Crop and resize",
    "Resize and fill",
]

# Labels for the img2img option checkboxes; order matters because the
# defaults are selected by index.
img2img_toggles = [
    "Create prompt matrix (separate multiple prompts using |, and get all combinations of them)",
    "Normalize Prompt Weights (ensure sum of weights add up to 1.0)",
    "Loopback (use images from previous batch when creating next batch)",
    "Random loopback seed",
    "Save individual images",
    "Save grid",
    "Sort samples by prompt",
    "Write sample info files",
    "Write sample info to one file",
    "jpg samples",
]

# Resolve the default toggle indices into their label strings.
img2img_toggle_defaults = [
    img2img_toggles[index] for index in img2img_defaults["toggles"]
]
|
|
|
def read_content(file_path: str) -> str:
    """Return the full text of ``file_path``, decoded as UTF-8."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
|
|
|
def base2picture(resbase64):
    """Decode a base64-encoded image string into a PIL image.

    Args:
        resbase64: Either a data URI such as
            ``"data:image/png;base64,<payload>"`` or a bare base64 payload
            without any prefix.

    Returns:
        The image opened by ``PIL.Image.open`` from the decoded bytes.
    """
    # Everything after the first comma is the payload. The standard base64
    # alphabet contains no comma, so a string without one is treated as a
    # bare payload instead of raising IndexError.
    _prefix, separator, payload = resbase64.partition(',')
    if not separator:
        payload = resbase64

    return Image.open(io.BytesIO(base64.b64decode(payload)))
|
|
|
def filter_content(raw_style: str):
    """Return the part of ``raw_style`` that precedes its first ``(``.

    UI labels in this file look like ``"通用(general)"``; this keeps only the
    leading part (``"通用"``). A string without ``(`` is returned unchanged.
    """
    leading_part, _paren, _rest = raw_style.partition("(")
    return leading_part
|
|
|
def request_images(raw_text, class_draw, style_draw, batch_size, sr_option, w, h, seed):
    """Request generated images for a prompt from the flagart.baai.ac.cn API.

    The prompt is extended before it is sent:

    * For the "国画" type, ",国画" is appended unless the prompt already ends
      with "国画", and the ``guohua`` endpoint is used.
    * Otherwise the type label (unless it is "通用") and every selected style
      label are appended, comma-separated, and the ``general`` endpoint is
      used.

    Labels are reduced with ``filter_content`` (text before the first "(").

    Args:
        raw_text: The user's prompt.
        class_draw: Selected type label, e.g. ``"通用(general)"``.
        style_draw: Iterable of selected style labels (may be empty or None).
        batch_size: Number of images requested; coerced with ``int()``.
        sr_option: Forwarded unchanged as the third payload element.
        w: Forwarded unchanged (width).
        h: Forwarded unchanged (height).
        seed: Forwarded unchanged.

    Returns:
        A list of images decoded with ``base2picture``: at most
        ``int(batch_size)`` entries, fewer if the server returned fewer.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    category = filter_content(class_draw)

    if category == "国画":
        if not raw_text.endswith("国画"):
            raw_text = raw_text + ",国画"
        url = "http://flagart.baai.ac.cn/api/guohua/"
    else:
        if category != "通用":
            raw_text = raw_text + f",{category}"

        for sty in style_draw or []:
            raw_text = raw_text + f",{filter_content(sty)}"
        print(f"raw text is {raw_text}")
        url = "http://flagart.baai.ac.cn/api/general/"

    count = int(batch_size)
    payload = {"data": [raw_text, count, sr_option, w, h, seed]}
    headers = {
        "Content-Type": "application/json",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
        "Connection": "keep-alive",
    }

    response = requests.post(url, json=payload, headers=headers)
    # Surface HTTP failures directly instead of as a JSON/KeyError further down.
    response.raise_for_status()
    encoded_images = response.json()["data"][0]

    # Slice rather than index so a short response does not raise IndexError.
    return [base2picture(encoded) for encoded in encoded_images[:count]]
|
|
|
def sr_request_images(img_str, idx, w, h, seed):
    """Send one selected gallery image to the ``general`` endpoint.

    The payload is ``[image_data, 0, False, w, h, seed]``; the first image of
    the response is decoded and returned.

    Args:
        img_str: Indexable collection of image strings (the gallery value).
        idx: Dropdown label selecting the image, one of ``"图片1(img1)"`` to
            ``"图片4(img4)"``.
        w: Forwarded unchanged (width).
        h: Forwarded unchanged (height).
        seed: Forwarded unchanged.

    Returns:
        A one-element list holding the image decoded with ``base2picture``.

    Raises:
        KeyError: If ``idx`` is not one of the known labels.
        IndexError: If ``img_str`` has no entry at the selected position.
        requests.HTTPError: If the server answers with an error status.
    """
    idx_map = {
        "图片1(img1)": 0,
        "图片2(img2)": 1,
        "图片3(img3)": 2,
        "图片4(img4)": 3,
    }
    image_data = img_str[idx_map[idx]]

    payload = {"data": [image_data, 0, False, w, h, seed]}
    headers = {
        "Content-Type": "application/json",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
        "Connection": "keep-alive",
    }

    url = "http://flagart.baai.ac.cn/api/general/"
    response = requests.post(url, json=payload, headers=headers)
    # Surface HTTP failures directly instead of as a JSON/KeyError further down.
    response.raise_for_status()
    encoded_images = response.json()["data"][0]

    return [base2picture(encoded_images[0])]
|
|
|
|
|
def encode_pil_to_base64(pil_image):
    """Serialize a PIL image to a ``data:image/png;base64,...`` string.

    Every ``str`` to ``str`` entry of ``pil_image.info`` is written into the
    PNG as a text chunk; when there are no such entries the image is saved
    with ``pnginfo=None``.
    """
    text_entries = [
        (key, value)
        for key, value in pil_image.info.items()
        if isinstance(key, str) and isinstance(value, str)
    ]

    png_info = None
    if text_entries:
        png_info = PngImagePlugin.PngInfo()
        for key, value in text_entries:
            png_info.add_text(key, value)

    with io.BytesIO() as buffer:
        pil_image.save(buffer, "PNG", pnginfo=png_info)
        png_bytes = buffer.getvalue()

    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return "data:image/png;base64," + encoded
|
|
|
def img2img(*args):
    """Run an image-to-image request against the ``img2img`` endpoint.

    All positional arguments are forwarded, in order, as the request's
    ``data`` list. Any ``dict`` argument (the sketch-tool image value, which
    carries ``'image'`` and ``'mask'`` PIL images) is sent with both images
    converted by ``encode_pil_to_base64``.

    Args:
        *args: The UI input values. ``args[8]`` is the number of images
            requested (the batch-count slider in this file's input list).

    Returns:
        A list of images decoded with ``base2picture``: at most
        ``int(args[8])`` entries, fewer if the server returned fewer.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    payload = []
    for item in args:
        if isinstance(item, dict):
            # Build a new dict so the caller's value is not mutated in place.
            item = {
                **item,
                'image': encode_pil_to_base64(item['image']),
                'mask': encode_pil_to_base64(item['mask']),
            }
        payload.append(item)

    # A slider may deliver a float; range()/slicing need an int.
    batch_size = int(args[8])

    url = "http://flagart.baai.ac.cn/api/img2img/"
    headers = {
        "Content-Type": "application/json",
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
        "Connection": "keep-alive",
    }
    response = requests.post(url, json={"data": payload}, headers=headers)
    # Surface HTTP failures directly instead of as a JSON/KeyError further down.
    response.raise_for_status()
    encoded_images = response.json()["data"][0]

    # Slice rather than index so a short response does not raise IndexError.
    return [base2picture(encoded) for encoded in encoded_images[:batch_size]]
|
|
|
|
|
# Example prompts offered below the text-to-image gallery.
examples = [
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
|
|
|
if __name__ == "__main__":
    # Build the Gradio Blocks UI (two tabs: text-to-image and image-to-image),
    # wire the event handlers, then launch the app with request queuing.
    #
    # NOTE(review): the container nesting below (which component sits in which
    # Row/Column/Tab) was reconstructed from statement order; confirm it
    # against the intended layout.
    block = gr.Blocks(css=read_content('style.css'))

    with block:
        with gr.Tabs(elem_id='tabss') as tabs:

            # ---------------- Text-to-image tab ----------------
            with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):

                with gr.Group():
                    with gr.Box():
                        # Prompt box and generate button.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(
                                label="Prompt",
                                show_label=False,
                                max_lines=1,
                                placeholder="Input text(输入文字)",
                                interactive=True,
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )

                            btn = gr.Button("Generate image").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )
                        # Generation type; request_images picks the endpoint from it.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            class_draw = gr.Radio(choices=["通用(general)","国画(traditional Chinese painting)",], value="通用(general)", show_label=True, label='生成类型(type)')
                        # Style labels; request_images appends them to the prompt.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                           "概念艺术(concept art)", "Warming lighting",
                                                           "Dramatic lighting", "Natural lighting",
                                                           "虚幻引擎(unreal engine)", "4k", "8k",
                                                           "充满细节(full details)"],
                                                          label="画面风格(style)",
                                                          show_label=True,
                                                          )
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            # Values are strings; request_images converts with int().
                            sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
                            seed = gr.Number(-1, label='seed', interactive=True)
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            w = gr.Slider(512,1024,value=512, step=64, label="width")
                            h = gr.Slider(512,1024,value=512, step=64, label="height")
                        # Hidden row: the super-resolution checkbox is not shown but
                        # its value (False) is still passed to request_images.
                        with gr.Row(visible=False).style(mobile_collapse=False, equal_height=True):
                            sr_option = gr.Checkbox(value=False, label="是否使用超分(Whether to use super-resolution)")

                    gallery = gr.Gallery(
                        label="Generated images", show_label=False, elem_id="gallery"
                    ).style(grid=[2,2])
                    # NOTE(review): request_images takes 8 positional parameters but
                    # only `text` is listed as input here; presumably fn is only
                    # invoked when example caching is enabled — confirm.
                    gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
                    # Selector for which generated image to forward to img2img / SR.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图或者超分',show_label=True,value="图片1(img1)")
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
                            margin=False,
                            rounded=(True, True, True, True),
                        )
                        output_txt2img_sr_btn = gr.Button("将选择的图片进行超分(super-resolution)").style(
                            margin=False,
                            rounded=(True, True, True, True),
                        )

                    sr_gallery = gr.Gallery(
                        label="SR images", show_label=True, elem_id="sr_gallery"
                    ).style(grid=[1,1])

                    # Initially hidden hints; they are outputs of the
                    # "send to img2img" handler wired further below.
                    with gr.Row():
                        prompt = gr.Markdown("提示(Prompt):", visible=False)
                    with gr.Row():
                        move_prompt_zh = gr.Markdown("请移至图生图部分进行编辑(拉到顶部)", visible=False)
                    with gr.Row():
                        move_prompt_en = gr.Markdown("Please move to the img2img section for editing(Pull to the top)", visible=False)

                    # Both pressing Enter in the prompt box and clicking the button
                    # trigger generation with the same inputs.
                    text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, sr_option, w, h, seed], outputs=gallery)
                    btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, sr_option, w, h, seed], outputs=gallery)

                    output_txt2img_sr_btn.click(sr_request_images, inputs=[gallery, img_choices, w, h, seed], outputs=[sr_gallery])

                    # Changing the image count updates the image selector through
                    # uifn.change_img_choices (defined in ui_functions).
                    sample_size.change(
                        fn=uifn.change_img_choices,
                        inputs=[sample_size],
                        outputs=[img_choices]
                    )

            # ---------------- Image-to-image tab ----------------
            with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
                with gr.Row(elem_id="prompt_row"):
                    img2img_prompt = gr.Textbox(label="Prompt",
                                                elem_id='img2img_prompt_input',
                                                placeholder="神奇的森林,流淌的河流.",
                                                lines=1,
                                                max_lines=1 if txt2img_defaults['submit_on_enter'] == 'Yes' else 25,
                                                value=img2img_defaults['prompt'],
                                                show_label=False).style()

                    # Two "Generate" buttons; the mask one starts hidden.
                    img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
                                                 elem_id="img2img_mask_btn")
                    img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
                gr.Markdown('#### 输入图像')
                with gr.Row().style(equal_height=False):

                    # Sketch-tool input (visible) and select-tool input (hidden).
                    img2img_image_mask = gr.Image(
                        value=sample_img2img,
                        source="upload",
                        interactive=True,
                        tool="sketch",
                        type='pil',
                        elem_id="img2img_mask",
                        image_mode="RGBA"
                    )
                    img2img_image_editor = gr.Image(
                        value=sample_img2img,
                        source="upload",
                        interactive=True,
                        tool="select",
                        type='pil',
                        visible=False,
                        image_mode="RGBA",
                        elem_id="img2img_editor"
                    )
                    # Hidden editor settings; the components still feed the
                    # img2img request with their default values.
                    with gr.Tabs(visible=False):
                        with gr.TabItem("编辑设置"):
                            with gr.Row():

                                # NOTE(review): `choices` is assigned but not referenced
                                # anywhere in this block; the radio below hard-codes
                                # ["Mask"].
                                choices=["Mask", "Crop", "Uncrop"]
                                img2img_image_editor_mode = gr.Radio(choices=["Mask"],
                                                                     label="编辑模式",
                                                                     value="Mask", elem_id='edit_mode_select',
                                                                     visible=True)
                                img2img_mask = gr.Radio(choices=["保留mask区域", "生成mask区域"],
                                                        label="Mask 方式",

                                                        value = "生成mask区域",
                                                        visible=True)

                                img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
                                                                       label="How much blurry should the mask be? (to avoid hard edges)",
                                                                       value=3, visible=False)

                                img2img_resize = gr.Radio(label="Resize mode",
                                                          choices=["Just resize", "Crop and resize",
                                                                   "Resize and fill"],
                                                          value=img2img_resize_modes[
                                                              img2img_defaults['resize_mode']], visible=False)

                            img2img_painterro_btn = gr.Button("Advanced Editor",visible=False)

                gr.Markdown('#### 编辑后的图片')
                with gr.Row():
                    output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
                        grid=[4,4,4] )
                    # job_manager is None at module level, so this is None unless
                    # that global is changed.
                    img2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None
                    # Hidden column holding output actions and generation info.
                    with gr.Column(visible=False):
                        with gr.Tabs(visible=False):
                            with gr.TabItem("", id="img2img_actions_tab",visible=False):
                                gr.Markdown("Select an image, then press one of the buttons below")
                                with gr.Row():
                                    output_img2img_copy_to_clipboard_btn = gr.Button("Copy to clipboard")
                                    output_img2img_copy_to_input_btn = gr.Button("Push to img2img input")
                                    output_img2img_copy_to_mask_btn = gr.Button("Push to img2img input mask")

                                gr.Markdown("Warning: This will clear your current image and mask settings!")
                            with gr.TabItem("", id="img2img_output_info_tab",visible=False):
                                output_img2img_params = gr.Textbox(label="Generation parameters")
                                with gr.Row():
                                    # The names below hold the return value of .click(),
                                    # not the Button component itself.
                                    output_img2img_copy_params = gr.Button("Copy full parameters").click(
                                        inputs=output_img2img_params, outputs=[],
                                        _js='(x) => {navigator.clipboard.writeText(x.replace(": ",":"))}', fn=None,
                                        show_progress=False)
                                    output_img2img_seed = gr.Number(label='Seed', interactive=False, visible=False)
                                    output_img2img_copy_seed = gr.Button("Copy only seed").click(
                                        inputs=output_img2img_seed, outputs=[],
                                        _js=call_JS("gradioInputToClipboard"), fn=None, show_progress=False)
                                output_img2img_stats = gr.HTML(label='Stats')
                # Usage hint shown in Chinese and English.
                with gr.Row():
                    gr.Markdown('提示(prompt):')
                with gr.Row():
                    gr.Markdown('请选择一张图像掩盖掉一部分区域,并输入文本描述')
                with gr.Row():
                    gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
                gr.Markdown('# 编辑设置',visible=False)

                # Hidden generation settings; values come from img2img_defaults and
                # are still sent with every img2img request.
                with gr.Row(visible=False):
                    with gr.Column():
                        img2img_width = gr.Slider(minimum=64, maximum=2048, step=64, label="图片宽度",
                                                  value=img2img_defaults["width"])
                        img2img_height = gr.Slider(minimum=64, maximum=2048, step=64, label="图片高度",
                                                   value=img2img_defaults["height"])
                        img2img_cfg = gr.Slider(minimum=-40.0, maximum=30.0, step=0.5,
                                                label='文本引导强度',
                                                value=img2img_defaults['cfg_scale'], elem_id='cfg_slider')
                        img2img_seed = gr.Textbox(label="随机种子", lines=1, max_lines=1,
                                                  value=img2img_defaults["seed"])
                        img2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1,
                                                        label='生成数量',
                                                        value=img2img_defaults['n_iter'])
                        img2img_dimensions_info_text_box = gr.Textbox(
                            label="长宽比设置")
                    with gr.Column():
                        img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="采样步数",
                                                  value=img2img_defaults['ddim_steps'])

                        img2img_sampling = gr.Dropdown(label='采样方式',
                                                       choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler',
                                                                'k_heun', 'k_lms'],
                                                       value=img2img_defaults['sampler_name'])

                        img2img_denoising = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength',
                                                      value=img2img_defaults['denoising_strength'],visible=False)

                        # NOTE(review): this rebinds the module-level name
                        # `img2img_toggles` (the list of labels) to the component;
                        # the list is read as `choices` before the rebinding happens.
                        img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles,
                                                           value=img2img_toggle_defaults,visible=False)

                        # RealESRGAN is True at module level, so this evaluates to
                        # visible=True (inside the hidden row).
                        img2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model',
                                                                    choices=['RealESRGAN_x4plus',
                                                                             'RealESRGAN_x4plus_anime_6B'],
                                                                    value='RealESRGAN_x4plus',
                                                                    visible=RealESRGAN is not None)

                        img2img_embeddings = gr.File(label="Embeddings file for textual inversion",
                                                     visible=show_embeddings)

                # Two handlers on the same editor-mode change event, both delegated
                # to ui_functions.
                img2img_image_editor_mode.change(
                    uifn.change_image_editor_mode,
                    [img2img_image_editor_mode,
                     img2img_image_editor,
                     img2img_image_mask,
                     img2img_resize,
                     img2img_width,
                     img2img_height
                     ],
                    [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
                     img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
                )

                img2img_image_editor_mode.change(
                    uifn.update_image_mask,
                    [img2img_image_editor, img2img_resize, img2img_width, img2img_height],
                    img2img_image_mask
                )

                # txt2img "send to img2img" button: outputs include the Tabs
                # container, both image inputs and the three hint Markdown
                # components.
                output_txt2img_copy_to_input_btn.click(
                    uifn.copy_img_to_input,
                    [gallery, img_choices],
                    [tabs, img2img_image_editor, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
                )

                # NOTE(review): fromId is "gallery" (the txt2img gallery's elem_id),
                # while the Python input is output_img2img_gallery (elem_id
                # "img2img_gallery_output") and the sibling handler below uses
                # "img2img_gallery_output" — confirm this is intended.
                output_img2img_copy_to_input_btn.click(
                    uifn.copy_img_to_edit,
                    [output_img2img_gallery],
                    [img2img_image_editor, tabs, img2img_image_editor_mode],
                    _js=call_JS("moveImageFromGallery",
                                fromId="gallery",
                                toId="img2img_editor")
                )
                output_img2img_copy_to_mask_btn.click(
                    uifn.copy_img_to_mask,
                    [output_img2img_gallery],
                    [img2img_image_mask, tabs, img2img_image_editor_mode],
                    _js=call_JS("moveImageFromGallery",
                                fromId="img2img_gallery_output",
                                toId="img2img_editor")
                )

                # JS-only handler (fn=None).
                output_img2img_copy_to_clipboard_btn.click(fn=None, inputs=output_img2img_gallery, outputs=[],
                                                           _js=call_JS("copyImageFromGalleryToClipboard",
                                                                       fromId="img2img_gallery_output")
                                                           )

                # The order of this list defines the positional args of img2img();
                # img2img reads args[8], which is img2img_batch_count here.
                img2img_func = img2img
                img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
                                  img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
                                  img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
                                  img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
                                  img2img_image_mask]

                img2img_outputs = [output_img2img_gallery]

                # Only taken when a job manager UI was drawn above.
                if img2img_job_ui:
                    img2img_func, img2img_inputs, img2img_outputs = img2img_job_ui.wrap_func(
                        func=img2img_func,
                        inputs=img2img_inputs,
                        outputs=img2img_outputs,
                    )

                img2img_btn_mask.click(
                    img2img_func,
                    img2img_inputs,
                    img2img_outputs
                )

                def img2img_submit_params():
                    """Return the (fn, inputs, outputs) triple for the editor button."""
                    return (img2img_func,
                            img2img_inputs,
                            img2img_outputs)

                img2img_btn_editor.click(*img2img_submit_params())

                # Pressing Enter in the prompt runs a JS helper only (no Python fn).
                img2img_prompt.submit(None, None, None,
                                      _js=call_JS("clickFirstVisibleButton",
                                                  rowId="prompt_row"))

                img2img_painterro_btn.click(None,
                                            [img2img_image_editor, img2img_image_mask, img2img_image_editor_mode],
                                            [img2img_image_editor, img2img_image_mask],
                                            _js=call_JS("Painterro.init", toId="img2img_editor")
                                            )

                # Either dimension slider refreshes the dimensions info text box.
                img2img_width.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
                                     outputs=img2img_dimensions_info_text_box)
                img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
                                      outputs=img2img_dimensions_info_text_box)

    # Enable request queuing, then start the server.
    block.queue(max_size=50, concurrency_count=20).launch()