sab commited on
Commit
9c91605
·
1 Parent(s): bf90ab3
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -27,11 +27,20 @@ def fn(image):
27
  def process(image):
28
  def numpy_to_pil(image):
29
  """Convert a numpy array to a PIL Image."""
30
- if image.dtype == np.uint8: # Most common case
 
 
 
 
 
 
31
  mode = "RGB"
 
 
32
  else:
33
- mode = "F" # Floating point
34
- return Image.fromarray(image.astype('uint8'), mode)
 
35
 
36
  image = numpy_to_pil(image) # Convert numpy array to PIL Image
37
  buffered = BytesIO()
 
27
  def process(image):
28
  def numpy_to_pil(image):
29
  """Convert a numpy array to a PIL Image."""
30
+ if not isinstance(image, np.ndarray):
31
+ raise TypeError("Input must be a numpy array")
32
+
33
+ # Determine the mode based on the shape and dtype of the image
34
+ if image.ndim == 2: # Grayscale image
35
+ mode = "L"
36
+ elif image.ndim == 3 and image.shape[2] == 3: # RGB image
37
  mode = "RGB"
38
+ elif image.ndim == 3 and image.shape[2] == 4: # RGBA image
39
+ mode = "RGBA"
40
  else:
41
+ raise ValueError("Unsupported image shape: {}".format(image.shape))
42
+
43
+ return Image.fromarray(image, mode)
44
 
45
  image = numpy_to_pil(image) # Convert numpy array to PIL Image
46
  buffered = BytesIO()