File size: 3,155 Bytes
9b889da
db673be
9678900
8a4658f
2355c91
9b889da
 
 
 
db673be
 
 
963d228
db673be
 
 
9b889da
955fc23
 
 
 
 
 
 
 
 
 
 
 
 
6d49cf1
 
955fc23
 
2a50088
4184b6d
955fc23
 
 
 
 
be60ccb
955fc23
be60ccb
 
dccd8f9
955fc23
dccd8f9
955fc23
 
 
 
2a50088
955fc23
 
 
 
 
 
 
 
 
 
 
 
6d49cf1
 
955fc23
 
 
9678900
 
 
 
 
 
 
6d49cf1
8a4658f
 
35a87f4
c48e224
 
 
 
 
 
 
46bc924
c48e224
 
8a4658f
 
 
 
2355c91
8a4658f
6d49cf1
8a4658f
955fc23
8a4658f
955fc23
6d49cf1
 
 
9678900
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from huggingface_hub import Repository
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hugging Face access token taken from the "token" environment variable.
# os.environ.get returns None when the variable is unset.
# NOTE(review): presumably an unset token only works while the repo is public.
token = os.environ.get("token")

# Check out the robocan/GeoG_City model repository into ./SVD, then pull so the
# files read below (SVD/le.gz, SVD/GeoG.pth) reflect the latest remote state.
repo = Repository(
    local_dir="SVD",
    clone_from="robocan/GeoG_City",
    repo_type="model",
    token=token,
)
repo.git_pull()

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import io
import joblib
import requests
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torchvision import models
import gradio as gr

device = 'cpu'
# Label encoder loaded from the cloned repository; predict() uses it to turn
# class indices back into location labels. (A throwaway LabelEncoder() that was
# immediately overwritten by this load has been removed.)
le = joblib.load("SVD/le.gz")
# The classifier head has one output more than the encoder has labels.
# NOTE(review): presumably this mirrors how SVD/GeoG.pth was trained (the extra
# output index has no label) — it must stay in sync with that checkpoint.
len_classes = len(le.classes_) + 1

class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a two-layer MLP classifier.

    The backbone is torchvision's ``convnext_small`` with its last child (the
    built-in classifier head) removed. Its output is flattened to the 768
    features expected by the first Linear layer, then mapped
    768 -> 512 -> ``len_classes`` logits.

    Args:
        pretrained: when True (the default, matching the original behaviour),
            initialise the backbone with torchvision's ImageNet-1k weights.
            Pass False to skip that download when a complete state_dict will
            be loaded over the model anyway.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.ConvNeXt_Small_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.convnext_small(weights=weights)
        self.embedding = torch.nn.Sequential(
            # Every backbone child except the final one (torchvision's classifier).
            *list(backbone.children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        )
        # All parameters are left trainable; call requires_grad_(False) on the
        # relevant submodules explicitly if only the head should be trained.

    def forward(self, data):
        """Return raw (un-normalised) class logits for a batch of images."""
        return self.embedding(data)

# Load the fine-tuned checkpoint; its "model" entry is ModelPre's state_dict.
# NOTE(review): torch.load unpickles the file — acceptable only because the
# checkpoint comes from the project's own Hub repository.
model = torch.load("SVD/GeoG.pth", map_location=torch.device(device))

modelm = ModelPre()
modelm.load_state_dict(model['model'])
# Keep the module on the same device the inputs are sent to in predict().
modelm.to(device)
# nn.Module starts in training mode. ConvNeXt's stochastic-depth layers randomly
# drop residual branches in that mode, so without eval() predictions would be
# non-deterministic. torch.inference_mode() does not switch the mode.
modelm.eval()

import warnings

# Suppress RuntimeWarnings raised from multiprocessing.popen_fork.
# The ``module`` argument is a regular expression matched against the name of
# the module that issues the warning.
warnings.filterwarnings(
    action="ignore",
    category=RuntimeWarning,
    module="multiprocessing.popen_fork",
)

# Per-channel statistics used by Normalize (the standard ImageNet values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every input image in predict():
# image -> float CHW tensor -> 224x224 (antialiased) -> channel-normalised.
cmp = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224), antialias=True),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

def predict(input_img, top_k=10):
    """Return the most probable locations for an image.

    Args:
        input_img: image to classify — normally a PIL image from the Gradio
            input; anything ``cmp`` (ToTensor) accepts also works.
        top_k: maximum number of locations to return (default 10, the
            previously hard-coded value).

    Returns:
        dict mapping location label -> softmax probability, ordered from most
        to least probable, with at most ``min(top_k, len(le.classes_))``
        entries.
    """
    if isinstance(input_img, Image.Image) and input_img.mode != "RGB":
        # Normalize() carries exactly three channel statistics, so RGBA,
        # greyscale and palette images must be converted first.
        input_img = input_img.convert("RGB")

    with torch.inference_mode():
        batch = cmp(input_img).unsqueeze(0)
        logits = modelm(batch.to(device))
        probabilities = torch.softmax(logits, dim=1).cpu().numpy().flatten()

    # The head has len(le.classes_) + 1 outputs (see len_classes), but only
    # indices below len(le.classes_) are known to the encoder. Letting the
    # extra index into the ranking would make inverse_transform raise ValueError.
    known = probabilities[: len(le.classes_)]
    k = max(0, min(top_k, known.size))
    top_indices = np.argsort(known)[::-1][:k]
    top_labels = le.inverse_transform(top_indices)
    return {label: float(prob) for label, prob in zip(top_labels, known[top_indices])}

def create_bar_plot(predictions):
    """Wrap a label -> probability mapping in a Gradio bar chart.

    Args:
        predictions: dict mapping location label to probability; rows are
            added to the chart's data in the dict's iteration order.

    Returns:
        An 800x600 ``gr.BarPlot`` whose y-axis spans 0 to the largest
        probability present.
    """
    rows = [(location, probability) for location, probability in predictions.items()]
    frame = pd.DataFrame(rows, columns=["Location", "Probability"])
    upper = frame["Probability"].max()
    chart_size = {"width": 800, "height": 600}
    return gr.BarPlot(
        frame,
        x="Location",
        y="Probability",
        title="Top 10 Predictions with Probabilities",
        tooltip=["Location", "Probability"],
        y_lim=[0, upper],
        **chart_size,
    )

def predict_and_plot(input_img):
    """Gradio callback: classify ``input_img`` and chart the top predictions.

    Raises:
        gr.Error: when no image was supplied (Gradio passes None for an empty
            image input). The UI then shows a clear message instead of a
            transform error.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    predictions = predict(input_img)
    return create_bar_plot(predictions)
    
    
# UI wiring: a PIL image upload goes through predict_and_plot into a bar chart.
upload_input = gr.Image(label="Upload an Image", type="pil")
chart_output = gr.BarPlot()
gradio_app = gr.Interface(
    fn=predict_and_plot,
    inputs=upload_input,
    outputs=chart_output,
    title="Predict the Location of this Image",
)

def _main():
    """Start the Gradio server for this app."""
    gradio_app.launch()


if __name__ == "__main__":
    _main()