animrods commited on
Commit
711fa9a
1 Parent(s): e692d95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -72
app.py CHANGED
@@ -5,58 +5,12 @@ import diffusers
5
  import os
6
  from PIL import Image
7
  hf_token = os.environ.get("HF_TOKEN")
8
- from diffusers import StableDiffusionXLInpaintPipeline, DDIMScheduler, UNet2DConditionModel
9
-
10
- ratios_map = {
11
- 0.5:{"width":704,"height":1408},
12
- 0.57:{"width":768,"height":1344},
13
- 0.68:{"width":832,"height":1216},
14
- 0.72:{"width":832,"height":1152},
15
- 0.78:{"width":896,"height":1152},
16
- 0.82:{"width":896,"height":1088},
17
- 0.88:{"width":960,"height":1088},
18
- 0.94:{"width":960,"height":1024},
19
- 1.00:{"width":1024,"height":1024},
20
- 1.13:{"width":1088,"height":960},
21
- 1.21:{"width":1088,"height":896},
22
- 1.29:{"width":1152,"height":896},
23
- 1.38:{"width":1152,"height":832},
24
- 1.46:{"width":1216,"height":832},
25
- 1.67:{"width":1280,"height":768},
26
- 1.75:{"width":1344,"height":768},
27
- 2.00:{"width":1408,"height":704}
28
- }
29
- ratios = np.array(list(ratios_map.keys()))
30
-
31
- def get_size(init_image):
32
- w,h=init_image.size
33
- curr_ratio = w/h
34
- ind = np.argmin(np.abs(curr_ratio-ratios))
35
- ratio = ratios[ind]
36
- chosen_ratio = ratios_map[ratio]
37
- w,h = chosen_ratio['width'], chosen_ratio['height']
38
 
39
- return w,h
40
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
- unet = UNet2DConditionModel.from_pretrained(
43
- "briaai/BRIA-2.3-Inpainting",
44
- subfolder="unet",
45
- torch_dtype=torch.float16,
46
- )
47
-
48
- scheduler = DDIMScheduler.from_pretrained("briaai/BRIA-2.3", subfolder="scheduler",
49
- rescale_betas_zero_snr=True,prediction_type='v_prediction',timestep_spacing="trailing",clip_sample=False)
50
-
51
- pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
52
- "briaai/BRIA-2.3",
53
- unet=unet,
54
- scheduler=scheduler,
55
- torch_dtype=torch.float16,
56
- force_zeros_for_empty_prompt=False
57
- )
58
- pipe = pipe.to(device)
59
- pipe.force_zeros_for_empty_prompt = False
60
 
61
  default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
62
 
@@ -69,27 +23,28 @@ def read_content(file_path: str) -> str:
69
 
70
  return content
71
 
72
- def predict(dict, prompt="", negative_prompt="", guidance_scale=5, steps=30, strength=1.0):
73
- if negative_prompt == "":
74
- negative_prompt = None
75
 
76
-
77
- init_image = dict["image"].convert("RGB")#.resize((1024, 1024))
78
- mask = dict["mask"].convert("RGB")#.resize((1024, 1024))
79
-
80
- w,h = get_size(init_image)
81
 
82
- init_image = init_image.resize((w, h))
83
- mask = mask.resize((w, h))
84
-
85
- # Resize to nearest ratio ?
86
-
87
- mask = np.array(mask)
88
- mask[mask>0]=255
89
- mask = Image.fromarray(mask)
90
 
91
- output = pipe(prompt = prompt,width=w,height=h, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  return output.images[0] #, gr.update(visible=True)
94
 
95
 
@@ -146,14 +101,20 @@ with image_blocks as demo:
146
  image = gr.Image(sources=['upload'], tool='sketch', elem_id="image_upload", type="pil", label="Upload", height=400)
147
  with gr.Row(elem_id="prompt-container", equal_height=True):
148
  with gr.Row():
149
- prompt = gr.Textbox(placeholder="Your prompt (what you want in place of what is erased)", show_label=False, elem_id="prompt")
150
- btn = gr.Button("Inpaint!", elem_id="run_button")
 
 
 
 
 
 
151
 
152
  with gr.Accordion(label="Advanced Settings", open=False):
153
  with gr.Row(equal_height=True):
154
  guidance_scale = gr.Number(value=5, minimum=1.0, maximum=10.0, step=0.5, label="guidance_scale")
155
- steps = gr.Number(value=30, minimum=20, maximum=50, step=1, label="steps")
156
- strength = gr.Number(value=1, minimum=0.01, maximum=1.0, step=0.01, label="strength")
157
  negative_prompt = gr.Textbox(label="negative_prompt", value=default_negative_prompt, placeholder=default_negative_prompt, info="what you don't want to see in the image")
158
 
159
 
@@ -162,8 +123,8 @@ with image_blocks as demo:
162
 
163
 
164
 
165
- btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength], outputs=[image_out], api_name='run')
166
- prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength], outputs=[image_out])
167
 
