# app.py, uploaded by Korbd in commit 4638390 ("Create app.py", verified), 2.77 kB.
import contextlib
import os
import tempfile

import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image
# --------------------------------------------------------
# API Clients
# --------------------------------------------------------
# One gradio_client connection per remote Gradio app (identified by Space id).
# Both are created once when this module is loaded and are shared by every
# call to run_prediction.
det_client, cls_client = (
    Client("Korbd/object_detection_model"),      # object-detection Space
    Client("Korbd/image_classification_model"),  # image-classification Space
)
# --------------------------------------------------------
# Main Prediction Function
# --------------------------------------------------------
def _call_space(client, image_path, api_name):
    """Send the image file at *image_path* to one endpoint of a remote Space.

    Returns whatever ``client.predict`` returns. If the call raises, returns
    ``{"error": <message>}`` instead, so that one failing model is reported in
    its own JSON panel rather than aborting the whole request.
    """
    try:
        return client.predict(image=handle_file(image_path), api_name=api_name)
    except Exception as e:  # UI boundary: surface any remote/network failure
        return {"error": str(e)}


def run_prediction(image, task, threshold):
    """Run classification and/or object detection on an uploaded image.

    Args:
        image: Image array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing has been uploaded.
        task: ``"Classification"``, ``"Object Detection"`` or ``"Both"``;
            selects which remote Space(s) are queried.
        threshold: Value of the "Detection Score Threshold" slider.
            NOTE(review): this value is currently neither forwarded to the
            detection Space nor applied to its output, so the slider has no
            effect. Confirm the ``/detect_objects`` signature before wiring it
            through.

    Returns:
        ``(classification_result, detection_result)``. Each element is one of:
        the raw ``predict`` result, ``{"error": ...}`` if that call failed, or
        ``None`` if that task was not selected. Both are ``None`` when *image*
        is ``None``.
    """
    if image is None:
        return None, None

    classification_result = None
    detection_result = None

    # gradio_client uploads images by file path, so the array is written to a
    # PNG first. The file is created and closed before PIL writes to it by
    # name, because an open NamedTemporaryFile cannot be reopened on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp_path = tmp.name
    try:
        Image.fromarray(image).save(tmp_path, format="PNG")

        if task in ("Classification", "Both"):
            classification_result = _call_space(
                cls_client, tmp_path, "/predict"
            )

        if task in ("Object Detection", "Both"):
            detection_result = _call_space(
                det_client, tmp_path, "/detect_objects"
            )
    finally:
        # delete=False means nothing else removes this file; without this,
        # every request leaves a PNG behind. Cleanup is best-effort so that a
        # failed delete never replaces a successful prediction with an error.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)

    return classification_result, detection_result
# --------------------------------------------------------
# UI
# --------------------------------------------------------
# Two-column layout: inputs (image, task, threshold, run button) on the left,
# one JSON panel per model on the right.
with gr.Blocks(title="Image Predictor") as demo:
    gr.Markdown(
        "# 📸 Image Predictor — Classification & Detection\n"
        "---\n"
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="numpy" hands run_prediction an array, which it converts
            # with PIL's Image.fromarray.
            image_input = gr.Image(
                type="numpy",
                label="Upload an Image",
            )
            task = gr.Radio(
                ["Classification", "Object Detection", "Both"],
                value="Both",
                label="Task",
            )
            # NOTE(review): run_prediction receives this value but does not
            # use it, so moving the slider currently has no effect.
            threshold = gr.Slider(
                0.0, 1.0, 0.5,
                step=0.01,
                label="Detection Score Threshold",
            )
            run_button = gr.Button("Run Prediction", variant="primary")
        with gr.Column(scale=1):
            class_output = gr.JSON(label="🔵 Classification Result")
            detect_output = gr.JSON(label="🟠 Object Detection Result")

    # Event listeners must be registered inside the Blocks context. The input
    # order matches run_prediction(image, task, threshold), and the outputs map
    # to its (classification_result, detection_result) return tuple.
    run_button.click(
        fn=run_prediction,
        inputs=[image_input, task, threshold],
        outputs=[class_output, detect_output],
    )

demo.launch()