Bastien Dechamps commited on
Commit
d78053e
1 Parent(s): 43bfbd0

[ADD] First demo with 2500 images

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ resources/embeddings.npy filter=lfs diff=lfs merge=lfs -text
geoguessr_bot/guessr/global_embedder_guessr.py CHANGED
@@ -1,6 +1,7 @@
1
  from dataclasses import dataclass
2
 
3
  from PIL import Image
 
4
 
5
  from geoguessr_bot.guessr import AbstractGuessr
6
  from geoguessr_bot.interfaces import Coordinate
@@ -15,16 +16,28 @@ class GlobalEmbedderGuessr(AbstractGuessr):
15
 
16
  embedder: AbstractImageEmbedder
17
  retriever: Retriever
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def guess(self, image: Image) -> Coordinate:
20
  """Guess a coordinate from an image
21
  """
22
  # Embed image
23
- image_embedding = self.embedder.embed(image)
24
-
25
  # Retrieve nearest neighbors
26
  nearest_neighbors = self.retriever.retrieve(image_embedding)
 
27
 
28
  # Guess coordinate
29
- guess_coordinate = self.retriever.image_to_coordinate[nearest_neighbors[0][0]]
30
  return guess_coordinate
 
1
  from dataclasses import dataclass
2
 
3
  from PIL import Image
4
+ import pandas as pd
5
 
6
  from geoguessr_bot.guessr import AbstractGuessr
7
  from geoguessr_bot.interfaces import Coordinate
 
16
 
17
  embedder: AbstractImageEmbedder
18
  retriever: Retriever
19
+ metadata_path: str
20
+
21
+ def __post_init__(self):
22
+ """Load metadata
23
+ """
24
+ metadata = pd.read_csv(self.metadata_path)
25
+ self.image_to_coordinate = {
26
+ image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
27
+ for image, latitude, longitude in zip(metadata["path"], metadata["latitude"], metadata["longitude"])
28
+ }
29
+
30
 
31
  def guess(self, image: Image) -> Coordinate:
32
  """Guess a coordinate from an image
33
  """
34
  # Embed image
35
+ image_embedding = self.embedder.embed(image)[None, :]
36
+
37
  # Retrieve nearest neighbors
38
  nearest_neighbors = self.retriever.retrieve(image_embedding)
39
+ nearest_neighbor = nearest_neighbors[0][0]
40
 
41
  # Guess coordinate
42
+ guess_coordinate = self.image_to_coordinate[nearest_neighbor]
43
  return guess_coordinate
geoguessr_bot/retriever/retriever.py CHANGED
@@ -5,8 +5,9 @@ import faiss
5
 
6
 
7
  class Retriever:
8
- def __init__(self, embeddings_path: str):
9
  self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
 
10
 
11
  # Keep track of image names
12
  self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
@@ -24,9 +25,8 @@ class Retriever:
24
  """
25
  return np.load(embeddings_path, allow_pickle=True).item()
26
 
27
- def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> List[List[str]]:
28
  """Retrieve nearest neighbors indexes from queries
29
  """
30
- dist, indexes = self.index.search(queries, n_neighbors)
31
  return [[self.index_to_image[i] for i in index] for index in indexes]
32
-
 
5
 
6
 
7
  class Retriever:
8
+ def __init__(self, embeddings_path: str, n_neighbors: int = 5):
9
  self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
10
+ self.n_neighbors = n_neighbors
11
 
12
  # Keep track of image names
13
  self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
 
25
  """
26
  return np.load(embeddings_path, allow_pickle=True).item()
27
 
28
+ def retrieve(self, queries: np.ndarray) -> List[List[str]]:
29
  """Retrieve nearest neighbors indexes from queries
30
  """
31
+ _, indexes = self.index.search(queries, self.n_neighbors)
32
  return [[self.index_to_image[i] for i in index] for index in indexes]
 
resources/metadatav3.csv ADDED
The diff for this file is too large to render. See raw diff