radames committed on
Commit
5f1aa51
1 Parent(s): 578d7dc

add ScharrOperator

Browse files
Files changed (1) hide show
  1. server/pipelines/utils/canny_gpu.py +58 -0
server/pipelines/utils/canny_gpu.py CHANGED
@@ -57,3 +57,61 @@ class SobelOperator(nn.Module):
57
  return edge
58
  elif output_type == "pil,tensor":
59
  return ToPILImage()(edge.squeeze(0).cpu()), edge
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return edge
58
  elif output_type == "pil,tensor":
59
  return ToPILImage()(edge.squeeze(0).cpu()), edge
60
+
61
+
62
class ScharrOperator(nn.Module):
    """Edge detector that convolves a grayscale image with the Scharr kernels.

    Two fixed 3x3 convolutions (horizontal and vertical Scharr kernels) are
    applied to the single-channel image; the edge strength is the sum of the
    absolute responses, normalized by its maximum and then thresholded.
    """

    # Horizontal-gradient kernel (responds to left/right intensity changes).
    SCHARR_KERNEL_X = torch.tensor(
        [[-3.0, 0.0, 3.0], [-10.0, 0.0, 10.0], [-3.0, 0.0, 3.0]]
    )
    # Vertical-gradient kernel (responds to top/bottom intensity changes).
    SCHARR_KERNEL_Y = torch.tensor(
        [[-3.0, -10.0, -3.0], [0.0, 0.0, 0.0], [3.0, 10.0, 3.0]]
    )
    # Values of ``output_type`` that ``forward`` knows how to produce.
    _OUTPUT_TYPES = ("pil", "tensor", "pil,tensor")

    def __init__(self, device="cuda"):
        """Create the two 3x3 convolutions on ``device`` and load the kernels.

        Args:
            device: Device the convolutions and their weights are moved to;
                input images are moved to the same device in ``forward``.
        """
        super().__init__()
        self.device = device
        self.edge_conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
            self.device
        )
        self.edge_conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
            self.device
        )
        # Replace the randomly initialised weights with the fixed kernels.
        self.edge_conv_x.weight = nn.Parameter(
            self.SCHARR_KERNEL_X.view((1, 1, 3, 3)).to(self.device)
        )
        self.edge_conv_y.weight = nn.Parameter(
            self.SCHARR_KERNEL_Y.view((1, 1, 3, 3)).to(self.device)
        )

    @torch.no_grad()
    def forward(
        self,
        image: Image.Image,
        low_threshold: float,
        high_threshold: float,
        output_type="pil",
        invert: bool = False,
    ) -> Image.Image | torch.Tensor | tuple[Image.Image, torch.Tensor]:
        """Compute a thresholded Scharr edge map of ``image``.

        Args:
            image: Input PIL image; it is converted to grayscale ("L") first.
            low_threshold: Normalized edge values at or below this become 0.
            high_threshold: Normalized edge values at or above this become 1.
            output_type: ``"pil"`` for a PIL image, ``"tensor"`` for the edge
                tensor (which keeps the leading batch dimension added here),
                or ``"pil,tensor"`` for a ``(pil_image, tensor)`` tuple.
            invert: If true, return ``1 - edge`` instead of ``edge``.

        Returns:
            The edge map in the representation selected by ``output_type``.

        Raises:
            ValueError: If ``output_type`` is not one of the supported values.
        """
        # Fail fast instead of doing all the work and silently returning None.
        if output_type not in self._OUTPUT_TYPES:
            raise ValueError(
                f"Unsupported output_type {output_type!r}; "
                f"expected one of {self._OUTPUT_TYPES}"
            )

        # Convert PIL image to PyTorch tensor
        image_gray = image.convert("L")
        image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)

        # Compute gradients
        edge_x = self.edge_conv_x(image_tensor)
        edge_y = self.edge_conv_y(image_tensor)
        edge = torch.abs(edge_x) + torch.abs(edge_y)

        # Normalize to 0-1 in place. If every gradient is zero (e.g. an
        # all-black input) the maximum is 0 and 0/0 would fill the map with
        # NaN, so the divisor is clamped to the smallest positive normal
        # value; this leaves any ordinary maximum untouched.
        peak = edge.max().clamp_min(torch.finfo(edge.dtype).tiny)
        edge.div_(peak)

        # Apply thresholding
        edge[edge >= high_threshold] = 1.0
        edge[edge <= low_threshold] = 0.0
        if invert:
            edge = 1 - edge

        if output_type == "tensor":
            return edge

        # Convert the result back to a PIL image
        pil_edge = ToPILImage()(edge.squeeze(0).cpu())
        if output_type == "pil":
            return pil_edge
        return pil_edge, edge