File size: 5,854 Bytes
ab5eb13 944cd16 ab5eb13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import spaces
import gradio as gr
import time
import torch
from PIL import Image
from segment_utils import(
segment_image_withmask,
restore_result,
)
from enhance_utils import enhance_image
from inversion_run_adapter_blur import run as adapter_run
# Default text for the "Input Image Prompt" box: describes the original hair.
DEFAULT_SRC_PROMPT = "RAW photo, Fujifilm XT3, sharp hair, high resolution hair, hair tones, natural hair, magazine hair"
# Default text for the "Edit Prompt" box: the source prompt plus the requested colour.
DEFAULT_EDIT_PROMPT = DEFAULT_SRC_PROMPT + ", white color hair"
# Value of the hidden "Category" textbox, forwarded to the segment and restore steps.
DEFAULT_CATEGORY = "hair"
# Prefer CUDA when torch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@spaces.GPU(duration=30)
def image_to_image(
    input_image: Image.Image,
    mask_image: Image.Image,
    input_image_prompt: str,
    edit_prompt: str,
    seed: int,
    w1: float,
    num_steps: int,
    start_step: int,
    guidance_scale: float,
    generate_size: int,
    strength: float,
    blur_radius: int,
    lineart_scale: float,
    canny_scale: float,
    lineart_detect: float,
    canny_detect: float,
):
    """Regenerate the masked region with ``adapter_run`` and enhance the result.

    Wrapped by ``spaces.GPU(duration=30)``. Every parameter is forwarded
    unchanged to ``adapter_run``; only ``w2`` is supplied internally.

    Returns:
        A 3-tuple ``(enhanced_image, generated_image, time_cost_str)`` where
        ``generated_image`` is ``adapter_run``'s output, ``enhanced_image`` is
        ``enhance_image(generated_image, False)``, and ``time_cost_str`` has the
        form ``"start--><ms in adapter_run>--><ms in enhance_image>"``.
    """
    # w2 is passed to adapter_run next to w1 but is not exposed in the UI.
    w2 = 1.0

    # A zero timestamp tells get_time_cost to start a fresh trace ("start").
    run_task_time, time_cost_str = get_time_cost(0, '')

    # NOTE: adapter_run takes its arguments positionally and in a different
    # order from this function's signature (generate_size comes before seed,
    # blur_radius comes last). Keep this order in sync with its definition.
    generated_image = adapter_run(
        input_image,
        mask_image,
        input_image_prompt,
        edit_prompt,
        generate_size,
        seed,
        w1,
        w2,
        num_steps,
        start_step,
        guidance_scale,
        strength,
        lineart_scale,
        canny_scale,
        lineart_detect,
        canny_detect,
        blur_radius,
    )
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    enhanced_image = enhance_image(generated_image, False)
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    return enhanced_image, generated_image, time_cost_str
def get_time_cost(run_task_time, time_cost_str):
    """Record a timing checkpoint and extend a human-readable trace.

    Args:
        run_task_time: Timestamp of the previous checkpoint in epoch
            milliseconds, or ``0`` to begin a new trace.
        time_cost_str: Trace accumulated so far, e.g. ``'start-->1200'``.

    Returns:
        ``(now_ms, trace)``: the current time in epoch milliseconds and the
        updated trace. When ``run_task_time`` is ``0`` the trace is reset to
        ``'start'``. Otherwise the elapsed milliseconds are appended, preceded
        by ``'-->'`` unless the trace is empty.
    """
    now_ms = int(time.time() * 1000)

    # First checkpoint: any previous trace is discarded.
    if run_task_time == 0:
        return now_ms, 'start'

    separator = '-->' if time_cost_str != '' else ''
    elapsed = f'{now_ms - run_task_time}'
    return now_ms, time_cost_str + separator + elapsed
