Spaces:
Sleeping
Sleeping
shreyas2509
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from pathlib import Path
|
4 |
+
import gradio as gr
|
5 |
+
from transformers import CLIPProcessor, CLIPModel
|
6 |
+
from torchvision import transforms
|
7 |
+
import reverse_geocoder as rg
|
8 |
+
import folium
|
9 |
+
from geopy.exc import GeocoderTimedOut
|
10 |
+
from geopy.geocoders import Nominatim
|
11 |
+
|
12 |
+
# streetclip_model = CLIPModel.from_pretrained("E:/github projects/Country Classification/GeolocationCountryClassification/")
|
13 |
+
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
|
14 |
+
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
|
15 |
+
labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
|
16 |
+
|
17 |
+
def create_map(lat, lon):
|
18 |
+
m = folium.Map(location=[lat, lon], zoom_start=4)
|
19 |
+
folium.Marker([lat, lon]).add_to(m)
|
20 |
+
map_html = m._repr_html_()
|
21 |
+
return map_html
|
22 |
+
|
23 |
+
geolocator = Nominatim(user_agent="predictGeolocforImage")
|
24 |
+
|
25 |
+
def get_country_coordinates(country_name):
|
26 |
+
try:
|
27 |
+
location = geolocator.geocode(country_name, timeout=10)
|
28 |
+
if location:
|
29 |
+
return location.latitude, location.longitude
|
30 |
+
except GeocoderTimedOut:
|
31 |
+
return None
|
32 |
+
return None
|
33 |
+
|
34 |
+
|
35 |
+
def classify_streetclip(image):
|
36 |
+
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
|
37 |
+
with torch.no_grad():
|
38 |
+
outputs = model(**inputs)
|
39 |
+
logits_per_image = outputs.logits_per_image
|
40 |
+
prediction = logits_per_image.softmax(dim=1)
|
41 |
+
confidences = {labels[i]: float(prediction[0][i].item()) for i in range(len(labels))}
|
42 |
+
|
43 |
+
sorted_confidences = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
|
44 |
+
top_label, top_confidence = sorted_confidences[0]
|
45 |
+
coords = get_country_coordinates(top_label)
|
46 |
+
map_html = create_map(*coords) if coords else "Map not available"
|
47 |
+
return f"Country: {top_label}", map_html, confidences
|
48 |
+
|
49 |
+
text = '''
|
50 |
+
<b style="color: #F36912;">List of countries supported</b>: Albania, Andorra, Argentina, Australia, Austria, Bangladesh, Belgium, Bermuda, Bhutan, Bolivia, Botswana, Brazil, Bulgaria, Cambodia, Canada, Chile, China, Colombia, Croatia, Czech Republic, Denmark, Dominican Republic, Ecuador, Estonia, Finland, France, Germany, Ghana, Greece, Greenland, Guam, Guatemala, Hungary, Iceland, India, Indonesia, Ireland, Israel, Italy, Japan, Jordan, Kenya, Kyrgyzstan, Laos, Latvia, Lesotho, Lithuania, Luxembourg, Macedonia, Madagascar, Malaysia, Malta, Mexico, Monaco, Mongolia, Montenegro, Netherlands, New Zealand, Nigeria, Norway, Pakistan, Palestine, Peru, Philippines, Poland, Portugal, Puerto Rico, Romania, Russia, Rwanda, Senegal, Serbia, Singapore, Slovakia, Slovenia, South Africa, South Korea, Spain, Sri Lanka, Swaziland, Sweden, Switzerland, Taiwan, Thailand, Tunisia, Turkey, Uganda, Ukraine, United Arab Emirates, United Kingdom, United States, Uruguay
|
51 |
+
</p>
|
52 |
+
---<br>
|
53 |
+
<span style="color: #F24F13;">You may choose to use the images provided below, or feel free to upload your own images.</span>
|
54 |
+
'''
|
55 |
+
|
56 |
+
interface = gr.Interface(
|
57 |
+
fn=classify_streetclip,
|
58 |
+
inputs=gr.Image(type="pil", label="Upload Image", elem_id="image_input"),
|
59 |
+
outputs=[gr.Textbox(label="Prediction", elem_id="output"), gr.HTML(label="Map", elem_id="map_output"), gr.Label(num_top_classes=10,label="Top 10 countries")],
|
60 |
+
title="COUNTRY GUESSER",
|
61 |
+
description=text,
|
62 |
+
article="<span style='color: #F24F13;'>Model is not running on a GPU, so the interpretation takes some time. Thank you for your patience🙏🏻</span>",
|
63 |
+
examples=["taj.jpg","stockholm.jpeg","palace-square-saint-petersburg.jpg","monument.jpg"],
|
64 |
+
allow_flagging="never",
|
65 |
+
)
|
66 |
+
|
67 |
+
interface.launch()
|