import joblib
import numpy as np
import torch
import gradio as gr
import plotly.graph_objects as go
import s2sphere
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder  # le.gz below is a pickled LabelEncoder
from huggingface_hub import snapshot_download
import torch_xla.utils.serialization as xser

# Download the model checkpoint and label encoder from the Hugging Face Hub
local_dir = snapshot_download(
    repo_id="robocan/GeoG_23k",
    repo_type="model",
    local_dir="SVD",
)

device = 'cpu'

# Label encoder mapping class indices to S2 cell labels
le = joblib.load("SVD/le.gz")
len_classes = len(le.classes_) + 1

class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by an MLP head over the S2-cell classes."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            *list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=1024),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=1024, out_features=len_classes),
        )

    def forward(self, data):
        return self.embedding(data)

# Load the pretrained weights (saved with torch_xla's serialization utilities)
model = ModelPre()
model_w = xser.load("SVD/GeoG.pth")
model.load_state_dict(model_w['model'])
model.to(device).eval()

# Preprocessing: tensor conversion, resize to 224x224, ImageNet normalization
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict(input_img):
    """Return the top-10 predicted S2 cells and their probabilities for a PIL image."""
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = model(img.to(device))
        probabilities = torch.softmax(res, dim=1).cpu().numpy().flatten()
        top_10_indices = np.argsort(probabilities)[-10:][::-1]
        top_10_probabilities = probabilities[top_10_indices]
        top_10_predictions = le.inverse_transform(top_10_indices)
    results = {top_10_predictions[i]: float(top_10_probabilities[i]) for i in range(10)}
    return results, top_10_predictions
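
# A quick way to sanity-check predict() outside of Gradio (hypothetical sketch;
# "street.jpg" is a placeholder path, not a file shipped with this repo):
#
#   from PIL import Image
#   scores, cells = predict(Image.open("street.jpg").convert("RGB"))
#   print(scores)    # {s2_cell_label: probability, ...} for the top-10 cells
#   print(cells[0])  # label of the highest-probability S2 cell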

# Function to get an S2 cell polygon as a closed list of (lat, lon) vertices
def get_s2_cell_polygon(cell_id):
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    vertices = []
    for i in range(4):
        vertex = s2sphere.LatLng.from_point(cell.get_vertex(i))
        vertices.append((vertex.lat().degrees, vertex.lng().degrees))
    vertices.append(vertices[0])  # Close the polygon
    return vertices

def create_map_figure(predictions, cell_ids, selected_index=None):
    fig = go.Figure()

    # Assign colors based on rank: top 3 cells in green, the remaining 7 in yellow
    colors = ['rgba(0, 255, 0, 0.2)'] * 3 + ['rgba(255, 255, 0, 0.2)'] * 7

    zoom_level = 1  # Default zoom level
    center_lat = None
    center_lon = None

    for rank, cell_id in enumerate(cell_ids):
        cell_id = int(float(cell_id))
        polygon = get_s2_cell_polygon(cell_id)
        lats, lons = zip(*polygon)
        color = colors[rank]

        # Draw S2 cell polygon
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor=color,
            line=dict(color='blue'),
            name=f'Prediction {rank + 1}',
        ))

        # Default the map center to the top-ranked cell
        if rank == 0:
            center_lat = np.mean(lats)
            center_lon = np.mean(lons)

        # Zoom in on the selected prediction, if any
        if selected_index is not None and rank == selected_index:
            zoom_level = 10
            center_lat = np.mean(lats)
            center_lon = np.mean(lons)

    # Update map layout
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(
                lat=center_lat,
                lon=center_lon,
            ),
            pitch=0,
            zoom=zoom_level,  # Zoom in based on selection
        ),
    )

    return fig
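
# For reference: create_map_figure(results, cells, selected_index=0) zooms to the
# top-ranked cell, while selected_index=None keeps the zoomed-out world view.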

# Create label output function: builds the map from a (results, cell_ids) tuple
# (kept for completeness; the UI below calls create_map_figure directly)
def create_label_output(predictions):
    results, cell_ids = predictions
    fig = create_map_figure(results, cell_ids)
    return fig

def predict_and_plot(input_img, selected_prediction):
    predictions = predict(input_img)

    # Convert the dropdown selection into an index ("Prediction 1" corresponds to index 0, etc.)
    if selected_prediction is not None:
        selected_index = int(selected_prediction.split()[-1]) - 1
    else:
        selected_index = None  # No selection, default view

    return create_map_figure(predictions, predictions[1], selected_index=selected_index)

# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        selected_prediction = gr.Dropdown(
            choices=[f"Prediction {i + 1}" for i in range(10)],
            label="Select Prediction to Zoom",
            value="Prediction 1",  # Default to the top prediction
        )
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")

        # Recompute the map when the button is clicked, using the selected prediction for zoom
        btn_predict.click(predict_and_plot, inputs=[input_image, selected_prediction], outputs=output_map)

        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()
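
# Running this Space locally assumes the dependencies imported above (torch, torchvision,
# torch_xla, gradio, plotly, s2sphere, scikit-learn, joblib, numpy, huggingface_hub) are
# installed and that the example images GB.PNG, IT.PNG, NL.PNG, NZ.PNG sit next to app.py.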