import os
import io
import base64
import json
from typing import Callable, Any

from PIL import Image
import gradio as gr

from common import utils as posex


def _this_file() -> str:
    """Return the path of this source file.

    ``__file__`` is not defined in every execution context; when it is
    missing, ask ``inspect`` which file a function defined here came from
    (cf. https://stackoverflow.com/a/53293924).
    """
    if '__file__' in globals():
        return __file__
    import inspect
    return inspect.getfile(lambda: None)


# Absolute directory of this file, resolved via _this_file() so the
# ``__file__`` fallback applies here too. Used to locate the bundled JS assets
# and the pose save directory independently of the working directory.
project_dir = os.path.dirname(os.path.abspath(_this_file()))
print(project_dir)

posex.set_save_dir(os.path.join(project_dir, 'saved_poses'))


# ---------------------------------------------------------------------------
# JS <-> Python bridges built from hidden Gradio components.
#
# In every helper below, ``id`` maps a short channel name to a DOM element id
# and ``js`` maps a name to the JavaScript snippet used as a ``_js`` handler;
# ``sink`` is a throwaway output component for JS-only handlers.
# ---------------------------------------------------------------------------

def js2py(
    name: str,
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
) -> gr.Textbox:
    """Create a channel through which JavaScript pushes a string to Python.

    Clicking the hidden ``{name}_set`` button runs ``js(name)``; its results
    are written to the hidden ``{name}`` textbox and a private sink. A change
    of that sink then runs ``js(f'{name}_after')``.

    Returns:
        The ``{name}`` textbox, so Python handlers can use it as an input.
    """
    v_set = gr.Button(elem_id=id(f'{name}_set'), visible=False)
    v = gr.Textbox(elem_id=id(name), visible=False)
    v_sink = gr.Textbox(visible=False)
    v_set.click(fn=None, _js=js(name), outputs=[v, v_sink])
    v_sink.change(fn=None, _js=js(f'{name}_after'), outputs=[sink])
    return v


def py2js(
    name: str,
    fn: Callable[[], str],
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
) -> None:
    """Create a channel through which JavaScript fetches a string from Python.

    Clicking the hidden ``{name}_get`` button calls ``fn()`` (wrapped by
    ``wrap_api``); the result goes to one hidden textbox and a call counter to
    another. The counter's change runs ``js(name)`` with the result as input.
    """
    v_fire = gr.Button(elem_id=id(f'{name}_get'), visible=False)
    v_sink = gr.Textbox(visible=False)
    v_sink2 = gr.Textbox(visible=False)
    v_fire.click(fn=wrap_api(fn), outputs=[v_sink, v_sink2])
    v_sink2.change(fn=None, _js=js(name), inputs=[v_sink], outputs=[sink])


def _args_channel(
    name: str,
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
) -> gr.JSON:
    """Create the hidden components through which JS hands call arguments to Python.

    Clicking the hidden ``{name}_args_set`` button runs ``js(f'{name}_args')``;
    its results are written to the ``{name}_args`` JSON component and a private
    sink, whose change runs ``js(f'{name}_args_after')``.

    Returns:
        The ``{name}_args`` component, for use as an input of the Python call.
    """
    v_args_set = gr.Button(elem_id=id(f'{name}_args_set'), visible=False)
    v_args = gr.JSON(elem_id=id(f'{name}_args'), visible=False)
    v_args_sink = gr.JSON(visible=False)
    v_args_set.click(fn=None, _js=js(f'{name}_args'), outputs=[v_args, v_args_sink])
    v_args_sink.change(fn=None, _js=js(f'{name}_args_after'), outputs=[sink])
    return v_args


