File size: 2,234 Bytes
64b550a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74d38e8
64b550a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from detect import detect
from huggingface_hub import hf_hub_download
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from transformers.models.auto.modeling_auto import \
    AutoModelForImageClassification


def run(image, auto_crop):
    """Classify one image and return per-class confidences plus the model input.

    Args:
        image: The image supplied by the Gradio interface (a PIL image when the
            input component uses ``type="pil"``).
        auto_crop: If truthy, ``image`` is first passed through ``detect`` from
            the local ``detect`` module (presumably a crop to a detected region --
            TODO confirm) before preprocessing.

    Returns:
        ``(confidences, display_image)``: ``confidences`` maps each entry of the
        module-level ``labels`` list to its softmax probability; ``display_image``
        is the resized 224x224 model input as an ``(H, W, C)`` float NumPy array
        with values in ``[0, 1]``.
    """
    if auto_crop:
        image = detect(image)

    # Preprocess: resize to 224x224, convert to a (C, H, W) tensor in [0, 1]
    # (ToTensor), then map each channel to [-1, 1] (Normalize with 0.5 / 0.5).
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    transforms = Compose(
        [
            Resize((224, 224)),
            ToTensor(),
            Normalize(mean=mean, std=std),
        ]
    )
    pixel_values = transforms(image).unsqueeze(0)  # add a batch dimension

    # Pure inference: no_grad avoids building an autograd graph on every request.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits[0]
    probabilities = F.softmax(logits, dim=0).tolist()
    # Index by label position so a label/class count mismatch still raises
    # IndexError instead of being silently truncated.
    confidences = {name: probabilities[i] for i, name in enumerate(labels)}

    # Undo Normalize exactly (x * std + mean) so the returned image is what
    # ToTensor produced, i.e. what the model actually saw. The clamp only guards
    # against floating-point round-off just outside [0, 1].
    std_t = torch.tensor(std).view(-1, 1, 1)
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    display = (pixel_values[0] * std_t + mean_t).clamp(0.0, 1.0)
    return confidences, display.permute(1, 2, 0).numpy()


# Load model
# Download the fine-tuned checkpoint from the Hugging Face Hub; use_auth_token=True
# sends the locally saved Hub token with the request.
# NOTE(review): newer huggingface_hub releases deprecate use_auth_token in favour
# of token=...; kept as-is for compatibility with the installed version.
ckpt_path = hf_hub_download(
    "bwconrad/beit-base-patch16-224-pt22k-ft22k-dafre",
    "beit-base-patch16-224-pt22k-ft22k-dafre.ckpt",
    use_auth_token=True,
)
# map_location="cpu" lets a checkpoint whose tensors were saved on a GPU be loaded
# on a CPU-only host; load_state_dict below copies the values into the model's
# own parameters regardless.
# NOTE(review): torch.load unpickles the file -- only point it at trusted repos.
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# Build the BEiT architecture with a 3263-way classification head;
# ignore_mismatched_sizes=True allows the head to differ in size from the
# pretrained one. All weights are then overwritten by the checkpoint below.
model = AutoModelForImageClassification.from_pretrained(
    "microsoft/beit-base-patch16-224-pt22k-ft22k",
    num_labels=3263,
    ignore_mismatched_sizes=True,
    image_size=224,
)

# The checkpoint stores the network's weights under a "net." key prefix
# (presumably the attribute name in the training wrapper -- TODO confirm).
# Strip only that leading prefix; keys without it are dropped.
CKPT_KEY_PREFIX = "net."
new_state_dict = {
    k[len(CKPT_KEY_PREFIX):]: v
    for k, v in ckpt.items()
    if k.startswith(CKPT_KEY_PREFIX)
}
# strict=True: fail loudly if any model weight is missing or unexpected.
model.load_state_dict(new_state_dict, strict=True)
# from_pretrained already returns the model in eval mode; make it explicit.
model.eval()

# Load label names.
# Because column names are passed explicitly, pandas treats the CSV as having no
# header row. Only the "name" column is used; row order is assumed to match the
# model's output indices (the "id" column is not consulted -- verify).
# For display, underscores become spaces and each word is title-cased.
labels = [
    raw_name.replace("_", " ").title()
    for raw_name in pd.read_csv(
        "classid_classname.csv", names=["id", "name"]
    )["name"].tolist()
]

# Run app
# NOTE(review): the description is currently a single-space placeholder.
description = """ """

# Inputs: the image to classify and a toggle for the optional detect() step.
# Outputs: the five most probable classes and the preprocessed image shown
# at 224x224.
app = gr.Interface(
    fn=run,
    inputs=[
        gr.Image(type="pil", tool="select"),
        gr.Checkbox(label="auto_crop"),
    ],
    outputs=[
        gr.Label(num_top_classes=5),
        gr.Image().style(height=224, width=224),
    ],
    title="Classification Model",
    description=description,
    allow_flagging="never",  # hide the flagging button
)
app.launch()