pipeline / engie_pipeline /pipeline.py
nathbotbol's picture
Upload folder using huggingface_hub
d7aea57 verified
raw
history blame contribute delete
No virus
4.68 kB
import json
import os.path
import uuid
from engie_pipeline.models import (
model_ingredient,
get_detr,
get_detr_feature_extractor,
get_siglip,
get_siglip_preprocessor,
)
from PIL import Image
from sacred import Experiment
from tqdm import tqdm
from engie_pipeline.utils import draw_boxes
# Sacred experiment named "pipeline"; `model_ingredient` is attached as an
# ingredient. The decorators below (`config`, `capture`, `automain`) register
# functions on this experiment object.
pipeline_experiment = Experiment("pipeline", ingredients=[model_ingredient])
# Default configuration of the experiment. Each local variable assigned in the
# body is a configuration entry; the names must stay as they are because the
# captured functions below request them by parameter name.
@pipeline_experiment.config
def config():
    path = "data/"  # input handed to `run`: a single image file or a directory of images
    output_path = "data/output"  # root folder that receives one sub-folder per output label
    conformity_threshold = 0.8  # a label-2 detection score must exceed this for "Conforme"
@pipeline_experiment.capture
def set_up_pipeline(output_path):
    """Prepare the output folders and load every model the pipeline needs.

    One sub-folder per output label ("Conforme", "Non-conforme", "Hors-sujet",
    "Non-admissible") is created under ``output_path`` when it is missing.
    The DETR and SigLIP models are loaded and switched to evaluation mode,
    and their preprocessors are loaded alongside them.

    Args:
        output_path: Root directory for the results (declared as a
            configuration entry in ``config``).

    Returns:
        A dict with the keys ``"detr"``, ``"detr_preprocessor"``, ``"siglip"``
        and ``"siglip_preprocessor"``, matching the first four parameters of
        ``pipeline`` so it can be passed as ``pipeline(**models, ...)``.
    """
    for category in ("Conforme", "Non-conforme", "Hors-sujet", "Non-admissible"):
        # exist_ok=True already makes this a no-op for an existing directory.
        os.makedirs(os.path.join(output_path, category), exist_ok=True)

    detection_model = get_detr()
    detection_model.eval()
    detection_preprocessor = get_detr_feature_extractor()

    classification_model = get_siglip()
    classification_model.eval()
    classification_preprocessor = get_siglip_preprocessor()

    return dict(
        detr=detection_model,
        detr_preprocessor=detection_preprocessor,
        siglip=classification_model,
        siglip_preprocessor=classification_preprocessor,
    )
# Every output label; each one is a sub-folder of ``output_path``.
_OUTPUT_LABELS = ("Conforme", "Non-conforme", "Hors-sujet", "Non-admissible")
# Classifier class indices that reject an image, mapped to their output label.
_REJECTED_CLASSES = {1: "Hors-sujet", 2: "Non-admissible"}
# Detector label index -> name used in the JSON report.
_DETECTION_NAMES = ("tableau", "disjoncteur", "bouton de test")
# Detector label index -> colour of the boxes drawn for that label.
_DETECTION_COLORS = ("gray", "orange", "purple")


def _output_stem(image):
    """Return the stem used for the output files of *image*.

    The stem is the source file name without its final extension. A random
    UUID is used when the image carries no file name, or when the name is
    empty, so that unrelated images never share the same output files.
    """
    source_name = getattr(image, "filename", None)
    if source_name:
        stem = os.path.splitext(os.path.basename(os.fsdecode(source_name)))[0]
        if stem:
            return stem
    return str(uuid.uuid4())


def _write_report(report_path, payload):
    """Write *payload* as indented JSON to *report_path*."""
    with open(report_path, "w", encoding="utf-8") as file:
        json.dump(payload, file, indent=4)