def jscall(
    name: str,
    fn: Callable[[str], str],
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
) -> None:
    """Create a JS -> Python -> JS call with arguments.

    JS first stores the arguments through the ``{name}_args`` channel, then
    clicks ``{name}_get``: ``fn(args)`` runs (wrapped by ``wrap_api``) and its
    result is handed back to ``js(name)`` once the call counter changes.
    """
    v_args = _args_channel(name, id, js, sink)

    v_fire = gr.Button(elem_id=id(f'{name}_get'), visible=False)
    v_sink = gr.Textbox(visible=False)
    v_sink2 = gr.Textbox(visible=False)
    v_fire.click(fn=wrap_api(fn), inputs=[v_args], outputs=[v_sink, v_sink2])
    v_sink2.change(fn=None, _js=js(name), inputs=[v_sink], outputs=[sink])


def generatecall(
    name: str,
    fn: Callable[..., Any],
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
    prompt: gr.components.IOComponent,
    prompt_n: gr.components.IOComponent,
    output_img: gr.components.IOComponent,
) -> None:
    """Create a JS-triggered generation call whose result goes to ``output_img``.

    JS stores the arguments through the ``{name}_args`` channel, then clicks
    ``{name}_get``, which runs ``fn(args, prompt, prompt_n)`` and writes its
    return value straight to ``output_img``.
    """
    v_args = _args_channel(name, id, js, sink)

    v_fire = gr.Button(elem_id=id(f'{name}_get'), visible=False)
    # NOTE(review): nothing ever writes v_sink / v_sink2 (the click handler
    # outputs to output_img only), so the change handler below cannot fire
    # as written. Kept to avoid altering the component set — confirm intent.
    v_sink = gr.Textbox(visible=False)
    v_sink2 = gr.Textbox(visible=False)
    v_fire.click(fn=fn, inputs=[v_args, prompt, prompt_n], outputs=[output_img])
    v_sink2.change(fn=None, _js=js(name), inputs=[v_sink], outputs=[sink])


def get_self_extension() -> str:
    """Return the path of this source file (see ``_this_file``)."""
    return _this_file()


# ---------------------------------------------------------------------------
# APIs
# ---------------------------------------------------------------------------

def wrap_api(fn):
    """Wrap ``fn`` so each call also returns a fresh call counter.

    The wrapper returns ``(fn(*args, **kwargs), str(n))``, where ``n`` counts
    calls to this wrapper. Because the second value differs on every call, a
    ``.change`` listener chained on it fires even when ``fn`` returns the same
    result twice in a row.
    """
    _r = 0

    def f(*args, **kwargs):
        nonlocal _r
        _r += 1
        v = fn(*args, **kwargs)
        return v, str(_r)

    return f


def all_pose() -> str:
    """Return every saved pose from ``posex.all_poses()`` as a JSON array string."""
    return json.dumps(list(posex.all_poses()))


def delete_pose(args: str) -> str:
    """Delete the pose named by the first element of the JSON array ``args``."""
    posex.delete_pose(json.loads(args)[0])
    return ''


def save_pose(args: str) -> str:
    """Save the pose given as the first element of the JSON array ``args``."""
    posex.save_pose(json.loads(args)[0])
    return ''


def load_pose(args: str) -> str:
    """Load the pose named by ``json.loads(args)[0]`` and return it as JSON."""
    return json.dumps(posex.load_pose(json.loads(args)[0]))


def generate_imgs(data, image_prompt, image_n_prompt):
    """Generate images from the pose ``data`` and the (negative) prompts."""
    return posex.generate_img(data, image_prompt, image_n_prompt)


def get_image_sketch(image_prompt, image_n_prompt, image):
    """Generate images from a sketch; argument order matches the UI inputs."""
    return posex.get_image_sketch(image, image_prompt, image_n_prompt)


def javascript_html() -> str:
    """Build the ``<script>`` tags injected before ``</head>`` of the page.

    Each URL carries the file's mtime as a query string so browsers fetch a
    fresh copy after the script changes.
    """
    script_js = f'script.js?{os.path.getmtime(os.path.join(project_dir, "script.js"))}'
    path7 = f'javascript/posex-webui.js?{os.path.getmtime(os.path.join(project_dir, "javascript/posex-webui.js"))}'
    # NOTE(review): the tag markup here was missing (only newlines were
    # emitted and script_js / path7 were unused). ``file=`` is Gradio's
    # static-file route; relative paths are presumably resolved against the
    # working directory — confirm the server is started from project_dir.
    head = f'<script type="text/javascript" src="file={script_js}"></script>\n'
    head += f'<script type="text/javascript" src="file={path7}"></script>\n'
    return head


