GeoGuessrRobot / app.py
robocan's picture
Update app.py
14085cb verified
raw
history blame
4.27 kB
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
import s2sphere
import folium
# Optional Hugging Face access token, read from the ``token`` environment
# variable (``None`` when it is not set).
token = os.environ.get("token")
# Download the model repository snapshot into ./SVD; ``local_dir`` is the
# directory path returned by snapshot_download.
local_dir = snapshot_download(
    repo_id="robocan/GeoG-GCP",
    repo_type="model",
    local_dir="SVD",
    token=token,
)
device = 'cpu'
# Fitted label encoder that maps classifier output indices back to labels.
# Load it from the directory snapshot_download reported instead of repeating
# the hard-coded path.
le = joblib.load(os.path.join(local_dir, "le.gz"))
# The classifier head has one more output than the encoder has classes. The
# checkpoint was presumably trained with this size (TODO confirm why the extra
# unit exists); it must stay in sync with the saved weights or loading fails.
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a two-layer MLP classification head.

    All layers live in one ``torch.nn.Sequential`` stored as ``self.embedding``.
    The saved checkpoint's state-dict keys are therefore ``embedding.<index>.*``,
    so the attribute name and the layer order must not change.

    Args:
        num_classes: Number of output units. ``None`` (the default) uses the
            module-level ``len_classes``, looked up when the model is built.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = len_classes
        self.embedding = torch.nn.Sequential(
            # ConvNeXt-Small (ImageNet weights) with its last child -- the
            # torchvision classifier -- removed, keeping features + avgpool.
            *list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=num_classes),
        )
        # No layers are frozen: every parameter keeps its default requires_grad.

    def forward(self, data):
        """Return raw class scores (logits) for a batch of images ``data``."""
        return self.embedding(data)
# Build the network and load the fine-tuned weights from the downloaded snapshot.
model = ModelPre()
# NOTE(review): torch.load is pickle-based; unless weights_only=True is in
# effect it can execute arbitrary code stored in the file. The checkpoint comes
# from the remote "robocan/GeoG-GCP" repo -- only load it if that repo is trusted.
model_w = torch.load(os.path.join(local_dir, "GeoG.pth"), map_location=torch.device(device))
model.load_state_dict(model_w['model'])
model.to(device)
# torch.inference_mode() in predict() only disables autograd; it does not change
# the module's mode. Without eval() the model stays in training mode, so
# training-only layers (e.g. the stochastic-depth layers in torchvision's
# ConvNeXt blocks) would remain active and make predictions non-deterministic.
model.eval()
# Preprocessing: image -> float tensor, resized to 224x224, then normalised
# with the ImageNet per-channel mean/std.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(input_img, top_k=10):
    """Classify an image and return its ``top_k`` most probable labels.

    Args:
        input_img: Image to classify. Objects with a ``convert`` method (PIL
            images) are converted to RGB first, because RGBA or grayscale input
            would not match the 3-channel normalisation in ``cmp``. Anything
            else is passed to ``cmp`` unchanged.
        top_k: How many predictions to return (default 10). It is clamped to
            the number of labels the encoder knows.

    Returns:
        ``(results, labels)``. ``labels`` is the array returned by
        ``le.inverse_transform`` for the top indices, ordered from most to
        least probable. ``results`` maps each of those labels to its softmax
        probability as a Python float.
    """
    if hasattr(input_img, "convert"):
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        logits = model(img.to(device))
    probabilities = torch.softmax(logits, dim=1).cpu().numpy().flatten()
    # The head has len(le.classes_) + 1 outputs. The extra index has no label,
    # so le.inverse_transform would raise if it reached the top-k. Rank only
    # the indices the encoder can decode; probabilities are not renormalised.
    n_labels = len(le.classes_)
    probabilities = probabilities[:n_labels]
    k = max(0, min(top_k, n_labels))
    # Highest probability first. Slicing after the reversal keeps k == 0 safe,
    # whereas [-0:] would select everything.
    top_indices = np.argsort(probabilities)[::-1][:k]
    top_probabilities = probabilities[top_indices]
    top_labels = le.inverse_transform(top_indices)
    results = {label: float(prob) for label, prob in zip(top_labels, top_probabilities)}
    return results, top_labels
