# Page header captured together with this file (not part of the program):
#   yunusserhat's picture
#   Create APP
#   894bc0c verified
import torch
from geoclip import GeoCLIP
from PIL import Image
import tempfile
from pathlib import Path
import gradio as gr
import spaces
from geopy.geocoders import Nominatim
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
import reverse_geocoder as rg
from models.huggingface import Geolocalizer
import folium
import json
from geopy.exc import GeocoderTimedOut
# GeoCLIP model; it is moved to the GPU only when CUDA is reported available.
geoclip_model = GeoCLIP()
if torch.cuda.is_available():
    geoclip_model = geoclip_model.to("cuda")

# Nominatim client used for the geocode / reverse-geocode lookups in this file.
geolocator = Nominatim(user_agent="predictGeolocforImage")

# StreetCLIP pre-processor and model, both loaded from the same checkpoint.
streetclip_processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
streetclip_model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
# Candidate country names passed as the text prompts to StreetCLIP.
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana",
    "Brazil", "Bulgaria",
    "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia",
    "Czech Republic",
    "Denmark", "Dominican Republic",
    "Ecuador", "Estonia",
    "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guam", "Guatemala",
    "Hungary",
    "Iceland", "India", "Indonesia", "Ireland", "Israel", "Italy",
    "Japan", "Jordan",
    "Kenya", "Kyrgyzstan",
    "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco",
    "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland", "Portugal",
    "Puerto Rico",
    "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia",
    "South Africa", "South Korea", "Spain", "Sri Lanka", "Swaziland",
    "Sweden", "Switzerland",
    "Taiwan", "Thailand", "Tunisia", "Turkey",
    "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom",
    "United States", "Uruguay",
]
# Target size handed to transforms.Resize in transform_image.
IMAGE_SIZE = (224, 224)
# Checkpoint identifier passed to Geolocalizer.from_pretrained.
GEOLOC_MODEL_NAME = "osv5m/baseline"
# Loaded once at import time and reused by every call to infer().
geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
# Switch to evaluation mode; in this file the model is only used for prediction.
geoloc_model.eval()
def transform_image(image):
    """Turn an image into a normalised, batched tensor for the OSV-5M model.

    Args:
        image: The picture to convert, normally a PIL image.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    # Normalize below is configured with three per-channel values, so
    # grayscale, palette or RGBA pictures would fail; coerce PIL input to RGB.
    # Non-PIL input is passed through untouched, exactly as before.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    pipeline = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Commonly used ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(image).unsqueeze(0)
def create_map(lat, lon):
    """Build a folium map centred on the given point and return it as HTML.

    Args:
        lat: Latitude of the map centre and of the marker.
        lon: Longitude of the map centre and of the marker.

    Returns:
        The HTML representation of the map, with a single marker on it.
    """
    world_map = folium.Map(zoom_start=4, location=[lat, lon])
    marker = folium.Marker([lat, lon])
    marker.add_to(world_map)
    return world_map._repr_html_()
def get_country_coordinates(country_name):
    """Look up the coordinates Nominatim reports for a country name.

    Args:
        country_name: Name of the country to geocode.

    Returns:
        A ``(latitude, longitude)`` tuple, or ``None`` when the name cannot be
        resolved or the geocoding service fails.
    """
    # Local import keeps this function self-contained. GeocoderServiceError is
    # the base class of GeocoderTimedOut, so timeouts are still covered, and
    # other service failures (unavailable, quota exceeded, ...) now also fall
    # back to None instead of crashing the caller.
    from geopy.exc import GeocoderServiceError

    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeocoderServiceError:
        return None
    if not location:
        return None
    return location.latitude, location.longitude
def _reverse_geocode_country(lat, lon):
    """Return the country name at (lat, lon), or "" when it cannot be resolved."""
    from geopy.exc import GeocoderServiceError

    try:
        # Same 10 second timeout as the forward lookup in this file.
        location = geolocator.reverse((lat, lon), exactly_one=True, timeout=10)
    except GeocoderServiceError:
        return ""
    # reverse() yields None when nothing is found at the point (e.g. open sea).
    if location is None:
        return ""
    return location.raw.get("address", {}).get("country", "")


@spaces.GPU
def predict_geoclip(image):
    """Predict where a picture was taken using GeoCLIP.

    Args:
        image: PIL image to geolocate.

    Returns:
        A ``(text, map_html)`` pair: a line with the predicted latitude,
        longitude and country, and the HTML of a map with a marker on the
        predicted point. The country is left empty when the reverse lookup
        fails or finds nothing.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = Path(tmp_dir) / "tmp.jpg"
        # GeoCLIP is given a file path; JPEG cannot store alpha or palette
        # modes, so convert to RGB before saving.
        image.convert("RGB").save(str(image_path))
        # Only the best candidate is reported, so only one is requested.
        top_pred_gps, _ = geoclip_model.predict(str(image_path), top_k=1)

    # Plain floats for formatting, geocoding and the map.
    lat, lon = (float(value) for value in top_pred_gps[0])
    country = _reverse_geocode_country(lat, lon)
    prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
    return prediction, create_map(lat, lon)
