File size: 1,838 Bytes
fff6fcf
 
481fb54
fff6fcf
ddd296c
 
 
 
481fb54
120e140
481fb54
fff6fcf
 
 
 
 
 
 
 
ddd296c
 
 
 
 
 
 
 
 
 
481fb54
fff6fcf
 
481fb54
ddd296c
 
 
 
 
 
 
120e140
 
 
 
 
 
ddd296c
 
 
 
 
 
120e140
 
 
 
 
161157e
120e140
ddd296c
 
 
 
 
 
 
120e140
 
 
161157e
ddd296c
 
58361e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import io

import gradio as gr
import requests
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from constants import MAKES_MODELS, PRICE_BIN_LABELS, YEARS

print("downloading checkpoint...")
# Fetch the TorchScript checkpoint fully into memory at startup. The whole body
# is read via ``.content`` anyway, so streaming is not requested.
_checkpoint_response = requests.get(
    "https://data.aqnichol.com/car-data/models/mobilenetv2_432000_calib_torchscript.pt",
    timeout=60,  # seconds; without a timeout a stalled server hangs startup forever
)
# Fail fast on HTTP errors instead of handing an error page to torch.jit.load,
# which would otherwise surface as an opaque deserialization failure.
_checkpoint_response.raise_for_status()
data = _checkpoint_response.content

print("creating model...")
model = torch.jit.load(io.BytesIO(data))
# Evaluation mode: affects layers such as dropout / batch-norm, if present.
model.eval()
# Per-request preprocessing: PIL image -> (C, H, W) float tensor (ToTensor scales
# 8-bit images to [0, 1]), then per-channel standardization with fixed stats.
# NOTE(review): these mean/std triples match the statistics published with
# OpenAI CLIP; presumably the checkpoint was trained with them — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

print("done.")


def _label_probs(labels, logits):
    """Pair each label with its softmax probability for the first batch item.

    Labels and the last dimension of ``logits`` are matched positionally by
    ``zip``; if their lengths differ, the surplus entries are dropped.
    """
    return dict(zip(labels, F.softmax(logits, dim=-1)[0].tolist()))


def classify(img: Image.Image):
    """Run the car model on one image and format its predictions for Gradio.

    Args:
        img: the input image (the interface passes a PIL image).

    Returns:
        A 5-tuple in the order of the interface outputs: the median price as a
        ``"$<int>"`` string; label -> probability dicts for price bin, year and
        make/model; and the input image echoed back unchanged.
    """
    # ``transform`` normalizes exactly three channels, so grayscale or RGBA
    # images would fail inside it; force RGB (a no-op copy for RGB input).
    rgb = img.convert("RGB")

    # Inference only: disable autograd so no graph is built per request.
    with torch.no_grad():
        in_tensor = transform(rgb)[None]  # add a leading batch dimension of 1
        outputs = model(in_tensor)

    price_bins = _label_probs(PRICE_BIN_LABELS, outputs["price_bin"])
    # Labels pair positionally with logits, so this assumes the year and
    # make/model heads put their extra "Unknown" class last.
    years = _label_probs(
        [str(year) for year in YEARS] + ["Unknown"],
        outputs["year"],
    )
    make_models = _label_probs(
        [f"{make} {model_name}" for make, model_name in MAKES_MODELS] + ["Unknown"],
        outputs["make_model"],
    )
    median_price = int(round(outputs["price_median"].item()))
    return (
        f"${median_price}",
        price_bins,
        years,
        make_models,
        img,
    )


# Build the web UI. The output components correspond one-to-one, in order, to
# the 5-tuple returned by ``classify``. The input component's ``shape`` makes
# gradio crop/resize the upload to 224x224 before ``classify`` receives it.
iface = gr.Interface(
    classify,
    gr.Image(type="pil", shape=(224, 224)),
    [
        gr.Text(label="Price Prediction"),
        gr.Label(label="Price Bin", num_top_classes=5),
        gr.Label(label="Year", num_top_classes=5),
        gr.Label(label="Make/Model", num_top_classes=10),
        gr.Image(label="Cropped Input"),
    ],
)

# Enable request queuing (two concurrent workers), then start the server.
iface.queue(concurrency_count=2)
iface.launch()