File size: 2,301 Bytes
89535c7
 
 
 
 
49e4de0
89535c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
from model import Pipeline

# Inference pipeline; constructed once when this module is imported.
pipeline = Pipeline()

# Class name (key of the prediction's `data` dict) -> display label shown in the UI.
labels = {
    "plastic": "ํ”Œ๋ผ์Šคํ‹ฑ",
    "glass": "์œ ๋ฆฌ",
    "metal": "๊ธˆ์†",
    "paper": "์ข…์ด",
    "cardboard": "๊ณจํŒ์ง€",
    "trash": "์“ฐ๋ ˆ๊ธฐ",
}

# Dropdown display name -> model identifier passed to `pipeline.predict_image`.
# Insertion order matters: the first entry is the dropdown's default.
models = {
    "ResNet152 (Pretrained, latest)": "resnet152_0902",
    "ResNet152 (Pretrained)": "resnet152_0813",
    "ResNet50 (Pretrained)": "resnet50_0810",
}

# Selects which image component `predict` reads; toggled by `set_input_type`.
# NOTE(review): module-level state like this is shared by every session of the app.
is_webcam = False

def predict(model, file_input, webcam_input):
    """Classify the currently selected input image with the chosen model.

    Args:
        model: Display name of the model; must be a key of ``models``.
        file_input: Image from the upload component, or ``None`` if empty.
        webcam_input: Image from the webcam component, or ``None`` if empty.

    Returns:
        A three-item list ``[scores, duration_text, heatmap]`` feeding the
        result label, duration label and heatmap image components. When the
        selected input holds no image, all three items are ``None`` so the
        outputs are cleared.
    """
    # `is_webcam` is module-level state toggled by `set_input_type`.
    image = webcam_input if is_webcam else file_input
    # Identity test: `==` against None is element-wise for array-like inputs.
    if image is None:
        return [None, None, None]

    pred = pipeline.predict_image(models[model], image)

    # Translate class names to display labels; a class missing from `labels`
    # is shown under its raw name instead of crashing with KeyError.
    scores = {labels.get(key, key): value for key, value in pred.data.items()}
    duration_text = f"{pred.duration}ms ({pred.duration / 1000}s)"
    return [scores, duration_text, pred.heatmap]

def set_input_type(value):
    """Show the image component that matches the selected input type.

    Args:
        value: Selected option of the input-type dropdown. The file option
            selects the upload component; any other value selects the webcam.

    Returns:
        Two ``gr.update`` objects setting the visibility of the upload and
        webcam image components, in that order.

    Side effects:
        Updates the module-level ``is_webcam`` flag read by ``predict``.
    """
    global is_webcam
    use_file = bool(value == "ํŒŒ์ผ")
    is_webcam = not use_file
    return [gr.update(visible=use_file), gr.update(visible=not use_file)]

with gr.Blocks(title="๐ŸŒฟ Trash AI") as demo:
    gr.Markdown('<h1 align="center">๐ŸŒฟ Trash AI</h1>')
    gr.Markdown('<p align="center">๋”ฅ๋Ÿฌ๋‹ ๊ธฐ๋ฐ˜ ์“ฐ๋ ˆ๊ธฐ ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ชจ๋ธ ๋ฐ๋ชจ</p>')

    with gr.Row():
        # Left column: model / input-type selection, the two image inputs
        # (only one visible at a time), the classify button and examples.
        with gr.Column():
            model_select = gr.Dropdown(label="๋ชจ๋ธ ์„ ํƒ", choices=list(models.keys()), value=list(models.keys())[0])
            input_select = gr.Dropdown(label="์ž…๋ ฅ ์œ ํ˜•", choices=["ํŒŒ์ผ", "์›น์บ "], value="ํŒŒ์ผ")
            file_input = gr.Image(label="์ž…๋ ฅ ์ด๋ฏธ์ง€", type="pil", source="upload")
            webcam_input = gr.Image(label="์ž…๋ ฅ ์ด๋ฏธ์ง€", type="pil", source="webcam", visible=False)
            with gr.Row():
                classify = gr.Button("๋ถ„๋ฅ˜")
            gr.Examples(["images/cocacola.jpg", "images/samdasoo.jpg", "images/sprite.jpg", "images/box.jpg", "images/tissue.jpg"], file_input)
        # Right column: class scores, inference time and heatmap outputs.
        with gr.Column():
            output = gr.Label(label="๊ฒฐ๊ณผ")
            duration = gr.Label(label="์†Œ์š” ์‹œ๊ฐ„")
            heatmap = gr.Image(label="ํžˆํŠธ๋งต")

    def _classify_selected(model, input_type, file_image, webcam_image):
        """Run `predict` on the image chosen by this session's input dropdown.

        `predict` picks between its two image arguments via the module-level
        `is_webcam` flag, which every session shares (the app is launched with
        `share=True`), so one user's toggle could change another user's result.
        Choosing the image from the dropdown value that arrives with this event
        and passing it in both slots makes the choice per-session.
        """
        # Same rule as `set_input_type`: the file option means upload,
        # anything else means webcam.
        image = file_image if input_type == "ํŒŒ์ผ" else webcam_image
        return predict(model, image, image)

    input_select.change(set_input_type, [input_select], [file_input, webcam_input])
    classify.click(_classify_selected, [model_select, input_select, file_input, webcam_input], [output, duration, heatmap])

if __name__ == "__main__":
    demo.launch(share=True)