yonishafir commited on
Commit
6a1229b
·
verified ·
1 Parent(s): 15935a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -25
app.py CHANGED
@@ -6,6 +6,20 @@ import os
6
  from PIL import Image
7
  hf_token = os.environ.get("HF_TOKEN")
8
  from diffusers import StableDiffusionXLInpaintPipeline, DDIMScheduler, UNet2DConditionModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  ratios_map = {
11
  0.5:{"width":704,"height":1408},
@@ -28,6 +42,30 @@ ratios_map = {
28
  }
29
  ratios = np.array(list(ratios_map.keys()))
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def get_size(init_image):
32
  w,h=init_image.size
33
  curr_ratio = w/h
@@ -40,26 +78,33 @@ def get_size(init_image):
40
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
43
- unet = UNet2DConditionModel.from_pretrained(
44
- "briaai/BRIA-2.2-Inpainting",
45
- subfolder="unet",
46
- torch_dtype=torch.float16,
47
- )
48
 
49
- scheduler = DDIMScheduler.from_pretrained("briaai/BRIA-2.3", subfolder="scheduler",clip_sample=False)
 
 
 
 
 
50
 
51
- pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
52
- "briaai/BRIA-2.3",
53
- unet=unet,
54
- scheduler=scheduler,
55
- torch_dtype=torch.float16,
56
- force_zeros_for_empty_prompt=False
57
- )
 
 
 
58
 
59
- pipe = pipe.to(device)
60
- pipe.force_zeros_for_empty_prompt = False
61
 
62
- default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
 
 
 
 
 
63
 
64
 
65
  def read_content(file_path: str) -> str:
@@ -70,26 +115,74 @@ def read_content(file_path: str) -> str:
70
 
71
  return content
72
 
73
- def predict(dict, prompt="", negative_prompt="", guidance_scale=5, steps=30, strength=1.0):
74
  if negative_prompt == "":
75
  negative_prompt = None
76
 
77
 
78
  init_image = dict["image"].convert("RGB")#.resize((1024, 1024))
79
- mask = dict["mask"].convert("RGB")#.resize((1024, 1024))
80
 
81
- w,h = get_size(init_image)
82
 
83
- init_image = init_image.resize((w, h))
84
- mask = mask.resize((w, h))
85
 
86
  # Resize to nearest ratio ?
87
 
