"""Gradio demo that geolocates an image onto S2 cells and plots them on a map.

A ConvNeXt-Small based classifier scores every S2 cell known to the fitted
LabelEncoder; the top-ranked cells are drawn as polygons on a Plotly
OpenStreetMap figure, and a dropdown lets the user zoom in on one of them.
"""

import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
# NOTE: gradio's Image shadows PIL's Image imported above; neither bare name
# is used directly in this file.
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import torch_xla.utils.serialization as xser
import s2sphere
import folium

token = os.environ.get("token")
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

le = joblib.load("SVD/le.gz")
# The classifier head has one more output than the encoder has classes.
len_classes = len(le.classes_) + 1

# Number of top-ranked cells plotted on the map / offered in the zoom dropdown.
TOP_K = 10

# Last prediction as (results, cell_ids), kept so the zoom dropdown can redraw
# the map without re-running the model.
# NOTE(review): this is process-wide, so concurrent users of the app share (and
# overwrite) each other's last prediction; gr.State would isolate sessions.
global_predictions = None


class ModelPre(torch.nn.Module):
    """ConvNeXt-Small feature extractor followed by an MLP over ``len_classes`` outputs."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            # Everything except torchvision's classifier head (features + avgpool).
            *list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        )

    def forward(self, data):
        return self.embedding(data)


# Load the pretrained model
model = ModelPre()
model_w = xser.load("SVD/GeoG.pth")
model.load_state_dict(model_w['model'])
model.to(device)
# Inference only: eval() turns off ConvNeXt's stochastic depth (and any other
# train-time behaviour), which would otherwise make predictions random.
model.eval()

cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _run_inference(input_img, top_k=TOP_K):
    """Classify ``input_img`` and return its ``top_k`` most likely S2 cells.

    Args:
        input_img: The uploaded image (a PIL image from ``gr.Image(type="pil")``).
        top_k: How many cells to return, best first.

    Returns:
        ``(results, cell_ids)`` where ``results`` maps each decoded cell label
        to its softmax probability and ``cell_ids`` lists the same labels in
        descending probability order.
    """
    # PNG uploads may be RGBA or palette images; Normalize expects 3 channels.
    if hasattr(input_img, "convert"):
        input_img = input_img.convert("RGB")

    with torch.inference_mode():
        batch = cmp(input_img).unsqueeze(0).to(device)
        probs = torch.softmax(model(batch), dim=1)[0].cpu().numpy()

    # NOTE(review): the head has len(le.classes_) + 1 outputs; the extra index
    # cannot be decoded by the LabelEncoder (inverse_transform would raise), so
    # it is excluded. Confirm how training mapped labels to output indices.
    probs = probs[: len(le.classes_)]

    k = min(top_k, len(probs))
    top_indices = np.argsort(probs)[::-1][:k]
    cell_ids = list(le.inverse_transform(top_indices))
    results = {cell_id: float(probs[idx]) for cell_id, idx in zip(cell_ids, top_indices)}
    return results, cell_ids


def _parse_selection(selected_prediction):
    """Convert a dropdown value like ``"Prediction 3"`` to a 0-based index, or None if unset."""
    if not selected_prediction:
        return None  # No selection, default view
    return int(selected_prediction.split()[-1]) - 1


def predict(input_img):
    """Run the model on ``input_img``, remember the result for zooming, and plot it.

    Returns a Plotly figure, or None when no image was provided.
    """
    global global_predictions
    if input_img is None:
        return None
    global_predictions = _run_inference(input_img)
    return create_map_figure(global_predictions, global_predictions[1])


def zoom_on_prediction(selected_prediction):
    """Redraw the last prediction's map zoomed on the dropdown's selected cell.

    Returns None when no prediction has been made yet.
    """
    if global_predictions is None:
        return None  # No prediction made yet
    return create_map_figure(
        global_predictions,
        global_predictions[1],
        selected_index=_parse_selection(selected_prediction),
    )


# Function to get S2 cell polygon
def get_s2_cell_polygon(cell_id):
    """Return the closed (lat, lng) vertex ring, in degrees, of S2 cell ``cell_id``."""
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    vertices = []
    for i in range(4):
        vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
        vertices.append((vertex.lat().degrees, vertex.lng().degrees))
    vertices.append(vertices[0])  # Close the polygon
    return vertices


def create_map_figure(predictions, cell_ids, selected_index=None):
    """Draw the S2 cells in ``cell_ids`` (best first) on an OpenStreetMap figure.

    Args:
        predictions: Not used for drawing; kept for compatibility with callers.
        cell_ids: S2 cell ids in rank order; ints, floats or numeric strings.
        selected_index: 0-based rank to zoom in on, or None for the world view.

    Returns:
        A ``plotly.graph_objects.Figure``. The view is centred on the selected
        cell when one is given, otherwise on the top-ranked cell.
    """
    fig = go.Figure()
    zoom_level = 1  # Default zoom level
    center_lat = None
    center_lon = None

    for rank, cell_id in enumerate(cell_ids):
        polygon = get_s2_cell_polygon(int(float(cell_id)))
        lats, lons = zip(*polygon)
        # Top 3 ranks in green, the rest in yellow (works for any number of cells).
        color = 'rgba(0, 255, 0, 0.2)' if rank < 3 else 'rgba(255, 255, 0, 0.2)'

        # Draw S2 cell polygon
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))

        is_selected = selected_index is not None and rank == selected_index
        if rank == 0 or is_selected:
            # [:-1] skips the closing vertex so it is not counted twice.
            center_lat = float(np.mean(lats[:-1]))
            center_lon = float(np.mean(lons[:-1]))
        if is_selected:
            zoom_level = 10  # Adjust the zoom level to your liking

    # Update map layout
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(
                # Explicit None checks: 0.0 is a valid latitude/longitude.
                lat=center_lat if center_lat is not None else 0.0,
                lon=center_lon if center_lon is not None else 0.0,
            ),
            pitch=0,
            zoom=zoom_level,  # Zoom in based on selection
        ),
    )
    return fig


# Create label output function
def create_label_output(predictions):
    """Plot a ``(results, cell_ids)`` prediction pair at the default zoom."""
    results, cell_ids = predictions
    fig = create_map_figure(results, cell_ids)
    return fig


def predict_and_plot(input_img, selected_prediction):
    """Predict on ``input_img`` and plot it, zoomed on ``selected_prediction`` if set.

    Returns a Plotly figure, or None when no image was provided.
    """
    global global_predictions
    if input_img is None:
        return None
    # Remember the result so a later dropdown change can zoom on it.
    global_predictions = _run_inference(input_img)
    return create_map_figure(
        global_predictions,
        global_predictions[1],
        selected_index=_parse_selection(selected_prediction),
    )


# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        selected_prediction = gr.Dropdown(
            choices=[f"Prediction {i+1}" for i in range(TOP_K)],
            label="Select Prediction to Zoom",
            value=None,
        )

        # Perform the prediction and plot the initial map
        btn_predict.click(predict, inputs=input_image, outputs=output_map)

        # Allow the user to zoom in on a selected prediction after the prediction is made
        selected_prediction.change(zoom_on_prediction, inputs=selected_prediction, outputs=output_map)

        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()