Spaces:
Runtime error
Runtime error
Bastien Dechamps
committed on
Commit
•
4388025
1
Parent(s):
bf6fd1c
wip embedder
Browse files
- app.py +1 -1
- geoguessr_bot/guessr/__init__.py +1 -0
- geoguessr_bot/guessr/global_embedder_guessr.py +25 -0
- geoguessr_bot/retriever/__init__.py +3 -0
- geoguessr_bot/retriever/abstract_embedder.py +19 -0
- geoguessr_bot/retriever/dino_embedder.py +18 -0
- geoguessr_bot/retriever/retriever.py +26 -0
- requirements.txt +3 -0
app.py
CHANGED
@@ -5,7 +5,7 @@ import plotly.graph_objects as go
|
|
5 |
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr
|
6 |
|
7 |
ALL_GUESSR_CLASS = {
|
8 |
-
"random": RandomGuessr
|
9 |
}
|
10 |
|
11 |
ALL_GUESSR_ARGS = {
|
|
|
5 |
from geoguessr_bot.guessr import RandomGuessr, AbstractGuessr
|
6 |
|
7 |
ALL_GUESSR_CLASS = {
|
8 |
+
"random": RandomGuessr,
|
9 |
}
|
10 |
|
11 |
ALL_GUESSR_ARGS = {
|
geoguessr_bot/guessr/__init__.py
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
from .abstract_guessr import AbstractGuessr
|
2 |
from .random_guessr import RandomGuessr
|
|
|
|
1 |
from .abstract_guessr import AbstractGuessr
|
2 |
from .random_guessr import RandomGuessr
|
3 |
+
from .global_embedder_guessr import GlobalEmbedderGuessr
|
geoguessr_bot/guessr/global_embedder_guessr.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass

from geoguessr_bot.guessr import AbstractGuessr
from geoguessr_bot.retriever import AbstractImageEmbedder
from geoguessr_bot.retriever import Retriever
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
class GlobalEmbedderGuessr(AbstractGuessr):
    """Guess a coordinate by embedding the image and retrieving its nearest neighbour.

    Attributes are declared as dataclass fields:

    * ``embedder`` turns the query image into an embedding via ``embed``.
    * ``retriever`` returns, for each query, the names of the closest indexed
      images via ``retrieve`` and is expected to expose an
      ``image_to_coordinate`` mapping from image name to coordinate.
    """

    embedder: AbstractImageEmbedder
    retriever: Retriever

    # Annotations are strings because neither ``Image`` nor ``Coordinate`` is
    # imported in this module; evaluating them eagerly raises NameError when
    # the class body is executed.
    def guess(self, image: "Image") -> "Coordinate":
        """Guess a coordinate from an image.

        Args:
            image: The image to localise; passed unchanged to the embedder.

        Returns:
            The coordinate stored for the closest indexed image.

        Raises:
            IndexError: If the retriever returns no neighbour for the image.
        """
        # Embed image
        image_embedding = self.embedder.embed(image)

        # Retrieve nearest neighbors (one list of image names per query)
        nearest_neighbors = self.retriever.retrieve(image_embedding)
        if len(nearest_neighbors) == 0 or len(nearest_neighbors[0]) == 0:
            # Same exception type as the unguarded ``[0][0]`` lookup would
            # raise, but with an actionable message.
            raise IndexError("retriever returned no nearest neighbour for the image")
        best_match = nearest_neighbors[0][0]

        # Guess coordinate
        # NOTE(review): ``image_to_coordinate`` is read from the retriever here
        # but no such attribute is set by ``Retriever.__init__`` as shown in
        # this commit -- confirm where the name -> coordinate mapping is built.
        return self.retriever.image_to_coordinate[best_match]
|
geoguessr_bot/retriever/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# retriever.py defines ``Retriever`` (there is no ``AbstractRetriever``), and
# ``geoguessr_bot.guessr.global_embedder_guessr`` imports it from this package.
from .retriever import Retriever
|
2 |
+
from .abstract_embedder import AbstractImageEmbedder
|
3 |
+
from .dino_embedder import DinoEmbedder
|
geoguessr_bot/retriever/abstract_embedder.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class AbstractImageEmbedder:
    """Base class for models that turn an image into an embedding array.

    Subclasses implement :meth:`embed`; :meth:`embed_folder` uses it to embed
    every file of a directory and persist the result.
    """

    def __init__(self, device: str = "cpu"):
        """Remember the device string for subclasses.

        Args:
            device: Device identifier stored on ``self.device``.
        """
        self.device = device

    # Annotations are strings: this module imports neither PIL nor numpy at
    # the top level, so eager evaluation would raise NameError on import.
    def embed(self, image: "Image.Image") -> "np.ndarray":
        """Embed an image.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def embed_folder(self, folder_path: str) -> None:
        """Embed all images in a folder and save them in a .npy file.

        The result is a ``{file name: embedding}`` dict written with
        ``numpy.save`` to ``folder_path + ".npy"``. Because the payload is a
        dict, it is stored pickled and must be read back with
        ``numpy.load(..., allow_pickle=True).item()``.

        Args:
            folder_path: Directory whose regular files are embedded.
                Sub-directories are skipped.
        """
        import os

        import numpy as np

        embeddings = {}
        # Sorted so the saved dict (and any index built from it) has a
        # reproducible order regardless of the file system.
        for image_name in sorted(os.listdir(folder_path)):
            image_path = os.path.join(folder_path, image_name)
            if not os.path.isfile(image_path):
                continue
            # Key by file name: the previous code rebound the loop variable to
            # the opened image and used that object as the dict key.
            with self._open_image(image_path) as image:
                embeddings[image_name] = self.embed(image)
        np.save(folder_path + ".npy", embeddings)

    @staticmethod
    def _open_image(image_path: str):
        """Open ``image_path`` with Pillow; the result is used as a context manager."""
        from PIL import Image

        return Image.open(image_path)
|
geoguessr_bot/retriever/dino_embedder.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import ViTFeatureExtractor, ViTModel
|
2 |
+
|
3 |
+
from .abstract_embedder import AbstractImageEmbedder
|
4 |
+
|
5 |
+
|
6 |
+
class DinoEmbedder(AbstractImageEmbedder):
    """Image embedder backed by a pretrained ViT checkpoint (DINO by default)."""

    def __init__(self, device: str = "cpu", model_name: str = "facebook/dino-vitb8"):
        """Load the feature extractor and model for ``model_name``.

        Args:
            device: Device the model and its inputs are moved to.
            model_name: Checkpoint identifier passed to ``from_pretrained``.
        """
        super().__init__(device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.model = ViTModel.from_pretrained(model_name).to(self.device)

    # Annotations are strings because this module imports neither PIL nor
    # numpy; evaluating them eagerly raises NameError at class creation.
    def embed(self, image: "Image.Image") -> "np.ndarray":
        """Return the model's last hidden state for ``image`` as a numpy array.

        Args:
            image: Image handed to the feature extractor.

        Returns:
            ``outputs.last_hidden_state`` moved to the CPU and converted to
            numpy.
        """
        import torch

        # The extractor and model live on the instance; the bare names used
        # previously were undefined inside this method.
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        # Inference only: skip autograd bookkeeping. ``.numpy()`` also refuses
        # tensors that are still attached to the graph, hence ``detach()``.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # NOTE(review): this is the full hidden-state tensor, not a single
        # vector per image -- confirm whether the retriever expects a pooled
        # (e.g. CLS token) embedding instead.
        return outputs.last_hidden_state.detach().cpu().numpy()
|
geoguessr_bot/retriever/retriever.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Retriever:
    """Nearest-neighbour lookup over precomputed image embeddings using Faiss.

    After construction the instance exposes:

    * ``image_to_index`` / ``index_to_image``: image name <-> row number.
    * ``embeddings``: float32 matrix with one flattened embedding per row.
    * ``dim``: number of columns of ``embeddings``.
    * ``index``: ``faiss.IndexFlatL2`` holding ``embeddings``.
    """

    def __init__(self, embeddings_path: str):
        """Load embeddings from ``embeddings_path`` and index them.

        Args:
            embeddings_path: File passed to :meth:`load_embeddings`.

        Raises:
            ValueError: If the file contains no embeddings.
        """
        import faiss
        import numpy as np

        loaded = self.load_embeddings(embeddings_path)
        if not loaded:
            raise ValueError(f"No embeddings found in {embeddings_path!r}")

        # Keep track of image names (row i of the index <-> image_names[i])
        image_names = list(loaded)
        self.image_to_index = {name: i for i, name in enumerate(image_names)}
        self.index_to_image = dict(enumerate(image_names))

        # Build Faiss index. Faiss wants a contiguous float32 (n, dim) matrix,
        # so each embedding is flattened to one row whatever its stored shape.
        self.embeddings = np.ascontiguousarray(
            np.stack(
                [
                    np.asarray(loaded[name], dtype=np.float32).reshape(-1)
                    for name in image_names
                ]
            )
        )
        self.dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dim)
        self.index.add(self.embeddings)

    @staticmethod
    def load_embeddings(embeddings_path: str) -> "Dict[str, np.ndarray]":
        """Load embeddings from a file.

        Reads a ``{image name: embedding}`` dict that was written with
        ``numpy.save`` (numpy stores such a dict as a pickled 0-d object
        array).

        Args:
            embeddings_path: Path of the ``.npy`` file.

        Returns:
            The stored dict.

        Raises:
            ValueError: If the file does not contain a dict.
        """
        import numpy as np

        # SECURITY: allow_pickle=True executes pickle data on load; only point
        # this at embedding files produced by a trusted source.
        loaded = np.load(embeddings_path, allow_pickle=True)
        if isinstance(loaded, np.ndarray) and loaded.dtype == object and loaded.shape == ():
            loaded = loaded.item()
        if not isinstance(loaded, dict):
            raise ValueError(
                f"{embeddings_path!r} does not contain a dict of embeddings"
            )
        return loaded

    def retrieve(self, queries: "np.ndarray", n_neighbors: int = 5) -> "List[List[str]]":
        """Retrieve the names of the nearest indexed images for each query.

        Args:
            queries: One embedding (1-D) or a batch of embeddings whose first
                axis is the query axis; extra axes are flattened per query.
            n_neighbors: Number of neighbours requested per query.

        Returns:
            One list of image names per query, closest first. A list may be
            shorter than ``n_neighbors`` when the index holds fewer vectors.
        """
        import numpy as np

        queries = np.asarray(queries, dtype=np.float32)
        if queries.ndim == 1:
            queries = queries.reshape(1, -1)
        elif queries.ndim > 2:
            queries = queries.reshape(queries.shape[0], -1)
        queries = np.ascontiguousarray(queries)

        _, indexes = self.index.search(queries, n_neighbors)
        # Faiss pads missing neighbours with label -1; those have no image.
        return [
            [self.index_to_image[int(i)] for i in row if i >= 0]
            for row in indexes
        ]
|
26 |
+
|
requirements.txt
CHANGED
@@ -4,3 +4,6 @@ plotly
|
|
4 |
gradio
|
5 |
pydantic
|
6 |
Pillow
|
|
|
|
|
|
|
|
4 |
gradio
|
5 |
pydantic
|
6 |
Pillow
|
7 |
+
transformers
|
8 |
+
requests
|
9 |
+
faiss-cpu
|