File size: 4,679 Bytes
d7aea57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import json
import os.path
import uuid

from engie_pipeline.models import (
    model_ingredient,
    get_detr,
    get_detr_feature_extractor,
    get_siglip,
    get_siglip_preprocessor,
)

from PIL import Image
from sacred import Experiment
from tqdm import tqdm

from engie_pipeline.utils import draw_boxes

# Sacred experiment running the whole pipeline; it registers the models'
# ingredient so its configuration is part of this experiment's config.
pipeline_experiment = Experiment("pipeline", ingredients=[model_ingredient])


@pipeline_experiment.config
def config():
    # NOTE: sacred turns the local variables of this config scope into config
    # entries injected into captured functions by name -- do not rename them.

    # Input image file, or a directory whose jpg/jpeg/png files are processed
    path = "data/"
    # Root output folder; one sub-folder per verdict is created under it
    output_path = "data/output"

    # Min DETR score of a class-2 detection for an image to be "Conforme"
    conformity_threshold = .8


@pipeline_experiment.capture
def set_up_pipeline(output_path):
    """Create the verdict output folders and load the models in eval mode.

    Args:
        output_path: Root output folder (injected from the experiment config
            by sacred when not passed explicitly).

    Returns:
        A dict with keys ``detr``, ``detr_preprocessor``, ``siglip`` and
        ``siglip_preprocessor``, meant to be splatted into ``pipeline``.
    """
    # One sub-folder per verdict. ``exist_ok=True`` already makes this
    # idempotent, so no separate isdir() pre-check (which could race) is needed.
    for verdict in ("Conforme", "Non-conforme", "Hors-sujet", "Non-admissible"):
        os.makedirs(os.path.join(output_path, verdict), exist_ok=True)

    # Detection model and its input preprocessor.
    detr = get_detr()
    detr.eval()
    detr_preprocessor = get_detr_feature_extractor()

    # Image classification model and its input preprocessor.
    siglip = get_siglip()
    siglip.eval()
    siglip_preprocessor = get_siglip_preprocessor()

    return {
        "detr": detr,
        "detr_preprocessor": detr_preprocessor,
        "siglip": siglip,
        "siglip_preprocessor": siglip_preprocessor,
    }


def _output_stem(image) -> str:
    """Return the base name (no directory, no extension) used for outputs.

    Falls back to a random UUID when *image* has no usable ``filename``
    (missing or empty, e.g. an image produced by ``Image.convert``) or when the
    stem would otherwise be empty.
    """
    source = os.fsdecode(getattr(image, "filename", "") or "")
    stem = os.path.splitext(os.path.basename(source))[0]
    return stem or str(uuid.uuid4())


def _already_processed(output_path: str, filename: str) -> bool:
    """Return True if ``<filename>.jpg`` already exists in any verdict folder."""
    return any(
        os.path.isfile(os.path.join(output_path, verdict, filename + ".jpg"))
        for verdict in ("Hors-sujet", "Non-admissible", "Conforme", "Non-conforme")
    )


def _dump_json(path: str, payload: dict) -> None:
    """Write *payload* to *path* as JSON indented by 4 spaces."""
    with open(path, "w") as file:
        json.dump(payload, file, indent=4)


