File size: 1,843 Bytes
f1c6cd4
 
 
 
 
 
 
b3836ea
f1c6cd4
 
ec9093c
f1c6cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3836ea
ec9093c
 
b3836ea
 
 
 
 
 
 
 
 
f1c6cd4
 
b3836ea
f1c6cd4
b3836ea
 
 
 
 
f1c6cd4
 
b3836ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os

import torch
from PIL import Image
import torchvision.transforms as T

import gradio as gr
from gradio import themes

# Path to the serialized TorchScript model, relative to the working directory.
model_path = os.path.join("ml", "models", "model.pt")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the TorchScript model straight onto ``device``. ``map_location``
# remaps tensors that were saved on another device (e.g. a model exported on a
# GPU machine), so loading also succeeds on CPU-only hosts; without it the
# load itself can fail before any ``.to(device)`` call is reached.
model = torch.jit.load(model_path, map_location=device)
model.eval()  # switch to evaluation mode for inference


# Per-channel ImageNet statistics consumed by ``T.Normalize`` below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image, in order:
#   1. resize so the shorter side is 224 px,
#   2. take the central 224x224 crop,
#   3. convert the PIL image to a float tensor [C, H, W] with values in [0, 1],
#   4. normalise each channel with the ImageNet mean / std above.
transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


def cls_helper(label):
    """Translate a predicted class index into a readable weather label.

    Index 0 -> "Clear sky", 1 -> "Cloudy", 2 -> "Haze"; anything else
    (including values that compare unequal to all three) -> "Unknown".
    """
    names = ("Clear sky", "Cloudy", "Haze")
    # Compare with ``==`` in index order, exactly like an if/elif chain would.
    for index, name in enumerate(names):
        if label == index:
            return name
    return "Unknown"


def predict(image: Image.Image) -> str:
    """Classify the weather condition shown in ``image``.

    Args:
        image: Picture delivered by the ``gr.Image(type="pil")`` input.
            May be ``None`` when the user submits without uploading.

    Returns:
        The label from ``cls_helper`` for the highest-scoring output, or a
        prompt to upload an image when ``image`` is ``None``.
    """
    if image is None:
        # Without this guard ``image.convert`` raises AttributeError.
        return "Please upload an image."

    # Force 3 channels so RGBA / greyscale / palette images match the
    # 3-channel normalisation performed by ``transform``.
    img = image.convert("RGB")
    # Add a batch dimension and put the input on the same device as the
    # model; a CPU tensor fed to a CUDA model raises a device-mismatch error.
    batch = transform(img).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    with torch.no_grad():
        logits = model(batch)

    # assumes the model returns (batch, num_classes) scores — TODO confirm
    pred_idx = torch.argmax(logits, dim=1).item()
    return cls_helper(pred_idx)


# Example images offered below the input component (paths are relative to
# the working directory, like ``model_path``).
examples = [
    "ml/data/train_11890.jpg",
    "ml/data/train_11716.jpg",
]

# Custom look for the app: blue accents, Arial for body text and
# Courier New for monospace text.
theme = gr.Theme(
    primary_hue="blue",
    secondary_hue="blue",
    font="Arial",
    font_mono="Courier New",
)

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", height=350),
    outputs=["text"],
    examples=examples,
    title="Weather Condition Classifier",
    description="Upload an image to classify the weather condition as Clear sky, Cloudy, or Haze.",
    preload_example=0,
    # Apply the custom theme built above; passing a plain ``themes.Base()``
    # here left ``theme`` unused.
    theme=theme,
)

# NOTE(review): debug=True is a development setting (see gr.Blocks.launch
# docs); confirm it is intended for deployment.
interface.launch(debug=True)