radames commited on
Commit
b212cf7
1 Parent(s): a080665
Files changed (2) hide show
  1. app.py +3 -7
  2. canny_gpu.py +117 -0
app.py CHANGED
@@ -8,12 +8,11 @@ from diffusers import (
8
  StableDiffusionXLControlNetImg2ImgPipeline,
9
  DDIMScheduler,
10
  )
11
-
12
  from compel import Compel, ReturnedEmbeddingsType
13
  from PIL import Image
14
  import os
15
  import time
16
- import cv2
17
  import numpy as np
18
 
19
  IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
@@ -64,6 +63,7 @@ if not IS_SPACES_ZERO:
64
  # pipe.enable_xformers_memory_efficient_attention()
65
  pipe.enable_model_cpu_offload()
66
  pipe.enable_vae_tiling()
 
67
 
68
 
69
  def pad_image(image):
@@ -106,11 +106,7 @@ def predict(
106
  conditioning, pooled = compel([prompt, negative_prompt])
107
  generator = torch.manual_seed(seed)
108
  last_time = time.time()
109
- canny_image = np.array(padded_image)
110
- canny_image = cv2.Canny(canny_image, 100, 200)
111
- canny_image = canny_image[:, :, None]
112
- canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
113
- canny_image = Image.fromarray(canny_image)
114
  images = pipe(
115
  image=padded_image,
116
  control_image=canny_image,
 
8
  StableDiffusionXLControlNetImg2ImgPipeline,
9
  DDIMScheduler,
10
  )
11
+ from canny_gpu import SobelOperator
12
  from compel import Compel, ReturnedEmbeddingsType
13
  from PIL import Image
14
  import os
15
  import time
 
16
  import numpy as np
17
 
18
  IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
 
63
  # pipe.enable_xformers_memory_efficient_attention()
64
  pipe.enable_model_cpu_offload()
65
  pipe.enable_vae_tiling()
66
+ canny_torch = SobelOperator(device=device)
67
 
68
 
69
  def pad_image(image):
 
106
  conditioning, pooled = compel([prompt, negative_prompt])
107
  generator = torch.manual_seed(seed)
108
  last_time = time.time()
109
+ canny_image = canny_torch(padded_image, 0.01, 0.2)
 
 
 
 
110
  images = pipe(
111
  image=padded_image,
112
  control_image=canny_image,
canny_gpu.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.transforms import ToTensor, ToPILImage
4
+ from PIL import Image
5
+
6
+
7
+ class SobelOperator(nn.Module):
8
+ SOBEL_KERNEL_X = torch.tensor(
9
+ [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]
10
+ )
11
+ SOBEL_KERNEL_Y = torch.tensor(
12
+ [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]
13
+ )
14
+
15
+ def __init__(self, device="cuda"):
16
+ super(SobelOperator, self).__init__()
17
+ self.device = device
18
+ self.edge_conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
19
+ self.device
20
+ )
21
+ self.edge_conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
22
+ self.device
23
+ )
24
+ self.edge_conv_x.weight = nn.Parameter(
25
+ self.SOBEL_KERNEL_X.view((1, 1, 3, 3)).to(self.device)
26
+ )
27
+ self.edge_conv_y.weight = nn.Parameter(
28
+ self.SOBEL_KERNEL_Y.view((1, 1, 3, 3)).to(self.device)
29
+ )
30
+
31
+ @torch.no_grad()
32
+ def forward(
33
+ self,
34
+ image: Image.Image,
35
+ low_threshold: float,
36
+ high_threshold: float,
37
+ output_type="pil",
38
+ ) -> Image.Image | torch.Tensor | tuple[Image.Image, torch.Tensor]:
39
+ # Convert PIL image to PyTorch tensor
40
+ image_gray = image.convert("L")
41
+ image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)
42
+
43
+ # Compute gradients
44
+ edge_x = self.edge_conv_x(image_tensor)
45
+ edge_y = self.edge_conv_y(image_tensor)
46
+ edge = torch.sqrt(torch.square(edge_x) + torch.square(edge_y))
47
+
48
+ # Apply thresholding
49
+ edge.div_(edge.max()) # Normalize to 0-1 (in-place operation)
50
+ edge[edge >= high_threshold] = 1.0
51
+ edge[edge <= low_threshold] = 0.0
52
+
53
+ # Convert the result back to a PIL image
54
+ if output_type == "pil":
55
+ return ToPILImage()(edge.squeeze(0).cpu())
56
+ elif output_type == "tensor":
57
+ return edge
58
+ elif output_type == "pil,tensor":
59
+ return ToPILImage()(edge.squeeze(0).cpu()), edge
60
+
61
+
62
+ class ScharrOperator(nn.Module):
63
+ SCHARR_KERNEL_X = torch.tensor(
64
+ [[-3.0, 0.0, 3.0], [-10.0, 0.0, 10.0], [-3.0, 0.0, 3.0]]
65
+ )
66
+ SCHARR_KERNEL_Y = torch.tensor(
67
+ [[-3.0, -10.0, -3.0], [0.0, 0.0, 0.0], [3.0, 10.0, 3.0]]
68
+ )
69
+
70
+ def __init__(self, device="cuda"):
71
+ super(ScharrOperator, self).__init__()
72
+ self.device = device
73
+ self.edge_conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
74
+ self.device
75
+ )
76
+ self.edge_conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
77
+ self.device
78
+ )
79
+ self.edge_conv_x.weight = nn.Parameter(
80
+ self.SCHARR_KERNEL_X.view((1, 1, 3, 3)).to(self.device)
81
+ )
82
+ self.edge_conv_y.weight = nn.Parameter(
83
+ self.SCHARR_KERNEL_Y.view((1, 1, 3, 3)).to(self.device)
84
+ )
85
+
86
+ @torch.no_grad()
87
+ def forward(
88
+ self,
89
+ image: Image.Image,
90
+ low_threshold: float,
91
+ high_threshold: float,
92
+ output_type="pil",
93
+ invert: bool = False,
94
+ ) -> Image.Image | torch.Tensor | tuple[Image.Image, torch.Tensor]:
95
+ # Convert PIL image to PyTorch tensor
96
+ image_gray = image.convert("L")
97
+ image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)
98
+
99
+ # Compute gradients
100
+ edge_x = self.edge_conv_x(image_tensor)
101
+ edge_y = self.edge_conv_y(image_tensor)
102
+ edge = torch.abs(edge_x) + torch.abs(edge_y)
103
+
104
+ # Apply thresholding
105
+ edge.div_(edge.max()) # Normalize to 0-1 (in-place operation)
106
+ edge[edge >= high_threshold] = 1.0
107
+ edge[edge <= low_threshold] = 0.0
108
+ if invert:
109
+ edge = 1 - edge
110
+
111
+ # Convert the result back to a PIL image
112
+ if output_type == "pil":
113
+ return ToPILImage()(edge.squeeze(0).cpu())
114
+ elif output_type == "tensor":
115
+ return edge
116
+ elif output_type == "pil,tensor":
117
+ return ToPILImage()(edge.squeeze(0).cpu()), edge