# Hugging Face Spaces status banner (page-scrape residue): "Sleeping"
import os.path | |
import numpy as np | |
import gradio as gr | |
import plotly.graph_objects as go | |
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \ | |
AverageNeighborsEmbedderGuessr | |
from geoguessr_bot.retriever import DinoV2Embedder, Retriever, RandomEmbedder | |
# Registry of available guessr implementations. Each key is the identifier
# offered in the UI dropdown, and each value is the class instantiated for it.
ALL_GUESSR_CLASS = dict(
    random=RandomGuessr,
    nearestNeighborEmbedder=NearestNeighborEmbedderGuessr,
    averageNeighborsEmbedder=AverageNeighborsEmbedderGuessr,
)
# Resources shared by the embedding-based guessrs. Both read the same
# embeddings and metadata files. So the paths are resolved once, relative to
# this file, and a single embedder/retriever pair is constructed and handed to
# both guessrs. This avoids building (and presumably loading weights /
# embeddings for) each of them twice.
# NOTE(review): sharing assumes the guessrs only read from the embedder and
# retriever and never mutate them per instance. Confirm against
# geoguessr_bot.guessr.
# NOTE: these objects are still constructed eagerly at import time, even if
# only the "random" guessr is ever selected.
_RESOURCES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources")
_EMBEDDINGS_PATH = os.path.join(_RESOURCES_DIR, "embeddings.npy")
_METADATA_PATH = os.path.join(_RESOURCES_DIR, "metadatav3.csv")
_SHARED_EMBEDDER = DinoV2Embedder(device="cpu")
_SHARED_RETRIEVER = Retriever(embeddings_path=_EMBEDDINGS_PATH)

# Keyword arguments passed to each guessr class in ALL_GUESSR_CLASS when it is
# instantiated; the keys must match ALL_GUESSR_CLASS.
ALL_GUESSR_ARGS = {
    "random": {},
    "nearestNeighborEmbedder": {
        "embedder": _SHARED_EMBEDDER,
        "retriever": _SHARED_RETRIEVER,
        "metadata_path": _METADATA_PATH,
    },
    "averageNeighborsEmbedder": {
        "embedder": _SHARED_EMBEDDER,
        "retriever": _SHARED_RETRIEVER,
        "metadata_path": _METADATA_PATH,
        # Presumably the number of retrieved neighbors to aggregate. Confirm
        # in AverageNeighborsEmbedderGuessr.
        "n_neighbors": 100,
        # Presumably the DBSCAN `eps` used to cluster neighbor coordinates.
        # Confirm units (degrees vs. radians/km) in the guessr.
        "dbscan_eps": 0.5,
    },
}
# Cache of instantiated guessrs, keyed by guessr name. Guessrs are built
# lazily on first request, so only the ones a user actually selects are
# constructed.
ALL_GUESSR = {}


def _get_guessr(guessr: str) -> AbstractGuessr:
    """Return the cached guessr named ``guessr``, instantiating it on first use.

    The instance is built as ``ALL_GUESSR_CLASS[guessr](**ALL_GUESSR_ARGS[guessr])``
    and stored in ``ALL_GUESSR``, so later calls reuse the same object.

    Args:
        guessr: Name of a registered guessr (a key of ``ALL_GUESSR_CLASS``).

    Returns:
        The guessr instance.

    Raises:
        KeyError: If ``guessr`` is not a registered guessr name.
    """
    if guessr not in ALL_GUESSR:
        ALL_GUESSR[guessr] = ALL_GUESSR_CLASS[guessr](**ALL_GUESSR_ARGS[guessr])
    return ALL_GUESSR[guessr]


def create_map(guessr: str) -> go.Figure:
    """Create the initial interactive map and warm up the selected guessr.

    Instantiating the guessr here, when the page loads, moves its construction
    cost out of the first "Guess" click.

    Args:
        guessr: Name of the guessr currently selected in the dropdown.

    Returns:
        The figure returned by ``AbstractGuessr.create_map()`` when called
        without a guess coordinate.
    """
    _get_guessr(guessr)
    return AbstractGuessr.create_map()


def guess(guessr: str, uploaded_image) -> go.Figure:
    """Guess a coordinate from an image uploaded in the Gradio interface.

    Args:
        guessr: Name of the guessr selected in the dropdown.
        uploaded_image: Image from the ``gr.Image`` component, or ``None`` if
            the user clicked "Guess" without uploading anything.

    Returns:
        The figure produced by the guessr's ``create_map(guess_coordinate)``.

    Raises:
        gr.Error: If no image was uploaded. Gradio displays the message to the
            user instead of a generic failure.
    """
    # Without this check, np.array(None) would produce a 0-d object array and
    # hand it to the guessr as if it were an image.
    if uploaded_image is None:
        raise gr.Error("Please upload an image before clicking Guess.")
    model = _get_guessr(guessr)
    guess_coordinate = model.guess(np.array(uploaded_image))
    return model.create_map(guess_coordinate)
if __name__ == "__main__":
    # Layout: a row with the inputs (guessr selector, image upload, button)
    # in one column and the result map beside it.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                guessr_selector = gr.Dropdown(
                    choices=list(ALL_GUESSR_CLASS),
                    value="nearestNeighborEmbedder",
                    label="Guessr type",
                    info="More Guessr types will be added soon!",
                )
                image_input = gr.Image()
                # "Guess" is the button's displayed value (its caption).
                guess_button = gr.Button(value="Guess")
            result_map = gr.Plot()

        # Draw the initial map (and build the default guessr) on page load.
        demo.load(
            fn=create_map,
            inputs=[guessr_selector],
            outputs=result_map,
        )
        # Run the selected guessr on the uploaded image and redraw the map.
        guess_button.click(
            fn=guess,
            inputs=[guessr_selector, image_input],
            outputs=result_map,
        )

    demo.launch()