def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI for mask-guided image editing.

    Clicking "Edit Image" runs three chained handlers; each later one only
    starts if the previous one succeeded (``.success``):

    1. ``segment_image_withmask``: fills the origin-area image, the mask
       image and a crop state.
    2. ``image_to_image``: fills the enhanced image, the generated image
       and the per-step timing text.
    3. ``restore_result``: uses the crop state, category and enhanced image
       to fill the restored image and the downloadable file.

    Returns:
        The constructed (not launched) ``gr.Blocks`` app.
    """
    # NOTE: component creation order inside the context managers defines the
    # on-screen layout, so statements below must not be reordered.
    with gr.Blocks() as demo:
        crop_state = gr.State()

        # Control panel: prompts | sampling settings | seed and trigger.
        with gr.Row():
            with gr.Column():
                src_prompt_box = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
                edit_prompt_box = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                # Hidden: carries the category into the segment and restore steps.
                category_box = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
            with gr.Column():
                steps_slider = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Num Steps")
                start_step_slider = gr.Slider(minimum=1, maximum=100, value=4, step=1, label="Start Step")
                strength_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Strength", visible=True)
                with gr.Accordion("Advanced Options", open=False):
                    blur_slider = gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Blur Radius", visible=True)
                    guidance_slider = gr.Slider(minimum=0, maximum=20, value=1, step=0.5, label="Guidance Scale", visible=True)
                    size_number = gr.Number(label="Generate Size", value=768)
                    expansion_number = gr.Number(label="Mask Expansion", value=10, visible=True)
                    dilation_slider = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
                    lineart_weight_slider = gr.Slider(minimum=0, maximum=5, value=0.8, step=0.1, label="Lineart Weights", visible=True)
                    canny_weight_slider = gr.Slider(minimum=0, maximum=5, value=0.1, step=0.1, label="Canny Weights", visible=True)
                    lineart_detect_number = gr.Number(label="Lineart Detect", value=0.375, visible=True)
                    canny_detect_number = gr.Number(label="Canny Detect", value=0.375, visible=True)
            with gr.Column():
                seed_number = gr.Number(label="Seed", value=8)
                w1_number = gr.Number(label="W1", value=5)
                edit_button = gr.Button("Edit Image")

        # Image panel: input | final result | intermediate outputs.
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                restored_view = gr.Image(label="Restored Image", type="pil", interactive=False)
                download_file = gr.File(label="Download the output image", interactive=False)
            with gr.Column():
                crop_view = gr.Image(label="Origin Area Image", type="pil", interactive=False)
                enhanced_view = gr.Image(label="Enhanced Image", type="pil", interactive=False)
                timing_box = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
                generated_view = gr.Image(label="Generated Image", type="pil", interactive=False)
                mask_view = gr.Image(label="Mask Image", type="pil", interactive=False)

        # Wiring. Each inputs list must match its handler's positional parameters.
        segment_inputs = [source_image, category_box, size_number, expansion_number, dilation_slider]
        segment_outputs = [crop_view, mask_view, crop_state]
        edit_inputs = [
            crop_view, mask_view, src_prompt_box, edit_prompt_box,
            seed_number, w1_number, steps_slider, start_step_slider,
            guidance_slider, size_number, strength_slider, blur_slider,
            lineart_weight_slider, canny_weight_slider,
            lineart_detect_number, canny_detect_number,
        ]
        edit_outputs = [enhanced_view, generated_view, timing_box]
        restore_inputs = [crop_state, category_box, enhanced_view]
        restore_outputs = [restored_view, download_file]

        segment_event = edit_button.click(
            fn=segment_image_withmask,
            inputs=segment_inputs,
            outputs=segment_outputs,
        )
        edit_event = segment_event.success(
            fn=image_to_image,
            inputs=edit_inputs,
            outputs=edit_outputs,
        )
        edit_event.success(
            fn=restore_result,
            inputs=restore_inputs,
            outputs=restore_outputs,
        )

    return demo