ryanwang058 commited on
Commit
8cff122
·
1 Parent(s): eddc864

Fix model input

Browse files
Files changed (1) hide show
  1. plant_disease_classifier.py +10 -6
plant_disease_classifier.py CHANGED
@@ -94,12 +94,16 @@ class PlantDiseaseClassifier:
94
  accuracy = (correct / total) * 100 if total > 0 else 0.0
95
  return accuracy
96
 
97
- def predict_image(self, image_path):
98
- # Load and transform the image
99
- image = Image.open(image_path).convert('RGB')
100
- transformed_image = self.data_transforms(image).unsqueeze(0).to(self.device)
101
-
102
- # Predict
 
 
 
 
103
  with torch.no_grad():
104
  outputs = self.model(transformed_image)
105
  logits = outputs.logits if self.model_type in ["levit"] else outputs
 
94
  accuracy = (correct / total) * 100 if total > 0 else 0.0
95
  return accuracy
96
 
97
+ def predict(self, image):
98
+ # Ensure the image is in RGB format if not already
99
+ if image.mode != "RGB":
100
+ image = image.convert("RGB")
101
+
102
+ # Transform the image to match the model's input requirements
103
+ transformed_image = self.data_transforms(image).unsqueeze(0)
104
+ transformed_image = transformed_image.to(self.device)
105
+
106
+ # Make prediction
107
  with torch.no_grad():
108
  outputs = self.model(transformed_image)
109
  logits = outputs.logits if self.model_type in ["levit"] else outputs