Pear-playground / app.py
Kaori1707's picture
Upload 10 files
69ef5c2 verified
raw
history blame
2.72 kB
import os
import tempfile

import cv2
import gradio as gr

from detection import PearDetectionModel
from classification import predict
# Object-detection model shared by every request to the "Detection" tab.
# The class names are presumably listed in the index order the detector's
# weights emit — TODO confirm against the training configuration.
config = {
    "model_path": "./weights/best.pt",
    "classes": [
        "burn_bbox",
        "defected_pear",
        "defected_pear_bbox",
        "normal_pear",
        "normal_pear_bbox",
    ],
}
model = PearDetectionModel(config)
def classify(image):
    """Classify an uploaded image and return the model's prediction.

    ``classification.predict`` takes a file path, so the image is written to a
    private temporary JPEG for the duration of the call and removed afterwards.

    Args:
        image (PIL.Image.Image | None): Image uploaded through Gradio; ``None``
            when the user submits without choosing an image.

    Returns:
        Whatever ``predict`` returns for the image (the model's prediction,
        displayed by the ``gr.Label`` output).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # JPEG cannot store alpha or palette modes (e.g. RGBA, P); saving them fails.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # A unique file per call: a fixed filename would let concurrent requests
    # overwrite each other's image before `predict` reads it.
    fd, image_path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as fh:
            image.save(fh, format="JPEG")
        return predict(image_path)
    finally:
        try:
            os.remove(image_path)
        except OSError:
            pass  # best-effort cleanup; never mask the prediction or its error
def detect(img):
    """Run the pear detector on one webcam frame and draw its detections.

    Args:
        img: The current frame from the streaming ``gr.Image`` (created with
            ``type="numpy"``), or ``None`` when Gradio has no frame to send.

    Returns:
        The same array with a green rectangle and confidence score drawn for
        every detected box and a class caption in the top-left corner (OpenCV
        draws in place), or ``None`` unchanged when no frame was received.
    """
    if img is None:
        # Nothing to run inference on; pass the empty frame straight through.
        return img

    green = (0, 255, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    cls, xyxy, conf = model.inference(img)

    # The loop variable is `score`, not `conf`, so the confidence sequence is
    # not shadowed by its per-box element.
    for box, score in zip(xyxy, conf):
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        cv2.rectangle(img, (x1, y1), (x2, y2), green, 2)
        # putText's origin is the text's bottom-left corner, so a box touching
        # the top edge would push its score off-screen; keep roughly one
        # text-height of room above the baseline.
        cv2.putText(img, f"{float(score):.2f}", (x1, max(y1, 25)), font, 1, green, 2)

    # NOTE(review): `cls` is compared as a single value, and index 0 in the
    # module-level `config["classes"]` is "burn_bbox", not a normal pear.
    # Confirm what model.inference returns for `cls` (scalar vs. per-box array,
    # and its label mapping); a multi-element numpy array here would raise
    # "truth value of an array ... is ambiguous".
    caption = "Class: Normal Pear" if cls == 0 else "Class: Abnormal Pear"
    cv2.putText(img, caption, (0, 50), font, 1, green, 2)
    return img
# Styles for the Detection tab: cap the webcam widget at 500x500 px and centre
# it inside its column. (The original had a stray ";" after the last rule.)
css = """
.my-group {max-width: 500px !important; max-height: 500px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}
"""
# Build the UI: an upload-based classifier tab, plus a detection tab that
# streams webcam frames through `detect` and shows each annotated frame in place.
with gr.Blocks(css=css, title="Pear Playground") as demo:
    gr.Markdown("## This is a demo for Pear Playground by AISeed.")

    with gr.Tab(label="Classification"):
        gr.Interface(
            fn=classify,
            inputs=gr.Image(type="pil", label="Upload an image"),
            outputs=gr.Label(num_top_classes=9),
            examples=["examples/1.jpg", "examples/2.jpg"],
            # Korean UI text: "Abnormal fruit classifier" / "Classifies
            # abnormal pears using the lightweight model ResNet101e".
            title="비정상 과수 분류기",
            description="경량 모델 ResNet101e 을 활용하여 비정상배 분류",
        )

    with gr.Tab(label="Detection"):
        with gr.Column(elem_classes=["my-column"]):
            with gr.Group(elem_classes=["my-group"]):
                # The same component is both the stream source and the output,
                # so each annotated frame replaces the live webcam view.
                input_img = gr.Image(sources=["webcam"], type="numpy", streaming=True)
                input_img.stream(
                    detect,
                    [input_img],
                    [input_img],
                    time_limit=30,  # Gradio time limit (seconds) for the stream
                    stream_every=0.1,  # send a frame to `detect` every 0.1 s
                )

if __name__ == "__main__":
    demo.launch()