"""Gradio front-end for the FlagStudio text-to-image / image-to-image API.

The module talks to the remote FlagStudio service over HTTP: it uploads
optional init/mask images, creates generation tasks, polls them until they
finish and shows the resulting pictures in a Gradio ``Blocks`` UI.
"""
import io
import re
import imp
import time
import json
import base64
import requests
import gradio as gr
import ui_functions as uifn
from css_and_js import js, call_JS
from PIL import Image, PngImagePlugin, ImageChops

url_host = "https://flagstudio.baai.ac.cn"
# NOTE(review): a bearer token is committed in source. Consider loading it
# from the environment or a secret store instead; value left unchanged here.
token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiZjAxOGMxMzJiYTUyNDBjMzk5NTMzYTI5YjBmMzZiODMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6IjVjMmQzMjdiLWI5Y2MtNDhiZS1hZWQ4LTllMjQ4MDk4NzMxYyIsIm5iZiI6MTY2OTAwNjE5NywiZXhwIjoxOTg0MzY2MTk3LCJpYXQiOjE2NjkwMDYxOTd9.9B3MDk8wA6iWH5puXjcD19tJJ4Ox7mdpRyWZs5Kwt70"

# Upper bound (seconds) for every single HTTP call, so that a stalled
# connection cannot block a queue worker forever.
REQUEST_TIMEOUT = 60