88
- mask = np.array(mask)
89
- mask[mask>0]=255
90
- mask = Image.fromarray(mask)
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- output = pipe(prompt = prompt,width=w,height=h, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  return output.images[0] #, gr.update(visible=True)
95
 
 
6
  from PIL import Image
7
  hf_token = os.environ.get("HF_TOKEN")
8
  from diffusers import StableDiffusionXLInpaintPipeline, DDIMScheduler, UNet2DConditionModel
9
+ from diffusers import (
10
+ AutoencoderKL,
11
+ LCMScheduler,
12
+ )
13
+ from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
14
+ from controlnet import ControlNetModel, ControlNetConditioningEmbedding
15
+ import torch
16
+ import numpy as np
17
+ from PIL import Image
18
+ import requests
19
+ import PIL
20
+ from io import BytesIO
21
+ from torchvision import transforms
22
+
23
 
24
  ratios_map = {
25
  0.5:{"width":704,"height":1408},
 
42
  }
43
  ratios = np.array(list(ratios_map.keys()))
44
 
45
+ image_transforms = transforms.Compose(
46
+ [
47
+ transforms.ToTensor(),
48
+ ]
49
+ )
50
+
51
+ default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
52
+
53
+
54
+ def get_masked_image(image, image_mask, width, height):
55
+ image_mask = image_mask # inpaint area is white
56
+ image_mask = image_mask.resize((width, height)) # object to remove is white (1)
57
+ image_mask_pil = image_mask
58
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
59
+ image_mask = np.array(image_mask_pil.convert("L")).astype(np.float32) / 255.0
60
+ assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
61
+ masked_image_to_present = image.copy()
62
+ masked_image_to_present[image_mask > 0.5] = (0.5,0.5,0.5) # set as masked pixel
63
+ image[image_mask > 0.5] = 0.5 # set as masked pixel - s.t. will be grey
64
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
65
+ masked_image_to_present = Image.fromarray((masked_image_to_present * 255.0).astype(np.uint8))
66
+ return image, image_mask_pil, masked_image_to_present
67
+
68
+
69
  def get_size(init_image):
70
  w,h=init_image.size
71
  curr_ratio = w/h
 
78
 
79
  device = "cuda" if torch.cuda.is_available() else "cpu"
80
 
 
 
 
 
 
81
 
82
+ # Load, init model
83
+ controlnet = ControlNetModel().from_config('briaai/DEV-ControlNetInpaintingFast', torch_dtype=torch.float16)
84
+ controlnet.controlnet_cond_embedding = ControlNetConditioningEmbedding(
85
+ conditioning_embedding_channels=320,
86
+ conditioning_channels = 5
87
+ )
88
 
89
+ controlnet = ControlNetModel().from_pretrained("briaai/DEV-ControlNetInpaintingFast", torch_dtype=torch.float16)
90
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
91
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained("briaai/BRIA-2.3", controlnet=controlnet.to(dtype=torch.float16), torch_dtype=torch.float16, vae=vae) #force_zeros_for_empty_prompt=False, # vae=vae)
92
+
93
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
94
+ pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
95
+ pipe.fuse_lora()
96
+
97
+ pipe = pipe.to('cuda:0')
98
+ pipe.enable_xformers_memory_efficient_attention()
99
 
100
+ generator = torch.Generator(device='cuda:0').manual_seed(123456)
 
101
 
102
+ vae = pipe.vae
103
+
104
+
105
+ # pipe.force_zeros_for_empty_prompt = False
106
+
107
+ # default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
108
 
109
 
110
  def read_content(file_path: str) -> str:
 
115
 
116
  return content
117
 
118
+ def predict(dict, prompt="", negative_prompt = default_negative_prompt, guidance_scale=1.2, steps=12, strength=1.0):
119
  if negative_prompt == "":
120
  negative_prompt = None
121
 
122
 
123
  init_image = dict["image"].convert("RGB")#.resize((1024, 1024))
124
+ mask = dict["mask"].convert("L")#.resize((1024, 1024))
125
 
126
+ width, height = get_size(init_image)
127
 
128
+ init_image = init_image.resize((width, height))
129
+ mask = mask.resize((width, height))
130
 
131
  # Resize to nearest ratio ?
132
 
133
+ # mask = np.array(mask)
134
+ # mask[mask>0]=255
135
+ # mask = Image.fromarray(mask)
136
+
137
+
138
+ masked_image, image_mask, masked_image_to_present = get_masked_image(init_image, mask, width, height)
139
+ masked_image_tensor = image_transforms(masked_image)
140
+ masked_image_tensor = (masked_image_tensor - 0.5) / 0.5
141
+
142
+ masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device="cuda")
143
+
144
+ control_latents = vae.encode(
145
+ masked_image_tensor[:, :3, :, :].to(vae.dtype)
146
+ ).latent_dist.sample()
147
 
148
+ control_latents = control_latents * vae.config.scaling_factor
149
+
150
+ image_mask = np.array(image_mask)[:,:]
151
+ mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, ...]
152
+ # binarize the mask
153
+ mask_tensor = torch.where(mask_tensor > 128.0, 255.0, 0)
154
+
155
+ mask_tensor = mask_tensor / 255.0
156
+
157
+ mask_tensor = mask_tensor.to(device="cuda")
158
+ mask_resized = torch.nn.functional.interpolate(mask_tensor[None, ...], size=(control_latents.shape[2], control_latents.shape[3]), mode='nearest')
159
+ # mask_resized = mask_resized.to(torch.float16)
160
+ masked_image = torch.cat([control_latents, mask_resized], dim=1)
161
+
162
+
163
+ output = pipe(prompt = prompt,
164
+ width=width,
165
+ height=height,
166
+ negative_prompt=negative_prompt,
167
+ image = masked_image, # control image V
168
+ init_image = init_image,
169
+ mask_image=mask_tensor,
170
+ guidance_scale=guidance_scale,
171
+ num_inference_steps=int(steps),
172
+ strength=strength,
173
+ generator=generator,
174
+ controlnet_conditioning_sale=1.0, )
175
+
176
+ # gen_img = pipe(negative_prompt=default_negative_prompt, prompt=prompt,
177
+ # controlnet_conditioning_sale=1.0,
178
+ # num_inference_steps=12,
179
+ # height=height, width=width,
180
+ # image = masked_image, # control image
181
+ # init_image = init_image,
182
+ # mask_image = mask_tensor,
183
+ # guidance_scale = 1.2,
184
+ # generator=generator).images[0]
185
+
186
 
187
  return output.images[0] #, gr.update(visible=True)
188