File size: 1,743 Bytes
fd2ef9c
 
 
 
 
 
 
 
 
 
 
 
5ed6ee0
fd2ef9c
 
5ed6ee0
fd2ef9c
 
 
 
 
 
5ed6ee0
fd2ef9c
 
 
 
5ed6ee0
 
 
 
fd2ef9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pathlib

import gradio as gr
import torch
from loguru import logger
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

logger.info("starting gradio app")

# Directory containing this script; used to locate bundled files.
CURRENT_DIR = pathlib.Path(__file__).resolve().parent
# Human-readable name of the application.
APP_NAME = "Mona Lisa Detection"
# Hugging Face Hub repository holding both the feature extractor and the
# classifier. Kept in a single constant so the two loads below cannot drift
# out of sync.
MODEL_ID = "drift-ai/autotrain-mona-lisa-detection-38345101350"

# Load once at import time so every request reuses the same objects.
# use_auth_token=True sends the locally stored Hugging Face token with the
# download. NOTE(review): presumably the repository is private; confirm.
logger.debug("loading processor and model.")
processor = AutoFeatureExtractor.from_pretrained(MODEL_ID, use_auth_token=True)
model = AutoModelForImageClassification.from_pretrained(
    MODEL_ID, use_auth_token=True
)
logger.debug("loading processor and model succeeded.")


def process_image(image, model=model, processor=processor):
    """Classify one image as "Mona Lisa" or "Not Mona Lisa".

    Args:
        image: The image to classify, as delivered by the Gradio input. The
            interface configures ``type="pil"``, so this is presumably a
            ``PIL.Image.Image``.
        model: Image-classification model. Defaults to the module-level
            model loaded at import time.
        processor: Feature extractor paired with ``model``. Defaults to the
            module-level processor.

    Returns:
        A one-entry dict mapping the predicted class name to its softmax
        probability, in the format expected by ``gr.Label``.

    Raises:
        ValueError: If ``image`` is None (no image was submitted).
    """
    if image is None:
        # Gradio passes None when the user submits without an image. Fail
        # with a clear message instead of an opaque processor error.
        raise ValueError("No image provided.")

    logger.info("Making a prediction ...")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # .item() requires a single element, i.e. a batch of exactly one image.
    predicted_class_idx = logits.argmax(-1).item()

    # NOTE(review): hard-coded index -> name mapping. Confirm it matches
    # model.config.id2label for this checkpoint.
    label = {1: "Not Mona Lisa", 0: "Mona Lisa"}
    predictions = logits.softmax(dim=-1).tolist()
    result = {label[predicted_class_idx]: predictions[0][predicted_class_idx]}
    # Route through the same logger as the rest of the app instead of print.
    logger.info("Predicted class: {}", result)
    logger.info("Prediction finished.")
    return result


# Sample images shown below the interface. Paths are resolved against the
# script's own directory rather than the process working directory, so the
# examples are found regardless of where the app is launched from.
examples = [
    str(CURRENT_DIR / name)
    for name in (
        "mona-lisa-1.jpg",
        "mona-lisa-2.jpg",
        "mona-lisa-3.jpg",
        "not-mona-lisa-1.jpg",
        "not-mona-lisa-2.jpg",
        "not-mona-lisa-3.jpg",
    )
]

if __name__ == "__main__":
    title = """
        Mona Lisa Detection.
"""
    app = gr.Interface(
        fn=process_image,
        # gr.Image is the current component API, consistent with gr.Label
        # below. The legacy gr.inputs namespace was removed in Gradio 4.
        inputs=[
            gr.Image(type="pil", label="Image"),
        ],
        outputs=gr.Label(label="Predictions:", show_label=True),
        examples=examples,
        examples_per_page=32,
        title=title,
    )
    # Enable request queuing via .queue(), which replaces the deprecated
    # enable_queue constructor argument, then start the server.
    app.queue().launch()