| import gradio as gr | |
| from model import Pipeline | |
# One inference pipeline, built once when the module is imported and reused
# by every call to predict().
pipeline = Pipeline()

# Class name emitted by the pipeline (keys of pred.data) -> label shown in the UI.
# NOTE(review): these display strings appear to be mis-encoded (mojibake);
# verify the file's encoding before editing them.
labels = {
    "plastic": "ํ๋ผ์คํฑ",
    "glass": "์ ๋ฆฌ",
    "metal": "๊ธ์",
    "paper": "์ข ์ด",
    "cardboard": "๊ณจํ์ง",
    "trash": "์ฐ๋ ๊ธฐ",
}

# Dropdown display name -> model identifier handed to pipeline.predict_image().
# Insertion order matters: the first entry is the dropdown's default.
models = {
    "ResNet152 (Pretrained, latest)": "resnet152_0902",
    "ResNet152 (Pretrained)": "resnet152_0813",
    "ResNet50 (Pretrained)": "resnet50_0810",
}

# True when predict() should read the webcam component instead of the upload
# component; toggled by set_input_type(). Module-level, so it is shared by
# every session served by this process.
is_webcam = False
def predict(model, file_input, webcam_input):
    """Classify the active input image with the selected model.

    Args:
        model: Display name of the model; must be a key of ``models``.
        file_input: Image from the upload component (PIL image, or None when
            nothing has been uploaded).
        webcam_input: Image from the webcam component (PIL image, or None).

    Returns:
        A 3-item list matching the Gradio outputs:
        ``[scores, duration_text, heatmap]``. ``scores`` maps display labels
        to the pipeline's values, and ``duration_text`` reports
        ``pred.duration`` in ms and s. All three are None when the active
        component holds no image.
    """
    # The module-level flag (set by set_input_type) decides which component
    # is considered active.
    image = webcam_input if is_webcam else file_input
    # Identity check: `==` on image/array objects is not a reliable None test.
    if image is None:
        return [None, None, None]
    pred = pipeline.predict_image(models[model], image)
    # Fall back to the raw class name when a label has no translation,
    # instead of failing the whole request with KeyError.
    scores = {labels.get(key, key): value for key, value in pred.data.items()}
    elapsed = f"{pred.duration}ms ({pred.duration / 1000}s)"
    return [scores, elapsed, pred.heatmap]
def set_input_type(value):
    """Show either the upload or the webcam image component.

    Args:
        value: Selected option of the input-type dropdown.

    Returns:
        Visibility updates for ``[file_input, webcam_input]``. The upload
        component is visible only when ``value`` is the file option; any
        other value shows the webcam component instead.

    Side effect: sets the module-level ``is_webcam`` flag that ``predict``
    reads.
    """
    global is_webcam
    use_file = value == "ํ์ผ"
    is_webcam = not use_file
    return [gr.update(visible=use_file), gr.update(visible=not use_file)]
def _classify_for_session(model, input_type, file_image, webcam_image):
    """Run ``predict`` on the image chosen by this request's own dropdown.

    ``set_input_type`` records the choice in the module-global ``is_webcam``,
    which is shared by every session served by this process. With several
    users, one user's switch would change which image ``predict`` reads for
    everyone. Here the choice comes from the per-request dropdown value
    instead. The chosen image is passed in both image slots, so whichever
    slot ``predict`` reads, it gets this session's image.
    """
    image = file_image if input_type == "ํ์ผ" else webcam_image
    return predict(model, image, image)


with gr.Blocks(title="๐ฟ Trash AI") as demo:
    gr.Markdown('<h1 align="center">๐ฟ Trash AI</h1>')
    gr.Markdown('<p align="center">๋ฅ๋ฌ๋ ๊ธฐ๋ฐ ์ฐ๋ ๊ธฐ ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ชจ๋ธ ๋ฐ๋ชจ</p>')
    with gr.Row():
        # Left column: model/input selection, the two image sources, examples.
        with gr.Column():
            model_select = gr.Dropdown(label="๋ชจ๋ธ ์ ํ", choices=list(models.keys()), value=list(models.keys())[0])
            input_select = gr.Dropdown(label="์ ๋ ฅ ์ ํ", choices=["ํ์ผ", "์น์บ "], value="ํ์ผ")
            file_input = gr.Image(label="์ ๋ ฅ ์ด๋ฏธ์ง", type="pil", source="upload")
            # Hidden until the input-type dropdown switches to the webcam option.
            webcam_input = gr.Image(label="์ ๋ ฅ ์ด๋ฏธ์ง", type="pil", source="webcam", visible=False)
            with gr.Row():
                classify = gr.Button("๋ถ๋ฅ")
            gr.Examples(["images/cocacola.jpg", "images/samdasoo.jpg", "images/sprite.jpg", "images/box.jpg", "images/tissue.jpg"], file_input)
        # Right column: class scores, timing, and heatmap from predict().
        with gr.Column():
            output = gr.Label(label="๊ฒฐ๊ณผ")
            duration = gr.Label(label="์์ ์๊ฐ")
            heatmap = gr.Image(label="ํํธ๋งต")
    # Toggling the dropdown only swaps which image component is visible.
    input_select.change(set_input_type, [input_select], [file_input, webcam_input])
    # The dropdown value is sent with each click, so the image choice is per request.
    classify.click(
        _classify_for_session,
        [model_select, input_select, file_input, webcam_input],
        [output, duration, heatmap],
    )

if __name__ == "__main__":
    # share=True additionally requests a public share link from Gradio.
    demo.launch(share=True)