168
  # gr.Examples(
169
  # examples=[
 
5
  import os
6
  from PIL import Image
7
  hf_token = os.environ.get("HF_TOKEN")
8
+ from diffusers import AutoPipelineForText2Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ pipe = AutoPipelineForText2Image.from_pretrained("briaai/BRIA-2.3", torch_dtype=torch.float16, force_zeros_for_empty_prompt=False).to(device)
13
+ pipe.load_ip_adapter("briaai/DEV-Image-Prompt", subfolder='models', weight_name="ip_adapter_bria.bin")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
16
 
 
23
 
24
  return content
25
 
26
+ def predict(dict, prompt="high quality, best quality", negative_prompt="", guidance_scale=5, steps=30, ip_adapter_scale = 1.0, width=1024, height=1024, seed=0):
 
 
27
 
28
+ pipeline.set_ip_adapter_scale(ip_adapter_scale)
 
 
 
 
29
 
30
+ if negative_prompt == "":
31
+ negative_prompt = None
 
 
 
 
 
 
32
 
33
+ init_image = dict["image"].convert("RGB")
34
+ init_image = init_image.resize((224, 224))
35
 
36
+ generator = torch.Generator(device="cpu").manual_seed(seed)
37
+
38
+ output = pipe(
39
+ prompt=prompt,
40
+ negative_prompt=negative_prompt,
41
+ ip_adapter_image=init_image,
42
+ num_inference_steps=steps,
43
+ generator=generator,
44
+ height=height, width=width,
45
+ guidance_scale=guidance_scale
46
+ ).images
47
+
48
  return output.images[0] #, gr.update(visible=True)
49
 
50
 
 
101
  image = gr.Image(sources=['upload'], tool='sketch', elem_id="image_upload", type="pil", label="Upload", height=400)
102
  with gr.Row(elem_id="prompt-container", equal_height=True):
103
  with gr.Row():
104
+ prompt = gr.Textbox(placeholder="Your prompt (you can leave it empty if you only want the image prompt as input)", show_label=False, elem_id="prompt")
105
+ btn = gr.Button("Generate!", elem_id="run_button")
106
+
107
+ with gr.Accordion(label="Settings", open=True):
108
+ with gr.Row(equal_height=True):
109
+ ip_adapter_scale = gr.Number(value=1.0, minimum=0.01, maximum=1.0, step=0.01, label="ip_adapter_scale")
110
+ width = gr.Number(value=1024, minimum=0.01, maximum=1.0, step=0.01, label="width")
111
+ height = gr.Number(value=1024, minimum=0.01, maximum=1.0, step=0.01, label="height")
112
 
113
  with gr.Accordion(label="Advanced Settings", open=False):
114
  with gr.Row(equal_height=True):
115
  guidance_scale = gr.Number(value=5, minimum=1.0, maximum=10.0, step=0.5, label="guidance_scale")
116
+ steps = gr.Number(value=30, minimum=10, maximum=100, step=1, label="steps")
117
+ seed = gr.Number(value=0, minimum=0, maximum=100000, step=1, label="seed")
118
  negative_prompt = gr.Textbox(label="negative_prompt", value=default_negative_prompt, placeholder=default_negative_prompt, info="what you don't want to see in the image")
119
 
120
 
 
123
 
124
 
125
 
126
+ btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, ip_adapter_scale, width, height, seed], outputs=[image_out], api_name='run')
127
+ prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, ip_adapter_scale, width, height, seed], outputs=[image_out])
128
 
129
  # gr.Examples(
130
  # examples=[