"""Gradio demo for YOLOv8 detection, segmentation and classification models.

Builds a three-tab ``gr.Blocks`` UI (Detection / Segmentation / Classification),
each backed by ``ultralyticsplus.YOLO`` models listed by
``load_models_from_txt_files``, and launches it on port 8080.
"""
import os
from pathlib import Path

import gradio as gr
from datasets import load_dataset
from ultralyticsplus import YOLO, render_result, postprocess_classify_output

from utils import load_models_from_txt_files, get_dataset_id_from_model_id, get_task_from_readme

EXAMPLE_IMAGE_DIR = 'example_images'
# Default confidence threshold used by the sliders and by the example rows.
DEFAULT_THRESHOLD = 0.25
# Maximum number of validation images exported per model as UI examples.
NUM_EXAMPLES_PER_MODEL = 2

DEFAULT_DET_MODEL_ID = 'keremberke/yolov8m-valorant-detection'
DEFAULT_DET_DATASET_ID = 'keremberke/valorant-object-detection'
DEFAULT_SEG_MODEL_ID = 'keremberke/yolov8s-building-segmentation'
DEFAULT_SEG_DATASET_ID = 'keremberke/satellite-building-segmentation'
DEFAULT_CLS_MODEL_ID = 'keremberke/yolov8m-chest-xray-classification'
DEFAULT_CLS_DATASET_ID = 'keremberke/chest-xray-classification'

# load model ids and default models
det_model_ids, seg_model_ids, cls_model_ids = load_models_from_txt_files()
task_to_model_ids = {'detect': det_model_ids, 'segment': seg_model_ids, 'classify': cls_model_ids}

# One cached model per task; predict() swaps a cached model only when the
# user selects a different checkpoint for that task.
det_model = YOLO(DEFAULT_DET_MODEL_ID)
det_model_id = DEFAULT_DET_MODEL_ID
seg_model = YOLO(DEFAULT_SEG_MODEL_ID)
seg_model_id = DEFAULT_SEG_MODEL_ID
cls_model = YOLO(DEFAULT_CLS_MODEL_ID)
cls_model_id = DEFAULT_CLS_MODEL_ID


def get_examples(task):
    """Export validation images for every model of *task* and build example rows.

    For each model id registered for *task*, the "mini" configuration of the
    model's dataset is loaded and up to ``NUM_EXAMPLES_PER_MODEL`` images from
    its "validation" split are saved as JPEGs in ``EXAMPLE_IMAGE_DIR``.

    Args:
        task: One of the keys of ``task_to_model_ids``
            ('detect', 'segment', 'classify').

    Returns:
        A list of ``[absolute_image_path, model_id, DEFAULT_THRESHOLD]`` rows,
        matching the (image, model dropdown, threshold slider) inputs of the UI.
    """
    example_dir = Path(EXAMPLE_IMAGE_DIR)
    example_dir.mkdir(parents=True, exist_ok=True)

    examples = []
    for model_id in task_to_model_ids[task]:
        dataset_id = get_dataset_id_from_model_id(model_id)
        ds = load_dataset(dataset_id, name="mini")["validation"]
        for ind in range(min(NUM_EXAMPLES_PER_MODEL, len(ds))):
            image = ds[ind]["image"]
            # Pillow cannot write modes such as RGBA or P as JPEG; convert
            # them instead of crashing at startup.
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")
            # len(examples) gives a running index across all models of the task.
            image_file_path = example_dir / f"{task}_example_{len(examples)}.jpg"
            image.save(image_file_path, format='JPEG', quality=100)
            examples.append([os.path.abspath(image_file_path), model_id, DEFAULT_THRESHOLD])
    return examples


# load default examples using default datasets
det_examples = get_examples('detect')
seg_examples = get_examples('segment')
cls_examples = get_examples('classify')


def _get_task(model_id):
    """Return the task ('detect', 'segment' or 'classify') that lists *model_id*.

    Raises:
        ValueError: if *model_id* is not registered for any task.
    """
    # Dict order (detect, segment, classify) fixes the lookup priority.
    for task, model_ids in task_to_model_ids.items():
        if model_id in model_ids:
            return task
    raise ValueError(f"Invalid model_id: {model_id}")


