Spaces:
Runtime error
Runtime error
import contextlib
import os
import tempfile

import gradio as gr
from gradio_client import Client, handle_file
from PIL import Image
# --------------------------------------------------------
# API Clients
# --------------------------------------------------------
class _LazyClient:
    """Defer connecting to a remote Gradio Space until the first prediction.

    Constructing ``gradio_client.Client`` fetches the Space's configuration
    over the network. Doing that at import time means the whole app fails to
    start whenever either remote Space is unavailable. Deferring the
    connection moves that failure into ``predict()``, where the caller can
    catch it and report it per request.
    """

    def __init__(self, src):
        self._src = src  # Space id passed to gradio_client.Client
        self._client = None  # real Client, created on first predict()

    def predict(self, *args, **kwargs):
        """Forward to ``Client.predict``, creating the Client on first use."""
        if self._client is None:
            # Not locked: two concurrent first calls may each build a Client;
            # the later assignment simply wins, which is harmless.
            self._client = Client(self._src)
        return self._client.predict(*args, **kwargs)


det_client = _LazyClient("Korbd/object_detection_model")
cls_client = _LazyClient("Korbd/image_classification_model")
# --------------------------------------------------------
# Main Prediction Function
# --------------------------------------------------------
def run_prediction(image, task, threshold):
    """Send an uploaded image to the classification and/or detection Spaces.

    Args:
        image: Value of the ``gr.Image(type="numpy")`` input, or ``None`` when
            nothing has been uploaded.
        task: ``"Classification"``, ``"Object Detection"`` or ``"Both"``;
            selects which remote model(s) are called.
        threshold: Value of the "Detection Score Threshold" slider.
            NOTE(review): currently unused. It is neither forwarded to the
            detection Space nor applied to its result. Wire it up once the
            ``/detect_objects`` response schema is confirmed.

    Returns:
        A ``(classification_result, detection_result)`` tuple for the two
        JSON outputs. Each entry is ``None`` when its task was not requested,
        the remote API's return value on success, or ``{"error": message}``
        if the remote call raised.
    """
    if image is None:
        return None, None

    classification_result = None
    detection_result = None

    # handle_file() takes a path, so the array is written to a temporary PNG.
    # The handle is closed before saving because reopening a still-open
    # NamedTemporaryFile by name is not portable (it can fail on Windows).
    # That requires delete=False plus explicit removal in the finally clause.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp_path = tmp.name
    try:
        Image.fromarray(image).save(tmp_path)

        # ---- Classification ----
        if task in ("Classification", "Both"):
            try:
                classification_result = cls_client.predict(
                    image=handle_file(tmp_path),
                    api_name="/predict",
                )
            except Exception as e:  # UI boundary: show the error, don't crash
                classification_result = {"error": str(e)}

        # ---- Object Detection ----
        if task in ("Object Detection", "Both"):
            try:
                detection_result = det_client.predict(
                    image=handle_file(tmp_path),
                    api_name="/detect_objects",
                )
            except Exception as e:  # UI boundary: show the error, don't crash
                detection_result = {"error": str(e)}
    finally:
        # Remove the temp file so each request does not leave a PNG behind.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)

    return classification_result, detection_result
# --------------------------------------------------------
# UI
# --------------------------------------------------------
# Two-column layout: inputs (image, task, threshold) on the left and one JSON
# panel per model on the right. The button feeds all three inputs to
# run_prediction and routes its two return values to the two panels.
with gr.Blocks(title="Image Predictor") as demo:
    gr.Markdown(
        """
        # 📸 Image Predictor — Classification & Detection
        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="numpy",
                label="Upload an Image"
            )
            task = gr.Radio(
                ["Classification", "Object Detection", "Both"],
                value="Both",
                label="Task",
            )
            # NOTE(review): passed to run_prediction, but no code visible here
            # applies it to the detection results yet.
            threshold = gr.Slider(
                0.0, 1.0, 0.5,
                step=0.01,
                label="Detection Score Threshold"
            )
            run_button = gr.Button("Run Prediction", variant="primary")

        with gr.Column(scale=1):
            class_output = gr.JSON(label="🔵 Classification Result")
            detect_output = gr.JSON(label="🟠 Object Detection Result")

    run_button.click(
        fn=run_prediction,
        inputs=[image_input, task, threshold],
        outputs=[class_output, detect_output]
    )

demo.launch()