# File size: 1,804 Bytes
# d7aea57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os

import gradio as gr
import pandas as pd
from PIL import ImageDraw
from PIL.Image import Image
from sacred import Experiment

from engie_pipeline.pipeline import pipeline_experiment, set_up_pipeline, pipeline
from engie_pipeline.utils import draw_boxes

# Sacred experiment driving this demo; the pipeline ingredient is attached so
# its captured functions/config (e.g. set_up_pipeline) belong to this run.
ex = Experiment(name="app", ingredients=[pipeline_experiment])


def process_image(models, image: Image, comformity_threshold: float):
    """Run the conformity pipeline on one image and draw its detections.

    Args:
        models: Mapping unpacked as keyword arguments into ``pipeline``
            (built by ``set_up_pipeline`` in this script).
        image: Input PIL image; converted to RGB before processing.
        comformity_threshold: Forwarded to ``pipeline`` as
            ``conformity_threshold``. The parameter keeps its original
            spelling because callers pass it by keyword.

    Returns:
        A ``(probs, image, conformity)`` tuple:
        a DataFrame of the SigLIP probabilities with the columns
        "Admissible", "Hors-sujet" and "Non-admissible"; the RGB image with
        the detected boxes drawn on it; and the conformity value returned
        by ``pipeline``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio hands the function None when nothing was uploaded; show a
        # readable message instead of failing with an AttributeError.
        raise gr.Error("Please provide an input image.")

    image = image.convert("RGB")
    siglip_probs, boxes, labels, scores, conformity = pipeline(
        **models,
        image=image,
        conformity_threshold=comformity_threshold,
        force_detr=True,
    )

    # One colour per detector label, drawn in this order (later labels are
    # drawn on top of earlier ones).
    # NOTE(review): the meaning of labels 0/1/2 is defined by `pipeline`.
    for label_id, color in ((0, "gray"), (1, "orange"), (2, "purple")):
        mask = labels == label_id
        image = draw_boxes(
            image=image, boxes=boxes[mask], probs=scores[mask], color=color
        )

    probs_df = pd.DataFrame(
        # .cpu() is a no-op for CPU tensors and is required before .numpy()
        # when the models run on a GPU.
        siglip_probs.detach().cpu().numpy(),
        columns=["Admissible", "Hors-sujet", "Non-admissible"],
    )
    return probs_df, image, conformity


def _gradio_auth():
    """Return the ``(username, password)`` pair protecting the demo.

    Both values come from the ``GRADIO_USERNAME`` and ``GRADIO_PASSWORD``
    environment variables.

    Raises:
        RuntimeError: If either variable is unset or empty.
    """
    username = os.environ.get("GRADIO_USERNAME")
    password = os.environ.get("GRADIO_PASSWORD")
    if not username or not password:
        # Passing (None, None) to launch() would not give any usable
        # credentials; fail loudly instead of serving an unusable login page.
        raise RuntimeError(
            "Set the GRADIO_USERNAME and GRADIO_PASSWORD environment variables "
            "before launching the app."
        )
    return username, password


# NOTE: sacred's automain runs the experiment as soon as the decorator is
# applied when this file is executed as a script, so everything app() uses
# (including _gradio_auth above) must be defined before this point.
@ex.automain
def app():
    """Build and launch the password-protected Gradio conformity demo."""
    # Resolve credentials before loading the models so a misconfiguration
    # is reported immediately rather than after the (slow) model setup.
    auth = _gradio_auth()
    models = set_up_pipeline()

    image_input = gr.Image(label="Input image", type="pil")
    comformity_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8)
    prob_output = gr.DataFrame(label="Probabilities (%)")
    image_output = gr.Image(label="Output image")
    label = gr.Label(label="Is the image conform?")

    demo = gr.Interface(
        fn=lambda im, thresh: process_image(models=models, image=im, comformity_threshold=thresh),
        inputs=[image_input, comformity_threshold],
        outputs=[prob_output, image_output, label],
    )

    demo.launch(auth=auth)