File size: 12,239 Bytes
fac3244
 
2f1f8ac
fac3244
2f1f8ac
 
 
 
c714447
2f1f8ac
 
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac3244
2f1f8ac
fac3244
2f1f8ac
 
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac3244
2f1f8ac
 
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac3244
 
2f1f8ac
 
 
 
 
 
 
 
 
 
 
 
 
fac3244
 
 
 
2f1f8ac
fac3244
2f1f8ac
fac3244
2f1f8ac
 
fac3244
 
2f1f8ac
 
 
 
 
fac3244
2f1f8ac
 
 
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
 
 
fac3244
2f1f8ac
 
 
fac3244
 
 
 
2f1f8ac
 
 
 
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
 
5c11fb8
fac3244
2f1f8ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac3244
 
2f1f8ac
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import base64
import io
import os
import numpy as np
from pathlib import Path
from plonk.pipe import PlonkPipeline
import random

# Global variable to store the model; lazily initialized by load_plonk_model()
model = None

# Real PLONK predictions for production deployment.
# Flip to True to serve mock_plonk_prediction() results (no model download/inference).
MOCK_MODE = False  # Set to True for testing with mock data

def load_plonk_model():
    """
    Lazily initialize and return the global PLONK pipeline.

    The pipeline is constructed on the first call only; every later call
    returns the instance cached in the module-level ``model`` global.
    """
    global model
    if model is not None:
        return model
    print("Loading PLONK_YFCC model...")
    model = PlonkPipeline(model_path="nicolas-dufour/PLONK_YFCC")
    print("Model loaded successfully!")
    return model

def mock_plonk_prediction():
    """
    Return a fake (latitude, longitude) pair for testing without the model.

    Used only when MOCK_MODE is True. Picks one of several well-known city
    coordinates, then jitters each axis by up to +/-2 degrees so repeated
    calls look like real, noisy predictions.
    """
    # (latitude, longitude) anchors drawn from major cities/regions.
    anchors = [
        (40.7128, -74.0060),   # New York
        (34.0522, -118.2437),  # Los Angeles
        (51.5074, -0.1278),    # London
        (48.8566, 2.3522),     # Paris
        (35.6762, 139.6503),   # Tokyo
        (37.7749, -122.4194),  # San Francisco
        (41.8781, -87.6298),   # Chicago
        (25.7617, -80.1918),   # Miami
        (45.5017, -73.5673),   # Montreal
        (52.5200, 13.4050),    # Berlin
        (-33.8688, 151.2093),  # Sydney
        (19.4326, -99.1332),   # Mexico City
    ]

    base_lat, base_lon = random.choice(anchors)
    # Jitter each coordinate independently (roughly a couple hundred km).
    return base_lat + random.uniform(-2, 2), base_lon + random.uniform(-2, 2)

def real_plonk_prediction(image):
    """
    Run the real PLONK model on *image* and summarize 32 sampled guesses.

    Draws 32 GPS samples from the diffusion pipeline (CFG=2.0, 32 steps)
    and reduces them to a mean location plus a rough uncertainty radius.

    Args:
        image: PIL image in RGB mode.

    Returns:
        Tuple of (mean_lat, mean_lon, uncertainty_km, num_samples).
    """
    # Reuse the module-level cached pipeline via load_plonk_model() instead
    # of duplicating the loading logic and stashing the pipeline as an
    # ad-hoc attribute on the gradio module (previous behavior).
    pipeline = load_plonk_model()

    # Get 32 predictions for uncertainty estimation.
    predicted_gps = pipeline(image, batch_size=32, cfg=2.0, num_steps=32)

    # Convert to numpy for easier processing; shape: (32, 2) = (lat, lon).
    predictions = predicted_gps.cpu().numpy()

    # Per-axis statistics across the sampled locations.
    mean_lat = float(np.mean(predictions[:, 0]))
    mean_lon = float(np.mean(predictions[:, 1]))
    std_lat = float(np.std(predictions[:, 0]))
    std_lon = float(np.std(predictions[:, 1]))

    # Rough degrees->km conversion (111.32 km per degree of latitude;
    # ignores cos(lat) scaling of longitude, so this is an approximation).
    uncertainty_km = np.sqrt(std_lat**2 + std_lon**2) * 111.32

    return mean_lat, mean_lon, uncertainty_km, len(predictions)