def css_html() -> str:
    """Return the markup injected before ``</body>``; currently empty."""
    return ''


# Gradio's original TemplateResponse. The patched version always delegates to
# this saved reference, so calling reload_javascript() again replaces the
# patch instead of wrapping it a second time.
GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse


def reload_javascript() -> None:
    """Patch Gradio's page template so every page includes this app's JS/CSS."""
    js = javascript_html()
    css = css_html()

    def template_response(*args, **kwargs):
        res = GradioTemplateResponseOriginal(*args, **kwargs)
        # Anchor on the closing tags: replacing b'' would splice the payload
        # between every byte of the page.
        res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
        res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
        # Presumably refreshes headers (e.g. Content-Length) for the new body.
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response


if __name__ == '__main__':
    reload_javascript()

    def elem_id(s: str) -> str:
        """Return the DOM id for UI element ``s``."""
        return f'posex-t2i-{s}'

    def js(s: str) -> str:
        """Return the JS expression naming the global handler for ``s``."""
        return f'globalThis["{elem_id(s)}"]'

    xpath = os.path.join(project_dir, "javascript/lazyload/posex-webui.js")
    js_ = [project_dir, f'{xpath}?{os.path.getmtime(xpath)}']
    print(js_)

    app = gr.Blocks()
    with app:
        with gr.Blocks():
            with gr.Tab("sketch"):
                with gr.Row():
                    # Ids distinct from the openpose tab: duplicate DOM ids
                    # make id lookups hit whichever element comes first.
                    image_prompt_sketch = gr.Textbox(
                        label="image prompt", value='',
                        elem_id=elem_id('sketch_image_prompt'))
                    image_n_prompt_sketch = gr.Textbox(
                        label="negative prompt", value='',
                        elem_id=elem_id('sketch_image_n_prompt'))
                with gr.Row():
                    image_sketch = gr.Image(
                        interactive=True, source='upload', type="numpy",
                        tool="sketch").style(height=512, width=512)
                    sketch_result_img = gr.Gallery(
                        label='Output', show_label=False,
                        elem_id="sketch_gallery").style(grid=2, height='auto')
                with gr.Row():
                    generate_result_button = gr.Button(value='generate')
                generate_result_button.click(
                    fn=get_image_sketch,
                    inputs=[image_prompt_sketch, image_n_prompt_sketch, image_sketch],
                    outputs=[sketch_result_img])

            with gr.Tab("openpose", elem_id=elem_id('tab')):
                with gr.Row():
                    image_prompt = gr.Textbox(
                        label="image prompt", value='',
                        elem_id=elem_id('image_prompt'))
                    image_n_prompt = gr.Textbox(
                        label="negative prompt", value='',
                        elem_id=elem_id('image_n_prompt'))

                # Script list consumed by the page JS, and its mount point.
                gr.HTML(value='\n'.join(js_), elem_id=elem_id('js'), visible=False)
                gr.HTML(value='', elem_id=elem_id('html'))
                generate_button = gr.Button(elem_id=elem_id('generate'), value='generate')

                with gr.Column():
                    result_img = gr.Gallery(
                        label='Output', show_label=False,
                        elem_id="gallery").style(grid=2, height='auto')

                with gr.Group(visible=False):
                    sink = gr.HTML(value='', visible=False)  # to suppress error in javascript
                    # Named so it does not rebind the imported ``base64`` module.
                    v_base64 = js2py('base64', elem_id, js, sink)
                    py2js('allposes', all_pose, elem_id, js, sink)
                    jscall('delpose', delete_pose, elem_id, js, sink)
                    jscall('savepose', save_pose, elem_id, js, sink)
                    jscall('loadpose', load_pose, elem_id, js, sink)
                    generatecall('getimgs', generate_imgs, elem_id, js, sink,
                                 image_prompt, image_n_prompt, result_img)

    app.launch()