import torch
from geoclip import GeoCLIP
from PIL import Image
import tempfile
from pathlib import Path
import gradio as gr
import spaces
from geopy.geocoders import Nominatim
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
import reverse_geocoder as rg
from models.huggingface import Geolocalizer
import folium
import json
from geopy.exc import GeocoderServiceError, GeocoderTimedOut

# Gradio demo comparing three image-geolocation models (GeoCLIP, StreetCLIP,
# OSV-5M baseline) in separate tabs. Each tab returns a text prediction and an
# HTML map.

if torch.cuda.is_available():
    geoclip_model = GeoCLIP().to("cuda")
else:
    geoclip_model = GeoCLIP()

geolocator = Nominatim(user_agent="predictGeolocforImage")
# Timeout (seconds) for every Nominatim request made by this app.
GEOCODER_TIMEOUT = 10

streetclip_model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
streetclip_processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

# Candidate countries for StreetCLIP zero-shot classification.
labels = [
    'Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh',
    'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil',
    'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia',
    'Czech Republic', 'Denmark', 'Dominican Republic', 'Ecuador', 'Estonia',
    'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam',
    'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland',
    'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos',
    'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar',
    'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro',
    'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine',
    'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania',
    'Russia', 'Rwanda', 'Senegal', 'Serbia', 'Singapore', 'Slovakia',
    'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka',
    'Swaziland', 'Sweden', 'Switzerland', 'Taiwan', 'Thailand', 'Tunisia',
    'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom',
    'United States', 'Uruguay',
]

IMAGE_SIZE = (224, 224)
GEOLOC_MODEL_NAME = "osv5m/baseline"
geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
geoloc_model.eval()

# Message shown when a tab is submitted without an uploaded image.
NO_IMAGE_MESSAGE = "Please upload an image."
# Placeholder used in the map output when no coordinates are available.
MAP_UNAVAILABLE = "Map not available"

# Preprocessing pipeline for the OSV-5M model. Built once instead of per call.
# The mean/std values are the ImageNet statistics used by the original code.
_OSV5M_TRANSFORM = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def transform_image(image):
    """Turn a PIL image into a normalized batch of one for the OSV-5M model.

    Args:
        image: PIL image in any mode.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    # Normalize expects exactly 3 channels. RGBA, grayscale and palette
    # uploads would otherwise fail, so force RGB first.
    return _OSV5M_TRANSFORM(image.convert("RGB")).unsqueeze(0)


def create_map(lat, lon, zoom_start=4):
    """Render a Folium map with a single marker as embeddable HTML.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.
        zoom_start: Initial zoom level (default 4, as before).

    Returns:
        HTML string suitable for a ``gr.HTML`` output.
    """
    m = folium.Map(location=[lat, lon], zoom_start=zoom_start)
    folium.Marker([lat, lon]).add_to(m)
    return m._repr_html_()


def get_country_coordinates(country_name):
    """Geocode a country name to its (lat, lon) using Nominatim.

    Returns:
        A ``(latitude, longitude)`` tuple, or ``None`` when the name cannot be
        resolved or the geocoding service fails (timeout, rate limit,
        unavailability).
    """
    try:
        location = geolocator.geocode(country_name, timeout=GEOCODER_TIMEOUT)
    except GeocoderServiceError:
        # GeocoderTimedOut is a subclass, so timeouts are still handled here.
        return None
    if location is None:
        return None
    return location.latitude, location.longitude


def _reverse_geocode_country(lat, lon):
    """Return the country Nominatim reports for (lat, lon), or "Unknown"."""
    try:
        location = geolocator.reverse(
            (lat, lon), exactly_one=True, timeout=GEOCODER_TIMEOUT
        )
    except GeocoderServiceError:
        return "Unknown"
    # reverse() returns None for points with no address (e.g. open ocean).
    if location is None:
        return "Unknown"
    address = location.raw.get("address", {})
    return address.get("country") or "Unknown"


@spaces.GPU
def predict_geoclip(image):
    """Predict the most likely GPS location of ``image`` with GeoCLIP.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        A ``(prediction_text, map_html)`` tuple.
    """
    if image is None:
        return NO_IMAGE_MESSAGE, None
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmppath = Path(tmp_dir) / "tmp.jpg"
        # JPEG cannot store alpha or palette modes. Converting to RGB keeps
        # PNG/GIF uploads from crashing in save().
        image.convert("RGB").save(str(tmppath))
        # Only the single best prediction is displayed.
        top_pred_gps, _top_pred_prob = geoclip_model.predict(str(tmppath), top_k=1)

    lat, lon = (float(v) for v in top_pred_gps[0])
    country = _reverse_geocode_country(lat, lon)
    prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
    return prediction, create_map(lat, lon)


@spaces.GPU
def classify_streetclip(image):
    """Classify the country shown in ``image`` with StreetCLIP zero-shot.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        A ``(prediction_text, map_html)`` tuple. The map is centred on the
        geocoded country when geocoding succeeds.
    """
    if image is None:
        return NO_IMAGE_MESSAGE, None
    inputs = streetclip_processor(
        text=labels, images=image, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = streetclip_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    top_label = labels[int(probs[0].argmax())]

    coords = get_country_coordinates(top_label)
    map_html = create_map(*coords) if coords else MAP_UNAVAILABLE
    return f"Country: {top_label}", map_html


def infer(image):
    """Predict the GPS location of ``image`` with the OSV-5M baseline model.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        A ``(prediction_text, map_html)`` tuple. On failure the text describes
        the error and the map is ``None``.
    """
    if image is None:
        return NO_IMAGE_MESSAGE, None
    # Broad catch is deliberate: this is the UI boundary, and the error is
    # reported back to the user instead of crashing the tab.
    try:
        img_tensor = transform_image(image)
        # Inference only; skip autograd bookkeeping.
        with torch.no_grad():
            gps_radians = geoloc_model(img_tensor)
        # The model outputs radians. Assumes a (1, 2) output — TODO confirm.
        gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
        lat, lon = gps_degrees[0], gps_degrees[1]
        # NOTE(review): admin1 is a first-level region (state/province), so
        # the "Country:" label below shows region and country code.
        location_query = rg.search((lat, lon))[0]
        map_html = create_map(lat, lon)
        return (
            f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: "
            f"{location_query['admin1']} - {location_query['cc']}",
            map_html,
        )
    except Exception as e:
        return f"Failed to predict the location: {e}", None


geoclip_interface = gr.Interface(
    fn=predict_geoclip,
    inputs=gr.Image(type="pil", label="Upload Image", elem_id="geoclip_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="geoclip_output"),
        gr.HTML(label="Map", elem_id="geoclip_map_output"),
    ],
    title="GeoCLIP",
)

streetclip_interface = gr.Interface(
    fn=classify_streetclip,
    inputs=gr.Image(type="pil", label="Upload Image", elem_id="streetclip_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="streetclip_output"),
        gr.HTML(label="Map", elem_id="streetclip_map_output"),
    ],
    title="StreetCLIP",
)

osv5m_interface = gr.Interface(
    fn=infer,
    inputs=gr.Image(label="Upload Image", type="pil", elem_id="osv5m_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="result_text"),
        gr.HTML(label="Map", elem_id="map_output"),
    ],
    title="OSV-5M Baseline",
)

demo = gr.TabbedInterface(
    [geoclip_interface, streetclip_interface, osv5m_interface],
    tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"],
)
demo.launch()