Bastien Dechamps commited on
Commit
4388025
1 Parent(s): bf6fd1c

wip embedder

Browse files
app.py CHANGED
@@ -5,7 +5,7 @@ import plotly.graph_objects as go
5
  from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr
6
 
7
  ALL_GUESSR_CLASS = {
8
- "random": RandomGuessr
9
  }
10
 
11
  ALL_GUESSR_ARGS = {
 
5
  from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr
6
 
7
  ALL_GUESSR_CLASS = {
8
+ "random": RandomGuessr,
9
  }
10
 
11
  ALL_GUESSR_ARGS = {
geoguessr_bot/guessr/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from .abstract_guessr import AbstractGuessr
2
  from .random_guessr import RandomGuessr
 
 
1
  from .abstract_guessr import AbstractGuessr
2
  from .random_guessr import RandomGuessr
3
+ from .global_embedder_guessr import GlobalEmbedderGuessr
geoguessr_bot/guessr/global_embedder_guessr.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from geoguessr_bot.guessr import AbstractGuessr
2
+ from geoguessr_bot.retriever import AbstractImageEmbedder
3
+ from geoguessr_bot.retriever import Retriever
4
+
5
+
6
+ @dataclass
7
+ class GlobalEmbedderGuessr(AbstractGuessr):
8
+ """Guesses a coordinate using an Embedder and a retriever
9
+ """
10
+
11
+ embedder: AbstractImageEmbedder
12
+ retriever: Retriever
13
+
14
+ def guess(self, image: Image) -> Coordinate:
15
+ """Guess a coordinate from an image
16
+ """
17
+ # Embed image
18
+ image_embedding = self.embedder.embed(image)
19
+
20
+ # Retrieve nearest neighbors
21
+ nearest_neighbors = self.retriever.retrieve(image_embedding)
22
+
23
+ # Guess coordinate
24
+ guess_coordinate = self.retriever.image_to_coordinate[nearest_neighbors[0][0]]
25
+ return guess_coordinate
geoguessr_bot/retriever/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .retriever import AbstractRetriever
2
+ from .abstract_embedder import AbstractImageEmbedder
3
+ from .dino_embedder import DinoEmbedder
geoguessr_bot/retriever/abstract_embedder.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class AbstractImageEmbedder:
2
+ def __init__(self, device: str = "cpu"):
3
+ self.device = device
4
+
5
+ def embed(self, image: Image) -> np.ndarray:
6
+ """Embed an image
7
+ """
8
+ raise NotImplementedError
9
+
10
+ def embed_folder(self, folder_path: str):
11
+ """Embed all images in a folder and save them in a .npy file
12
+ """
13
+ embeddings = {}
14
+ for image in os.listdir(folder_path):
15
+ image_path = os.path.join(folder_path, image)
16
+ image = Image.open(image_path)
17
+ embedding = self.embed(image)
18
+ embeddings[image] = embedding
19
+ np.save(folder_path + ".npy", embeddings)
geoguessr_bot/retriever/dino_embedder.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ViTFeatureExtractor, ViTModel
2
+
3
+ from .abstract_embedder import AbstractImageEmbedder
4
+
5
+
6
+ class DinoEmbedder(AbstractImageEmbedder):
7
+ def __init__(self, device: str = "cpu", model_name: str = "facebook/dino-vitb8"):
8
+ super().__init__(device)
9
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
10
+ self.model = ViTModel.from_pretrained(model_name).to(self.device)
11
+
12
+ def embed(self, image: Image) -> np.ndarray:
13
+ inputs = feature_extractor(images=image, return_tensors="pt")
14
+ for key in inputs:
15
+ inputs[key] = inputs[key].to(self.device)
16
+ outputs = model(**inputs)
17
+ last_hidden_states = outputs.last_hidden_state.to("cpu").numpy()
18
+ return last_hidden_states
geoguessr_bot/retriever/retriever.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Retriever:
2
+ def __init__(self, embeddings_path: str):
3
+ self.embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
4
+
5
+ # Keep track of image names
6
+ self.image_to_index = {image_name: i for i, image_name in enumerate(self.embeddings.keys())}
7
+ self.index_to_image = {i: image_name for i, image_name in enumerate(self.embeddings.keys())}
8
+
9
+ # Build Faiss index
10
+ self.embeddings = np.array(list(self.embeddings.values()))
11
+ self.dim = self.embeddings.shape[1]
12
+ self.index = faiss.IndexFlatL2(self.dim)
13
+ self.index.add(self.embeddings)
14
+
15
+ @staticmethod
16
+ def load_embeddings(embeddings_path: str) -> Dict[str, np.ndarray]:
17
+ """Load embeddings from a file
18
+ """
19
+ raise NotImplementedError
20
+
21
+ def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> List[str]:
22
+ """Retrieve nearest neighbors indexes from queries
23
+ """
24
+ dist, indexes = self.index.search(queries, n_neighbors)
25
+ return [[self.index_to_image[i] for i in index] for index in indexes]
26
+
requirements.txt CHANGED
@@ -4,3 +4,6 @@ plotly
4
  gradio
5
  pydantic
6
  Pillow
 
 
 
 
4
  gradio
5
  pydantic
6
  Pillow
7
+ transformers
8
+ requests
9
+ faiss-cpu