"""Gradio demo that predicts the location class of an uploaded image.

At import time the script downloads the ``robocan/GeoG_City`` model snapshot
into ``SVD/``, loads the label encoder and the checkpoint found there, and
builds a Gradio ``Interface`` around :func:`predict_and_plot`.
"""
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder
# NOTE: this import rebinds `Image` to gradio's input component, shadowing the
# PIL module imported above. It must stay *after* the PIL import, because the
# Interface below relies on `Image` being the gradio component.
from gradio import Interface, Image, Label
from huggingface_hub import snapshot_download

# Retrieve the token from the environment variables
token = os.environ.get("token")

# Download the repository snapshot
local_dir = snapshot_download(
    repo_id="robocan/GeoG_City",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

# SECURITY NOTE: joblib.load unpickles the file, which can execute arbitrary
# code. Only load artifacts from a repository you trust.
le = joblib.load("SVD/le.gz")

# The network has one more output than the encoder has labels, so the last
# output index has no corresponding entry in `le.classes_`.
# NOTE(review): the meaning of that extra output is not visible here — confirm
# against the training code.
len_classes = len(le.classes_) + 1


class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a two-layer classification head.

    The backbone is every child module of torchvision's ``convnext_small``
    except the last one; its flattened 768-dimensional output is mapped
    through ``Linear(768, 512) -> ReLU -> Linear(512, len_classes)``.
    """

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            *list(
                models.convnext_small(
                    weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
                ).children()
            )[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        )

    def forward(self, data):
        """Return the raw class logits for a batch of image tensors."""
        return self.embedding(data)


# SECURITY NOTE: torch.load may unpickle arbitrary objects from the checkpoint.
# Only load checkpoints from a repository you trust.
model = torch.load("SVD/GeoG.pth", map_location=torch.device(device))

modelm = ModelPre()
modelm.load_state_dict(model['model'])
modelm.to(device)
# Switch to evaluation mode. `torch.inference_mode()` only disables autograd;
# without `eval()` the train-time stochastic layers of the backbone stay
# active and make predictions random and less accurate.
modelm.eval()

# Preprocessing: tensor conversion, resize to 224x224, then normalization with
# the standard ImageNet channel mean/std.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(input_img, top_k=10):
    """Return the most probable location labels for an image.

    Args:
        input_img: Image accepted by the ``cmp`` transform pipeline (the
            Gradio interface below passes a PIL image). ``None`` — e.g. when
            the user submits without uploading anything — yields ``{}``.
        top_k: Maximum number of labels to return. Defaults to 10.

    Returns:
        A dict mapping label to softmax probability (as ``float``), ordered
        from most to least probable. It holds at most ``top_k`` entries.
    """
    if input_img is None:
        return {}

    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = modelm(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()

    # Rank only the indices the encoder can decode: the network has an extra
    # output (see `len_classes`) and `inverse_transform` raises ValueError on
    # an index outside `le.classes_`. Probabilities are still the softmax over
    # all outputs, so values match the unrestricted ranking whenever the
    # extra output is not among the top results.
    labelled = probabilities[:len(le.classes_)]
    k = min(top_k, labelled.size)
    if k <= 0:
        # Guard: `arr[-0:]` would select everything instead of nothing.
        return {}

    top_indices = np.argsort(labelled)[-k:][::-1]
    top_probabilities = labelled[top_indices]
    top_predictions = le.inverse_transform(top_indices)

    return {
        label: float(prob)
        for label, prob in zip(top_predictions, top_probabilities)
    }


def create_label_output(predictions):
    """Return ``predictions`` unchanged, in the form passed to the Label output."""
    return predictions


def predict_and_plot(input_img):
    """Gradio callback: run :func:`predict` and return its label/probability dict."""
    predictions = predict(input_img)
    return create_label_output(predictions)


gradio_app = Interface(
    fn=predict_and_plot,
    inputs=Image(label="Upload an Image", type="pil"),
    outputs=Label(num_top_classes=10),
    title="Predict the Location of this Image",
)

if __name__ == "__main__":
    gradio_app.launch()