File size: 2,301 Bytes
89535c7 49e4de0 89535c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import gradio as gr
from model import Pipeline
# Shared inference pipeline, constructed once at import time.
pipeline = Pipeline()

# Maps the class names found in the pipeline's prediction data to the Korean
# labels shown in the UI. The file must be saved as UTF-8 for these literals
# to render correctly.
labels = {
    "plastic": "플라스틱",
    "glass": "유리",
    "metal": "금속",
    "paper": "종이",
    "cardboard": "골판지",
    "trash": "쓰레기",
}

# Dropdown display name -> model identifier passed to pipeline.predict_image().
models = {
    "ResNet152 (Pretrained, latest)": "resnet152_0902",
    "ResNet152 (Pretrained)": "resnet152_0813",
    "ResNet50 (Pretrained)": "resnet50_0810",
}

# Whether predict() reads the webcam widget instead of the upload widget;
# toggled by set_input_type().
# NOTE(review): this is module-level state, so it is shared by every session
# of the running app.
is_webcam = False
def predict(model, file_input, webcam_input):
    """Classify the currently selected input image.

    Args:
        model: Display name of the model; must be a key of ``models``.
        file_input: Image from the upload widget, or None if empty.
        webcam_input: Image from the webcam widget, or None if empty.

    Returns:
        A 3-item list feeding the [result, duration, heatmap] outputs:
        a dict of display label -> value taken from ``pred.data``, a
        duration string in milliseconds and seconds, and ``pred.heatmap``.
        All three items are None when the selected input has no image.
    """
    image = webcam_input if is_webcam else file_input
    # `is None` instead of `== None`: image types may overload `==`, and a
    # NumPy array would make the comparison ambiguous and raise.
    if image is None:
        return [None, None, None]

    pred = pipeline.predict_image(models[model], image)
    # Fall back to the raw class name for any class without a Korean display
    # label instead of failing the whole request with a KeyError.
    scores = {labels.get(key, key): value for key, value in pred.data.items()}
    duration = f"{pred.duration}ms ({pred.duration / 1000}s)"
    return [scores, duration, pred.heatmap]
# Choices of the input-type dropdown: file upload / webcam.
INPUT_FILE = "파일"
INPUT_WEBCAM = "웹캠"


def set_input_type(value):
    """Show the image widget that matches the selected input type.

    Also records the choice in the module-level ``is_webcam`` flag read by
    predict().

    Args:
        value: Selected dropdown choice. ``INPUT_FILE`` selects the upload
            widget; any other value selects the webcam widget.

    Returns:
        Visibility updates for [file_input, webcam_input].
    """
    global is_webcam
    is_webcam = value != INPUT_FILE
    return [gr.update(visible=not is_webcam), gr.update(visible=is_webcam)]


def _classify(model, file_image, webcam_image, input_type):
    """Run predict() on the image chosen by this request's own dropdown value.

    predict() selects its image from the module-level ``is_webcam`` flag,
    which is shared across all sessions of the app and can be flipped by
    another visitor between events. Passing the already-selected image in
    both slots makes the result independent of that flag.
    """
    image = file_image if input_type == INPUT_FILE else webcam_image
    return predict(model, image, image)


# NOTE(review): the title emoji below is still garbled in the source (its
# UTF-8 bytes were partially lost and cannot be uniquely recovered) — restore
# the intended emoji.
with gr.Blocks(title="๐ฟ Trash AI") as demo:
    gr.Markdown('<h1 align="center">๐ฟ Trash AI</h1>')
    gr.Markdown('<p align="center">딥러닝 기반 쓰레기 이미지 분류 모델 데모</p>')
    with gr.Row():
        # Left column: model/input selection, the two image inputs, examples.
        with gr.Column():
            model_select = gr.Dropdown(
                label="모델 선택",
                choices=list(models.keys()),
                value=list(models.keys())[0],
            )
            input_select = gr.Dropdown(
                label="입력 선택",
                choices=[INPUT_FILE, INPUT_WEBCAM],
                value=INPUT_FILE,
            )
            file_input = gr.Image(label="입력 이미지", type="pil", source="upload")
            # Hidden until the webcam input type is selected.
            webcam_input = gr.Image(
                label="입력 이미지", type="pil", source="webcam", visible=False
            )
            with gr.Row():
                classify = gr.Button("분류")
            gr.Examples(
                [
                    "images/cocacola.jpg",
                    "images/samdasoo.jpg",
                    "images/sprite.jpg",
                    "images/box.jpg",
                    "images/tissue.jpg",
                ],
                file_input,
            )
        # Right column: classification result, elapsed time, heatmap.
        with gr.Column():
            output = gr.Label(label="결과")
            duration = gr.Label(label="소요 시간")
            heatmap = gr.Image(label="히트맵")
    input_select.change(set_input_type, [input_select], [file_input, webcam_input])
    # input_select is passed along so each request uses its own selection
    # rather than the shared module-level flag.
    classify.click(
        _classify,
        [model_select, file_input, webcam_input, input_select],
        [output, duration, heatmap],
    )
if __name__ == "__main__":
    # Start the demo server; share=True also asks Gradio for a public share link.
    demo.launch(share=True)