def predict(image, model_id, threshold):
    """Perform inference on image.

    Args:
        image: The image from a ``gr.Image`` input; ``None`` when nothing was
            uploaded.
        model_id: A model id from one of the task model-id lists.
        threshold: Confidence threshold, written to ``model.overrides['conf']``.

    Returns:
        The rendered predictions (``render_result``) for detection and
        segmentation, or the ``postprocess_classify_output`` result for
        classification.

    Raises:
        gradio.Error: if no image was provided.
        ValueError: if *model_id* is not registered for any task.
    """
    global det_model, det_model_id, seg_model, seg_model_id, cls_model, cls_model_id

    if image is None:
        raise gr.Error("Please provide an input image.")

    task = _get_task(model_id)

    # Reload the cached model for this task only when the selection changed.
    if task == 'detect':
        if model_id != det_model_id:
            det_model = YOLO(model_id)
            det_model_id = model_id
        model = det_model
    elif task == 'segment':
        if model_id != seg_model_id:
            seg_model = YOLO(model_id)
            seg_model_id = model_id
        model = seg_model
    else:  # 'classify' -- _get_task only returns the three known tasks
        if model_id != cls_model_id:
            cls_model = YOLO(model_id)
            cls_model_id = model_id
        model = cls_model

    # set model parameters
    model.overrides['conf'] = threshold

    # perform inference
    results = model.predict(image)

    if task == 'classify':
        # postprocess classification output
        return postprocess_classify_output(model, result=results[0])
    # draw predictions
    return render_result(model=model, image=image, result=results[0])


def _add_example_columns(examples, inputs, output):
    """Render *examples* as two side-by-side ``gr.Examples`` columns.

    The first column gets the first half of the rows, the second column the
    remainder. An empty half (e.g. only one example) is skipped rather than
    rendered as an empty examples table. Must be called inside a
    ``gr.Blocks`` context.
    """
    half_ind = len(examples) // 2
    with gr.Row():
        for examples_part in (examples[:half_ind], examples[half_ind:]):
            if not examples_part:
                continue
            with gr.Column():
                gr.Examples(
                    examples_part,
                    inputs=inputs,
                    outputs=output,
                    fn=predict,
                    cache_examples=False,
                )


with gr.Blocks() as demo:
    # NOTE(review): the header renders only "#" (an empty heading); a title
    # was presumably intended here -- TODO confirm.
    gr.Markdown("""#
""")
    with gr.Tab("Detection"):
        with gr.Row():
            with gr.Column():
                detect_input = gr.Image()
                detect_model_id = gr.Dropdown(choices=det_model_ids, label="Model:", value=DEFAULT_DET_MODEL_ID, interactive=True)
                detect_threshold = gr.Slider(maximum=1, step=0.01, value=DEFAULT_THRESHOLD, label="Threshold:", interactive=True)
                detect_button = gr.Button("Detect!")
            with gr.Column():
                detect_output = gr.Image(label="Predictions:", interactive=False)
        detect_inputs = [detect_input, detect_model_id, detect_threshold]
        _add_example_columns(det_examples, detect_inputs, detect_output)

    with gr.Tab("Segmentation"):
        with gr.Row():
            with gr.Column():
                segment_input = gr.Image()
                segment_model_id = gr.Dropdown(choices=seg_model_ids, label="Model:", value=DEFAULT_SEG_MODEL_ID, interactive=True)
                segment_threshold = gr.Slider(maximum=1, step=0.01, value=DEFAULT_THRESHOLD, label="Threshold:", interactive=True)
                segment_button = gr.Button("Segment!")
            with gr.Column():
                segment_output = gr.Image(label="Predictions:", interactive=False)
        segment_inputs = [segment_input, segment_model_id, segment_threshold]
        _add_example_columns(seg_examples, segment_inputs, segment_output)

    with gr.Tab("Classification"):
        with gr.Row():
            with gr.Column():
                classify_input = gr.Image()
                classify_model_id = gr.Dropdown(choices=cls_model_ids, label="Model:", value=DEFAULT_CLS_MODEL_ID, interactive=True)
                classify_threshold = gr.Slider(maximum=1, step=0.01, value=DEFAULT_THRESHOLD, label="Threshold:", interactive=True)
                classify_button = gr.Button("Classify!")
            with gr.Column():
                classify_output = gr.Label(
                    label="Predictions:", show_label=True, num_top_classes=5
                )
        classify_inputs = [classify_input, classify_model_id, classify_threshold]
        _add_example_columns(cls_examples, classify_inputs, classify_output)

    detect_button.click(
        predict, inputs=detect_inputs, outputs=detect_output
    )
    segment_button.click(
        predict, inputs=segment_inputs, outputs=segment_output
    )
    classify_button.click(
        predict, inputs=classify_inputs, outputs=classify_output
    )

demo.launch(server_port=8080)