@pipeline_experiment.capture
def pipeline(
    detr,
    detr_preprocessor,
    siglip,
    siglip_preprocessor,
    image: Image.Image,
    output_path: str,
    conformity_threshold: float,
    force_detr: bool = False,
):
    """Classify one image, optionally run detection, and write the results.

    SigLIP scores the image first. If its arg-max class is 1 the verdict is
    ``Hors-sujet``; if 2, ``Non-admissible``. In both cases the raw image is
    saved to that folder and, unless ``force_detr`` is set, a JSON with the
    class probabilities is written next to it and processing stops.

    Otherwise DETR runs, boxes are drawn per class (0 "tableau" gray,
    1 "disjoncteur" orange, 2 "bouton de test" purple) and the annotated image
    plus a JSON with the probabilities and per-class detections are written to
    the verdict folder. If SigLIP did not reject the image, the verdict is
    ``Conforme`` when a class-2 detection scores above
    ``conformity_threshold``, else ``Non-conforme``.

    Images whose ``<name>.jpg`` already exists in any verdict folder are
    skipped, so a re-run does not redo finished work.

    Args:
        detr, detr_preprocessor, siglip, siglip_preprocessor: Models and
            preprocessors as returned by ``set_up_pipeline``.
        image: Input image; its ``filename`` (if any) names the outputs.
        output_path: Root output folder (injected by sacred from the config).
        conformity_threshold: Score threshold for the "Conforme" verdict
            (injected by sacred from the config).
        force_detr: Run detection even for images SigLIP rejected.

    Returns:
        ``(siglip_probs, boxes, labels, scores, conformity)`` when DETR ran,
        otherwise ``None``.
    """
    import torch  # only needed to switch off autograd during inference

    filename = _output_stem(image)
    if _already_processed(output_path, filename):
        return None

    # Pure inference: no_grad stops autograd from recording a graph for every
    # forward pass, which would only cost memory here.
    with torch.no_grad():
        siglip_image_input = siglip_preprocessor(image.copy())
        siglip_probs = siglip(siglip_image_input.unsqueeze(0)).softmax(-1)
        predicted_class = int(siglip_probs.argmax())

        # SigLIP classes 1 and 2 reject the image before detection.
        conformity = {1: "Hors-sujet", 2: "Non-admissible"}.get(predicted_class)
        if conformity is not None:
            # With force_detr, the annotated image saved below overwrites this.
            image.save(os.path.join(output_path, conformity, filename + ".jpg"))
            if not force_detr:
                _dump_json(
                    os.path.join(output_path, conformity, filename + ".json"),
                    {"classification_probs": siglip_probs.tolist()},
                )
                return None

        detr_image_input = detr_preprocessor(image.copy(), return_tensors="pt")
        detr_output = detr(detr_image_input["pixel_values"])
        boxes, labels, scores = detr.process_output(detr_output, image.size)

    if conformity is None:
        # Conforming requires a confident class-2 ("bouton de test") detection.
        conformity = (
            "Conforme"
            if 2 in labels and scores[labels == 2].max() > conformity_threshold
            else "Non-conforme"
        )

    # Draw each detection class in its own color, in class-id order.
    for class_id, color in enumerate(("gray", "orange", "purple")):
        mask = labels == class_id
        image = draw_boxes(
            image=image, boxes=boxes[mask], probs=scores[mask], color=color
        )

    image.save(os.path.join(output_path, conformity, filename + ".jpg"))
    _dump_json(
        os.path.join(output_path, conformity, filename + ".json"),
        {
            "classification_probs": siglip_probs.tolist(),
            "detection_statistics": {
                class_name: {
                    "scores": scores[labels == class_id].tolist(),
                    "boxes": boxes[labels == class_id].tolist(),
                }
                for class_id, class_name in enumerate(
                    ("tableau", "disjoncteur", "bouton de test")
                )
            },
        },
    )

    return siglip_probs, boxes, labels, scores, conformity


def _open_rgb(image_path):
    """Open *image_path* as an RGB image that keeps its ``filename``.

    ``Image.convert`` returns a new image without the ``filename`` attribute
    set by ``Image.open``. ``pipeline`` names its outputs (and detects
    already-processed images) from that attribute, so it is restored here.
    """
    with Image.open(image_path) as source:
        rgb = source.convert("RGB")
        rgb.filename = source.filename
    return rgb


@pipeline_experiment.automain
def run(path: str):
    """Run the pipeline on one image file or on every image in a directory.

    Args:
        path: An image file, or a directory whose ``.jpg``/``.jpeg``/``.png``
            files (case-insensitive) are processed in sorted order.
    """
    models = set_up_pipeline()

    if os.path.isfile(path):
        # Single-image mode. Convert like the directory branch so RGBA/P inputs
        # can still be saved as JPEG, then stop: os.listdir() below would raise
        # NotADirectoryError on a file path.
        pipeline(**models, image=_open_rgb(path))
        return

    for file in tqdm(sorted(os.listdir(path))):
        if not file.lower().endswith((".jpg", ".jpeg", ".png")):
            continue

        pipeline(**models, image=_open_rgb(os.path.join(path, file)), force_detr=False)