Commit
•
714d90d
1
Parent(s):
fefd52e
Update app.py
Browse files
app.py
CHANGED
@@ -17,8 +17,14 @@ with open(config_path, "r", encoding="utf-8") as f:
|
|
17 |
# Initialize the pipeline
|
18 |
pipe = pipeline(task="image-classification", model=model_repo)
|
19 |
|
|
|
|
|
|
|
20 |
# Define a custom prediction function
|
21 |
def predict(image):
|
|
|
|
|
|
|
22 |
# Get the predictions from the pipeline
|
23 |
predictions = pipe(image)
|
24 |
# Get the predicted label index
|
@@ -29,6 +35,7 @@ def predict(image):
|
|
29 |
confidence_score = predictions[0]['score']
|
30 |
return f"{label_name} ({confidence_score:.2f})"
|
31 |
|
|
|
32 |
# Create Gradio interface
|
33 |
iface = gr.Interface(fn=predict,
|
34 |
inputs=gr.Image(type="numpy"),
|
|
|
17 |
# Initialize the pipeline
|
18 |
pipe = pipeline(task="image-classification", model=model_repo)
|
19 |
|
20 |
+
# Define a custom prediction function
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
# Define a custom prediction function
|
24 |
def predict(image):
|
25 |
+
# Convert the image to numpy array if it's not already in that format
|
26 |
+
if not isinstance(image, np.ndarray):
|
27 |
+
raise ValueError("Input image must be a numpy array.")
|
28 |
# Get the predictions from the pipeline
|
29 |
predictions = pipe(image)
|
30 |
# Get the predicted label index
|
|
|
35 |
confidence_score = predictions[0]['score']
|
36 |
return f"{label_name} ({confidence_score:.2f})"
|
37 |
|
38 |
+
|
39 |
# Create Gradio interface
|
40 |
iface = gr.Interface(fn=predict,
|
41 |
inputs=gr.Image(type="numpy"),
|