bhaveshgoel07 commited on
Commit
790c8e6
1 Parent(s): 504f5bf

Fixed errors

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -39,18 +39,26 @@ transform = transforms.Compose([
39
  # Prediction function
40
 
41
 
42
-
43
- def predict(image):
44
  try:
45
- # The input image is already a 2D numpy array (grayscale)
46
- # Ensure it's the right size and normalize it
 
 
 
 
 
 
 
47
 
48
- pil_image = Image.fromarray(image.squeeze().astype('uint8'))
49
- pil_image = pil_image.resize((28, 28))
50
- image = np.array(pil_image)
 
 
51
 
52
  # Convert to tensor and add batch and channel dimensions
53
- tensor_image = torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
54
 
55
  # Normalize the image
56
  tensor_image = tensor_image / 255.0 # Scale to [0, 1]
 
39
  # Prediction function
40
 
41
 
42
+ def predict(sketch):
 
43
  try:
44
+ # Extract the image data from the dictionary
45
+ image = sketch['image']
46
+
47
+ # Convert the image to a numpy array
48
+ image_array = np.array(image)
49
+
50
+ # Ensure the image is 2D (grayscale)
51
+ if len(image_array.shape) == 3:
52
+ image_array = image_array[:,:,0] # Take the first channel if it's RGB
53
 
54
+ # Resize to 28x28 if necessary
55
+ if image_array.shape != (28, 28):
56
+ pil_image = Image.fromarray(image_array.astype('uint8'))
57
+ pil_image = pil_image.resize((28, 28))
58
+ image_array = np.array(pil_image)
59
 
60
  # Convert to tensor and add batch and channel dimensions
61
+ tensor_image = torch.tensor(image_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
62
 
63
  # Normalize the image
64
  tensor_image = tensor_image / 255.0 # Scale to [0, 1]