bhaveshgoel07 commited on
Commit
7355da1
1 Parent(s): 790c8e6

Fixed errors

Browse files
Files changed (1) hide show
  1. app.py +9 -37
app.py CHANGED
@@ -39,47 +39,19 @@ transform = transforms.Compose([
39
  # Prediction function
40
 
41
 
42
- def predict(sketch):
43
- try:
44
- # Extract the image data from the dictionary
45
- image = sketch['image']
46
-
47
- # Convert the image to a numpy array
48
- image_array = np.array(image)
49
-
50
- # Ensure the image is 2D (grayscale)
51
- if len(image_array.shape) == 3:
52
- image_array = image_array[:,:,0] # Take the first channel if it's RGB
53
-
54
- # Resize to 28x28 if necessary
55
- if image_array.shape != (28, 28):
56
- pil_image = Image.fromarray(image_array.astype('uint8'))
57
- pil_image = pil_image.resize((28, 28))
58
- image_array = np.array(pil_image)
59
-
60
- # Convert to tensor and add batch and channel dimensions
61
- tensor_image = torch.tensor(image_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
62
-
63
- # Normalize the image
64
- tensor_image = tensor_image / 255.0 # Scale to [0, 1]
65
- tensor_image = (tensor_image - 0.5) / 0.5 # Normalize to [-1, 1]
66
-
67
- with torch.no_grad():
68
- output = model(tensor_image)
69
- probabilities = nn.Softmax(dim=1)(output)
70
-
71
- # Return the probabilities as a dictionary
72
- return {str(i): float(probabilities[0][i]) for i in range(10)}
73
- except Exception as e:
74
- print(f"Error in predict function: {e}")
75
- import traceback
76
- traceback.print_exc()
77
- return {"error": str(e)}
78
 
79
  # Create the Gradio interface
80
  interface = gr.Interface(
81
  fn=predict,
82
- inputs=gr.Sketchpad(),
83
  outputs=gr.Label(num_top_classes=10)
84
  )
85
 
 
39
  # Prediction function
40
 
41
 
42
+ # Prediction function
43
+ def predict(image):
44
+ image = transform(image).unsqueeze(0) # Add batch dimension
45
+ with torch.no_grad():
46
+ output = model(image)
47
+ probabilities = nn.Softmax(dim=1)(output)
48
+ predicted_class = torch.argmax(probabilities, dim=1)
49
+ return {str(i): probabilities[0][i].item() for i in range(10)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Create the Gradio interface
52
  interface = gr.Interface(
53
  fn=predict,
54
+ inputs=gr.Sketchpad(type='pil'),
55
  outputs=gr.Label(num_top_classes=10)
56
  )
57