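# Scraped Hugging Face page header (not part of the program):
#   author: keremberke | commit message: "Update app.py" | commit: 734f9b3
#   page links: raw / history / blame / contribute / delete
#   scan status: No virus | file size: 9.2 kB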
import os
from pathlib import Path
import gradio as gr
from datasets import load_dataset
from ultralyticsplus import YOLO, render_result, postprocess_classify_output
from utils import load_models_from_txt_files, get_dataset_id_from_model_id, get_task_from_readme
# Directory that get_examples() writes the example JPEGs into.
EXAMPLE_IMAGE_DIR = "example_images"

# Default model per task, plus a dataset id listed alongside each one.
# NOTE(review): the DEFAULT_*_DATASET_ID constants are not read anywhere in
# the visible code — confirm whether they are still needed.
DEFAULT_DET_MODEL_ID = "keremberke/yolov8m-valorant-detection"
DEFAULT_DET_DATASET_ID = "keremberke/valorant-object-detection"

DEFAULT_SEG_MODEL_ID = "keremberke/yolov8s-building-segmentation"
DEFAULT_SEG_DATASET_ID = "keremberke/satellite-building-segmentation"

DEFAULT_CLS_MODEL_ID = "keremberke/yolov8m-chest-xray-classification"
DEFAULT_CLS_DATASET_ID = "keremberke/chest-xray-classification"

# Model ids selectable in each tab, and the same lists keyed by task name.
det_model_ids, seg_model_ids, cls_model_ids = load_models_from_txt_files()
task_to_model_ids = dict(
    detect=det_model_ids,
    segment=seg_model_ids,
    classify=cls_model_ids,
)

# One loaded model per task together with the id it was loaded from;
# predict() swaps these out when a different model id is requested.
det_model, det_model_id = YOLO(DEFAULT_DET_MODEL_ID), DEFAULT_DET_MODEL_ID
seg_model, seg_model_id = YOLO(DEFAULT_SEG_MODEL_ID), DEFAULT_SEG_MODEL_ID
cls_model, cls_model_id = YOLO(DEFAULT_CLS_MODEL_ID), DEFAULT_CLS_MODEL_ID
def get_examples(task):
    """Write example images for every model of *task* and return example rows.

    For each model id registered under ``task_to_model_ids[task]``, the
    "mini" configuration of the model's dataset is loaded and up to two
    entries of its "validation" split are saved as JPEG files inside
    ``EXAMPLE_IMAGE_DIR`` (created if missing).

    Returns a list of ``[absolute_image_path, model_id, 0.25]`` rows; the
    last item is the threshold value used for every example.
    """
    target_dir = Path(EXAMPLE_IMAGE_DIR)
    target_dir.mkdir(parents=True, exist_ok=True)

    rows = []
    next_index = 0  # running file index, shared across all models of the task
    for model_id in task_to_model_ids[task]:
        validation_split = load_dataset(
            get_dataset_id_from_model_id(model_id), name="mini"
        )["validation"]

        for sample_idx in range(min(2, len(validation_split))):
            relative_path = str(target_dir / f"{task}_example_{next_index}.jpg")
            # The "image" field is expected to expose a PIL-style .save().
            validation_split[sample_idx]["image"].save(
                relative_path, format='JPEG', quality=100
            )
            rows.append([os.path.abspath(relative_path), model_id, 0.25])
            next_index += 1

    return rows
# Build the example rows shown in each tab, one list per task, covering every
# model id registered for that task (runs at import time, in this order).
det_examples, seg_examples, cls_examples = [
    get_examples(task_name) for task_name in ("detect", "segment", "classify")
]
def predict(image, model_id, threshold):
    """Run inference with the model named *model_id* on *image*.

    The task is derived from which id list contains *model_id*. One model per
    task is kept in module-level globals and is only reloaded when a different
    id is requested for that task.

    Args:
        image: Input image as delivered by the UI; must not be ``None``.
        model_id: Id that appears in ``det_model_ids``, ``seg_model_ids`` or
            ``cls_model_ids``.
        threshold: Confidence threshold stored in ``model.overrides['conf']``.

    Returns:
        For detection and segmentation, the result of ``render_result`` (shown
        in an image component); for classification, the result of
        ``postprocess_classify_output`` (shown in a label component).

    Raises:
        ValueError: If *image* is ``None`` or *model_id* is not a known id.
    """
    global det_model, det_model_id
    global seg_model, seg_model_id
    global cls_model, cls_model_id

    # Clicking a button before uploading anything sends None; fail with a
    # clear message instead of an obscure error further down the pipeline.
    if image is None:
        raise ValueError("No input image provided.")

    # Resolve the task and fetch (or reload) its cached model in one pass.
    if model_id in det_model_ids:
        task = 'detect'
        if model_id != det_model_id:
            det_model = YOLO(model_id)
            det_model_id = model_id
        model = det_model
    elif model_id in seg_model_ids:
        task = 'segment'
        if model_id != seg_model_id:
            seg_model = YOLO(model_id)
            seg_model_id = model_id
        model = seg_model
    elif model_id in cls_model_ids:
        task = 'classify'
        if model_id != cls_model_id:
            cls_model = YOLO(model_id)
            cls_model_id = model_id
        model = cls_model
    else:
        raise ValueError(f"Invalid model_id: {model_id}")

    # set model parameters
    model.overrides['conf'] = threshold

    # perform inference
    results = model.predict(image)

    # Kept from the original: records which model/task served the request.
    print(model_id)
    print(task)

    if task == 'classify':
        # postprocess classification output
        return postprocess_classify_output(model, result=results[0])

    # detect / segment: draw predictions on the input image
    return render_result(model=model, image=image, result=results[0])
def _add_examples_row(examples, inputs, output):
    """Add a row with two columns of clickable examples wired to ``predict``.

    The second half of *examples* goes into the first column and the first
    half into the second column, which is the layout the page already used.
    """
    half_ind = len(examples) // 2
    with gr.Row():
        for chunk in (examples[half_ind:], examples[:half_ind]):
            with gr.Column():
                # Skip an empty half (empty or single-item example list)
                # rather than handing gr.Examples nothing to show.
                if chunk:
                    gr.Examples(
                        chunk,
                        inputs=list(inputs),
                        outputs=output,
                        fn=predict,
                        cache_examples=False,
                        run_on_click=False,
                    )


def _build_task_tab(title, model_ids, default_model_id, button_label, examples, make_output):
    """Create one task tab.

    *make_output* is a zero-argument callable that creates the output
    component; it is called inside the tab's right-hand column so the
    component is placed there.

    Returns ``(image_input, model_dropdown, threshold_slider, run_button, output)``.
    """
    with gr.Tab(title):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image()
                model_dropdown = gr.Dropdown(
                    choices=model_ids, label="Model:", value=default_model_id, interactive=True
                )
                threshold_slider = gr.Slider(
                    maximum=1, step=0.01, value=0.25, label="Threshold:", interactive=True
                )
                run_button = gr.Button(button_label)
            with gr.Column():
                output = make_output()
        _add_examples_row(examples, [image_input, model_dropdown, threshold_slider], output)
    return image_input, model_dropdown, threshold_slider, run_button, output


def _image_output():
    """Output component for the detection and segmentation tabs."""
    return gr.Image(label="Predictions:", interactive=False)


def _label_output():
    """Output component for the classification tab (top five classes)."""
    return gr.Label(label="Predictions:", show_label=True, num_top_classes=5)


with gr.Blocks() as demo:
    gr.Markdown("""# <p align='center'><a href="https://github.com/keremberke/awesome-yolov8-models" target='_blank'><img width='500px' src='https://user-images.githubusercontent.com/34196005/215836968-fb54e066-a524-4caf-b469-92bbaa96f921.gif' /></a></p>
<p style='text-align: center'>
<br> <a href='https://yolov8.xyz' target='_blank'>project website</a> | <a href='https://github.com/keremberke/awesome-yolov8-models' target='_blank'>project github</a>
</p>
<p style='text-align: center'>
Follow me for more!
<br> <a href='https://twitter.com/_keremberke' target='_blank'>twitter</a> | <a href='https://github.com/keremberke' target='_blank'>github</a> | <a href='https://www.linkedin.com/in/kerem-berke-bba6a5204/' target='_blank'>linkedin</a>
</p>
""")

    # The three tabs share one layout; only the model list, default model,
    # button caption, examples and output component differ.
    (
        detect_input,
        detect_model_id,
        detect_threshold,
        detect_button,
        detect_output,
    ) = _build_task_tab(
        "Detection", det_model_ids, DEFAULT_DET_MODEL_ID, "Detect!", det_examples, _image_output
    )
    (
        segment_input,
        segment_model_id,
        segment_threshold,
        segment_button,
        segment_output,
    ) = _build_task_tab(
        "Segmentation", seg_model_ids, DEFAULT_SEG_MODEL_ID, "Segment!", seg_examples, _image_output
    )
    (
        classify_input,
        classify_model_id,
        classify_threshold,
        classify_button,
        classify_output,
    ) = _build_task_tab(
        "Classification", cls_model_ids, DEFAULT_CLS_MODEL_ID, "Classify!", cls_examples, _label_output
    )

    # Button handlers; api_name also exposes each one as a named API endpoint.
    detect_button.click(
        predict,
        inputs=[detect_input, detect_model_id, detect_threshold],
        outputs=detect_output,
        api_name="detect",
    )
    segment_button.click(
        predict,
        inputs=[segment_input, segment_model_id, segment_threshold],
        outputs=segment_output,
        api_name="segment",
    )
    classify_button.click(
        predict,
        inputs=[classify_input, classify_model_id, classify_threshold],
        outputs=classify_output,
        api_name="classify",
    )

# Enable request queuing through Blocks.queue(); the `enable_queue` argument
# of launch() is the deprecated way to do this.
demo.queue()
demo.launch()