zhiweili committed on
Commit
c4a0a85
1 Parent(s): 312679f

add app_haircolor_pix2pix

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. app_haircolor_pix2pix.py +117 -0
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
 
3
  # from app_base import create_demo as create_demo_face
4
- from app_haircolor import create_demo as create_demo_haircolor
5
 
6
  with gr.Blocks(css="style.css") as demo:
7
  with gr.Tabs():
 
1
  import gradio as gr
2
 
3
  # from app_base import create_demo as create_demo_face
4
+ from app_haircolor_pix2pix import create_demo as create_demo_haircolor
5
 
6
  with gr.Blocks(css="style.css") as demo:
7
  with gr.Tabs():
app_haircolor_pix2pix.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import spaces
import gradio as gr
import time
import torch

from PIL import Image
from segment_utils import (
    segment_image,
    restore_result,
)
from diffusers import (
    StableDiffusionInstructPix2PixPipeline,
    EulerAncestralDiscreteScheduler,
)

# Instruction-following image-editing model used for prompt-driven hair recoloring.
BASE_MODEL = "timbrooks/instruct-pix2pix"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

DEFAULT_EDIT_PROMPT = "change hair to blue"

# Segmentation category passed to segment_image (hidden in the UI).
DEFAULT_CATEGORY = "hair"

# Built once at import time; reused across requests.
basepipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    use_safetensors=True,
)

basepipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(basepipeline.scheduler.config)

if DEVICE == "cuda":
    # enable_model_cpu_offload() manages device placement itself; moving the
    # pipeline to CUDA with .to() first defeats the memory savings and is
    # rejected by newer diffusers releases. Only use offload when CUDA exists
    # (it requires a CUDA device), otherwise fall back to a plain .to().
    basepipeline.enable_model_cpu_offload()
else:
    basepipeline = basepipeline.to(DEVICE)
35
+
36
@spaces.GPU(duration=15)
def image_to_image(
    input_image: Image.Image,
    edit_prompt: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    image_guidance_scale: float,
):
    """Run the InstructPix2Pix pipeline on the cropped hair-region image.

    Args:
        input_image: cropped region produced by segment_image.
        edit_prompt: instruction-style prompt, e.g. "change hair to blue".
        seed: RNG seed for reproducible output (Gradio may deliver a float,
            so it is coerced to int before seeding).
        num_steps: number of denoising steps.
        guidance_scale: classifier-free guidance strength for the text prompt.
        image_guidance_scale: how closely the result follows the input image.

    Returns:
        Tuple of (generated PIL image, per-step time-cost string in ms).
    """
    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    # torch.Generator.manual_seed requires an int; gr.Number can hand back a float.
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    generated_image = basepipeline(
        generator=generator,
        prompt=edit_prompt,
        image=input_image,
        guidance_scale=guidance_scale,
        image_guidance_scale=image_guidance_scale,
        num_inference_steps=int(num_steps),
    ).images[0]

    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    return generated_image, time_cost_str
62
+
63
def get_time_cost(run_task_time, time_cost_str):
    """Append the milliseconds elapsed since the previous call to a log string.

    Args:
        run_task_time: previous timestamp in ms, or 0 on the first call.
        time_cost_str: accumulated log, e.g. "start-->120-->4500".

    Returns:
        Tuple of (current timestamp in ms, updated log string).
    """
    now_time = int(time.time() * 1000)
    if run_task_time == 0:
        # First call marks the start of the task; no delta to record yet.
        time_cost_str = 'start'
    else:
        if time_cost_str != '':
            time_cost_str += '-->'  # plain literal; original f-string had no placeholder
        time_cost_str += f'{now_time - run_task_time}'
    run_task_time = now_time
    return run_task_time, time_cost_str
73
+
74
def create_demo() -> gr.Blocks:
    """Build the hair-color editing UI: segment the hair region, edit it with
    InstructPix2Pix, then paste the result back into the original photo.

    Returns:
        The assembled gr.Blocks demo (not launched here).
    """
    with gr.Blocks() as demo:
        # Holds the cropper state returned by segment_image so restore_result
        # can paste the edited crop back into the full image.
        croper = gr.State()
        with gr.Row():
            with gr.Column():
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                generate_size = gr.Number(label="Generate Size", value=512)
            with gr.Column():
                num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
                guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
            with gr.Column():
                image_guidance_scale = gr.Slider(minimum=0, maximum=30, value=1.5, step=0.1, label="Image Guidance Scale")
                with gr.Accordion("Advanced Options", open=False):
                    mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
                    mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
                    seed = gr.Number(label="Seed", value=8)
                    # Fixed to "hair" for this app; hidden from the user.
                    category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
                g_btn = gr.Button("Edit Image")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
            with gr.Column():
                origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)

        # Three-stage chain; each .success() step only fires if the previous
        # one completed without error: segment -> edit -> paste back.
        g_btn.click(
            fn=segment_image,
            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
            outputs=[origin_area_image, croper],
        ).success(
            fn=image_to_image,
            inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, image_guidance_scale],
            outputs=[generated_image, generated_cost],
        ).success(
            fn=restore_result,
            inputs=[croper, category, generated_image],
            outputs=[restored_image],
        )

    return demo