Spaces:
Running
Running
import torch | |
from geoclip import GeoCLIP | |
from PIL import Image | |
import tempfile | |
from pathlib import Path | |
import gradio as gr | |
import spaces | |
from geopy.geocoders import Nominatim | |
from transformers import CLIPProcessor, CLIPModel | |
from torchvision import transforms | |
import reverse_geocoder as rg | |
from models.huggingface import Geolocalizer | |
import folium | |
import json | |
from geopy.exc import GeocoderTimedOut | |
# ---------------------------------------------------------------------------
# Shared models and lookup data, created once at import time and reused by the
# request handlers below.
# ---------------------------------------------------------------------------

# GeoCLIP goes to the GPU only when CUDA is available; otherwise it is left on
# whatever device GeoCLIP() constructs it on.
geoclip_model = GeoCLIP().to("cuda") if torch.cuda.is_available() else GeoCLIP()

# Nominatim geocoding client, used for forward (name -> coords) and reverse
# (coords -> address) lookups.
geolocator = Nominatim(user_agent="predictGeolocforImage")

# StreetCLIP model and its processor come from the same checkpoint.
STREETCLIP_MODEL_NAME = "geolocal/StreetCLIP"
streetclip_model = CLIPModel.from_pretrained(STREETCLIP_MODEL_NAME)
streetclip_processor = CLIPProcessor.from_pretrained(STREETCLIP_MODEL_NAME)

# Candidate countries for StreetCLIP zero-shot classification. Each entry is
# passed verbatim as a text prompt, and the winning entry is geocoded by name.
labels = [
    'Albania', 'Andorra', 'Argentina', 'Australia', 'Austria',
    'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana',
    'Brazil', 'Bulgaria',
    'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia',
    'Czech Republic',
    'Denmark', 'Dominican Republic',
    'Ecuador', 'Estonia',
    'Finland', 'France',
    'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala',
    'Hungary',
    'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy',
    'Japan', 'Jordan',
    'Kenya', 'Kyrgyzstan',
    'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg',
    'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco',
    'Mongolia', 'Montenegro',
    'Netherlands', 'New Zealand', 'Nigeria', 'Norway',
    'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal',
    'Puerto Rico',
    'Romania', 'Russia', 'Rwanda',
    'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia',
    'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland',
    'Sweden', 'Switzerland',
    'Taiwan', 'Thailand', 'Tunisia', 'Turkey',
    'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom',
    'United States', 'Uruguay',
]