@pipeline_experiment.capture
def pipeline(
    detr,
    detr_preprocessor,
    siglip,
    siglip_preprocessor,
    image: Image.Image,
    output_path: str,
    conformity_threshold: float,
    force_detr: bool = False,
):
    """Classify one image, optionally run detection, and store the results.

    Steps:
      1. Derive the output stem from ``image.filename`` (random UUID when
         unavailable) and skip the image if ``<stem>.jpg`` already exists in
         any label folder.
      2. Run the SigLIP classifier. Class index 1 is stored as "Hors-sujet"
         and class index 2 as "Non-admissible"; unless ``force_detr`` is set,
         processing stops there after writing the image and a JSON file with
         the classification probabilities.
      3. Otherwise run DETR, draw the detected boxes per label, and write the
         annotated image plus a JSON report into the folder of the final
         label. For images not rejected in step 2 the label is "Conforme"
         when a detection with label 2 has a score above
         ``conformity_threshold``, else "Non-conforme".

    Args:
        detr: Detection model; must be callable on pixel values and expose
            ``process_output(output, image_size)``.
        detr_preprocessor: Callable producing a mapping with "pixel_values".
        siglip: Classification model (callable on a batched input).
        siglip_preprocessor: Callable turning a PIL image into the model
            input (presumably a tensor, since ``unsqueeze`` is called on it).
        image: Image to process.
        output_path: Root directory containing one sub-folder per label.
        conformity_threshold: Minimum score (exclusive) for "Conforme".
        force_detr: Run detection even when the classifier rejects the image.

    Returns:
        ``None`` when the image was already processed or was rejected by the
        classifier without ``force_detr``; otherwise the tuple
        ``(siglip_probs, boxes, labels, scores, conformity)``.
    """
    filename = _output_stem(image)

    # Skip images whose output already exists under any label.
    if any(
        os.path.isfile(os.path.join(output_path, label, filename + ".jpg"))
        for label in _OUTPUT_LABELS
    ):
        return None

    siglip_image_input = siglip_preprocessor(image.copy())
    siglip_probs = siglip(siglip_image_input.unsqueeze(0)).softmax(-1)
    predicted_class = int(siglip_probs.argmax())

    conformity = _REJECTED_CLASSES.get(predicted_class)
    if conformity is not None:
        image.save(os.path.join(output_path, conformity, filename + ".jpg"))
        if not force_detr:
            _write_report(
                os.path.join(output_path, conformity, filename + ".json"),
                {"classification_probs": siglip_probs.tolist()},
            )
            return None

    detr_image_input = detr_preprocessor(image.copy(), return_tensors="pt")
    detr_output = detr(detr_image_input["pixel_values"])
    boxes, labels, scores = detr.process_output(detr_output, image.size)

    if conformity is None:
        # `2 in labels` guards the max() against an empty selection.
        is_conform = 2 in labels and scores[labels == 2].max() > conformity_threshold
        conformity = "Conforme" if is_conform else "Non-conforme"

    for label_index, color in enumerate(_DETECTION_COLORS):
        selected = labels == label_index
        image = draw_boxes(
            image=image, boxes=boxes[selected], probs=scores[selected], color=color
        )

    # For a rejected image processed with force_detr this overwrites the
    # unannotated copy saved above with the annotated one.
    image.save(os.path.join(output_path, conformity, filename + ".jpg"))
    _write_report(
        os.path.join(output_path, conformity, filename + ".json"),
        {
            "classification_probs": siglip_probs.tolist(),
            "detection_statistics": {
                name: {
                    "scores": scores[labels == label_index].tolist(),
                    "boxes": boxes[labels == label_index].tolist(),
                }
                for label_index, name in enumerate(_DETECTION_NAMES)
            },
        },
    )
    return siglip_probs, boxes, labels, scores, conformity
def _open_rgb(image_path):
    """Load *image_path* as an RGB image that remembers its source path.

    ``Image.convert`` returns a new image without the ``filename`` attribute
    of the opened file, so it is copied over explicitly; ``pipeline`` uses it
    to name its outputs and to recognise images that were already processed.
    The file handle is closed before returning (``convert`` has loaded the
    pixel data by then).
    """
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    rgb_image.filename = image_path
    return rgb_image


@pipeline_experiment.automain
def run(path: str):
    """Run the pipeline on a single image file or on every image of a folder.

    Args:
        path: Path of one image file, or of a directory. In a directory only
            entries ending in ".jpg", ".jpeg" or ".png" (case-insensitive)
            are processed; everything else is skipped.
    """
    models = set_up_pipeline()

    if os.path.isfile(path):
        pipeline(**models, image=_open_rgb(path))
        # Stop here: `path` is a file, so listing it as a directory would fail.
        return

    for file in tqdm(os.listdir(path)):
        if not file.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        pipeline(**models, image=_open_rgb(os.path.join(path, file)), force_detr=False)