sigyllly commited on
Commit
fa8c892
1 Parent(s): 945cca9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -25,6 +25,9 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
25
 
26
  pred = torch.sigmoid(preds)
27
  mat = pred.cpu().numpy()
 
 
 
28
  mask = Image.fromarray(np.uint8(mat * 255), "L")
29
  mask = mask.convert("RGB")
30
  mask = mask.resize(image.size)
@@ -37,7 +40,6 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
37
 
38
  # threshold the mask
39
  bmask = mask > threshold
40
- # zero out values below the threshold
41
  mask[mask < threshold] = 0
42
 
43
  fig, ax = plt.subplots()
@@ -74,6 +76,7 @@ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
74
 
75
  return fig, result_mask, result_output
76
 
 
77
  # Existing process_image function, copy it here
78
  # ...
79
 
 
25
 
26
  pred = torch.sigmoid(preds)
27
  mat = pred.cpu().numpy()
28
+
29
+ # Ensure we are working with a single-channel 2D mask
30
+ mat = np.squeeze(mat, axis=0) # Remove batch dimension if it exists
31
  mask = Image.fromarray(np.uint8(mat * 255), "L")
32
  mask = mask.convert("RGB")
33
  mask = mask.resize(image.size)
 
40
 
41
  # threshold the mask
42
  bmask = mask > threshold
 
43
  mask[mask < threshold] = 0
44
 
45
  fig, ax = plt.subplots()
 
76
 
77
  return fig, result_mask, result_output
78
 
79
+
80
  # Existing process_image function, copy it here
81
  # ...
82