# OSV-5M baseline: input resolution used by transform_image, plus the model
# itself, switched to eval mode for inference.
IMAGE_SIZE = (224, 224)
GEOLOC_MODEL_NAME = "osv5m/baseline"
geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
geoloc_model.eval()
def transform_image(image):
    """Turn an image into the normalised batch tensor fed to the OSV-5M model.

    The image is resized to ``IMAGE_SIZE``, converted to a float tensor,
    normalised with the standard ImageNet per-channel mean/std, and given a
    leading batch dimension of size 1.

    Args:
        image: A PIL image. Non-RGB PIL images (e.g. RGBA, L, P) are converted
            to RGB first. The 3-channel ``Normalize`` below cannot be applied
            to 1- or 4-channel tensors.

    Returns:
        A tensor of shape ``(1, 3, H, W)`` with ``(H, W) == IMAGE_SIZE``.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the batch dimension expected by the model.
    return transform(image).unsqueeze(0)
def create_map(lat, lon):
    """Render a folium map centred on ``(lat, lon)`` with a marker there.

    The map opens at zoom level 4. The return value is the HTML string
    produced by folium's ``_repr_html_``, suitable for a ``gr.HTML`` output.
    """
    fmap = folium.Map(location=[lat, lon], zoom_start=4)
    pin = folium.Marker([lat, lon])
    pin.add_to(fmap)
    return fmap._repr_html_()
def get_country_coordinates(country_name):
    """Geocode a country name to a ``(latitude, longitude)`` pair.

    Uses the module-level Nominatim client with a 10-second timeout. This is
    best-effort: it returns ``None`` when nothing matches or when the
    geocoding service fails, so callers can fall back (e.g. to
    "Map not available").

    Args:
        country_name: Free-text place name to look up.

    Returns:
        ``(latitude, longitude)`` or ``None``.
    """
    # GeocoderServiceError is the common base of GeocoderTimedOut,
    # GeocoderUnavailable, GeocoderRateLimited, etc. Catching only the timeout
    # let the other service failures crash the caller.
    from geopy.exc import GeocoderServiceError

    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeocoderServiceError:
        return None
    if location is None:
        return None
    return location.latitude, location.longitude
def predict_geoclip(image):
    """Predict where ``image`` was taken with GeoCLIP and map the top guess.

    Args:
        image: A PIL image from the Gradio upload, or ``None`` if nothing
            was uploaded.

    Returns:
        ``(text, map_html)``. ``text`` gives the predicted latitude/longitude
        and the reverse-geocoded country. The country is left empty when
        reverse geocoding finds nothing or the service fails. If no image is
        given, ``text`` is an error message and ``map_html`` is ``None``.
    """
    from geopy.exc import GeocoderServiceError

    if image is None:
        return "Failed to predict the location: no image provided", None

    # GeoCLIP.predict takes a file path, so the upload is written to a
    # temporary JPEG first. JPEG cannot store alpha/palette modes, hence the
    # RGB conversion.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmppath = Path(tmp_dir) / "tmp.jpg"
        image.convert("RGB").save(str(tmppath))
        # Only the best guess is displayed, so request a single candidate.
        top_pred_gps, _top_pred_prob = geoclip_model.predict(str(tmppath), top_k=1)

    # Convert to plain floats so geopy/folium receive ordinary numbers.
    lat, lon = (float(v) for v in top_pred_gps[0])

    # Reverse geocoding is best-effort: it may return None (e.g. open sea) or
    # fail on the network. Neither should discard a valid prediction.
    try:
        location = geolocator.reverse((lat, lon), exactly_one=True, timeout=10)
    except GeocoderServiceError:
        location = None
    country = ""
    if location is not None:
        country = location.raw.get("address", {}).get("country", "")

    prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
    return prediction, create_map(lat, lon)
def classify_streetclip(image):
    """Zero-shot classify ``image`` into one of ``labels`` with StreetCLIP.

    Every entry of ``labels`` is scored against the image. The highest-scoring
    label is reported and geocoded for the map.

    Returns:
        ``("Country: <label>", map_html)``. ``map_html`` is a folium map of
        the geocoded country, or the string ``"Map not available"`` when
        geocoding yields nothing.
    """
    batch = streetclip_processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_logits = streetclip_model(**batch).logits_per_image
    scores = image_logits.softmax(dim=1)[0]
    # argmax returns the first maximal index on ties, the same label a stable
    # descending sort would put first.
    best_label = labels[int(scores.argmax())]

    coords = get_country_coordinates(best_label)
    if coords:
        map_html = create_map(*coords)
    else:
        map_html = "Map not available"
    return f"Country: {best_label}", map_html
def infer(image):
    """Predict where ``image`` was taken with the OSV-5M baseline model.

    Args:
        image: A PIL image from the Gradio upload.

    Returns:
        ``(text, map_html)``. On any failure (including a missing image),
        ``text`` is an error message and ``map_html`` is ``None``, so the tab
        reports the problem instead of raising.
    """
    try:
        img_tensor = transform_image(image)
        # Pure inference: disable autograd so no graph or activations are kept
        # around per request.
        with torch.no_grad():
            gps_radians = geoloc_model(img_tensor)
        # The model output is converted from radians to degrees. After
        # squeeze(0), index 0 is used as latitude and index 1 as longitude
        # (assumes a (1, 2) output - TODO confirm against the model).
        gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
        lat, lon = gps_degrees[0], gps_degrees[1]
        place = rg.search((lat, lon))[0]
        map_html = create_map(lat, lon)
        # NOTE(review): the "Country" field shows place['admin1'] and
        # place['cc']. In reverse_geocoder results admin1 is presumably the
        # first-level administrative region, not the country - consider
        # relabelling.
        return (
            f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {place['admin1']} - {place['cc']}",
            map_html,
        )
    except Exception as e:  # UI boundary: surface the error in the tab.
        return f"Failed to predict the location: {e}", None
def _build_tab(fn, title, image_id, text_id, map_id):
    """Build one image-in / (text, map)-out Gradio tab around ``fn``.

    The element ids are passed through so each tab keeps its own DOM ids.
    """
    return gr.Interface(
        fn=fn,
        inputs=gr.Image(type="pil", label="Upload Image", elem_id=image_id),
        outputs=[
            gr.Textbox(label="Prediction", elem_id=text_id),
            gr.HTML(label="Map", elem_id=map_id),
        ],
        title=title,
    )


# One tab per geolocation backend, created in display order.
geoclip_interface = _build_tab(
    predict_geoclip, "GeoCLIP",
    "geoclip_image_input", "geoclip_output", "geoclip_map_output",
)
streetclip_interface = _build_tab(
    classify_streetclip, "StreetCLIP",
    "streetclip_image_input", "streetclip_output", "streetclip_map_output",
)
osv5m_interface = _build_tab(
    infer, "OSV-5M Baseline",
    "osv5m_image_input", "result_text", "map_output",
)

demo = gr.TabbedInterface(
    [geoclip_interface, streetclip_interface, osv5m_interface],
    tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"],
)
demo.launch()