Spaces:
Sleeping
Sleeping
bhaveshgoel07
commited on
Commit
•
790c8e6
1
Parent(s):
504f5bf
Fixed errors
Browse files
app.py
CHANGED
@@ -39,18 +39,26 @@ transform = transforms.Compose([
|
|
39 |
# Prediction function
|
40 |
|
41 |
|
42 |
-
|
43 |
-
def predict(image):
|
44 |
try:
|
45 |
-
#
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
51 |
|
52 |
# Convert to tensor and add batch and channel dimensions
|
53 |
-
tensor_image = torch.tensor(
|
54 |
|
55 |
# Normalize the image
|
56 |
tensor_image = tensor_image / 255.0 # Scale to [0, 1]
|
|
|
39 |
# Prediction function
|
40 |
|
41 |
|
42 |
+
def predict(sketch):
|
|
|
43 |
try:
|
44 |
+
# Extract the image data from the dictionary
|
45 |
+
image = sketch['image']
|
46 |
+
|
47 |
+
# Convert the image to a numpy array
|
48 |
+
image_array = np.array(image)
|
49 |
+
|
50 |
+
# Ensure the image is 2D (grayscale)
|
51 |
+
if len(image_array.shape) == 3:
|
52 |
+
image_array = image_array[:,:,0] # Take the first channel if it's RGB
|
53 |
|
54 |
+
# Resize to 28x28 if necessary
|
55 |
+
if image_array.shape != (28, 28):
|
56 |
+
pil_image = Image.fromarray(image_array.astype('uint8'))
|
57 |
+
pil_image = pil_image.resize((28, 28))
|
58 |
+
image_array = np.array(pil_image)
|
59 |
|
60 |
# Convert to tensor and add batch and channel dimensions
|
61 |
+
tensor_image = torch.tensor(image_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
|
62 |
|
63 |
# Normalize the image
|
64 |
tensor_image = tensor_image / 255.0 # Scale to [0, 1]
|