Improve image transformation

#1
Files changed (1)
  1. image_transformation.py +12 -36
image_transformation.py CHANGED
@@ -15,7 +15,7 @@ if is_vision_available():
15
  from PIL import Image
16
 
17
  if is_diffusers_available():
18
- from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
19
 
20
  if is_opencv_available():
21
  import cv2
@@ -29,8 +29,7 @@ IMAGE_TRANSFORMATION_DESCRIPTION = (
29
 
30
 
31
  class ImageTransformationTool(Tool):
32
- default_stable_diffusion_checkpoint = "runwayml/stable-diffusion-v1-5"
33
- default_controlnet_checkpoint = "lllyasviel/sd-controlnet-canny"
34
  description = IMAGE_TRANSFORMATION_DESCRIPTION
35
  inputs = ['image', 'text']
36
  outputs = ['image']
@@ -47,13 +46,7 @@ class ImageTransformationTool(Tool):
47
 
48
  super().__init__()
49
 
50
- if controlnet is None:
51
- controlnet = self.default_controlnet_checkpoint
52
- self.controlnet_checkpoint = controlnet
53
-
54
- if stable_diffusion is None:
55
- stable_diffusion = self.default_stable_diffusion_checkpoint
56
- self.stable_diffusion_checkpoint = stable_diffusion
57
 
58
  self.device = device
59
  self.hub_kwargs = hub_kwargs
@@ -62,37 +55,20 @@ class ImageTransformationTool(Tool):
62
  if self.device is None:
63
  self.device = get_default_device()
64
 
65
- self.controlnet = ControlNetModel.from_pretrained(self.controlnet_checkpoint)
66
- self.pipeline = StableDiffusionControlNetPipeline.from_pretrained(
67
- self.stable_diffusion_checkpoint, controlnet=self.controlnet
68
- )
69
- self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
70
- self.pipeline.enable_model_cpu_offload()
71
 
72
  self.is_initialized = True
73
 
74
- def __call__(self, image, prompt):
75
  if not self.is_initialized:
76
  self.setup()
77
 
78
- initial_prompt = "super-hero character, best quality, extremely detailed"
79
- prompt = initial_prompt + prompt
80
-
81
- low_threshold = 100
82
- high_threshold = 200
83
-
84
- image = np.array(image)
85
- image = cv2.Canny(image, low_threshold, high_threshold)
86
- image = image[:, :, None]
87
- image = np.concatenate([image, image, image], axis=2)
88
- canny_image = Image.fromarray(image)
89
-
90
- generator = torch.Generator(device="cpu").manual_seed(2)
91
-
92
  return self.pipeline(
93
- prompt,
94
- canny_image,
95
- negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
96
- num_inference_steps=20,
97
- generator=generator,
98
  ).images[0]
 
15
  from PIL import Image
16
 
17
  if is_diffusers_available():
18
+ from diffusers import DiffusionPipeline
19
 
20
  if is_opencv_available():
21
  import cv2
 
29
 
30
 
31
  class ImageTransformationTool(Tool):
32
+ default_stable_diffusion_checkpoint = "timbrooks/instruct-pix2pix"
 
33
  description = IMAGE_TRANSFORMATION_DESCRIPTION
34
  inputs = ['image', 'text']
35
  outputs = ['image']
 
46
 
47
  super().__init__()
48
 
49
+ self.stable_diffusion = self.default_stable_diffusion_checkpoint
 
 
 
 
 
 
50
 
51
  self.device = device
52
  self.hub_kwargs = hub_kwargs
 
55
  if self.device is None:
56
  self.device = get_default_device()
57
 
58
+ self.pipeline = DiffusionPipeline.from_pretrained(self.stable_diffusion)
59
+
60
+ self.pipeline.to(self.device)
61
+ if self.device.type == "cuda":
62
+ self.pipeline.to(torch_dtype=torch.float16)
 
63
 
64
  self.is_initialized = True
65
 
66
+ def __call__(self, image, prompt, negative_prompt="low quality, bad quality, deformed, low resolution", added_prompt=" , highest quality, highly realistic, very high resolution"):
67
  if not self.is_initialized:
68
  self.setup()
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return self.pipeline(
71
+ prompt + added_prompt,
72
+ image,
73
+ negative_prompt=negative_prompt,
 
 
74
  ).images[0]