# Uploaded by SivaResearch
# Commit 0bed063: "Added file for initial run"
import functools

import gradio as gr
from ultralyticsplus import YOLO, render_result, postprocess_classify_output
@functools.lru_cache(maxsize=1)
def _classification_model():
    """Load the 'yolov8n-cls.pt' weights once and reuse them for every request.

    NOTE(review): the returned model is shared between requests; this relies on
    the Gradio queue serialising calls — confirm before raising concurrency.
    """
    return YOLO('yolov8n-cls.pt')


def classification(image, threshold):
    """Classify *image* and return class scores suitable for a ``gr.Label``.

    Args:
        image: The image delivered by the Gradio ``Image`` input component.
        threshold: Minimum score (0..1) a class must reach to be reported.

    Returns:
        A mapping of class name -> score, as produced by
        ``postprocess_classify_output``, restricted to scores >= ``threshold``.
    """
    model = _classification_model()
    # Pass conf per call rather than mutating model.overrides: the model is
    # shared across requests, and ultralytics' predict() may apply its own conf
    # default on top of model.overrides, whereas keyword arguments win.
    results = model.predict(image, conf=threshold)
    scores = postprocess_classify_output(model=model, result=results[0])
    # Classification outputs are not filtered by conf inside ultralytics, so
    # apply the slider value here; otherwise it would have no visible effect.
    return {label: score for label, score in scores.items() if score >= threshold}
@functools.lru_cache(maxsize=1)
def _detection_model():
    """Load the 'yolov8n.pt' detection weights once and reuse them for every request.

    NOTE(review): the returned model is shared between requests; this relies on
    the Gradio queue serialising calls — confirm before raising concurrency.
    """
    return YOLO('yolov8n.pt')


def detection(image, threshold):
    """Run object detection on *image* and return it with predictions drawn.

    Args:
        image: The image delivered by the Gradio ``Image`` input component.
        threshold: Confidence threshold (0..1) forwarded to the predictor as ``conf``.

    Returns:
        The image rendered by ``render_result`` for the first prediction result.
    """
    model = _detection_model()
    # Per-call conf instead of mutating the shared model's overrides, which
    # predict() may override with its own default.
    results = model.predict(image, conf=threshold)
    return render_result(model=model, image=image, result=results[0])
@functools.lru_cache(maxsize=1)
def _segmentation_model():
    """Load the 'yolov8n-seg.pt' segmentation weights once and reuse them for every request.

    NOTE(review): the returned model is shared between requests; this relies on
    the Gradio queue serialising calls — confirm before raising concurrency.
    """
    return YOLO('yolov8n-seg.pt')


def segmentation(image, threshold):
    """Run instance segmentation on *image* and return it with predictions drawn.

    Args:
        image: The image delivered by the Gradio ``Image`` input component.
        threshold: Confidence threshold (0..1) forwarded to the predictor as ``conf``.

    Returns:
        The image rendered by ``render_result`` for the first prediction result.
    """
    model = _segmentation_model()
    # Per-call conf instead of mutating the shared model's overrides, which
    # predict() may override with its own default.
    results = model.predict(image, conf=threshold)
    return render_result(model=model, image=image, result=results[0])
def _input_column(button_label):
    """Create the shared input widgets: image upload, threshold slider, run button.

    Must be called inside an active ``gr.Blocks`` layout context.

    Args:
        button_label: Text shown on the run button.

    Returns:
        A ``(image, threshold_slider, button)`` tuple for wiring events.
    """
    # type="pil": Gradio's NumPy images are RGB, but ultralytics interprets
    # NumPy input as BGR; PIL images are interpreted as RGB.
    image = gr.Image(type="pil")
    threshold = gr.Slider(
        maximum=1,
        step=0.01,
        value=0.25,
        label="Threshold:",
        interactive=True)
    button = gr.Button(button_label)
    return image, threshold, button


# Three tabs (detection, segmentation, classification); each pairs an input
# column with an output component.
with gr.Blocks() as demo:
    with gr.Tab("Detection"):
        with gr.Row():
            with gr.Column():
                detect_input, detect_threshold, detect_button = _input_column("Detect!")
            with gr.Column():
                detect_output = gr.Image(
                    label="Predictions:", interactive=False)
    with gr.Tab("Segmentation"):
        with gr.Row():
            with gr.Column():
                segment_input, segment_threshold, segment_button = _input_column("Segment!")
            with gr.Column():
                segment_output = gr.Image(
                    label="Predictions:", interactive=False)
    with gr.Tab("Classification"):
        with gr.Row():
            with gr.Column():
                classify_input, classify_threshold, classify_button = _input_column("Classify!")
            with gr.Column():
                classify_output = gr.Label(
                    label="Predictions:", show_label=True, num_top_classes=5)

    # Event listeners must be registered inside the Blocks context.
    # api_name values are kept unchanged because external API clients use them.
    detect_button.click(
        detection,
        inputs=[detect_input, detect_threshold],
        outputs=detect_output,
        api_name="Detect")
    segment_button.click(
        segmentation,
        inputs=[segment_input, segment_threshold],
        outputs=segment_output,
        api_name="Segmentation")
    classify_button.click(
        classification,
        inputs=[classify_input, classify_threshold],
        outputs=classify_output,
        api_name="classify")

# launch(enable_queue=...) is deprecated in Gradio 3 and no longer accepted in
# Gradio 4; queue() is the supported way to enable the request queue.
demo.queue()
demo.launch(debug=True)