def get_s2_cell_polygon(cell_id):
    """Return the four corners of an S2 cell as a closed ring of coordinates.

    Args:
        cell_id: Integer S2 cell id.

    Returns:
        A list of five ``(latitude, longitude)`` tuples in degrees: the cell's
        corners in vertex order 0..3, followed by a repeat of the first corner
        so the ring is closed.
    """
    cell = s2sphere.Cell(s2sphere.CellId(cell_id))
    corners = [s2sphere.LatLng.from_point(cell.get_vertex(k)) for k in range(4)]
    ring = [(corner.lat().degrees, corner.lng().degrees) for corner in corners]
    ring.append(ring[0])
    return ring
def create_map_figure(predictions, cell_ids):
    """Draw the predicted S2 cells as polygons on a Plotly map.

    Args:
        predictions: Label -> probability mapping from ``predict``. It is not
            used for drawing; it is accepted so existing callers keep working.
        cell_ids: Iterable of S2 cell ids (any value ``int()`` accepts), ordered
            from most to least likely as ``predict`` returns them.

    Returns:
        A ``plotly.graph_objects.Figure`` on an OpenStreetMap basemap. It has
        one translucent blue polygon per cell and is centred on the first (most
        likely) cell. If ``cell_ids`` is empty, the figure has no traces and is
        centred on (0, 0) instead of raising.
    """
    fig = go.Figure()
    center_lat, center_lon = 0.0, 0.0
    for rank, raw_id in enumerate(cell_ids):
        cell_id = int(raw_id)
        polygon = get_s2_cell_polygon(cell_id)
        lats, lons = zip(*polygon)
        if rank == 0:
            # Centre the view on the top prediction rather than on whichever
            # cell is drawn last (the least likely one). The final vertex
            # repeats the first to close the ring, so leave it out of the mean.
            center_lat = float(np.mean(lats[:-1]))
            center_lon = float(np.mean(lons[:-1]))
        fig.add_trace(go.Scattermapbox(
            lat=lats,
            lon=lons,
            mode='lines',
            fill='toself',
            fillcolor='rgba(0, 0, 255, 0.2)',
            line=dict(color='blue'),
            name=f'Cell ID: {cell_id}',
        ))
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon),
            pitch=0,
            zoom=3,
        ),
    )
    return fig
def create_label_output(predictions):
    """Turn the ``(results, cell_ids)`` pair from ``predict`` into a map figure.

    Args:
        predictions: Two-item sequence ``(results, cell_ids)``, where
            ``results`` maps labels to probabilities and ``cell_ids`` lists the
            predicted labels in ranked order.

    Returns:
        The figure built by ``create_map_figure``.
    """
    label_probs, ranked_cells = predictions
    return create_map_figure(label_probs, ranked_cells)
def predict_and_plot(input_img):
    """Gradio callback: classify ``input_img`` and return the map figure.

    Args:
        input_img: The uploaded image, or ``None`` when "Predict" is clicked
            before any image has been provided.

    Returns:
        The figure produced by ``create_label_output`` for the predictions.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a failure from the preprocessing transform.
    """
    if input_img is None:
        raise gr.Error("Please upload an image before clicking Predict.")
    predictions = predict(input_img)
    return create_label_output(predictions)
# Sample images offered below the input (bare file names, presumably shipped
# alongside this script).
examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
# Gradio UI: upload an image, press "Predict", and view the candidate cells on a map.
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        # The button sends the uploaded PIL image to predict_and_plot and shows
        # the returned Plotly figure in output_map.
        btn_predict.click(fn=predict_and_plot, inputs=input_image, outputs=output_map)
        gr.Examples(examples=examples, inputs=input_image)
# Starts the web server as soon as this module is executed.
gradio_app.launch()