Spaces:
Runtime error
Runtime error
Upload predict.py
Browse files- handetect/predict.py +2 -6
handetect/predict.py
CHANGED
@@ -6,17 +6,13 @@ from PIL import Image
|
|
6 |
from handetect.models import *
|
7 |
from torchmetrics import ConfusionMatrix
|
8 |
import matplotlib.pyplot as plt
|
9 |
-
import pathlib
|
10 |
-
import sys
|
11 |
|
12 |
# Define the path to your model checkpoint
|
13 |
model_checkpoint_path = "model.pth"
|
14 |
|
15 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
16 |
|
17 |
-
NUM_CLASSES =
|
18 |
-
os.listdir(r"C:\Users\User\Documents\PISTEK\HANDETECT\data\train\Task 1")
|
19 |
-
) # Update with the correct number of classes
|
20 |
|
21 |
# Define transformation for preprocessing the input image
|
22 |
preprocess = transforms.Compose(
|
@@ -41,7 +37,7 @@ torch.set_grad_enabled(False)
|
|
41 |
|
42 |
def predict_image(image_path, model=model, transform=preprocess):
|
43 |
# Define images variable to recursively list all the data file in the image_path
|
44 |
-
classes =
|
45 |
|
46 |
print("---------------------------")
|
47 |
print("Image path:", image_path)
|
|
|
6 |
from handetect.models import *
|
7 |
from torchmetrics import ConfusionMatrix
|
8 |
import matplotlib.pyplot as plt
|
|
|
|
|
9 |
|
10 |
# Define the path to your model checkpoint
|
11 |
model_checkpoint_path = "model.pth"
|
12 |
|
13 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
14 |
|
15 |
+
NUM_CLASSES = 6
|
|
|
|
|
16 |
|
17 |
# Define transformation for preprocessing the input image
|
18 |
preprocess = transforms.Compose(
|
|
|
37 |
|
38 |
def predict_image(image_path, model=model, transform=preprocess):
|
39 |
# Define images variable to recursively list all the data file in the image_path
|
40 |
+
classes = ['Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']
|
41 |
|
42 |
print("---------------------------")
|
43 |
print("Image path:", image_path)
|