Spaces:
Runtime error
Runtime error
Bastien Dechamps
committed on
Commit
•
ed8157d
1
Parent(s):
06ee965
[ADD] Random Embedder
Browse files
app.py
CHANGED
@@ -6,8 +6,7 @@ import plotly.graph_objects as go
|
|
6 |
|
7 |
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
|
8 |
AverageNeighborsEmbedderGuessr
|
9 |
-
from geoguessr_bot.retriever import DinoV2Embedder, Retriever
|
10 |
-
|
11 |
|
12 |
ALL_GUESSR_CLASS = {
|
13 |
"random": RandomGuessr,
|
@@ -21,6 +20,7 @@ ALL_GUESSR_ARGS = {
|
|
21 |
"embedder": DinoV2Embedder(
|
22 |
device="cpu"
|
23 |
),
|
|
|
24 |
"retriever": Retriever(
|
25 |
embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
26 |
"resources/embeddings.npy"),
|
@@ -32,13 +32,14 @@ ALL_GUESSR_ARGS = {
|
|
32 |
"embedder": DinoV2Embedder(
|
33 |
device="cpu"
|
34 |
),
|
|
|
35 |
"retriever": Retriever(
|
36 |
embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
37 |
"resources/embeddings.npy"),
|
38 |
),
|
39 |
"metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
40 |
"resources/metadatav3.csv"),
|
41 |
-
"n_neighbors":
|
42 |
"dbscan_eps": 0.5
|
43 |
}
|
44 |
}
|
|
|
6 |
|
7 |
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr, NearestNeighborEmbedderGuessr, \
|
8 |
AverageNeighborsEmbedderGuessr
|
9 |
+
from geoguessr_bot.retriever import DinoV2Embedder, Retriever, RandomEmbedder
|
|
|
10 |
|
11 |
ALL_GUESSR_CLASS = {
|
12 |
"random": RandomGuessr,
|
|
|
20 |
"embedder": DinoV2Embedder(
|
21 |
device="cpu"
|
22 |
),
|
23 |
+
# "embedder": RandomEmbedder(n_dim=384),
|
24 |
"retriever": Retriever(
|
25 |
embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
26 |
"resources/embeddings.npy"),
|
|
|
32 |
"embedder": DinoV2Embedder(
|
33 |
device="cpu"
|
34 |
),
|
35 |
+
# "embedder": RandomEmbedder(n_dim=384),
|
36 |
"retriever": Retriever(
|
37 |
embeddings_path=os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
38 |
"resources/embeddings.npy"),
|
39 |
),
|
40 |
"metadata_path": os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
41 |
"resources/metadatav3.csv"),
|
42 |
+
"n_neighbors": 100,
|
43 |
"dbscan_eps": 0.5
|
44 |
}
|
45 |
}
|
geoguessr_bot/guessr/average_neighbor_embedder_guessr.py
CHANGED
@@ -13,6 +13,11 @@ from geoguessr_bot.retriever import AbstractImageEmbedder
|
|
13 |
from geoguessr_bot.retriever import Retriever
|
14 |
|
15 |
|
|
|
|
|
|
|
|
|
|
|
16 |
@dataclass
|
17 |
class AverageNeighborsEmbedderGuessr(AbstractGuessr):
|
18 |
"""Guesses a coordinate using an Embedder and a retriever followed by NN.
|
@@ -20,7 +25,7 @@ class AverageNeighborsEmbedderGuessr(AbstractGuessr):
|
|
20 |
embedder: AbstractImageEmbedder
|
21 |
retriever: Retriever
|
22 |
metadata_path: str
|
23 |
-
n_neighbors: int =
|
24 |
dbscan_eps: float = 0.05
|
25 |
|
26 |
def __post_init__(self):
|
@@ -32,7 +37,7 @@ class AverageNeighborsEmbedderGuessr(AbstractGuessr):
|
|
32 |
for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
|
33 |
}
|
34 |
# DBSCAN will be used to take the centroid of the biggest cluster among the N neighbors, using Haversine
|
35 |
-
self.dbscan = DBSCAN(eps=self.dbscan_eps, metric=
|
36 |
|
37 |
def guess(self, image: Image) -> Coordinate:
|
38 |
"""Guess a coordinate from an image
|
|
|
13 |
from geoguessr_bot.retriever import Retriever
|
14 |
|
15 |
|
16 |
+
def haversine_distance(x, y) -> float:
    """Return the haversine distance between two single coordinates.

    Each input is coerced to a one-row 2-D array because
    ``haversine_distances`` operates on matrices of points; the single entry
    of the resulting 1x1 distance matrix is returned.

    NOTE(review): scikit-learn's ``haversine_distances`` presumably expects
    ``[latitude, longitude]`` pairs in radians and returns a distance in
    radians -- confirm that callers (e.g. DBSCAN with this metric) pass
    radians.
    """
    point_a = np.array(x).reshape(1, -1)
    point_b = np.array(y).reshape(1, -1)
    pairwise = haversine_distances(point_a, point_b)
    return pairwise[0][0]
