Spaces:
Runtime error
Runtime error
import os | |
from pathlib import Path | |
import gradio as gr | |
from datasets import load_dataset | |
from ultralyticsplus import YOLO, render_result, postprocess_classify_output | |
from utils import load_models_from_txt_files, get_dataset_id_from_model_id, get_task_from_readme | |
# Directory into which get_examples() exports the example images.
EXAMPLE_IMAGE_DIR = "example_images"

# Default model (and its dataset) preselected for each task.
# -- detection
DEFAULT_DET_MODEL_ID = "keremberke/yolov8m-valorant-detection"
DEFAULT_DET_DATASET_ID = "keremberke/valorant-object-detection"
# -- segmentation
DEFAULT_SEG_MODEL_ID = "keremberke/yolov8s-building-segmentation"
DEFAULT_SEG_DATASET_ID = "keremberke/satellite-building-segmentation"
# -- classification
DEFAULT_CLS_MODEL_ID = "keremberke/yolov8m-chest-xray-classification"
DEFAULT_CLS_DATASET_ID = "keremberke/chest-xray-classification"
# NOTE(review): the *_DATASET_ID constants are not referenced elsewhere in this
# file; example datasets are resolved per model via get_dataset_id_from_model_id().
# Model ids offered per task, as returned by utils.load_models_from_txt_files().
det_model_ids, seg_model_ids, cls_model_ids = load_models_from_txt_files()
task_to_model_ids = dict(
    detect=det_model_ids,
    segment=seg_model_ids,
    classify=cls_model_ids,
)

# One cached model per task together with the id it was loaded from; predict()
# replaces the pair whenever a different model id is requested for that task.
det_model_id = DEFAULT_DET_MODEL_ID
det_model = YOLO(det_model_id)
seg_model_id = DEFAULT_SEG_MODEL_ID
seg_model = YOLO(seg_model_id)
cls_model_id = DEFAULT_CLS_MODEL_ID
cls_model = YOLO(cls_model_id)
def get_examples(task, num_examples_per_model=2, threshold=0.25):
    """Export example images for every model of ``task`` and build Gradio example rows.

    For each model id in ``task_to_model_ids[task]`` the matching dataset is
    loaded (config ``name="mini"``, ``"validation"`` split) and up to
    ``num_examples_per_model`` of its images are written as JPEG files into
    ``EXAMPLE_IMAGE_DIR`` as ``{task}_example_{i}.jpg``, where ``i`` counts
    across all models of the task (so files from different models never clash).

    Args:
        task: A key of ``task_to_model_ids`` (``'detect'``, ``'segment'`` or
            ``'classify'``).
        num_examples_per_model: Maximum number of images taken from each
            model's dataset. Defaults to 2.
        threshold: Confidence threshold placed in every example row.
            Defaults to 0.25.

    Returns:
        A list of ``[absolute_image_path, model_id, threshold]`` rows, in the
        same order as the ``(image, model_id, threshold)`` inputs of ``predict``.

    Raises:
        KeyError: If ``task`` is not a key of ``task_to_model_ids``.
    """
    example_dir = Path(EXAMPLE_IMAGE_DIR)
    example_dir.mkdir(parents=True, exist_ok=True)

    examples = []
    image_ind = 0  # running index across all models of this task
    for model_id in task_to_model_ids[task]:
        dataset_id = get_dataset_id_from_model_id(model_id)
        ds = load_dataset(dataset_id, name="mini")["validation"]
        for ind in range(min(num_examples_per_model, len(ds))):
            image = ds[ind]["image"]  # presumably a PIL image decoded by `datasets`
            # JPEG cannot store alpha or palette data (Pillow raises e.g. for
            # RGBA), so normalize such images to RGB before saving.
            if image.mode not in ("L", "RGB"):
                image = image.convert("RGB")
            image_file_path = str(example_dir / f"{task}_example_{image_ind}.jpg")
            image.save(image_file_path, format='JPEG', quality=100)
            examples.append([os.path.abspath(image_file_path), model_id, threshold])
            image_ind += 1
    return examples
# Build the example rows shown under each tab. get_examples() exports images
# from the dataset of EVERY model listed for the task (not only the default
# models), i.e. one load_dataset() call per listed model at import time.
det_examples = get_examples('detect')
seg_examples = get_examples('segment')
cls_examples = get_examples('classify')
def predict(image, model_id, threshold):
    """Run the model ``model_id`` on ``image`` and return a Gradio-displayable result.

    The task is inferred from which model-id list ``model_id`` belongs to
    (checked in order: detection, segmentation, classification). One model per
    task is cached in module globals and is reloaded only when a different
    model id is requested for that task.

    Args:
        image: Input image from the Gradio ``Image`` component
            (presumably a numpy array -- TODO confirm for the Gradio version used).
        model_id: Model id; must appear in ``det_model_ids``,
            ``seg_model_ids`` or ``cls_model_ids``.
        threshold: Confidence threshold, written to ``model.overrides['conf']``
            before inference.

    Returns:
        For detection/segmentation, the output of ``render_result`` (the
        predictions drawn on the input image); for classification, the output
        of ``postprocess_classify_output``.

    Raises:
        gr.Error: If no input image was provided.
        ValueError: If ``model_id`` is not in any known model-id list.
    """
    global det_model, det_model_id, seg_model, seg_model_id, cls_model, cls_model_id

    if image is None:
        # Fail early with a user-facing message instead of passing None on to
        # the model and the renderer.
        raise gr.Error("Please provide an input image.")

    # Resolve the task from the model id and reuse that task's cached model,
    # loading a new one only when a different model id is requested.
    if model_id in det_model_ids:
        task = 'detect'
        if model_id != det_model_id:
            det_model = YOLO(model_id)
            det_model_id = model_id
        model = det_model
    elif model_id in seg_model_ids:
        task = 'segment'
        if model_id != seg_model_id:
            seg_model = YOLO(model_id)
            seg_model_id = model_id
        model = seg_model
    elif model_id in cls_model_ids:
        task = 'classify'
        if model_id != cls_model_id:
            cls_model = YOLO(model_id)
            cls_model_id = model_id
        model = cls_model
    else:
        raise ValueError(f"Invalid model_id: {model_id}")

    # NOTE(review): this mutates the shared cached model, so concurrent requests
    # on the same model may race on the threshold -- confirm the queue setup.
    model.overrides['conf'] = threshold

    results = model.predict(image)

    if task == 'classify':
        return postprocess_classify_output(model, result=results[0])
    # detect / segment: draw the predictions onto the input image
    return render_result(model=model, image=image, result=results[0])
