djairbee5 commited on
Commit
c152ea8
1 Parent(s): f32f1b1
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -5,6 +5,7 @@ import torch.nn as nn
5
  from torchvision import transforms
6
  import torchvision.models as models
7
  from PIL import Image
 
8
 
9
  # Function to load class names from a file
10
  def load_class_names(file_path):
@@ -38,7 +39,10 @@ val_transform = transforms.Compose([
38
 
39
  # Define the prediction function
40
  def predict_image(image):
41
- image = Image.open(image).convert('RGB')
 
 
 
42
  image = val_transform(image)
43
  image = image.unsqueeze(0) # Add batch dimension
44
 
 
5
  from torchvision import transforms
6
  import torchvision.models as models
7
  from PIL import Image
8
+ import numpy as np
9
 
10
  # Function to load class names from a file
11
  def load_class_names(file_path):
 
39
 
40
  # Define the prediction function
41
  def predict_image(image):
42
+ # Convert the NumPy array to a PIL Image
43
+ if isinstance(image, np.ndarray):
44
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
45
+
46
  image = val_transform(image)
47
  image = image.unsqueeze(0) # Add batch dimension
48