@spaces.GPU
def classify_streetclip(image):
    """Pick the most likely country for a picture using StreetCLIP.

    Args:
        image: PIL image to classify against the entries of ``labels``.

    Returns:
        A ``(text, map_html)`` pair: the predicted country line and either the
        HTML of a map centred on that country or the string
        "Map not available" when no coordinates could be found.
    """
    model_inputs = streetclip_processor(
        text=labels, images=image, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        image_logits = streetclip_model(**model_inputs).logits_per_image

    # One probability per label for the single input image.
    probabilities = image_logits.softmax(dim=1)[0].tolist()
    best_label, _ = max(zip(labels, probabilities), key=lambda pair: pair[1])

    centre = get_country_coordinates(best_label)
    if centre:
        map_html = create_map(*centre)
    else:
        map_html = "Map not available"
    return f"Country: {best_label}", map_html
def infer(image):
    """Predict where a picture was taken using the OSV-5M baseline model.

    Args:
        image: PIL image to geolocate.

    Returns:
        A ``(text, map_html)`` pair. On success the text holds the predicted
        latitude and longitude plus the ``admin1`` and ``cc`` fields of the
        nearest reverse_geocoder match, and ``map_html`` is the map HTML. On
        any failure the text describes the error and ``map_html`` is ``None``.
    """
    try:
        img_tensor = transform_image(image)
        # Inference only: skip autograd bookkeeping, as classify_streetclip does.
        with torch.no_grad():
            gps_radians = geoloc_model(img_tensor)
        # The model output is converted from radians to degrees.
        gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
        lat, lon = gps_degrees[0], gps_degrees[1]
        location_query = rg.search((lat, lon))[0]
        map_html = create_map(lat, lon)
        return f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {location_query['admin1']} - {location_query['cc']}", map_html
    except Exception as e:
        # UI boundary: report the problem in the text box instead of raising.
        return f"Failed to predict the location: {e}", None
# One Gradio interface per model: each takes an uploaded PIL image and shows
# a text prediction next to an HTML map.
geoclip_interface = gr.Interface(
    title="GeoCLIP",
    fn=predict_geoclip,
    inputs=gr.Image(label="Upload Image", type="pil", elem_id="geoclip_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="geoclip_output"),
        gr.HTML(label="Map", elem_id="geoclip_map_output"),
    ],
)

streetclip_interface = gr.Interface(
    title="StreetCLIP",
    fn=classify_streetclip,
    inputs=gr.Image(label="Upload Image", type="pil", elem_id="streetclip_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="streetclip_output"),
        gr.HTML(label="Map", elem_id="streetclip_map_output"),
    ],
)

osv5m_interface = gr.Interface(
    title="OSV-5M Baseline",
    fn=infer,
    inputs=gr.Image(type="pil", label="Upload Image", elem_id="osv5m_image_input"),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="result_text"),
        gr.HTML(label="Map", elem_id="map_output"),
    ],
)

# Combine the three interfaces into tabs and start the app.
demo = gr.TabbedInterface(
    [geoclip_interface, streetclip_interface, osv5m_interface],
    tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"],
)
demo.launch()