import os
from huggingface_hub import Repository
import numpy as np
import pandas as pd
# Retrieve the token from the environment variables
token = os.environ.get("token")
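# Clone the model repository into the local ./SVD folder (if not already
# present) and pull the latest revision before loading any artifacts from it.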
repo = Repository(
    local_dir="SVD",
    repo_type="model",
    clone_from="robocan/GeoG_City",
    token=token,
)
repo.git_pull()
import torch
import joblib
from torchvision import transforms, models
import gradio as gr
# Inference runs on CPU; the checkpoint below is mapped to this device.
device = "cpu"

# Load the fitted scikit-learn LabelEncoder that maps class indices back to
# location names; the model's output layer has one extra class on top of these.
le = joblib.load("SVD/le.gz")
len_classes = len(le.classes_) + 1
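# Classifier: a pretrained ConvNeXt-Small backbone (original classification
# head removed) followed by a two-layer MLP that scores each encoded location class.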
class ModelPre(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            *list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        )
        # Freezing the backbone only matters during training; nothing needs to
        # be frozen for inference.

    def forward(self, data):
        return self.embedding(data)
# Load the trained weights from the checkpoint and switch the model to eval
# mode so training-only behaviour (e.g. stochastic depth) is disabled.
checkpoint = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
modelm = ModelPre()
modelm.load_state_dict(checkpoint["model"])
modelm.eval()
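# Ignore RuntimeWarnings raised from multiprocessing's fork-based process startup.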
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning, module="multiprocessing.popen_fork")
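# Preprocessing pipeline: convert to a tensor, resize to the 224x224 input
# size used by ConvNeXt, and normalize with the ImageNet statistics.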
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
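# Run a single image through the model and return a dict mapping the ten most
# likely locations to their softmax probabilities.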
def predict(input_img):
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = modelm(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
        top_10_indices = np.argsort(probabilities)[-10:][::-1]
        top_10_probabilities = probabilities[top_10_indices]
        top_10_predictions = le.inverse_transform(top_10_indices)
        results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
        return results
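# Turn the prediction dict into a Gradio bar plot of the top-10 probabilities.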
def create_bar_plot(predictions):
    data = pd.DataFrame(list(predictions.items()), columns=["Location", "Probability"])
    max_prob = data["Probability"].max()
    return gr.BarPlot(
        data,
        x="Location",
        y="Probability",
        title="Top 10 Predictions with Probabilities",
        tooltip=["Location", "Probability"],
        y_lim=[0, max_prob],
        width=800,   # Width of the plot in pixels
        height=600,  # Height of the plot in pixels
    )
def predict_and_plot(input_img):
    predictions = predict(input_img)
    return create_bar_plot(predictions)
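# Gradio UI: upload an image, get back a bar plot of the predicted locations.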
gradio_app = gr.Interface(
    fn=predict_and_plot,
    inputs=gr.Image(label="Upload an Image", type="pil"),
    outputs=gr.BarPlot(),
    title="Predict the Location of this Image",
)
if __name__ == "__main__":
    gradio_app.launch()