import gradio as gr
import torch
from pipeline import CustomPipeline, setup_scheduler
from diffusers import StableDiffusionPipeline
from PIL import Image
# from easydict import EasyDict as edict
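
# Shared state, populated by prepare_model() once the user selects a model.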
original_pipe = None
original_config = None
device = None


# def run_dpm_demo(id, prompt, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_dpm_demo(prompt, beta, num_inference_steps, guidance_scale, seed):
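    """Render the same prompt and seed with DPM-Solver++ twice: without momentum
    (beta=1.0) and with the user-selected beta, for a side-by-side comparison."""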
    global original_pipe, original_config
    # Rebuild a pipeline from the already-loaded components instead of reloading weights.
    pipe = CustomPipeline(**original_pipe.components)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    scheduler = "DPM-Solver++"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "dpm",
    }
    # Without momentum (beta=1.0 is equivalent to plain DPM-Solver++)
    pipe = setup_scheduler(pipe, scheduler, beta=1.0, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]
    # With momentum
    pipe = setup_scheduler(pipe, scheduler, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]
    # Keep copies on disk for inspection.
    ori_image.save("temp1.png")
    image.save("temp2.png")
    return [ori_image, image]


# def run_plms_demo(id, prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_plms_demo(prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed):
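    """Render the same prompt and seed with PLMS twice: without momentum
    (beta=1.0) and with the user-selected beta and momentum type."""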
    global original_pipe, original_config
    pipe = CustomPipeline(**original_pipe.components)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    scheduler = "PLMS"
    # Map the UI label to the internal method name: heavy ball ("hb") or Nesterov ("nt").
    method = "hb" if momentum_type == "Polyak's heavy ball" else "nt"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": method,
    }
    # Without momentum (beta=1.0 is equivalent to plain PLMS)
    pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=1.0, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]
    # With momentum
    pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]
    return [ori_image, image]


# def run_ghvb_demo(id, prompt, order, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_ghvb_demo(prompt, order, beta, num_inference_steps, guidance_scale, seed):
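    """Render the same prompt and seed with GHVB twice: with beta=1.0 (which
    reduces GHVB to plain PLMS) and with the user-selected beta."""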
    global original_pipe, original_config
    pipe = CustomPipeline(**original_pipe.components)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    scheduler = "GHVB"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "ghvb",
    }
    # Without momentum (beta=1.0 reduces GHVB to plain PLMS)
    pipe = setup_scheduler(pipe, scheduler, order=order, beta=1.0, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]
    # With momentum
    pipe = setup_scheduler(pipe, scheduler, order=order, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]
    return [ori_image, image]


if __name__ == "__main__":
    inputs = {}
    outputs = {}
    buttons = {}
    list_models = [
        "Linaqruf/anything-v3.0",
        "runwayml/stable-diffusion-v1-5",
        "dreamlike-art/dreamlike-photoreal-2.0",
    ]
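
    # Download each model once up front so that selecting it in the UI only
    # loads from the local cache instead of blocking on a network fetch.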
    for model_id in list_models:
        pipeline = StableDiffusionPipeline.from_pretrained(model_id)
        del pipeline
        print(f"Downloaded {model_id}")

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Momentum-Diffusion Demo
            Momentum-based sampling for diffusion models that mitigates divergence artifacts.
            Each tab compares a baseline sampler against its momentum variant on the same seed.
            """
        )
        model_id = gr.Dropdown(list_models, label="Model ID", value="Linaqruf/anything-v3.0", allow_custom_value=True)
        enable_token_merging = gr.Checkbox(label="Enable Token Merging", value=False)
        # output = gr.Textbox()
        buttons["select_model"] = gr.Button("Select")
| with gr.Tab("GHVB", visible=False) as tab3: | |
| prompt3 = gr.Textbox(label="Prompt", value="a cozy cafe", visible=False) | |
| with gr.Row(visible=False) as row31: | |
| order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order") | |
| beta = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, label="beta") | |
| num_inference_steps = gr.Number(label="Number of steps", value=12) | |
| guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10) | |
| seed = gr.Number(label="Seed", value=42) | |
| with gr.Row(visible=False) as row32: | |
| out1 = gr.Image(label="PLMS", interactive=False) | |
| out2 = gr.Image(label="GHVB", interactive=False) | |
| inputs["GHVB"] = [prompt3, order, beta, num_inference_steps, guidance_scale, seed] | |
| outputs["GHVB"] = [out1, out2] | |
| buttons["GHVB"] = gr.Button("Sample", visible=False) | |
| with gr.Tab("PLMS", visible=False) as tab2: | |
| prompt2 = gr.Textbox(label="Prompt", value="1girl", visible=False) | |
| with gr.Row(visible=False) as row21: | |
| order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order") | |
| beta = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, label="beta") | |
| momentum_type = gr.Dropdown(["Polyak's heavy ball", "Nesterov"], label="Momentum Type", value="Polyak's heavy ball") | |
| num_inference_steps = gr.Number(label="Number of steps", value=10) | |
| guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10) | |
| seed = gr.Number(label="Seed", value=42) | |
| with gr.Row(visible=False) as row22: | |
| out1 = gr.Image(label="Without momentum", interactive=False) | |
| out2 = gr.Image(label="With momentum", interactive=False) | |
| inputs["PLMS"] = [prompt2, order, beta, momentum_type, num_inference_steps, guidance_scale, seed] | |
| outputs["PLMS"] = [out1, out2] | |
| buttons["PLMS"] = gr.Button("Sample", visible=False) | |
| with gr.Tab("DPM-Solver++", visible=False) as tab1: | |
| prompt1 = gr.Textbox(label="Prompt", value="1girl", visible=False) | |
| with gr.Row(visible=False) as row11: | |
| beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="beta") | |
| num_inference_steps = gr.Number(label="Number of steps", value=15) | |
| guidance_scale = gr.Number(label="Guidance scale (cfg)", value=20) | |
| seed = gr.Number(label="Seed", value=0) | |
| with gr.Row(visible=False) as row12: | |
| out1 = gr.Image(label="Without momentum", interactive=False) | |
| out2 = gr.Image(label="With momentum", interactive=False) | |
| inputs["DPM-Solver++"] = [prompt1, beta, num_inference_steps, guidance_scale, seed] | |
| outputs["DPM-Solver++"] = [out1, out2] | |
| buttons["DPM-Solver++"] = gr.Button("Sample", visible=False) | |

        def prepare_model(model_id, enable_token_merging):
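            """Load the selected model, optionally patch it with Token Merging,
            and reveal the demo controls that were hidden until now."""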
            global original_pipe, original_config, device
            if original_pipe is not None:
                del original_pipe
            original_pipe = CustomPipeline.from_pretrained(model_id)
            # Use the default GPU when available; a hard-coded index such as
            # "cuda:1" raises a runtime error on single-GPU (or CPU-only) machines.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            original_pipe = original_pipe.to(device)
            if enable_token_merging:
                import tomesd
                tomesd.apply_patch(original_pipe, ratio=0.5)
                print("Enabled Token Merging.")
            original_config = original_pipe.scheduler.config
            print(type(original_pipe))
            print(original_config)
            return {
                row11: gr.update(visible=True),
                row12: gr.update(visible=True),
                row21: gr.update(visible=True),
                row22: gr.update(visible=True),
                row31: gr.update(visible=True),
                row32: gr.update(visible=True),
                prompt1: gr.update(visible=True),
                prompt2: gr.update(visible=True),
                prompt3: gr.update(visible=True),
                buttons["DPM-Solver++"]: gr.update(visible=True),
                buttons["PLMS"]: gr.update(visible=True),
                buttons["GHVB"]: gr.update(visible=True),
            }

        all_outputs = [row11, row12, row21, row22, row31, row32, prompt1, prompt2, prompt3,
                       buttons["DPM-Solver++"], buttons["PLMS"], buttons["GHVB"]]
        buttons["select_model"].click(prepare_model, inputs=[model_id, enable_token_merging], outputs=all_outputs)
        buttons["DPM-Solver++"].click(run_dpm_demo, inputs=inputs["DPM-Solver++"], outputs=outputs["DPM-Solver++"])
        buttons["PLMS"].click(run_plms_demo, inputs=inputs["PLMS"], outputs=outputs["PLMS"])
        buttons["GHVB"].click(run_ghvb_demo, inputs=inputs["GHVB"], outputs=outputs["GHVB"])

    demo.launch()