bhaveshgoel07 commited on
Commit
1210b12
1 Parent(s): 43be314

Fixed errors

Browse files
Files changed (2) hide show
  1. app.py +14 -5
  2. requirements.txt +3 -1
app.py CHANGED
@@ -2,7 +2,8 @@ import torch
2
  import torch.nn as nn
3
  import torchvision.transforms as transforms
4
  import gradio as gr
5
-
 
6
  # Define the CNN
7
  class SimpleCNN(nn.Module):
8
  def __init__(self):
@@ -36,14 +37,22 @@ transform = transforms.Compose([
36
  ])
37
 
38
  # Prediction function
39
- def predict(image):
 
 
40
  try:
 
 
 
 
41
  image = transform(image).unsqueeze(0) # Add batch dimension
 
42
  with torch.no_grad():
43
  output = model(image)
44
  probabilities = nn.Softmax(dim=1)(output)
45
- predicted_class = torch.argmax(probabilities, dim=1)
46
- return {str(i): probabilities[0][i].item() for i in range(10)}
 
47
  except Exception as e:
48
  print(f"Error in predict function: {e}")
49
  return {"error": str(e)}
@@ -52,7 +61,7 @@ def predict(image):
52
  interface = gr.Interface(
53
  fn=predict,
54
  inputs=gr.Sketchpad(),
55
- outputs=gr.Label()
56
  )
57
 
58
  # Launch the interface
 
2
  import torch.nn as nn
3
  import torchvision.transforms as transforms
4
  import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
  # Define the CNN
8
  class SimpleCNN(nn.Module):
9
  def __init__(self):
 
37
  ])
38
 
39
  # Prediction function
40
+
41
+
42
+ def predict(image_dict):
43
  try:
44
+ # Convert the dictionary to a PIL Image
45
+ image = Image.fromarray(np.uint8(image_dict["image"])).convert('RGB')
46
+
47
+ # Apply the transformation
48
  image = transform(image).unsqueeze(0) # Add batch dimension
49
+
50
  with torch.no_grad():
51
  output = model(image)
52
  probabilities = nn.Softmax(dim=1)(output)
53
+
54
+ # Return the probabilities as a dictionary
55
+ return {str(i): float(probabilities[0][i]) for i in range(10)}
56
  except Exception as e:
57
  print(f"Error in predict function: {e}")
58
  return {"error": str(e)}
 
61
  interface = gr.Interface(
62
  fn=predict,
63
  inputs=gr.Sketchpad(),
64
+ outputs=gr.Label(num_top_classes=10)
65
  )
66
 
67
  # Launch the interface
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  torch
2
  torchvision
3
- gradio
 
 
 
1
  torch
2
  torchvision
3
+ gradio
4
+ numpy
5
+ PIL