import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
# NOTE: this rebinds ``Image`` to the gradio component, shadowing PIL's
# ``Image`` imported above. Nothing below uses either name directly.
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import s2sphere
import folium

# Access token for the model repository, read from the environment.
token = os.environ.get("token")

# Download the label encoder and model checkpoint into ./SVD.
local_dir = snapshot_download(
    repo_id="robocan/GeoG-GCP",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

# Fitted LabelEncoder that maps class indices back to cell-id labels.
le = joblib.load(os.path.join(local_dir, "le.gz"))
# The classifier head has one output more than the encoder has classes.
# This must match the checkpoint's final layer for the strict load below.
len_classes = len(le.classes_) + 1


class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a two-layer MLP classification head.

    Args:
        backbone_weights: torchvision weights used to initialise the backbone.
            Defaults to the ImageNet weights. Pass ``None`` when a full
            checkpoint is loaded afterwards, to skip downloading them.
    """

    def __init__(self, backbone_weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1):
        super().__init__()
        backbone = models.convnext_small(weights=backbone_weights)
        self.embedding = torch.nn.Sequential(
            # Reuse every backbone child except the last one (its own classifier).
            *list(backbone.children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        )

    def forward(self, data):
        """Return unnormalised class logits for a batch of images."""
        return self.embedding(data)


# Load the pretrained model. The ImageNet backbone weights are not fetched:
# the strict load_state_dict below overwrites every parameter anyway.
model = ModelPre(backbone_weights=None)
model_w = torch.load(os.path.join(local_dir, "GeoG.pth"), map_location=torch.device(device))
model.load_state_dict(model_w['model'])
model.to(device)
# torch.inference_mode() does not leave training mode. Without eval(),
# ConvNeXt's stochastic depth stays active and predictions are random.
model.eval()

# Preprocessing: tensor conversion, 224x224 resize, and ImageNet mean/std
# normalisation (3 channels).
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(input_img, top_k=10):
    """Classify an image and return the most likely labels.

    Args:
        input_img: image accepted by ``transforms.ToTensor`` (a PIL image
            when called from the Gradio UI).
        top_k: maximum number of labels to return (default 10).

    Returns:
        A ``(results, predictions)`` tuple. ``results`` maps each label to
        its softmax probability. ``predictions`` is the array of labels,
        ordered from most to least likely.
    """
    # RGBA/greyscale images would not match the 3-channel Normalize step.
    if hasattr(input_img, "convert"):
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = model(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
    # The head has one more output than the encoder has classes. Indices the
    # encoder does not know would make inverse_transform raise, so rank only
    # the decodable ones.
    known = probabilities[:len(le.classes_)]
    k = max(0, min(top_k, known.size))
    top_indices = np.argsort(-known, kind="stable")[:k]
    top_probabilities = known[top_indices]
    top_predictions = le.inverse_transform(top_indices)
    results = {
        label: float(prob)
        for label, prob in zip(top_predictions, top_probabilities)
    }
    return results, top_predictions


# Function to get S2 cell polygon
def get_s2_cell_polygon(cell_id):
    """Return the outline of an S2 cell as a list of (lat, lng) degree pairs.

    The four cell vertices are followed by the first vertex again, so the
    polygon is closed when drawn.
    """
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    vertices = []
    for i in range(4):
        vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
        vertices.append((vertex.lat().degrees, vertex.lng().degrees))
    vertices.append(vertices[0])  # Close the polygon
    return vertices


def create_map_figure(predictions, cell_ids):
    """Draw the predicted S2 cells on an OpenStreetMap plotly figure.

    Args:
        predictions: label -> probability mapping from ``predict``. It is
            currently unused and kept for interface compatibility.
        cell_ids: S2 cell ids (int-convertible), most likely first.

    Returns:
        A ``go.Figure`` with one filled polygon per cell, centred on the most
        likely cell (or on 0, 0 when ``cell_ids`` is empty).
    """
    fig = go.Figure()
    center_lat, center_lon = 0.0, 0.0
    for rank, cell_id in enumerate(cell_ids):
        cell_id = int(cell_id)
        polygon = get_s2_cell_polygon(cell_id)
        lats, lons = zip(*polygon)
        if rank == 0:
            # Centre on the top prediction. Skip the duplicated closing vertex.
            center_lat = float(np.mean(lats[:-1]))
            center_lon = float(np.mean(lons[:-1]))
        # Shade the top 3 cells green and the rest yellow.
        color = 'rgba(0, 255, 0, 0.2)' if rank < 3 else 'rgba(255, 255, 0, 0.2)'
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Cell ID: {cell_id}',
        ))

    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon),
            pitch=0,
            zoom=1,
        ),
    )
    return fig


# Create label output function
def create_label_output(predictions):
    """Turn the ``(results, cell_ids)`` tuple from ``predict`` into a map figure."""
    results, cell_ids = predictions
    fig = create_map_figure(results, cell_ids)
    return fig


# Predict and plot function
def predict_and_plot(input_img):
    """Gradio handler: classify the uploaded image and return the map figure.

    Raises:
        gr.Error: if no image has been uploaded, so the UI shows a message
            instead of a stack trace.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    predictions = predict(input_img)
    return create_label_output(predictions)


# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")

        btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)

        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()