Spaces:
Sleeping
Sleeping
bhaveshgoel07
commited on
Commit
•
7355da1
1
Parent(s):
790c8e6
Fixed errors
Browse files
app.py
CHANGED
@@ -39,47 +39,19 @@ transform = transforms.Compose([
|
|
39 |
# Prediction function
|
40 |
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
# Ensure the image is 2D (grayscale)
|
51 |
-
if len(image_array.shape) == 3:
|
52 |
-
image_array = image_array[:,:,0] # Take the first channel if it's RGB
|
53 |
-
|
54 |
-
# Resize to 28x28 if necessary
|
55 |
-
if image_array.shape != (28, 28):
|
56 |
-
pil_image = Image.fromarray(image_array.astype('uint8'))
|
57 |
-
pil_image = pil_image.resize((28, 28))
|
58 |
-
image_array = np.array(pil_image)
|
59 |
-
|
60 |
-
# Convert to tensor and add batch and channel dimensions
|
61 |
-
tensor_image = torch.tensor(image_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
62 |
-
|
63 |
-
# Normalize the image
|
64 |
-
tensor_image = tensor_image / 255.0 # Scale to [0, 1]
|
65 |
-
tensor_image = (tensor_image - 0.5) / 0.5 # Normalize to [-1, 1]
|
66 |
-
|
67 |
-
with torch.no_grad():
|
68 |
-
output = model(tensor_image)
|
69 |
-
probabilities = nn.Softmax(dim=1)(output)
|
70 |
-
|
71 |
-
# Return the probabilities as a dictionary
|
72 |
-
return {str(i): float(probabilities[0][i]) for i in range(10)}
|
73 |
-
except Exception as e:
|
74 |
-
print(f"Error in predict function: {e}")
|
75 |
-
import traceback
|
76 |
-
traceback.print_exc()
|
77 |
-
return {"error": str(e)}
|
78 |
|
79 |
# Create the Gradio interface
|
80 |
interface = gr.Interface(
|
81 |
fn=predict,
|
82 |
-
inputs=gr.Sketchpad(),
|
83 |
outputs=gr.Label(num_top_classes=10)
|
84 |
)
|
85 |
|
|
|
39 |
# Prediction function
|
40 |
|
41 |
|
42 |
+
# Prediction function
|
43 |
+
def predict(image):
|
44 |
+
image = transform(image).unsqueeze(0) # Add batch dimension
|
45 |
+
with torch.no_grad():
|
46 |
+
output = model(image)
|
47 |
+
probabilities = nn.Softmax(dim=1)(output)
|
48 |
+
predicted_class = torch.argmax(probabilities, dim=1)
|
49 |
+
return {str(i): probabilities[0][i].item() for i in range(10)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
# Create the Gradio interface
|
52 |
interface = gr.Interface(
|
53 |
fn=predict,
|
54 |
+
inputs=gr.Sketchpad(type='pil'),
|
55 |
outputs=gr.Label(num_top_classes=10)
|
56 |
)
|
57 |
|