File size: 2,459 Bytes
9b889da
afe2deb
 
8a4658f
afe2deb
 
 
bde2d24
afe2deb
 
dbf177b
9b889da
 
 
 
dbf177b
 
 
db673be
dbf177b
db673be
 
9b889da
955fc23
 
2a50088
4184b6d
955fc23
 
 
 
 
be60ccb
955fc23
be60ccb
 
dccd8f9
955fc23
 
 
 
 
2a50088
955fc23
 
 
 
 
 
 
 
 
6d49cf1
 
955fc23
 
 
9678900
 
 
 
 
 
 
6d49cf1
afe2deb
 
8a4658f
 
 
afe2deb
 
 
8a4658f
afe2deb
 
955fc23
6d49cf1
 
 
9678900
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
from PIL import Image
from torchvision import transforms,models
from sklearn.preprocessing import LabelEncoder
from gradio import Interface, Image, Label
from huggingface_hub import snapshot_download

# Hugging Face access token, read from the environment variable named "token"
# (None when the variable is unset).
token = os.environ.get("token")

# Download the model repository snapshot into ./SVD. snapshot_download returns
# the local directory it populated; artifacts are loaded relative to it so the
# paths cannot drift from the download location.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_City",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

# Label encoder stored with the model; predict() uses its `classes_` and
# `inverse_transform` to turn class indices back into location names.
# NOTE(review): joblib.load unpickles the file -- only load trusted artifacts.
le = joblib.load(os.path.join(local_dir, "le.gz"))

# Width of the classifier head. NOTE(review): this is one more than the number
# of encoder classes, presumably to match how the checkpoint was trained; the
# extra output index has no entry in `le`. Confirm against the training code.
len_classes = len(le.classes_) + 1

class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a small MLP classification head.

    The backbone is torchvision's ``convnext_small`` initialised with the
    ImageNet-1K V1 weights, minus its last top-level child (the torchvision
    classifier). Its flattened 768-wide output feeds
    ``Linear(768, 512) -> ReLU -> Linear(512, len_classes)``.

    Everything lives in one ``torch.nn.Sequential`` stored as ``embedding``.
    Keep that attribute name and the child order unchanged: the checkpoint's
    state-dict keys (``embedding.<index>....``) depend on both.

    NOTE(review): the ImageNet weights are fetched on construction even though
    the module-level code overwrites them with ``load_state_dict``; passing
    ``weights=None`` would skip that download if nothing else needs them.
    """

    def __init__(self):
        super().__init__()
        backbone = models.convnext_small(
            weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
        )
        # All top-level children except the final one (the original classifier).
        trunk = list(backbone.children())[:-1]
        head = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        ]
        self.embedding = torch.nn.Sequential(*trunk, *head)

    def forward(self, data):
        """Return the raw (pre-softmax) class logits for the batch *data*."""
        logits = self.embedding(data)
        return logits

# Load the training checkpoint; the model weights live under its "model" key.
# NOTE(review): torch.load unpickles the file, so it must only ever point at a
# trusted checkpoint (here, the snapshot downloaded from the model repo).
model = torch.load(
    os.path.join(local_dir, "GeoG.pth"), map_location=torch.device(device)
)

modelm = ModelPre()
modelm.load_state_dict(model['model'])
modelm.to(device)
# Switch to inference behaviour. A freshly built module is in training mode,
# where ConvNeXt's stochastic-depth layers randomly drop residual branches,
# making predictions non-deterministic. torch.inference_mode() only disables
# autograd; it does not change the training flag.
modelm.eval()

# Preprocessing applied to each input image before inference, in order:
#   1. convert to a float CHW tensor,
#   2. resize to 224x224 (with antialiasing),
#   3. normalise each channel with the standard ImageNet mean/std.
cmp = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(
            size=(224, 224),
            antialias=True,
        ),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict(input_img, top_k=10):
    """Return the model's most probable locations for a single image.

    Args:
        input_img: the image to classify (a PIL image when called from the
            Gradio interface). ``None`` yields an empty result.
        top_k: maximum number of labelled classes to return (default 10).

    Returns:
        dict mapping each predicted class name to its softmax probability,
        inserted from most to least probable.
    """
    if input_img is None:
        return {}
    # The Normalize step in ``cmp`` expects three channels; convert greyscale,
    # RGBA or palette PIL images to RGB first. Inputs without a ``mode``
    # attribute (e.g. arrays) are passed through unchanged.
    if getattr(input_img, "mode", "RGB") != "RGB":
        input_img = input_img.convert("RGB")

    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = modelm(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()

    # The head has len(le.classes_) + 1 outputs, but only the first
    # len(le.classes_) indices can be decoded by the label encoder. Ranking the
    # extra index would make inverse_transform raise, so it is excluded. The
    # reported probabilities are still the softmax over all outputs.
    n_labelled = len(le.classes_)
    labelled = probabilities[:n_labelled]
    k = max(0, min(top_k, n_labelled))  # also guards against < top_k classes
    if k == 0:
        return {}

    top_indices = np.argsort(labelled)[::-1][:k]
    top_names = le.inverse_transform(top_indices)
    return {name: float(labelled[i]) for name, i in zip(top_names, top_indices)}

def create_label_output(predictions):
    """Adapt *predictions* for the Gradio ``Label`` output.

    The ``{label: probability}`` dict is passed through as the very same
    object, with no copying or reformatting.
    """
    label_payload = predictions
    return label_payload

def predict_and_plot(input_img):
    """Gradio callback: classify *input_img* and shape the result for display."""
    return create_label_output(predict(input_img))

# Web UI: a PIL image upload in, the ten highest-scoring labels out.
# `Image` here is gradio's component: the gradio import comes after
# `from PIL import Image` and rebinds the name.
_image_input = Image(label="Upload an Image", type="pil")
_label_output = Label(num_top_classes=10)

gradio_app = Interface(
    fn=predict_and_plot,
    inputs=_image_input,
    outputs=_label_output,
    title="Predict the Location of this Image",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    gradio_app.launch()