def predict_location(image):
    """
    Main Gradio handler: predict where *image* was taken.

    Returns a Markdown-formatted report with coordinates, confidence, and
    (in production mode) an uncertainty radius; returns a short message
    string when the input is missing or processing fails.
    """
    if image is None:
        return "Please upload an image."

    try:
        # The model expects 3-channel RGB input.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Route to mock or real inference depending on the global flag.
        if MOCK_MODE:
            lat, lon = mock_plonk_prediction()
            confidence, uncertainty_km, num_samples = "mock", None, 1
            note = " (Mock prediction for testing)"
        else:
            lat, lon, uncertainty_km, num_samples = real_plonk_prediction(image)
            confidence = "high"
            note = f" (Real PLONK prediction, {num_samples} samples)"

        # Only real predictions carry an uncertainty estimate.
        uncertainty_text = ""
        if uncertainty_km is not None:
            uncertainty_text = f"\n**Uncertainty:** ยฑ{uncertainty_km:.1f} km"

        return f"""๐Ÿ—บ๏ธ **Predicted Location**{note}

**Latitude:** {lat:.6f}
**Longitude:** {lon:.6f}{uncertainty_text}

**Confidence:** {confidence}
**Samples:** {num_samples}
**Mode:** {'๐Ÿงช Mock Testing' if MOCK_MODE else '๐Ÿš€ Production'}

๐ŸŒ *This prediction estimates where the image was taken based on visual content.*
"""

    except Exception as e:
        return f"โŒ Error processing image: {str(e)}"

def predict_location_json(image):
    """
    API-facing variant of predict_location returning structured data.

    On success the payload contains coordinates, sample count, and (when
    available) an uncertainty radius; on failure it carries an error
    message and an "error" status.
    """
    if image is None:
        return {
            "error": "No image provided",
            "status": "error"
        }

    try:
        # The model expects 3-channel RGB input.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Route to mock or real inference depending on the global flag.
        if MOCK_MODE:
            lat, lon = mock_plonk_prediction()
            confidence, uncertainty_km, num_samples = "mock", None, 1
        else:
            lat, lon, uncertainty_km, num_samples = real_plonk_prediction(image)
            confidence = "high"

        payload = {
            "status": "success",
            "mode": "mock" if MOCK_MODE else "production",
            "predicted_location": {
                "latitude": round(lat, 6),
                "longitude": round(lon, 6)
            },
            "confidence": confidence,
            "samples": num_samples,
            "note": "This is a mock prediction for testing" if MOCK_MODE else f"Real PLONK prediction using {num_samples} samples"
        }

        # Attach uncertainty info only when real inference produced it.
        if uncertainty_km is not None:
            payload["uncertainty_km"] = round(uncertainty_km, 1)

        return payload

    except Exception as e:
        return {
            "error": str(e),
            "status": "error"
        }