|
20 |
+
|
21 |
@dataclass
|
22 |
class AverageNeighborsEmbedderGuessr(AbstractGuessr):
|
23 |
"""Guesses a coordinate using an Embedder and a retriever followed by NN.
|
|
|
25 |
embedder: AbstractImageEmbedder
|
26 |
retriever: Retriever
|
27 |
metadata_path: str
|
28 |
+
n_neighbors: int = 50
|
29 |
dbscan_eps: float = 0.05
|
30 |
|
31 |
def __post_init__(self):
|
|
|
37 |
for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
|
38 |
}
|
39 |
# DBSCAN will be used to take the centroid of the biggest cluster among the N neighbors, using Haversine
|
40 |
+
self.dbscan = DBSCAN(eps=self.dbscan_eps, metric=haversine_distance)
|
41 |
|
42 |
def guess(self, image: Image) -> Coordinate:
|
43 |
"""Guess a coordinate from an image
|
geoguessr_bot/retriever/__init__.py
CHANGED
@@ -2,3 +2,4 @@ from .retriever import Retriever
|
|
2 |
from .abstract_embedder import AbstractImageEmbedder
|
3 |
from .dino_embedder import DinoEmbedder
|
4 |
from .dino_v2_embedder import DinoV2Embedder
|
|
|
|
2 |
from .abstract_embedder import AbstractImageEmbedder
|
3 |
from .dino_embedder import DinoEmbedder
|
4 |
from .dino_v2_embedder import DinoV2Embedder
|
5 |
+
from .random_embedder import RandomEmbedder
|
geoguessr_bot/retriever/abstract_embedder.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import os
|
|
|
|
|
2 |
from PIL import Image
|
3 |
import numpy as np
|
4 |
from tqdm import tqdm
|
@@ -8,6 +10,7 @@ class AbstractImageEmbedder:
|
|
8 |
def __init__(self, device: str = "cpu"):
|
9 |
self.device = device
|
10 |
|
|
|
11 |
def embed(self, image: Image) -> np.ndarray:
|
12 |
"""Embed an image
|
13 |
"""
|
|
|
1 |
import os
|
2 |
+
from abc import abstractmethod
|
3 |
+
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
6 |
from tqdm import tqdm
|
|
|
10 |
def __init__(self, device: str = "cpu"):
|
11 |
self.device = device
|
12 |
|
13 |
+
@abstractmethod
|
14 |
def embed(self, image: Image) -> np.ndarray:
|
15 |
"""Embed an image
|
16 |
"""
|
geoguessr_bot/retriever/random_embedder.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
from geoguessr_bot.retriever import AbstractImageEmbedder
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
class RandomEmbedder(AbstractImageEmbedder):
    """Embedder that ignores the image content and returns a random vector.

    Each call to ``embed`` draws a fresh vector of length ``n_dim`` with
    ``np.random.rand``, so two calls on the same image give different
    embeddings.
    """
    # Length of the random vector returned by ``embed``.
    n_dim: int = 8

    def __post_init__(self):
        # The dataclass-generated __init__ replaces the base-class __init__
        # and does not call it, so chain to it here; otherwise the ``device``
        # attribute set by AbstractImageEmbedder.__init__ would be missing.
        super().__init__()

    def embed(self, image: Image) -> np.ndarray:
        """Return a random embedding; ``image`` is accepted but not used.

        Returns:
            A 1-D array of ``n_dim`` values drawn by ``np.random.rand``.
        """
        return np.random.rand(self.n_dim)
|