def _add_example_columns(examples, inputs, outputs):
    """Show ``examples`` as two side-by-side ``gr.Examples`` columns wired to ``predict``.

    The first half of ``examples`` goes in the left column and the remaining
    rows in the right column, so every example appears exactly once. A half
    with no rows (fewer than two examples) is skipped.
    """
    half_ind = len(examples) // 2
    with gr.Row():
        for examples_part in (examples[:half_ind], examples[half_ind:]):
            if not examples_part:
                continue
            with gr.Column():
                gr.Examples(
                    examples_part,
                    inputs=inputs,
                    outputs=outputs,
                    fn=predict,
                    cache_examples=False,
                )


with gr.Blocks() as demo:
    gr.Markdown("""# <p align='center'><img width='500px' src='https://user-images.githubusercontent.com/34196005/215836968-fb54e066-a524-4caf-b469-92bbaa96f921.gif' /></p>
<p style='text-align: center'>
<br> <a href='https://yolov8.xyz' target='_blank'>project website</a> | <a href='https://github.com/keremberke/awesome-yolov8-models' target='_blank'>project github</a>
</p>
<p style='text-align: center'>
Follow me for more!
<br> <a href='https://twitter.com/_keremberke' target='_blank'>twitter</a> | <a href='https://github.com/keremberke' target='_blank'>github</a> | <a href='https://www.linkedin.com/in/kerem-berke-bba6a5204/' target='_blank'>linkedin</a>
</p>
""")

    with gr.Tab("Detection"):
        with gr.Row():
            with gr.Column():
                detect_input = gr.Image()
                detect_model_id = gr.Dropdown(choices=det_model_ids, label="Model:", value=DEFAULT_DET_MODEL_ID, interactive=True)
                detect_threshold = gr.Slider(maximum=1, step=0.01, value=0.25, label="Threshold:", interactive=True)
                detect_button = gr.Button("Detect!")
            with gr.Column():
                detect_output = gr.Image(label="Predictions:", interactive=False)
        _add_example_columns(
            det_examples,
            inputs=[detect_input, detect_model_id, detect_threshold],
            outputs=detect_output,
        )

    with gr.Tab("Segmentation"):
        with gr.Row():
            with gr.Column():
                segment_input = gr.Image()
                segment_model_id = gr.Dropdown(choices=seg_model_ids, label="Model:", value=DEFAULT_SEG_MODEL_ID, interactive=True)
                segment_threshold = gr.Slider(maximum=1, step=0.01, value=0.25, label="Threshold:", interactive=True)
                segment_button = gr.Button("Segment!")
            with gr.Column():
                segment_output = gr.Image(label="Predictions:", interactive=False)
        _add_example_columns(
            seg_examples,
            inputs=[segment_input, segment_model_id, segment_threshold],
            outputs=segment_output,
        )

    with gr.Tab("Classification"):
        with gr.Row():
            with gr.Column():
                classify_input = gr.Image()
                classify_model_id = gr.Dropdown(choices=cls_model_ids, label="Model:", value=DEFAULT_CLS_MODEL_ID, interactive=True)
                classify_threshold = gr.Slider(maximum=1, step=0.01, value=0.25, label="Threshold:", interactive=True)
                classify_button = gr.Button("Classify!")
            with gr.Column():
                classify_output = gr.Label(
                    label="Predictions:", show_label=True, num_top_classes=5
                )
        _add_example_columns(
            cls_examples,
            inputs=[classify_input, classify_model_id, classify_threshold],
            outputs=classify_output,
        )

    # Event listeners must be registered inside the Blocks context.
    detect_button.click(
        predict, inputs=[detect_input, detect_model_id, detect_threshold], outputs=detect_output
    )
    segment_button.click(
        predict, inputs=[segment_input, segment_model_id, segment_threshold], outputs=segment_output
    )
    classify_button.click(
        predict, inputs=[classify_input, classify_model_id, classify_threshold], outputs=classify_output
    )

# An explicit server_port overrides Gradio's own GRADIO_SERVER_PORT handling, so
# honour that variable here and keep 8080 as the fallback.
demo.launch(server_port=int(os.environ.get("GRADIO_SERVER_PORT", 8080)))