# Create the Gradio interface: an upload tab, an API-docs tab, and an about tab.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="๐Ÿ—บ๏ธ PLONK: Around the World in 80 Timesteps"
) as demo:
    
    # Header. BUG FIX: this must be an f-string — without the f prefix the
    # {'...' if MOCK_MODE else '...'} conditional below was rendered
    # literally in the UI instead of being evaluated.
    gr.Markdown(f"""
    # ๐Ÿ—บ๏ธ PLONK: Around the World in 80 Timesteps
    
    A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken!
    
    This uses the PLONK model concept from the paper: *"Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation"*
    
    **Current Mode:** {'๐Ÿงช Mock Testing' if MOCK_MODE else '๐Ÿš€ Production'} - Real PLONK model predictions with 32 samples for uncertainty estimation.
    **Configuration:** Guidance Scale = 2.0, Samples = 32, Steps = 32
    """)
    
    with gr.Tab("๐Ÿ–ผ๏ธ Image Upload"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload an image",
                    type="pil",
                    sources=["upload", "webcam", "clipboard"]
                )
                
                predict_btn = gr.Button(
                    "๐Ÿ” Predict Location",
                    variant="primary",
                    size="lg"
                )
                
                clear_btn = gr.ClearButton(
                    components=[image_input],
                    value="๐Ÿ—‘๏ธ Clear"
                )
            
            with gr.Column(scale=1):
                output_text = gr.Markdown(
                    label="Prediction Result",
                    value="Upload an image and click 'Predict Location' to see results."
                )
    
    with gr.Tab("๐Ÿ“ก API Information"):
        # NOTE: literal braces in the code examples are escaped as {{ }}
        # because this Markdown body is an f-string.
        gr.Markdown(f"""
        ## ๐Ÿ”— API Access
        
        This Space provides both web interface and programmatic API access:
        
        ### **REST API Endpoint**
        ```
        POST https://kylanoconnor-plonk-geolocation.hf.space/api/predict
        ```
        
        ### **Python Example**
        ```python
        import requests
        
        # For API access
        response = requests.post(
            "https://kylanoconnor-plonk-geolocation.hf.space/api/predict",
            files={{"file": open("image.jpg", "rb")}}
        )
        result = response.json()
        print(f"Location: {{result['data']['latitude']}}, {{result['data']['longitude']}}")
        ```
        
        ### **cURL Example**
        ```bash
        curl -X POST \\
          -F "data=@image.jpg" \\
          "https://kylanoconnor-plonk-geolocation.hf.space/api/predict"
        ```
        
        ### **Gradio Client (Python)**
        ```python
        from gradio_client import Client
        
        client = Client("kylanoconnor/plonk-geolocation")
        result = client.predict("path/to/image.jpg", api_name="/predict")
        print(result)
        ```
        
        ### **JavaScript/Node.js**
        ```javascript
        const formData = new FormData();
        formData.append('data', imageFile);
        
        const response = await fetch(
            'https://kylanoconnor-plonk-geolocation.hf.space/api/predict',
            {{
                method: 'POST',
                body: formData
            }}
        );
        
        const result = await response.json();
        console.log('Location:', result.data);
        ```
        
        **Current Status:** {'๐Ÿงช Mock Mode - Returns realistic test coordinates' if MOCK_MODE else '๐Ÿš€ Production Mode - Real PLONK predictions with 32 samples'}
        
        **Response Format:** 
        - Latitude/Longitude coordinates
        - Uncertainty estimation (ยฑkm radius)
        - Number of samples used (32 for production)
        - Prediction confidence metrics
        
        **Rate Limits:** Standard Hugging Face Spaces limits apply
        
        **CORS:** Enabled for web integration
        """)
    
    with gr.Tab("โ„น๏ธ About"):
        gr.Markdown(f"""
        ## About PLONK
        
        PLONK is a generative approach to global visual geolocation that uses diffusion models to predict where images were taken.
        
        **Paper:** [Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation](https://arxiv.org/abs/2412.06781)
        
        **Authors:** Nicolas Dufour, David Picard, Vicky Kalogeiton, Loic Landrieu
        
        **Original Code:** https://github.com/nicolas-dufour/plonk
        
        ### Current Deployment
        - **Mode:** {'Mock Testing' if MOCK_MODE else 'Production'}
        - **Model:** {'Simulated predictions for API testing' if MOCK_MODE else 'Real PLONK model inference'}
        - **Response Format:** Structured JSON + formatted text
        - **API:** Fully functional REST endpoints
        
        ### Production Deployment
        This Space is running with the real PLONK model using:
        - **Model:** nicolas-dufour/PLONK_YFCC
        - **Dataset:** YFCC-100M
        - **Inference:** CFG=2.0, 32 samples, 32 timesteps for high quality predictions
        - **Uncertainty:** Statistical analysis across 32 predictions for reliability estimation
        
        ### Available Models
        - `nicolas-dufour/PLONK_YFCC` - YFCC-100M dataset
        - `nicolas-dufour/PLONK_iNaturalist` - iNaturalist dataset
        - `nicolas-dufour/PLONK_OSV_5M` - OpenStreetView-5M dataset
        """)
    
    # Event handlers
    predict_btn.click(
        fn=predict_location,
        inputs=[image_input],
        outputs=[output_text],
        api_name="predict"  # This enables API access at /api/predict
    )
    
    # Hidden API function for JSON responses
    predict_json = gr.Interface(
        fn=predict_location_json,
        inputs=gr.Image(type="pil"),
        outputs=gr.JSON(),
        api_name="predict_json"  # Available at /api/predict_json
    )
    
    # Add examples if available; skip silently when the demo assets are
    # absent from the deployment. Narrowed from a bare `except:` so that
    # KeyboardInterrupt/SystemExit are not swallowed.
    try:
        examples = [
            ["demo/examples/condor.jpg"],
            ["demo/examples/Kilimanjaro.jpg"],
            ["demo/examples/pigeon.png"]
        ]
        gr.Examples(
            examples=examples,
            inputs=image_input,
            outputs=output_text,
            fn=predict_location,
            cache_examples=True
        )
    except Exception:
        pass  # Examples not available, skip

if __name__ == "__main__":
    # For local testing: start the Gradio server directly.
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (needed inside containers/Spaces)
        server_port=7860,       # standard Gradio / Hugging Face Spaces port
        show_api=True           # expose the /api endpoints described in the UI
    )