File size: 2,685 Bytes
49a7401
 
 
 
1a423d1
 
49a7401
bddc3d6
49a7401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a5c0f
49a7401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a423d1
bddc3d6
1a423d1
 
 
 
bddc3d6
1a423d1
49a7401
 
 
 
 
 
 
 
 
 
 
bda76e2
30a5c0f
49a7401
bda76e2
49a7401
3f3c078
49a7401
3f3c078
30a5c0f
 
3f3c078
94d6d22
7d51cfd
 
 
49a7401
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import functools
import io
import os

import requests
import torch
from torch import nn
from torchvision import models
from torchvision.transforms import v2


# Morph names the classifier can predict. Order matters: predict() pairs the
# i-th model output with labels[i], so this must match the training order.
labels = [
    'Pastel', 'Yellow Belly', 'Enchi', 'Clown', 'Leopard',
    'Piebald', 'Orange Dream', 'Fire', 'Mojave', 'Pinstripe',
    'Banana', 'Normal', 'Black Pastel', 'Lesser', 'Spotnose',
    'Cinnamon', 'GHI', 'Hypo', 'Spider', 'Super Pastel',
]
# Number of morph classes; used as the output width of the classifier head.
num_labels = len(labels)

_IMAGE_SIZE = 512  # side length (pixels) every input image is resized to


def _build_classifier_head() -> nn.Sequential:
    """Return the fine-tuned head that replaces DenseNet201's ``classifier``.

    The layer sizes must match the checkpoint exactly, because the state dict
    is loaded strictly.
    """
    return nn.Sequential(
        nn.Linear(1920, 1000),        # DenseNet201 features (1920) -> 1000 hidden units
        nn.BatchNorm1d(1000),         # normalize the hidden activations
        nn.ReLU(),                    # non-linear activation
        nn.Dropout(0.5),              # only active in train mode; a no-op after eval()
        nn.Linear(1000, num_labels),  # one logit per entry in `labels`
    )


@functools.lru_cache(maxsize=1)
def _build_transform():
    """Return the preprocessing pipeline applied to every input image (built once)."""
    return v2.Compose([
        v2.ToImage(),
        v2.Resize((_IMAGE_SIZE, _IMAGE_SIZE)),
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
        # ImageNet channel statistics, matching torchvision's pretrained backbones.
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


@functools.lru_cache(maxsize=1)
def _load_model() -> nn.Module:
    """Download the fine-tuned checkpoint once and return the model in eval mode.

    The result is cached, so the checkpoint is fetched and deserialized a single
    time per process instead of on every prediction. The model is placed on
    CUDA when available, otherwise on CPU.

    Raises:
        RuntimeError: if the ``model_path`` environment variable is not set.
        requests.HTTPError: if the checkpoint download returns an error status.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No pretrained weights: load_state_dict below is strict and overwrites
    # every parameter/buffer, so fetching ImageNet weights would be wasted work.
    densenet = models.densenet201(weights=None)
    densenet.classifier = _build_classifier_head()

    # Download the checkpoint (URL supplied via the environment).
    model_path = os.getenv('model_path')
    if not model_path:
        raise RuntimeError("environment variable 'model_path' is not set")
    response = requests.get(model_path, timeout=60)
    response.raise_for_status()

    # Deserialize straight from memory instead of a shared 'model.pt' file in
    # the working directory, which concurrent requests could clobber.
    # NOTE(security): torch.load unpickles the payload; `model_path` must point
    # at a trusted location.
    checkpoint = torch.load(io.BytesIO(response.content), map_location=device)
    densenet.load_state_dict(checkpoint['model_state_dict'])

    densenet.to(device)
    densenet.eval()
    return densenet


def predict(img, confidence):
    """Return the morphs whose predicted probability exceeds ``confidence``.

    Args:
        img: the uploaded image (a PIL image from ``gr.Image(type="pil")``),
            or ``None`` when nothing has been uploaded.
        confidence: threshold (the UI slider supplies 0-1); only labels whose
            probability is strictly greater are returned.

    Returns:
        dict mapping morph name -> probability for every label above the
        threshold. The dict is empty when no image is given or nothing qualifies.
    """
    if img is None:
        return {}
    # RGBA or greyscale uploads would yield the wrong channel count for Normalize.
    if hasattr(img, "convert"):
        img = img.convert("RGB")

    model = _load_model()
    device = next(model.parameters()).device

    # Add a batch dimension and move the input to the same device as the model.
    input_img = _build_transform()(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_img)

    # Multi-label: each morph is scored independently with a sigmoid.
    predicted_probs = torch.sigmoid(output).cpu().flatten().tolist()
    return {
        label: prob
        for label, prob in zip(labels, predicted_probs)
        if prob > confidence
    }

import gradio as gr

# Build the UI at import time, but start the server only when run as a script.
# Importing this module (e.g. from tests or Gradio's reload mode, which looks
# for a variable named `demo`) then no longer blocks inside launch().
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1, value=0.5, label="Confidence", info="Show predictions that are above this confidence level"),
    ],
    outputs=gr.Label(),
    examples=[["pastel_yb.png", 0.5], ["piebald.png", 0.5], ["leopard_fire.png", 0.5]],
    title='Ball Python Morph Identifier',
    description="Upload or paste an image of your ball python to identify its morphs!",
)

if __name__ == "__main__":
    demo.launch()