im2 commited on
Commit
1937011
1 Parent(s): a8024f1

image still not correct

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -32,9 +32,12 @@ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
- def predict_digit(image):
36
- # Convert the numpy array (from gr.Sketchpad) to a PIL Image
37
- image = Image.fromarray(image).convert('L') # Convert to grayscale
 
 
 
38
 
39
  # Preprocess: resize to 28x28 and normalize
40
  transform = transforms.Compose([
@@ -64,4 +67,3 @@ interface = gr.Interface(
64
  # Launch the app
65
  if __name__ == "__main__":
66
  interface.launch()
67
-
 
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
+ def predict_digit(image_dict):
36
+ # Extract the image array from the 'mask' key (gr.Sketchpad output)
37
+ image = image_dict["mask"] # Get the image from the dict
38
+
39
+ # Convert the image to a numpy array, then to a PIL image, and preprocess
40
+ image = Image.fromarray(np.array(image)).convert('L') # Convert to grayscale
41
 
42
  # Preprocess: resize to 28x28 and normalize
43
  transform = transforms.Compose([
 
67
  # Launch the app
68
  if __name__ == "__main__":
69
  interface.launch()