LanHarmony committed
Commit d911096
1 Parent(s): 569cb36

introduce ControlNet from diffusers

Files changed (3)
  1. app.py +18 -18
  2. requirements.txt +1 -0
  3. visual_foundation_models.py +193 -44
app.py CHANGED
@@ -118,24 +118,24 @@ class ConversationBot:
         self.edit = ImageEditing(device="cuda:0")
         self.i2t = ImageCaptioning(device="cuda:0")
         self.t2i = T2I(device="cuda:0")
-        # self.image2canny = image2canny()
-        # self.canny2image = canny2image(device="cuda:1")
-        # self.image2line = image2line()
-        # self.line2image = line2image(device="cuda:1")
-        # self.image2hed = image2hed()
-        # self.hed2image = hed2image(device="cuda:2")
-        # self.image2scribble = image2scribble()
-        # self.scribble2image = scribble2image(device="cuda:3")
-        # self.image2pose = image2pose()
-        # self.pose2image = pose2image(device="cuda:3")
-        # self.BLIPVQA = BLIPVQA(device="cuda:4")
-        # self.image2seg = image2seg()
-        # self.seg2image = seg2image(device="cuda:7")
-        # self.image2depth = image2depth()
-        # self.depth2image = depth2image(device="cuda:7")
-        # self.image2normal = image2normal()
-        # self.normal2image = normal2image(device="cuda:5")
-        # self.pix2pix = Pix2Pix(device="cuda:0")
+        self.image2canny = image2canny()
+        self.canny2image = canny2image(device="cuda:1")
+        self.image2line = image2line()
+        self.line2image = line2image(device="cuda:1")
+        self.image2hed = image2hed()
+        self.hed2image = hed2image(device="cuda:2")
+        self.image2scribble = image2scribble()
+        self.scribble2image = scribble2image(device="cuda:3")
+        self.image2pose = image2pose()
+        self.pose2image = pose2image(device="cuda:3")
+        self.BLIPVQA = BLIPVQA(device="cuda:4")
+        self.image2seg = image2seg()
+        self.seg2image = seg2image(device="cuda:7")
+        self.image2depth = image2depth()
+        self.depth2image = depth2image(device="cuda:7")
+        self.image2normal = image2normal()
+        self.normal2image = normal2image(device="cuda:5")
+        self.pix2pix = Pix2Pix(device="cuda:0")
         self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
         self.tools = [
             Tool(name="Get Photo Description", func=self.i2t.inference,
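Note: the hunk above re-enables the ControlNet-backed tools and pins each model to a fixed CUDA device at construction time. For reference, a hedged sketch of how one such re-enabled model is typically registered as a LangChain tool, mirroring the "Get Photo Description" pattern visible in the context lines; the tool name and description strings here are illustrative, not the repo's exact text:

    from langchain.agents import Tool

    def make_canny2image_tool(canny2image):
        # canny2image.inference takes "image_path, instruction text" as one string
        return Tool(
            name="Generate Image Condition On Canny Image",  # illustrative name
            func=canny2image.inference,
            description="useful when you want to generate a new image from a canny edge image and a text description",
        )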
requirements.txt CHANGED
@@ -28,3 +28,4 @@ diffusers==0.14.0
 gradio
 openai
 accelerate
+controlnet-aux==0.0.1
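The new controlnet-aux dependency packages the ControlNet annotators (Openpose, MLSD, HED) as pip-installable detectors, which is what lets this commit comment out the vendored ControlNet.annotator imports in visual_foundation_models.py below. A minimal sketch of loading one detector, using the same checkpoint id the new code uses; treat it as illustrative:

    from controlnet_aux import MLSDdetector

    # the detector is a callable that returns a PIL image of detected line segments
    mlsd = MLSDdetector.from_pretrained('lllyasviel/ControlNet')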
visual_foundation_models.py CHANGED
@@ -1,19 +1,22 @@
 import os
+
+import diffusers.utils
 from diffusers import StableDiffusionPipeline
 from diffusers import StableDiffusionInpaintPipeline
 from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
 from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
 from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
 from ldm.util import instantiate_from_config
 from ControlNet.cldm.model import create_model, load_state_dict
 from ControlNet.cldm.ddim_hacked import DDIMSampler
-from ControlNet.annotator.canny import CannyDetector
-from ControlNet.annotator.mlsd import MLSDdetector
-from ControlNet.annotator.util import HWC3, resize_image
-from ControlNet.annotator.hed import HEDdetector, nms
-from ControlNet.annotator.openpose import OpenposeDetector
-from ControlNet.annotator.uniformer import UniformerDetector
-from ControlNet.annotator.midas import MidasDetector
+# from ControlNet.annotator.canny import CannyDetector
+# from ControlNet.annotator.mlsd import MLSDdetector
+# from ControlNet.annotator.hed import HEDdetector, nms
+# from ControlNet.annotator.openpose import OpenposeDetector
+# from ControlNet.annotator.uniformer import UniformerDetector
+# from ControlNet.annotator.midas import MidasDetector
 from PIL import Image
 import torch
 import numpy as np
@@ -23,6 +26,36 @@ from pytorch_lightning import seed_everything
 import cv2
 import random
 
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+def resize_image(input_image, resolution):
+    H, W, C = input_image.shape
+    H = float(H)
+    W = float(W)
+    k = float(resolution) / min(H, W)
+    H *= k
+    W *= k
+    H = int(np.round(H / 64.0)) * 64
+    W = int(np.round(W / 64.0)) * 64
+    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+    return img
+
 def get_new_image_name(org_img_name, func_name="update"):
     head_tail = os.path.split(org_img_name)
     head = head_tail[0]
@@ -139,40 +172,41 @@ class ImageCaptioning:
         captions = self.processor.decode(out[0], skip_special_tokens=True)
         return captions
 
-class image2canny:
+class image2canny_new:
     def __init__(self):
         print("Direct detect canny.")
-        self.detector = CannyDetector()
-        self.low_thresh = 100
-        self.high_thresh = 200
+        self.low_threshold = 100
+        self.high_threshold = 200
 
     def inference(self, inputs):
        print("===>Starting image2canny Inference")
         image = Image.open(inputs)
         image = np.array(image)
-        canny = self.detector(image, self.low_thresh, self.high_thresh)
+        canny = cv2.Canny(image, self.low_threshold, self.high_threshold)
+        canny = canny[:, :, None]
+        canny = np.concatenate([canny, canny, canny], axis=2)
         canny = 255 - canny
-        image = Image.fromarray(canny)
+        canny = Image.fromarray(canny)
         updated_image_path = get_new_image_name(inputs, func_name="edge")
-        image.save(updated_image_path)
+        canny.save(updated_image_path)
         return updated_image_path
 
-class canny2image:
+class canny2image_new:
     def __init__(self, device):
-        print("Initialize the canny2image model.")
-        model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
-        model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_canny.pth', location='cpu'))
-        self.model = model.to(device)
-        self.device = device
-        self.ddim_sampler = DDIMSampler(self.model)
-        self.ddim_steps = 20
+        self.controlnet = ControlNetModel.from_pretrained(
+            "fusing/stable-diffusion-v1-5-controlnet-canny"
+        )
+
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None
+        )
+
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
         self.image_resolution = 512
-        self.num_samples = 1
-        self.save_memory = False
-        self.strength = 1.0
-        self.guess_mode = False
-        self.scale = 9.0
+        self.num_inference_steps = 20
         self.seed = -1
+        self.unconditional_guidance_scale = 9.0
         self.a_prompt = 'best quality, extremely detailed'
         self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
 
@@ -184,28 +218,143 @@ class canny2image:
         image = 255 - image
         prompt = instruct_text
         img = resize_image(HWC3(image), self.image_resolution)
-        H, W, C = img.shape
-        control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
-        control = torch.stack([control for _ in range(self.num_samples)], dim=0)
-        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+        img = Image.fromarray(img)
+
         self.seed = random.randint(0, 65535)
         seed_everything(self.seed)
-        if self.save_memory:
-            self.model.low_vram_shift(is_diffusing=False)
-        cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
-        un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
-        shape = (4, H // 8, W // 8)
-        self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
-        samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
-        if self.save_memory:
-            self.model.low_vram_shift(is_diffusing=False)
-        x_samples = self.model.decode_first_stage(samples)
-        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+        prompt = prompt + ', ' + self.a_prompt
+        image = self.pipe(prompt, img, num_inference_steps=self.num_inference_steps, eta=0.0, negative_prompt=self.n_prompt, guidance_scale=self.unconditional_guidance_scale).images[0]
         updated_image_path = get_new_image_name(image_path, func_name="canny2image")
-        real_image = Image.fromarray(x_samples[0])  # get default the index0 image
-        real_image.save(updated_image_path)
+        image.save(updated_image_path)
+        return updated_image_path
+
+
+# class image2canny:
+#     def __init__(self):
+#         print("Direct detect canny.")
+#         self.detector = CannyDetector()
+#         self.low_thresh = 100
+#         self.high_thresh = 200
+#
+#     def inference(self, inputs):
+#         print("===>Starting image2canny Inference")
+#         image = Image.open(inputs)
+#         image = np.array(image)
+#         canny = self.detector(image, self.low_thresh, self.high_thresh)
+#         canny = 255 - canny
+#         image = Image.fromarray(canny)
+#         updated_image_path = get_new_image_name(inputs, func_name="edge")
+#         image.save(updated_image_path)
+#         return updated_image_path
+#
+# class canny2image:
+#     def __init__(self, device):
+#         print("Initialize the canny2image model.")
+#         model = create_model('ControlNet/models/cldm_v15.yaml', device=device).to(device)
+#         model.load_state_dict(load_state_dict('ControlNet/models/control_sd15_canny.pth', location='cpu'))
+#         self.model = model.to(device)
+#         self.device = device
+#         self.ddim_sampler = DDIMSampler(self.model)
+#         self.ddim_steps = 20
+#         self.image_resolution = 512
+#         self.num_samples = 1
+#         self.save_memory = False
+#         self.strength = 1.0
+#         self.guess_mode = False
+#         self.scale = 9.0
+#         self.seed = -1
+#         self.a_prompt = 'best quality, extremely detailed'
+#         self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
+#
+#     def inference(self, inputs):
+#         print("===>Starting canny2image Inference")
+#         image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+#         image = Image.open(image_path)
+#         image = np.array(image)
+#         image = 255 - image
+#         prompt = instruct_text
+#         img = resize_image(HWC3(image), self.image_resolution)
+#         H, W, C = img.shape
+#         control = torch.from_numpy(img.copy()).float().to(device=self.device) / 255.0
+#         control = torch.stack([control for _ in range(self.num_samples)], dim=0)
+#         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+#         self.seed = random.randint(0, 65535)
+#         seed_everything(self.seed)
+#         if self.save_memory:
+#             self.model.low_vram_shift(is_diffusing=False)
+#         cond = {"c_concat": [control], "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + self.a_prompt] * self.num_samples)]}
+#         un_cond = {"c_concat": None if self.guess_mode else [control], "c_crossattn": [self.model.get_learned_conditioning([self.n_prompt] * self.num_samples)]}
+#         shape = (4, H // 8, W // 8)
+#         self.model.control_scales = [self.strength * (0.825 ** float(12 - i)) for i in range(13)] if self.guess_mode else ([self.strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+#         samples, intermediates = self.ddim_sampler.sample(self.ddim_steps, self.num_samples, shape, cond, verbose=False, eta=0., unconditional_guidance_scale=self.scale, unconditional_conditioning=un_cond)
+#         if self.save_memory:
+#             self.model.low_vram_shift(is_diffusing=False)
+#         x_samples = self.model.decode_first_stage(samples)
+#         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+#         updated_image_path = get_new_image_name(image_path, func_name="canny2image")
+#         real_image = Image.fromarray(x_samples[0])  # get default the index0 image
+#         real_image.save(updated_image_path)
+#         return updated_image_path
+class image2line_new:
+    def __init__(self):
+        self.detector = MLSDdetector.from_pretrained('lllyasviel/ControlNet')
+        self.value_thresh = 0.1
+        self.dis_thresh = 0.1
+        self.resolution = 512
+
+    def inference(self, inputs):
+        print("===>Starting image2line Inference")
+        image = Image.open(inputs)
+        image = np.array(image)
+        image = HWC3(image)
+        mlsd = self.detector(resize_image(image, self.resolution), thr_v=self.value_thresh, thr_d=self.dis_thresh)
+        mlsd = np.array(mlsd)
+        mlsd = 255 - mlsd
+        mlsd = Image.fromarray(mlsd)
+        updated_image_path = get_new_image_name(inputs, func_name="line-of")
+        mlsd.save(updated_image_path)
         return updated_image_path
 
+class line2image_new:
+    def __init__(self, device):
+        print("Initialize the line2image model...")
+        self.controlnet = ControlNetModel.from_pretrained(
+            "fusing/stable-diffusion-v1-5-controlnet-mlsd"
+        )
+
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None
+        )
+
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.image_resolution = 512
+        self.num_inference_steps = 20
+        self.seed = -1
+        self.unconditional_guidance_scale = 9.0
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
+
+    def inference(self, inputs):
+        print("===>Starting line2image Inference")
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        image = np.array(image)
+        image = 255 - image
+        prompt = instruct_text
+        img = resize_image(HWC3(image), self.image_resolution)
+        img = Image.fromarray(img)
+
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+
+        prompt = prompt + ', ' + self.a_prompt
+        image = self.pipe(prompt, img, num_inference_steps=self.num_inference_steps, eta=0.0, negative_prompt=self.n_prompt, guidance_scale=self.unconditional_guidance_scale).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="line2image")
+        image.save(updated_image_path)
+        return updated_image_path
+
+
 class image2line:
     def __init__(self):
         print("Direct detect straight line...")