def read_content(file_path: str) -> str:
    """Return the full text of ``file_path``, decoded as UTF-8."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    return content


def filter_content(raw_style: str):
    """Return the part of ``raw_style`` that precedes the first ``"("``.

    UI labels look like ``"通用(general)"``; only the text before the
    parenthesis is used in prompts. A label without ``"("`` is returned
    unchanged.
    """
    return raw_style.partition("(")[0]


def _http(method, url, **kwargs):
    """Send one HTTP request and return the response.

    A default timeout is applied, and both transport failures and non-200
    status codes are turned into ``gr.Error`` so the UI shows a message
    instead of hanging or crashing.

    Raises:
        gr.Error: on a network error/timeout or when the status is not 200.
    """
    kwargs.setdefault("timeout", REQUEST_TIMEOUT)
    try:
        response = requests.request(method, url, **kwargs)
    except requests.RequestException as exc:
        raise gr.Error(f"Network error: {exc}") from exc
    if response.status_code != 200:
        raise gr.Error(response.reason)
    return response


def upload_image(img):
    """Upload a PIL image to the service and return ``(image_id, image_url)``.

    The service is first asked for an upload link; the image is then
    PNG-encoded in memory and sent with ``PUT`` to that link using the
    headers the service supplied.

    Raises:
        gr.Error: if any request fails or the service reports a non-zero code.
    """
    url = url_host + "/api/v1/image/get-upload-link"
    headers = {"token": token}
    head_res = _http("POST", url, json={}, headers=headers).json()
    if head_res["code"] != 0:
        raise gr.Error("Unknown error")
    image_id = head_res["data"]["image_id"]
    image_url = head_res["data"]["url"]
    image_headers = head_res["data"]["headers"]

    png_buffer = io.BytesIO()
    img.save(png_buffer, "PNG")
    _http("PUT", image_url, data=png_buffer.getvalue(), headers=image_headers)
    return image_id, image_url


def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
    """Create ``image_num`` generation tasks and return the finished images.

    Args:
        seed: Random seed; converted with ``int()`` before sending.
        prompt: Text prompt for the generation.
        width: Output image width.
        height: Output image height.
        image_num: Number of tasks to create (one request per task).
        img: Optional PIL image uploaded as the init image.
        mask: Optional PIL mask; uploaded only when it contains at least
            one non-zero pixel after conversion to mode ``"L"``.

    Returns:
        A list of RGB ``PIL.Image`` objects collected from all tasks.

    Raises:
        gr.Error: on HTTP failures, rejected prompts (code 3002), NSFW
            results (code 6005) or any other non-zero service code.
    """
    data = {
        "type": "gen-image",
        "parameters": {
            "width": width,  # output image width
            "height": height,  # output image height
            "prompts": [prompt],
        }
    }
    data["parameters"]["seed"] = int(seed)
    if img is not None:
        # Upload image
        image_id, image_url = upload_image(img)
        data["parameters"]["init_image"] = {
            "image_id": image_id,
            "url": image_url,
            "width": img.width,
            "height": img.height,
        }
        if mask is not None:
            # Upload mask only if the user actually painted something:
            # an all-black mask has a maximum value of 0.
            extrema = mask.convert("L").getextrema()
            if extrema[1] > 0:
                mask_id, mask_url = upload_image(mask)
                data["parameters"]["mask_image"] = {
                    "image_id": mask_id,
                    "url": mask_url,
                    "width": mask.width,
                    "height": mask.height,
                }
    headers = {"token": token}

    # Send create task request
    all_task_data = []
    url = url_host+"/api/v1/task/create"
    for _ in range(image_num):
        create_res = _http("POST", url, json=data, headers=headers).json()
        if create_res['code'] == 3002:
            raise gr.Error("Inappropriate prompt detected.")
        elif create_res['code'] != 0:
            raise gr.Error("Unknown error")
        all_task_data.append(create_res["data"])

    # Get result
    url = url_host+"/api/v1/task/status"
    images = []
    while all_task_data:
        # Walk backwards so finished tasks can be deleted while iterating.
        for i in range(len(all_task_data)-1, -1, -1):
            task = all_task_data[i]
            res = _http("POST", url, json=task, headers=headers).json()
            code = res["code"]
            if code == 6002:
                # Running
                continue
            if code == 6005:
                raise gr.Error("NSFW image detected.")
            elif code == 0:
                # Finished
                for img_info in res["data"]["images"]:
                    img_res = _http("GET", img_info["url"])
                    images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
                del all_task_data[i]
            else:
                raise gr.Error(f"Error code: {res['code']}")
        if all_task_data:
            # Something is still running: wait before the next poll round.
            time.sleep(1)
    return images


def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
    """Build the final prompt from the UI selections and generate images.

    For the "国画" type a fixed set of keywords is appended. For any other
    type, the type name (unless it is "通用") and every selected style are
    appended, each stripped of its parenthesised suffix.

    Returns:
        The list of images produced by ``post_reqest``.
    """
    category = filter_content(class_draw)
    if category != "国画":
        if category != "通用":
            raw_text = raw_text + f",{category}"

        for sty in style_draw:
            raw_text = raw_text + f",{filter_content(sty)}"
    else:
        raw_text = raw_text + ",国画,水墨画,大作,黑白,高清,传统"
    print(f"raw text is {raw_text}")

    images = post_reqest(seed, raw_text, w, h, int(batch_size))

    return images


def img2img(prompt, image_and_mask):
    """Run image-to-image generation on the sketch-tool input.

    ``image_and_mask`` is the dict with ``"image"`` and ``"mask"`` entries
    that this function reads. The shorter side of the output is fixed at
    512 and the other side is scaled to keep the input's aspect ratio.

    Raises:
        gr.Error: when no image has been provided.
    """
    if not image_and_mask or image_and_mask.get("image") is None:
        # Without this guard an empty input crashes with a TypeError.
        raise gr.Error("Please upload an image first.")

    image = image_and_mask["image"]
    if image.width <= image.height:
        width = 512
        height = int((width/image.width)*image.height)
    else:
        height = 512
        width = int((height/image.height)*image.width)
    return post_reqest(0, prompt, width, height, 1, image, image_and_mask["mask"])


examples = [
    '水墨蝴蝶和牡丹花,国画',
    '苍劲有力的墨竹,国画',
    '暴风雨中的灯塔',
    '机械小松鼠,科学幻想',
    '中国水墨山水画,国画',
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]

if __name__ == "__main__":
    block = gr.Blocks(css=read_content('style.css'))

    with block:
        gr.HTML(read_content("header.html"))
        with gr.Tabs(elem_id='tabss') as tabs:

            with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):

                with gr.Group():
                    with gr.Box():
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(
                                label="Prompt",
                                show_label=False,
                                max_lines=1,
                                placeholder="Input text(输入文字)",
                                interactive=True,
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )

                            btn = gr.Button("Generate image").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            class_draw = gr.Radio(choices=["通用(general)","国画(traditional Chinese painting)",], value="通用(general)", show_label=True, label='生成类型(type)')
                            # class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
                            #                           "照片,摄影(picture photography)", "油画(oil painting)",
                            #                           "铅笔素描(pencil sketch)", "CG",
                            #                           "水彩画(watercolor painting)", "水墨画(ink and wash)",
                            #                           "插画(illustrations)", "3D", "图生图(img2img)"],
                            #                          label="生成类型(type)",
                            #                          show_label=True,
                            #                          value="通用(general)")
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                           "概念艺术(concept art)", "Warming lighting",
                                                           "Dramatic lighting", "Natural lighting",
                                                           "虚幻引擎(unreal engine)", "4k", "8k",
                                                           "充满细节(full details)"],
                                                          label="画面风格(style)",
                                                          show_label=True,
                                                          )
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            # sample_size = gr.Slider(minimum=1,
                            #                         maximum=4,
                            #                         step=1,
                            #                         label="生成数量(number)",
                            #                         show_label=True,
                            #                         interactive=True,
                            #                         )
                            sample_size = gr.Radio(choices=["1","2","3","4"], value="1", show_label=True, label='生成数量(number)')
                            seed = gr.Number(0, label='seed', interactive=True)
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            w = gr.Slider(512,1024,value=512, step=64, label="width")
                            h = gr.Slider(512,1024,value=512, step=64, label="height")

                    gallery = gr.Gallery(
                        label="Generated images", show_label=False, elem_id="gallery"
                    ).style(grid=[2,2])
                    gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        img_choices = gr.Dropdown(["图片1(img1)"],label='请选择一张图片发送到图生图',show_label=True,value="图片1(img1)")
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        output_txt2img_copy_to_input_btn = gr.Button("发送图片到图生图(Sent the image to img2img)").style(
                            margin=False,
                            rounded=(True, True, True, True),
                        )

                    with gr.Row():
                        prompt = gr.Markdown("提示(Prompt):", visible=False)
                    with gr.Row():
                        move_prompt_zh = gr.Markdown("请移至图生图部分进行编辑(拉到顶部)", visible=False)
                    with gr.Row():
                        move_prompt_en = gr.Markdown("Please move to the img2img section for editing(Pull to the top)", visible=False)

                    # Both pressing Enter and clicking the button generate.
                    text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)
                    btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, w, h, seed], outputs=gallery)

                    sample_size.change(
                        fn=uifn.change_img_choices,
                        inputs=[sample_size],
                        outputs=[img_choices]
                    )

            with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
                with gr.Row(elem_id="prompt_row"):
                    img2img_prompt = gr.Textbox(label="Prompt",
                                                elem_id='img2img_prompt_input',
                                                placeholder="神奇的森林,流淌的河流.",
                                                lines=1,
                                                max_lines=1,
                                                value="",
                                                show_label=False).style()

                    img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
                                                 elem_id="img2img_mask_btn")
                    img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
                gr.Markdown('#### 输入图像')
                with gr.Row().style(equal_height=False):
                    #with gr.Column():
                    img2img_image_mask = gr.Image(
                        value=None,
                        source="upload",
                        interactive=True,
                        tool="sketch",
                        type='pil',
                        elem_id="img2img_mask",
                        image_mode="RGBA"
                    )
                gr.Markdown('#### 编辑后的图片')
                with gr.Row():
                    output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
                        grid=[4,4,4]
                    )
                with gr.Row():
                    gr.Markdown('提示(prompt):')
                with gr.Row():
                    gr.Markdown('请选择一张图像掩盖掉一部分区域,并输入文本描述')
                with gr.Row():
                    gr.Markdown('Please select an image to cover up a part of the area and enter a text description.')
                gr.Markdown('# 编辑设置',visible=False)

            output_txt2img_copy_to_input_btn.click(
                uifn.copy_img_to_input,
                [gallery, img_choices],
                [tabs, img2img_image_mask, move_prompt_zh, move_prompt_en, prompt]
            )

            img2img_func = img2img
            img2img_inputs = [img2img_prompt, img2img_image_mask]
            img2img_outputs = [output_img2img_gallery]

            img2img_btn_mask.click(
                img2img_func,
                img2img_inputs,
                img2img_outputs
            )

            def img2img_submit_params():
                """Return the (fn, inputs, outputs) triple for img2img events."""
                return (img2img_func,
                        img2img_inputs,
                        img2img_outputs)

            img2img_btn_editor.click(*img2img_submit_params())

            # GENERATE ON ENTER
            img2img_prompt.submit(None, None, None,
                                  _js=call_JS("clickFirstVisibleButton",
                                              rowId="prompt_row"))

        gr.HTML(read_content("footer.html"))
        # gr.Image('./contributors.png')

    block.queue(max_size=512, concurrency_count=256).launch()