"""Gradio front-end that forwards an uploaded image to two remote Gradio apps.

One remote app does image classification and the other does object detection.
Both results are shown as JSON.
"""

import contextlib
import os
import tempfile

import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image

# --------------------------------------------------------
# API Clients
# --------------------------------------------------------
det_client = Client("Korbd/object_detection_model")
cls_client = Client("Korbd/image_classification_model")


# --------------------------------------------------------
# Main Prediction Function
# --------------------------------------------------------
def _call_remote(client, tmp_path, api_name):
    """Send the image file at ``tmp_path`` to ``client``'s ``api_name`` endpoint.

    Returns the endpoint's prediction. Any exception is caught and returned as
    ``{"error": <message>}`` so the gr.JSON output shows it instead of the
    callback failing.
    """
    try:
        return client.predict(image=handle_file(tmp_path), api_name=api_name)
    except Exception as e:  # UI boundary: surface the failure to the user
        return {"error": str(e)}


def run_prediction(image, task, threshold):
    """Run the selected remote task(s) on ``image``.

    Args:
        image: Image array from ``gr.Image(type="numpy")``, or ``None`` when
            nothing has been uploaded.
        task: One of ``"Classification"``, ``"Object Detection"`` or ``"Both"``.
        threshold: Value of the "Detection Score Threshold" slider.

    Returns:
        ``(classification_result, detection_result)``. A slot is ``None`` when
        its task was not selected, or ``{"error": ...}`` when its remote call
        failed.
    """
    # NOTE(review): `threshold` is not used below, so the slider currently has
    # no effect. TODO: forward it to the detection API or filter detections
    # once the response schema of /detect_objects is confirmed.
    if image is None:
        return None, None

    # The remote clients upload by file path, so write the image to a named
    # temp file. delete=False keeps the file after the `with` closes the
    # handle. Writing through the open handle (rather than reopening it by
    # name) also works on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        Image.fromarray(image).save(tmp, format="PNG")
        tmp_path = tmp.name

    classification_result = None
    detection_result = None
    try:
        # ---- Classification ----
        if task in ("Classification", "Both"):
            classification_result = _call_remote(cls_client, tmp_path, "/predict")

        # ---- Object Detection ----
        if task in ("Object Detection", "Both"):
            detection_result = _call_remote(det_client, tmp_path, "/detect_objects")
    finally:
        # Best-effort cleanup so each request does not leak a PNG in the
        # temp directory; a failed delete must not hide the results.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)

    return classification_result, detection_result


# --------------------------------------------------------
# UI
# --------------------------------------------------------
with gr.Blocks(title="Image Predictor") as demo:
    gr.Markdown(
        """
        # 📸 Image Predictor — Classification & Detection
        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="numpy",
                label="Upload an Image"
            )
            task = gr.Radio(
                ["Classification", "Object Detection", "Both"],
                value="Both",
                label="Task",
            )
            threshold = gr.Slider(
                0.0, 1.0, 0.5, step=0.01,
                label="Detection Score Threshold"
            )
            run_button = gr.Button("Run Prediction", variant="primary")

        with gr.Column(scale=1):
            class_output = gr.JSON(label="🔵 Classification Result")
            detect_output = gr.JSON(label="🟠 Object Detection Result")

    run_button.click(
        fn=run_prediction,
        inputs=[image_input, task, threshold],
        outputs=[class_output, detect_output]
    )

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a server.
if __name__ == "__main__":
    demo.launch()