tronskel committed (verified)
Commit 71026d8 · 1 Parent(s): 5c3b2be
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +35 -35
  2. app.py +409 -0
  3. configs/computer/a100.yaml +8 -0
  4. configs/computer/cluster-node-a100.yaml +8 -0
  5. configs/computer/cluster-node-v100.yaml +8 -0
  6. configs/computer/cpu.yaml +8 -0
  7. configs/computer/v100.yaml +8 -0
  8. configs/config.yaml +89 -0
  9. configs/dataset/baselines/im2gps.yaml +16 -0
  10. configs/dataset/baselines/im2gps3k.yaml +16 -0
  11. configs/dataset/baselines/yfcc4k.yaml +16 -0
  12. configs/dataset/osv5m.yaml +46 -0
  13. configs/dataset/osv5m_contrastive.yaml +34 -0
  14. configs/dataset/osv5m_contrastive_best.yaml +37 -0
  15. configs/dataset/osv5m_text_contrastive.yaml +34 -0
  16. configs/dataset/test_transform/center_crop.yaml +12 -0
  17. configs/dataset/test_transform/clip.yaml +2 -0
  18. configs/dataset/test_transform/fast_clip.yaml +12 -0
  19. configs/dataset/test_transform/fast_resnet.yaml +12 -0
  20. configs/dataset/test_transform/none.yaml +6 -0
  21. configs/dataset/train_transform/augmentation.yaml +85 -0
  22. configs/dataset/train_transform/center_crop.yaml +14 -0
  23. configs/dataset/train_transform/clip.yaml +2 -0
  24. configs/dataset/train_transform/fast_clip.yaml +12 -0
  25. configs/dataset/train_transform/fast_resnet.yaml +12 -0
  26. configs/dataset/train_transform/none.yaml +7 -0
  27. configs/exp/DinoV2.yaml +18 -0
  28. configs/exp/ResNet.yaml +21 -0
  29. configs/exp/base_model.yaml +19 -0
  30. configs/exp/best_model.yaml +25 -0
  31. configs/exp/classification_area.yaml +19 -0
  32. configs/exp/classification_cell.yaml +19 -0
  33. configs/exp/classification_cell_hier.yaml +20 -0
  34. configs/exp/classification_city.yaml +19 -0
  35. configs/exp/classification_city_hier.yaml +20 -0
  36. configs/exp/classification_country.yaml +19 -0
  37. configs/exp/classification_region copy.yaml +19 -0
  38. configs/exp/classification_region.yaml +19 -0
  39. configs/exp/clip_L_14_DataComp.yaml +18 -0
  40. configs/exp/clip_L_14_Laion.yaml +18 -0
  41. configs/exp/clip_L_14_OpenAI.yaml +18 -0
  42. configs/exp/clip_bigG_14_Laion.yaml +18 -0
  43. configs/exp/contrastive_area.yaml +20 -0
  44. configs/exp/contrastive_cell.yaml +20 -0
  45. configs/exp/contrastive_city.yaml +20 -0
  46. configs/exp/contrastive_country.yaml +20 -0
  47. configs/exp/contrastive_region.yaml +20 -0
  48. configs/exp/contrastive_text.yaml +22 -0
  49. configs/exp/eval_best_model.yaml +29 -0
  50. configs/exp/fine_tuning.yaml +20 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,409 @@
+ import streamlit as st
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+ import pydeck as pdk
+ from geopy.geocoders import Nominatim
+ import time
+ import requests
+ from io import BytesIO
+ import reverse_geocoder as rg
+ from bs4 import BeautifulSoup
+ from urllib.parse import urljoin
+ from models.huggingface import Geolocalizer
+ import spacy
+ from collections import Counter
+ from spacy.cli import download
+ from typing import Tuple, List, Optional, Union
+
+
+ def load_spacy_model(model_name: str = "en_core_web_md") -> spacy.Language:
+     """
+     Load the specified spaCy model.
+
+     Args:
+         model_name (str): Name of the spaCy model to load.
+
+     Returns:
+         spacy.Language: Loaded spaCy model.
+     """
+     try:
+         return spacy.load(model_name)
+     except IOError:
+         print(f"Model {model_name} not found, downloading...")
+         download(model_name)
+         return spacy.load(model_name)
+
+
+ nlp = load_spacy_model()
+
+ IMAGE_SIZE = (224, 224)
+ GEOLOC_MODEL_NAME = "osv5m/baseline"
+
+
+ @st.cache_resource(show_spinner=True)
+ def load_geoloc_model() -> Optional[Geolocalizer]:
+     """
+     Load the geolocation model.
+
+     Returns:
+         Optional[Geolocalizer]: Loaded geolocation model, or None if loading fails.
+     """
+     with st.spinner('Loading model...'):
+         try:
+             model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
+             model.eval()
+             return model
+         except Exception as e:
+             st.error(f"Failed to load the model: {e}")
+             return None
+
+
+ def most_frequent_locations(text: str) -> Tuple[str, List[str]]:
+     """
+     Find the most frequently mentioned locations in the text.
+
+     Args:
+         text (str): Input text to analyze.
+
+     Returns:
+         Tuple[str, List[str]]: Description of the most mentioned locations and a list of those locations.
+     """
+     doc = nlp(text)
+     locations = []
+
+     for ent in doc.ents:
+         if ent.label_ in ['LOC', 'GPE']:
+             print(f"Entity: {ent.text} | Label: {ent.label_} | Sentence: {ent.sent}")
+             locations.append(ent.text)
+
+     if locations:
+         location_counts = Counter(locations)
+         most_common_locations = location_counts.most_common(2)
+         common_locations_str = ', '.join(f"{loc} ({count} occurrences)" for loc, count in most_common_locations)
+         return f"Most Mentioned Locations: {common_locations_str}", [loc for loc, _ in most_common_locations]
+     else:
+         return "No locations found", []
+
+
+ def transform_image(image: Image.Image) -> torch.Tensor:
+     """
+     Transform the input image for model prediction.
+
+     Args:
+         image (Image.Image): Input image.
+
+     Returns:
+         torch.Tensor: Transformed image tensor.
+     """
+     transform = transforms.Compose([
+         transforms.Resize(IMAGE_SIZE),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+     return transform(image).unsqueeze(0)
+
+
+ def check_location_match(location_query: dict, most_common_locations: List[str]) -> bool:
+     """
+     Check if the predicted location matches any of the most common locations.
+
+     Args:
+         location_query (dict): Predicted location details.
+         most_common_locations (List[str]): List of most common locations.
+
+     Returns:
+         bool: True if a match is found, False otherwise.
+     """
+     name = location_query['name']
+     admin1 = location_query['admin1']
+     cc = location_query['cc']
+
+     for loc in most_common_locations:
+         if name in loc and admin1 in loc and cc in loc:
+             return True
+     return False
+
+
+ def get_city_geojson(location_name: str) -> Optional[dict]:
+     """
+     Fetch the GeoJSON data for the specified city.
+
+     Args:
+         location_name (str): Name of the city.
+
+     Returns:
+         Optional[dict]: GeoJSON data of the city, or None if fetching fails.
+     """
+     geolocator = Nominatim(user_agent="predictGeolocforImage")
+     try:
+         location = geolocator.geocode(location_name, geometry='geojson')
+         return location.raw['geojson'] if location else None
+     except Exception as e:
+         st.error(f"Failed to geocode location: {e}")
+         return None
+
+
+ def get_media(url: str) -> Optional[List[Tuple[str, str]]]:
+     """
+     Fetch media URLs and associated text from the specified URL.
+
+     Args:
+         url (str): URL to fetch media from.
+
+     Returns:
+         Optional[List[Tuple[str, str]]]: List of tuples containing media URLs and associated text, or None if fetching fails.
+     """
+     try:
+         response = requests.get(url)
+         response.raise_for_status()
+         data = response.json()
+         return [(media['media_url'], entry['full_text'])
+                 for entry in data for media in entry.get('media', []) if 'media_url' in media]
+     except requests.RequestException as e:
+         st.error(f"Failed to fetch media URL: {e}")
+         return None
+
+
+ def predict_location(image: Image.Image, model: Geolocalizer) -> Optional[Tuple[List[float], dict, Optional[dict], float]]:
+     """
+     Predict the location from the input image using the specified model.
+
+     Args:
+         image (Image.Image): Input image.
+         model (Geolocalizer): Geolocation model.
+
+     Returns:
+         Optional[Tuple[List[float], dict, Optional[dict], float]]: Predicted GPS coordinates, location query, city GeoJSON data, and processing time, or None if prediction fails.
+     """
+     with st.spinner('Processing image and predicting location...'):
+         start_time = time.time()
+         try:
+             img_tensor = transform_image(image)
+             gps_radians = model(img_tensor)
+             gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
+             location_query = rg.search((gps_degrees[0], gps_degrees[1]))[0]
+             location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
+             city_geojson = get_city_geojson(location_name)
+             processing_time = time.time() - start_time
+             return gps_degrees, location_query, city_geojson, processing_time
+         except Exception as e:
+             st.error(f"Failed to predict the location: {e}")
+             return None
+
+
+ def display_map(city_geojson: dict, gps_degrees: List[float]) -> None:
+     """
+     Display a map with the specified city GeoJSON data and GPS coordinates.
+
+     Args:
+         city_geojson (dict): GeoJSON data of the city.
+         gps_degrees (List[float]): GPS coordinates.
+     """
+     map_view = pdk.Deck(
+         map_style='mapbox://styles/mapbox/light-v9',
+         initial_view_state=pdk.ViewState(
+             latitude=gps_degrees[0],
+             longitude=gps_degrees[1],
+             zoom=8,
+             pitch=0,
+         ),
+         layers=[
+             pdk.Layer(
+                 'GeoJsonLayer',
+                 data=city_geojson,
+                 get_fill_color=[255, 180, 0, 140],
+                 pickable=True,
+                 stroked=True,
+                 filled=True,
+                 extruded=False,
+                 line_width_min_pixels=1,
+             ),
+         ],
+     )
+     st.pydeck_chart(map_view)
+
+
+ def display_image(image_url: str) -> None:
+     """
+     Display an image from the specified URL.
+
+     Args:
+         image_url (str): URL of the image.
+     """
+     try:
+         response = requests.get(image_url)
+         response.raise_for_status()
+         image_bytes = BytesIO(response.content)
+         st.image(image_bytes, caption=f'Image from URL: {image_url}', use_column_width=True)
+     except requests.RequestException as e:
+         st.error(f"Failed to fetch image at URL {image_url}: {e}")
+     except Exception as e:
+         st.error(f"An error occurred: {e}")
+
+
+ def scrape_webpage(url: str) -> Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]:
+     """
+     Scrape the specified webpage for text and images.
+
+     Args:
+         url (str): URL of the webpage to scrape.
+
+     Returns:
+         Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]: Extracted text and list of image URLs, or (None, None) if scraping fails.
+     """
+     with st.spinner('Scraping web page...'):
+         try:
+             response = requests.get(url)
+             response.raise_for_status()
+             soup = BeautifulSoup(response.content, 'html.parser')
+             base_url = url  # Adjust based on <base> tags or other HTML clues
+             text = ' '.join(p.text for p in soup.find_all('p'))
+             images = [urljoin(base_url, img['src']) for img in soup.find_all('img') if 'src' in img.attrs]
+             return text, images
+         except requests.RequestException as e:
+             st.error(f"Failed to fetch and parse the URL: {e}")
+             return None, None
+
+
+ def main() -> None:
+     """
+     Main function to run the Streamlit app.
+     """
+     st.title('Welcome to Geolocation Guesstimation Demo 👋')
+
+     page = st.sidebar.selectbox(
+         "Choose your action:",
+         ("Home", "Images", "Social Media", "Web Pages"),
+         index=0
+     )
+
+     st.sidebar.success("Select a demo above.")
+     st.sidebar.info(
+         """
+         - Web App URL: <https://yunusserhat-guesstimatelocation.hf.space/>
+         """
+     )
+
+     st.sidebar.title("Contact")
+     st.sidebar.info(
+         """
+         Yunus Serhat Bıçakçı at [yunusserhat.com](https://yunusserhat.com) | [GitHub](https://github.com/yunusserhat) | [Twitter](https://twitter.com/yunusserhat) | [LinkedIn](https://www.linkedin.com/in/yunusserhat)
+         """
+     )
+
+     if page == "Home":
+         st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
+
+     elif page == "Images":
+         upload_images_page()
+
+     elif page == "Social Media":
+         social_media_page()
+
+     elif page == "Web Pages":
+         web_page_url_page()
+
+
+ def upload_images_page() -> None:
+     """
+     Display the image upload page for geolocation prediction.
+     """
+     st.header("Image Upload for Geolocation Prediction")
+     uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
+     if uploaded_files:
+         for idx, file in enumerate(uploaded_files, start=1):
+             with st.spinner(f"Processing {file.name}..."):
+                 image = Image.open(file).convert('RGB')
+                 st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)
+                 model = load_geoloc_model()
+                 if model:
+                     result = predict_location(image, model)
+                     if result:
+                         gps_degrees, location_query, city_geojson, processing_time = result
+                         st.write(
+                             f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
+                         if city_geojson:
+                             display_map(city_geojson, gps_degrees)
+                         st.write(f"Processing Time (seconds): {processing_time}")
+
+
+ def social_media_page() -> None:
+     """
+     Display the social media analysis page.
+     """
+     st.header("Social Media Analyser")
+     social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
+     if social_media_url:
+         media_data = get_media(social_media_url)
+         if media_data:
+             full_text = media_data[0][1]
+             st.subheader("Full Text")
+             st.write(full_text)
+             most_used_location, most_common_locations = most_frequent_locations(full_text)
+             st.subheader("Most Frequent Location")
+             st.write(most_used_location)
+
+             for idx, (media_url, _) in enumerate(media_data, start=1):
+                 st.subheader(f"Image {idx}")
+                 response = requests.get(media_url)
+                 if response.status_code == 200:
+                     image = Image.open(BytesIO(response.content)).convert('RGB')
+                     st.image(image, caption=f'Image from URL: {media_url}', use_column_width=True)
+                     model = load_geoloc_model()
+                     if model:
+                         result = predict_location(image, model)
+                         if result:
+                             gps_degrees, location_query, city_geojson, processing_time = result
+                             location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
+                             st.write(
+                                 f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
+                             if city_geojson:
+                                 display_map(city_geojson, gps_degrees)
+                             st.write(f"Processing Time (seconds): {processing_time}")
+                             if check_location_match(location_query, most_common_locations):
+                                 st.success(
+                                     f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
+                 else:
+                     st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
+
+
+ def web_page_url_page() -> None:
+     """
+     Display the web page URL analysis page.
+     """
+     st.header("Web Page Analyser")
+     web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
+     if web_page_url:
+         text, images = scrape_webpage(web_page_url)
+         if text:
+             st.subheader("Extracted Text (First 500 Characters):")
+             st.write(text[:500])
+             most_used_location, most_common_locations = most_frequent_locations(text)
+             st.subheader("Most Frequent Location")
+             st.write(most_used_location)
+         if images:
+             selected_image_url = st.selectbox("Select an image to predict location:", images)
+             if selected_image_url:
+                 response = requests.get(selected_image_url)
+                 if response.status_code == 200:
+                     image = Image.open(BytesIO(response.content)).convert('RGB')
+                     st.image(image, caption=f'Selected Image from URL: {selected_image_url}', use_column_width=True)
+                     model = load_geoloc_model()
+                     if model:
+                         result = predict_location(image, model)
+                         if result:
+                             gps_degrees, location_query, city_geojson, processing_time = result
+                             location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
+                             st.write(
+                                 f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
+                             if city_geojson:
+                                 display_map(city_geojson, gps_degrees)
+                             st.write(f"Processing Time (seconds): {processing_time}")
+                             if check_location_match(location_query, most_common_locations):
+                                 st.success(
+                                     f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
+
+
+ if __name__ == '__main__':
+     main()
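
For reference, the prediction pipeline above can also be driven outside Streamlit. The following is a minimal sketch, assuming this repo's models.huggingface.Geolocalizer is importable; 'photo.jpg' is a placeholder path. It mirrors the transform and radians-to-degrees conversion used in predict_location:

    # Minimal sketch: drive the "osv5m/baseline" geolocalizer directly.
    # Assumes this repo's models.huggingface.Geolocalizer; 'photo.jpg' is a placeholder.
    import torch
    import reverse_geocoder as rg
    from PIL import Image
    from torchvision import transforms
    from models.huggingface import Geolocalizer

    model = Geolocalizer.from_pretrained("osv5m/baseline")
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image = Image.open("photo.jpg").convert("RGB")
    with torch.no_grad():
        gps_radians = model(transform(image).unsqueeze(0))  # model outputs (lat, lon) in radians
    lat, lon = torch.rad2deg(gps_radians).squeeze(0).tolist()
    print(lat, lon, rg.search((lat, lon))[0])  # offline reverse geocode to the nearest town
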
configs/computer/a100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 1
+ progress_bar_refresh_rate: 2
+ num_workers: 8
+ sync_batchnorm: False
+ accelerator: gpu
+ precision: 32
+ strategy: auto
+ num_nodes: 1
configs/computer/cluster-node-a100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 8
+ num_workers: 8
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: True
+ accelerator: gpu
+ precision: 32
+ strategy: ddp
+ num_nodes: 1
configs/computer/cluster-node-v100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 4
+ num_workers: 10
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: True
+ accelerator: gpu
+ precision: 32
+ strategy: ddp
+ num_nodes: 1
configs/computer/cpu.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: null
+ num_workers: 0
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: False
+ accelerator: cpu
+ precision: 32
+ strategy: auto
+ num_nodes: null
configs/computer/v100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 1
+ num_workers: 10
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: False
+ accelerator: gpu
+ precision: 32
+ strategy: auto
+ num_nodes: 1
configs/config.yaml ADDED
@@ -0,0 +1,89 @@
+ defaults:
+   - model: default
+   - computer: v100
+   - dataset: osv5m
+   - _self_
+   - exp: ???
+
+ model:
+   val_metrics:
+     _target_: metrics.distance_based.HaversineMetrics
+     acc_radiuses:
+       - 1
+       - 25
+       - 200
+       - 750
+       - 2500
+     acc_area: []
+     aux_data: ${aux_data}
+   test_metrics:
+     _target_: metrics.distance_based.HaversineMetrics
+     acc_radiuses:
+       - 1
+       - 25
+       - 200
+       - 750
+       - 2500
+     acc_area: ${areas}
+     aux_data: ${aux_data}
+
+ datamodule:
+   _target_: data.datamodule.ImageDataModule
+   train_dataset: ${dataset.train_dataset}
+   val_dataset: ${dataset.val_dataset}
+   test_dataset: ${dataset.test_dataset}
+   global_batch_size: ${dataset.global_batch_size}
+   num_workers: ${computer.num_workers}
+   num_nodes: ${computer.num_nodes}
+   num_devices: ${computer.devices}
+   val_proportion: 0.1
+
+ trainer:
+   _target_: pytorch_lightning.Trainer
+   devices: ${computer.devices}
+   accelerator: ${computer.accelerator}
+   strategy: ${computer.strategy}
+   num_nodes: ${computer.num_nodes}
+   precision: ${computer.precision}
+   max_epochs: ${max_epochs}
+
+ logger:
+   _target_: pytorch_lightning.loggers.WandbLogger
+   save_dir: ${root_dir}
+   name: ${experiment_name}
+   project: plonk
+   log_model: False
+   offline: False
+   entity: imaginelab
+
+ checkpoints:
+   _target_: pytorch_lightning.callbacks.ModelCheckpoint
+   dirpath: ${root_dir}/checkpoints/${experiment_name}
+   filename: 'epoch_{epoch}'
+   monitor: val/loss
+   save_last: True
+   save_top_k: 0
+   every_n_epochs: 1
+
+ progress_bar:
+   _target_: pytorch_lightning.callbacks.TQDMProgressBar
+   refresh_rate: ${computer.progress_bar_refresh_rate}
+
+ aux_data: []
+ max_epochs: 100
+ data_dir: ${root_dir}/datasets
+ root_dir: ${hydra:runtime.cwd}
+ experiment_name: ${dataset.name}__${model.name}
+ mode: train  # set to eval to run testing
+ num_classes: 0
+ areas: ['country', 'region', 'sub-region', 'city']
+ class_name: null
+ streetclip: False
+ blur: False
+ text_tuning: False
+
+ hydra:
+   run:
+     dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name}
+   job:
+     chdir: true
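
Note that exp is a mandatory choice (???), so every run must name one of the experiment configs below. A minimal sketch of composing this config with Hydra's compose API follows; the repo's actual training entry point is not part of this diff, and the chosen overrides are illustrative only:

    # Illustrative sketch: compose configs/config.yaml with Hydra's public compose API.
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="configs"):
        cfg = compose(config_name="config", overrides=["exp=base_model", "computer=cpu"])
    print(cfg.mode, cfg.max_epochs, cfg.dataset.global_batch_size)
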
configs/dataset/baselines/im2gps.yaml ADDED
@@ -0,0 +1,16 @@
+ dataset:
+   name: im2gps
+   global_batch_size: 512
+   test_dataset:
+     _partial_: true
+     _target_: data.data.Baseline
+     path: ${data_dir}/baselines/im2gps
+     which: 'im2gps'
+     transforms: ${dataset.test_transform}
+ datamodule:
+   _target_: data.datamodule.BaselineDataModule
+   test_dataset: ${dataset.test_dataset}
+   global_batch_size: ${dataset.global_batch_size}
+   num_workers: ${computer.num_workers}
+   num_nodes: ${computer.num_nodes}
+   num_devices: ${computer.devices}
configs/dataset/baselines/im2gps3k.yaml ADDED
@@ -0,0 +1,16 @@
+ dataset:
+   name: im2gps3k
+   global_batch_size: 512
+   test_dataset:
+     _partial_: true
+     _target_: data.data.Baseline
+     path: ${data_dir}/baselines/im2gps3k
+     which: 'im2gps3k'
+     transforms: ${dataset.test_transform}
+ datamodule:
+   _target_: data.datamodule.BaselineDataModule
+   test_dataset: ${dataset.test_dataset}
+   global_batch_size: ${dataset.global_batch_size}
+   num_workers: ${computer.num_workers}
+   num_nodes: ${computer.num_nodes}
+   num_devices: ${computer.devices}
configs/dataset/baselines/yfcc4k.yaml ADDED
@@ -0,0 +1,16 @@
+ dataset:
+   name: yfcc4k
+   global_batch_size: 512
+   test_dataset:
+     _partial_: true
+     _target_: data.data.Baseline
+     path: ${data_dir}/baselines/yfcc4k
+     which: 'yfcc4k'
+     transforms: ${dataset.test_transform}
+ datamodule:
+   _target_: data.datamodule.BaselineDataModule
+   test_dataset: ${dataset.test_dataset}
+   global_batch_size: ${dataset.global_batch_size}
+   num_workers: ${computer.num_workers}
+   num_nodes: ${computer.num_nodes}
+   num_devices: ${computer.devices}
configs/dataset/osv5m.yaml ADDED
@@ -0,0 +1,46 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.osv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   aux_data: ${aux_data}
+   is_baseline: ${is_baseline}
+   areas: ${areas}
+   streetclip: ${streetclip}
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.osv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   aux_data: ${aux_data}
+   is_baseline: ${is_baseline}
+   areas: ${areas}
+   streetclip: ${streetclip}
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.osv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   aux_data: ${aux_data}
+   is_baseline: ${is_baseline}
+   areas: ${areas}
+   streetclip: ${streetclip}
+   blur: ${blur}
configs/dataset/osv5m_contrastive.yaml ADDED
@@ -0,0 +1,34 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
configs/dataset/osv5m_contrastive_best.yaml ADDED
@@ -0,0 +1,37 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   class_name2: 'unique_region'
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   class_name2: 'unique_region'
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   class_name2: 'unique_region'
+   blur: ${blur}
configs/dataset/osv5m_text_contrastive.yaml ADDED
@@ -0,0 +1,34 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.TextContrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.TextContrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.TextContrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
configs/dataset/test_transform/center_crop.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: utils.image_processing.CenterCrop
+     ratio: "1:1"
+   - _target_: torchvision.transforms.Resize
+     size: ${dataset.img_resolution}
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.Normalize
+     mean: 0.5
+     std: 0.5
configs/dataset/test_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: data.transforms.ClipTransform
+ split: val
configs/dataset/test_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.48145466, 0.4578275, 0.40821073]
+     std: [0.26862954, 0.26130258, 0.27577711]
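
For reference, interpolation: 3 is PIL/torchvision bicubic resampling, and the mean/std are CLIP's published normalization statistics. Once instantiated (e.g. via hydra.utils.instantiate), the config above amounts to this torchvision pipeline (a sketch of the equivalent, not repo code):

    # Sketch: the torchvision pipeline the fast_clip config above instantiates.
    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    fast_clip = transforms.Compose([
        transforms.Resize(224, interpolation=InterpolationMode.BICUBIC, antialias=True),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # CLIP's published normalization statistics
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])
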
configs/dataset/test_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.485, 0.456, 0.406]
+     std: [0.229, 0.224, 0.225]
configs/dataset/test_transform/none.yaml ADDED
@@ -0,0 +1,6 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: 0.5
+     std: 0.5
configs/dataset/train_transform/augmentation.yaml ADDED
@@ -0,0 +1,85 @@
+ _target_: data.augmentation.ImageAugmentation
+ names: "standard_augmentation,geometric_augmentation,clip_transform"
+
+ # always apply clip_transform at the end
+ clip_transform:
+   _target_: torchvision.transforms.Compose
+   transforms:
+     - _target_: torchvision.transforms.Resize
+       size: 224
+       interpolation: 3
+       antialias: true
+     - _target_: torchvision.transforms.CenterCrop
+       size: 224
+     - _target_: torchvision.transforms.ToTensor
+     - _target_: torchvision.transforms.Normalize
+       mean: [0.48145466, 0.4578275, 0.40821073]
+       std: [0.26862954, 0.26130258, 0.27577711]
+
+ standard_augmentation:
+   _target_: data.augmentation.StandardAugmentation
+   # by default, we use all augmentation methods
+   names: "brightness,contrast,sharpness,color,blur,gaussian_noise"
+
+   # random PIL brightness
+   brightness:
+     _target_: data.augmentation.PillowBrightness
+     p: 0.2
+     factor_interval: [0.5, 1.5]
+
+   # random PIL contrast
+   contrast:
+     _target_: data.augmentation.PillowContrast
+     p: 0.2
+     factor_interval: [0.3, 3]
+
+   # random PIL sharpness
+   sharpness:
+     _target_: data.augmentation.PillowSharpness
+     p: 0.2
+     factor_interval: [0.5, 30.0]
+
+   # random PIL color
+   color:
+     _target_: data.augmentation.PillowColor
+     p: 0.2
+     factor_interval: [0.0, 2.0]
+
+   # random PIL blur
+   blur:
+     _target_: data.augmentation.PillowBlur
+     p: 0.2
+     factor_interval: [1, 2]
+
+   # random numpy gaussian noise
+   gaussian_noise:
+     _target_: data.augmentation.NumpyGaussianNoise
+     p: 0.2
+     factor_interval: [0.1, 0.04]
+
+ geometric_augmentation:
+   _target_: data.augmentation.GeometricAugmentation
+   # by default, we use all augmentation methods
+   names: "random_rotation,random_resized_crop,random_horizontal_flip"
+
+   # random rotation
+   random_rotation:
+     _target_: torchvision.transforms.RandomRotation
+     degrees: [-15, 15]
+
+   # random crop
+   random_resized_crop:
+     _target_: torchvision.transforms.RandomResizedCrop
+     scale: [0.5, 1.0]
+     ratio: [0.9, 1.1]
+     size: 224
+
+   # random horizontal flip
+   random_horizontal_flip:
+     _target_: torchvision.transforms.RandomHorizontalFlip
+     p: 0.5
+
+   # random vertical flip (defined but not enabled in names above)
+   random_vertical_flip:
+     _target_: torchvision.transforms.RandomVerticalFlip
+     p: 0.5
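
The data.augmentation.* classes referenced above are not included in this diff. As an illustration of the p/factor_interval pattern they follow, a probability-gated PIL brightness augmentation could look like this (hypothetical sketch, not the repo's implementation):

    # Hypothetical sketch of a probability-gated PIL augmentation; the repo's
    # data.augmentation.PillowBrightness is not included in this diff.
    import random
    from PIL import Image, ImageEnhance

    class RandomBrightness:
        def __init__(self, p: float = 0.2, factor_interval: tuple = (0.5, 1.5)):
            self.p = p  # probability of applying the augmentation
            self.factor_interval = factor_interval  # sampling range for the brightness factor

        def __call__(self, img: Image.Image) -> Image.Image:
            if random.random() < self.p:
                factor = random.uniform(*self.factor_interval)
                img = ImageEnhance.Brightness(img).enhance(factor)
            return img
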
configs/dataset/train_transform/center_crop.yaml ADDED
@@ -0,0 +1,14 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: utils.image_processing.CenterCrop
+     ratio: "1:1"
+   - _target_: torchvision.transforms.Resize
+     size: ${dataset.img_resolution}
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.RandomHorizontalFlip
+     p: 0.5
+   - _target_: torchvision.transforms.Normalize
+     mean: 0.5
+     std: 0.5
configs/dataset/train_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: data.transforms.ClipTransform
+ split: val
configs/dataset/train_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.48145466, 0.4578275, 0.40821073]
+     std: [0.26862954, 0.26130258, 0.27577711]
configs/dataset/train_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.485, 0.456, 0.406]
+     std: [0.229, 0.224, 0.225]
configs/dataset/train_transform/none.yaml ADDED
@@ -0,0 +1,7 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.ToTensor
configs/exp/DinoV2.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: dinov2_vitl14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/ResNet.yaml ADDED
@@ -0,0 +1,21 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /dataset/test_transform: fast_resnet
+   - override /dataset/train_transform: fast_resnet
+   - override /model/network/mid: mlp_resnet
+   - override /model/network/backbone: ResNet50
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/base_model.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ model:
+   name: base_model
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/best_model.yaml ADDED
@@ -0,0 +1,25 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive_best
+   - override /model: hybrid
+   - override /model/network: best_backbone
+   - override /model/network/backbone: clip_L_14_DataComp
+   - override /model/network/mid: mlp_hybrid
+   - override /model/loss: best_model
+   - _self_
+
+ class_name: 'quadtree_10_1000'
+ is_baseline: false
+ max_epochs: 30
+
+ model:
+   name: best_model
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_area.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'area'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_cell.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: quadtree_10_1000
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_cell_hier.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: cls_hier_quad
+   - _self_
+
+ class_name: quadtree_10_1000
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_city.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'city'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_city_hier.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: cls_hier
+   - _self_
+
+ class_name: 'city'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_country.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'country'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_region copy.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'region'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_region.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'region'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_L_14_DataComp.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: clip_L_14_DataComp
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_L_14_Laion.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: openclip_L_14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_L_14_OpenAI.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: clip_L_14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_bigG_14_Laion.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: openclip_bigG_14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/contrastive_area.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: area
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_cell.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: quadtree_10_1000
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_city.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: city
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_country.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: country
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_region.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: region
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_text.yaml ADDED
@@ -0,0 +1,22 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_text_contrastive
+   - override /model: text_tuning
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ model:
+   network:
+     backbone:
+       instance:
+         _target_: models.networks.backbones.CLIPText
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ class_name: city
+ text_tuning: True
+ max_epochs: 30
configs/exp/eval_best_model.yaml ADDED
@@ -0,0 +1,29 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive_best
+   - override /model: hybrid
+   - override /model/network: best_backbone
+   - override /model/network/backbone: clip_L_14_DataComp
+   - override /model/network/mid: mlp_hybrid
+   - _self_
+
+ class_name: 'quadtree_10_1000'
+ is_baseline: false
+ max_epochs: 30
+ mode: 'eval'
+
+ model:
+   name: best_model
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+   network:
+     head:
+       instance:
+         quadtree_path: ${root_dir}/quadtree_10_1000.csv
+
+ dataset:
+   global_batch_size: 2048
configs/exp/fine_tuning.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network: unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048