Spaces:
Runtime error
Runtime error
Bastien Dechamps
commited on
Commit
•
d78053e
1
Parent(s):
43bfbd0
[ADD] First demo with 2500 images
Browse files
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
resources/embeddings.npy filter=lfs diff=lfs merge=lfs -text
|
geoguessr_bot/guessr/global_embedder_guessr.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
|
3 |
from PIL import Image
|
|
|
4 |
|
5 |
from geoguessr_bot.guessr import AbstractGuessr
|
6 |
from geoguessr_bot.interfaces import Coordinate
|
@@ -15,16 +16,28 @@ class GlobalEmbedderGuessr(AbstractGuessr):
|
|
15 |
|
16 |
embedder: AbstractImageEmbedder
|
17 |
retriever: Retriever
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def guess(self, image: Image) -> Coordinate:
|
20 |
"""Guess a coordinate from an image
|
21 |
"""
|
22 |
# Embed image
|
23 |
-
image_embedding = self.embedder.embed(image)
|
24 |
-
|
25 |
# Retrieve nearest neighbors
|
26 |
nearest_neighbors = self.retriever.retrieve(image_embedding)
|
|
|
27 |
|
28 |
# Guess coordinate
|
29 |
-
guess_coordinate = self.
|
30 |
return guess_coordinate
|
|
|
1 |
from dataclasses import dataclass
|
2 |
|
3 |
from PIL import Image
|
4 |
+
import pandas as pd
|
5 |
|
6 |
from geoguessr_bot.guessr import AbstractGuessr
|
7 |
from geoguessr_bot.interfaces import Coordinate
|
|
|
16 |
|
17 |
embedder: AbstractImageEmbedder
|
18 |
retriever: Retriever
|
19 |
+
metadata_path: str
|
20 |
+
|
21 |
+
def __post_init__(self):
|
22 |
+
"""Load metadata
|
23 |
+
"""
|
24 |
+
metadata = pd.read_csv(self.metadata_path)
|
25 |
+
self.image_to_coordinate = {
|
26 |
+
image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
|
27 |
+
for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
|
28 |
+
}
|
29 |
+
|
30 |
|
31 |
def guess(self, image: Image) -> Coordinate:
|
32 |
"""Guess a coordinate from an image
|
33 |
"""
|
34 |
# Embed image
|
35 |
+
image_embedding = self.embedder.embed(image)[None, :]
|
36 |
+
|
37 |
# Retrieve nearest neighbors
|
38 |
nearest_neighbors = self.retriever.retrieve(image_embedding)
|
39 |
+
nearest_neighbor = nearest_neighbors[0][0]
|
40 |
|
41 |
# Guess coordinate
|
42 |
+
guess_coordinate = self.image_to_coordinate[nearest_neighbor]
|
43 |
return guess_coordinate
|
geoguessr_bot/retriever/retriever.py
CHANGED
@@ -5,8 +5,9 @@ import faiss
|
|
5 |
|
6 |
|
7 |
class Retriever:
|
8 |
-
def __init__(self, embeddings_path: str):
|
9 |
self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
|
|
|
10 |
|
11 |
# Keep track of image names
|
12 |
self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
|
@@ -24,9 +25,8 @@ class Retriever:
|
|
24 |
"""
|
25 |
return np.load(embeddings_path, allow_pickle=True).item()
|
26 |
|
27 |
-
def retrieve(self, queries: np.ndarray
|
28 |
"""Retrieve nearest neighbors indexes from queries
|
29 |
"""
|
30 |
-
|
31 |
return [[self.index_to_image[i] for i in index] for index in indexes]
|
32 |
-
|
|
|
5 |
|
6 |
|
7 |
class Retriever:
|
8 |
+
def __init__(self, embeddings_path: str, n_neighbors: int = 5):
|
9 |
self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
|
10 |
+
self.n_neighbors = n_neighbors
|
11 |
|
12 |
# Keep track of image names
|
13 |
self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
|
|
|
25 |
"""
|
26 |
return np.load(embeddings_path, allow_pickle=True).item()
|
27 |
|
28 |
+
def retrieve(self, queries: np.ndarray) -> List[List[str]]:
|
29 |
"""Retrieve nearest neighbors indexes from queries
|
30 |
"""
|
31 |
+
_, indexes = self.index.search(queries, self.n_neighbors)
|
32 |
return [[self.index_to_image[i] for i in index] for index in indexes]
|
|
resources/metadatav3.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|