cycool29 commited on
Commit
3549285
·
1 Parent(s): 6afd133

Upload predict.py

Browse files
Files changed (1) hide show
  1. handetect/predict.py +2 -6
handetect/predict.py CHANGED
@@ -6,17 +6,13 @@ from PIL import Image
6
  from handetect.models import *
7
  from torchmetrics import ConfusionMatrix
8
  import matplotlib.pyplot as plt
9
- import pathlib
10
- import sys
11
 
12
  # Define the path to your model checkpoint
13
  model_checkpoint_path = "model.pth"
14
 
15
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
 
17
- NUM_CLASSES = len(
18
- os.listdir(r"C:\Users\User\Documents\PISTEK\HANDETECT\data\train\Task 1")
19
- ) # Update with the correct number of classes
20
 
21
  # Define transformation for preprocessing the input image
22
  preprocess = transforms.Compose(
@@ -41,7 +37,7 @@ torch.set_grad_enabled(False)
41
 
42
  def predict_image(image_path, model=model, transform=preprocess):
43
  # Define images variable to recursively list all the data file in the image_path
44
- classes = os.listdir(r"C:\Users\User\Documents\PISTEK\HANDETECT\data\train\Task 1")
45
 
46
  print("---------------------------")
47
  print("Image path:", image_path)
 
6
  from handetect.models import *
7
  from torchmetrics import ConfusionMatrix
8
  import matplotlib.pyplot as plt
 
 
9
 
10
  # Define the path to your model checkpoint
11
  model_checkpoint_path = "model.pth"
12
 
13
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
 
15
+ NUM_CLASSES = 6
 
 
16
 
17
  # Define transformation for preprocessing the input image
18
  preprocess = transforms.Compose(
 
37
 
38
  def predict_image(image_path, model=model, transform=preprocess):
39
  # Define images variable to recursively list all the data file in the image_path
40
+ classes = ['Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']
41
 
42
  print("---------------------------")
43
  print("Image path:", image_path)