KhadijaAsehnoune12 commited on
Commit
4755ddc
1 Parent(s): 714d90d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -20,13 +20,28 @@ pipe = pipeline(task="image-classification", model=model_repo)
20
  # Define a custom prediction function
21
  import numpy as np
22
 
23
- # Define a custom prediction function
 
 
 
 
24
  def predict(image):
25
- # Convert the image to numpy array if it's not already in that format
26
- if not isinstance(image, np.ndarray):
27
- raise ValueError("Input image must be a numpy array.")
 
 
 
 
 
 
 
 
 
 
 
28
  # Get the predictions from the pipeline
29
- predictions = pipe(image)
30
  # Get the predicted label index
31
  predicted_index = predictions[0]['label']
32
  # Map the index to the corresponding disease name using id2label
@@ -36,6 +51,7 @@ def predict(image):
36
  return f"{label_name} ({confidence_score:.2f})"
37
 
38
 
 
39
  # Create Gradio interface
40
  iface = gr.Interface(fn=predict,
41
  inputs=gr.Image(type="numpy"),
 
20
  # Define a custom prediction function
21
  import numpy as np
22
 
23
+ from PIL import Image
24
+ from io import BytesIO
25
+ import numpy as np
26
+ import base64
27
+
28
  def predict(image):
29
+ # Check if the input image is a base64 encoded string
30
+ if isinstance(image, str):
31
+ # Decode the base64 encoded image string and convert it to a PIL image object
32
+ image_data = BytesIO(base64.b64decode(image))
33
+ pil_image = Image.open(image_data)
34
+ # Convert the PIL image to a numpy array
35
+ image_np = np.array(pil_image)
36
+ elif isinstance(image, np.ndarray):
37
+ # If the input image is already a numpy array, use it directly
38
+ image_np = image
39
+ else:
40
+ # If the input is neither a base64 encoded string nor a numpy array, raise an error
41
+ raise ValueError("Input image must be either a base64 encoded string or a numpy array.")
42
+
43
  # Get the predictions from the pipeline
44
+ predictions = pipe(image_np)
45
  # Get the predicted label index
46
  predicted_index = predictions[0]['label']
47
  # Map the index to the corresponding disease name using id2label
 
51
  return f"{label_name} ({confidence_score:.2f})"
52
 
53
 
54
+
55
  # Create Gradio interface
56
  iface = gr.Interface(fn=predict,
57
  inputs=gr.Image(type="numpy"),