# Last commit by ddecosmo: "Update app.py" (f9e051f, verified)
# -*- coding: utf-8 -*-
"""Untitled2.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1UcsSFSmZqIdAQTsD_4_CmwwAcAzz0h60
"""
import numpy as np
import pandas as pd
import os
import io
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import folium
import matplotlib.colors
from scipy.stats import gaussian_kde
from PIL import Image
import gradio as gr
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download, create_repo, file_exists, upload_file
import tempfile
import pathlib
import json
import uuid
import shutil
import zipfile
from datetime import datetime
# Import MultiModalPredictor for the model loading logic.
# NOTE: This import assumes 'autogluon.multimodal' is installed in the environment.
try:
    from autogluon.multimodal import MultiModalPredictor
    AUTOGLUON_IMPORTED = True
except Exception as _autogluon_import_exc:  # deliberately broader than ImportError
    # Importing AutoGluon (and the native libraries it pulls in) can fail with
    # errors other than ImportError, for example an OSError from a broken
    # native library. Any such failure should disable classification rather
    # than crash the whole app at startup, so log it and fall back to a stub.
    print(f"AutoGluon import failed: {type(_autogluon_import_exc).__name__}: {_autogluon_import_exc}")
    AUTOGLUON_IMPORTED = False

    class MultiModalPredictor:
        """Placeholder used when AutoGluon is unavailable; ``load`` always raises."""

        @staticmethod
        def load(path):
            raise ImportError("AutoGluon MultiModalPredictor is not installed or failed to import.")
# --- 1. CLASSIFICATION CONFIGURATION & MODEL LOADING ---
# Hub location of the zipped AutoGluon predictor and the local folder it is
# unpacked into.
MODEL_REPO_ID = "ddecosmo/lanternfly_classifier"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
MODEL_DIR_NAME = "autogluon_predictor_extracted"
# Output classes, in the order the UI lays out the three confidence boxes.
CLASSIFICATION_LABELS = ["Lanternfly", "Other Insect", "Neither"]
# Runtime state; overwritten by the startup initialization further down.
MODEL_STATUS = "Attempting to load model..."
PREDICTOR = None
# Robust download and extraction of the AutoGluon model zip file
def _prepare_predictor_dir(repo_id, zip_filename, extract_dir_name) -> str:
    """Download the zipped model from the Hub and extract it to a clean directory.

    Args:
        repo_id: Hugging Face model repo holding the zip.
        zip_filename: Name of the zip file inside that repo.
        extract_dir_name: Folder name (under the current working directory)
            that will contain the predictor after extraction.

    Returns:
        The absolute path of the extracted predictor directory, or ``""`` if
        anything failed (the error is printed).
    """
    base_extract_dir = os.path.join(os.getcwd(), extract_dir_name)
    temp_extract_dir = os.path.join(os.getcwd(), "temp_ag_extract")
    try:
        # 1. Download the zipped model file from Hugging Face Hub
        zip_path = hf_hub_download(repo_id=repo_id, filename=zip_filename)

        # 2. Start from clean directories. A temp dir left over from an
        #    earlier failed run would otherwise be merged into this extraction
        #    and break the single-top-level-folder detection below.
        for stale_dir in (base_extract_dir, temp_extract_dir):
            if os.path.exists(stale_dir):
                shutil.rmtree(stale_dir)
        os.makedirs(temp_extract_dir)

        # 3. Extract contents
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_extract_dir)

        # 4. Handle a zip that wraps everything in one top-level folder.
        #    Archives created on macOS also carry a "__MACOSX" metadata folder,
        #    which must not count as a second top-level entry.
        extracted_contents = [name for name in os.listdir(temp_extract_dir) if name != "__MACOSX"]
        if len(extracted_contents) == 1 and os.path.isdir(os.path.join(temp_extract_dir, extracted_contents[0])):
            shutil.move(os.path.join(temp_extract_dir, extracted_contents[0]), base_extract_dir)
            shutil.rmtree(temp_extract_dir)
        else:
            os.rename(temp_extract_dir, base_extract_dir)
        return base_extract_dir
    except Exception as e:
        print(f"Error during model prep: {e}")
        # Don't leave a half-extracted temp dir behind for the next run.
        shutil.rmtree(temp_extract_dir, ignore_errors=True)
        return ""
# Global initialization on startup: download/extract the predictor once and
# record a human-readable status string for the UI.
if AUTOGLUON_IMPORTED:
    try:
        predictor_dir = _prepare_predictor_dir(MODEL_REPO_ID, ZIP_FILENAME, MODEL_DIR_NAME)
        if predictor_dir:
            PREDICTOR = MultiModalPredictor.load(predictor_dir)
            MODEL_STATUS = f"βœ… Model Active: {MODEL_REPO_ID}"
        else:
            MODEL_STATUS = "❌ Initialization failed during extraction/download."
    except Exception as e:
        PREDICTOR = None
        # The UI status only shows the exception type; print the full message
        # so the actual cause shows up in the server logs.
        print(f"Error loading model: {type(e).__name__}: {e}")
        MODEL_STATUS = f"❌ Error loading model: {type(e).__name__} (Load Fail)"
else:
    MODEL_STATUS = "❌ AutoGluon not imported. Classification tab is disabled."
# Core Lanternfly classification function
def classify_image(img: Image.Image):
    """Predict the class of ``img`` with the loaded AutoGluon predictor.

    The image is written to a temporary PNG because the predictor is fed a
    DataFrame of image file paths.

    Returns:
        ``(result_text, score_0, score_1, score_2)``, where the scores follow
        the order of ``CLASSIFICATION_LABELS``. When the model is missing, no
        image is given, or prediction fails, the scores are all ``0.0`` and
        ``result_text`` describes the problem.
    """
    if PREDICTOR is None:
        return "MODEL FAILED TO LOAD", 0.0, 0.0, 0.0
    if img is None:
        return "NO IMAGE PROVIDED", 0.0, 0.0, 0.0
    final_output = [0.0] * len(CLASSIFICATION_LABELS)
    final_result = "PREDICTION FAILED"
    # TemporaryDirectory removes the folder even if img.save() itself raises;
    # previously the save ran outside try/finally and leaked the directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        img_path = pathlib.Path(tmp_dir) / "input.png"
        try:
            img.save(img_path)
            df_path = pd.DataFrame({"image": [str(img_path)]})
            proba_df = PREDICTOR.predict_proba(df_path, as_pandas=True)
            scores_dict = proba_df.iloc[0].to_dict()
            # Map scores to the expected order of CLASSIFICATION_LABELS
            scores = [float(scores_dict.get(label, 0.0))
                      for label in CLASSIFICATION_LABELS]
            predicted_class_label = max(scores_dict, key=scores_dict.get)
            final_output = scores
            final_result = f"Predicted Class: **{predicted_class_label}**"
        except Exception as e:
            final_result = f"CRITICAL PREDICTION FAILURE: {type(e).__name__} - Check AutoGluon dependencies."
    return final_result, final_output[0], final_output[1], final_output[2]
# --- 2. GPS CAPTURE & SAVE CONFIGURATION & FUNCTIONS ---
# Write credentials: HF_TOKEN, falling back to HF_TOKEN_SPACE.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_TOKEN_SPACE")
DATASET_REPO = os.getenv("DATASET_REPO", "rlogh/lanternfly-data")
METADATA_PATH = "metadata/entries.jsonl"
# `api` stays None unless credentials exist AND the dataset repo could be
# created/confirmed; save_to_dataset and the Save button both key off it.
api = None
if not (HF_TOKEN and DATASET_REPO):
    GPS_SAVE_STATUS = "⚠️ Running in test mode - no HF credentials (dataset saving disabled)."
else:
    api = HfApi(token=HF_TOKEN)
    try:
        # exist_ok=True makes this a no-op when the repo is already there.
        create_repo(DATASET_REPO, repo_type="dataset", exist_ok=True, token=HF_TOKEN)
    except Exception as repo_err:
        GPS_SAVE_STATUS = f"⚠️ Error creating dataset repo: {repo_err}"
        api = None
    else:
        GPS_SAVE_STATUS = "βœ… Dataset saving enabled."
# Browser-side snippet: ask for the device position and write the result (or an
# {error: ...} object) as JSON into the hidden #hidden_gps_input textbox,
# dispatching an `input` event so Gradio registers the new value.
_GPS_LOCATION_JS = """
() => {
// Look for the hidden textbox element by its ID
const textarea = document.querySelector('#hidden_gps_input textarea');
if (!textarea) return;
if (!navigator.geolocation) {
textarea.value = JSON.stringify({error: "Geolocation not supported by this browser/device."});
textarea.dispatchEvent(new Event('input', { bubbles: true }));
return;
}
// Request current position
navigator.geolocation.getCurrentPosition(
function(position) {
const data = {
latitude: position.coords.latitude,
longitude: position.coords.longitude,
accuracy: position.coords.accuracy,
timestamp: position.timestamp
};
// Write JSON string to the hidden textbox and trigger a change event
textarea.value = JSON.stringify(data);
textarea.dispatchEvent(new Event('input', { bubbles: true }));
},
function(err) {
textarea.value = JSON.stringify({ error: err.message });
textarea.dispatchEvent(new Event('input', { bubbles: true }));
},
{ enableHighAccuracy: true, timeout: 10000 }
);
}
"""


def get_gps_js():
    """Return the JS callback injected into the "Get GPS" button's ``js=`` hook."""
    return _GPS_LOCATION_JS
def handle_gps_location(json_str):
    """Parse the GPS JSON written by the browser and produce form-field values.

    Args:
        json_str: JSON text from the injected JS, either
            ``{latitude, longitude, accuracy, timestamp}`` or ``{error}``.

    Returns:
        ``(status_markdown, lat, lon, accuracy, device_ts)``, all strings. On
        any error the four data fields are ``""``. ``device_ts`` is the device
        timestamp (milliseconds since the epoch) as a UTC ISO-8601 string with
        an explicit ``+00:00`` offset, or ``""`` when absent.
    """
    from datetime import timezone

    try:
        data = json.loads(json_str)
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        if 'error' in data:
            status_msg = f"❌ **GPS Error**: {data['error']}"
            return status_msg, "", "", "", ""
        raw_lat = data.get('latitude')
        raw_lon = data.get('longitude')
        if raw_lat is None or raw_lon is None:
            # Otherwise str(None) == "None" would end up in the textboxes.
            raise ValueError("latitude/longitude missing from GPS data")
        lat = str(raw_lat)
        lon = str(raw_lon)
        raw_accuracy = data.get('accuracy')
        accuracy = "" if raw_accuracy is None else str(raw_accuracy)
        timestamp_ms = data.get('timestamp')
        device_ts = ""
        # bool is a subclass of int, so exclude it explicitly.
        if timestamp_ms and isinstance(timestamp_ms, (int, float)) and not isinstance(timestamp_ms, bool):
            # Render in UTC with an explicit offset. A naive fromtimestamp()
            # uses the server's local zone and drops the offset.
            device_ts = datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc).isoformat()
        status_msg = f"βœ… **GPS Captured**: {lat[:8]}, {lon[:8]} (accuracy: {accuracy}m)"
        return status_msg, lat, lon, accuracy, device_ts
    except Exception as e:
        status_msg = f"❌ **Error parsing GPS data**: {str(e)}"
        return status_msg, "", "", "", ""
def _save_image_to_repo(pil_img: Image.Image, dest_rel_path: str) -> None:
    """Encode ``pil_img`` as JPEG in memory and upload it to the dataset repo.

    Args:
        pil_img: Image to store. Modes JPEG cannot hold (e.g. RGBA, LA or P,
            typical for PNG uploads) are converted to RGB first, so any alpha
            channel is dropped.
        dest_rel_path: Destination path inside the dataset repo.
    """
    # PIL raises "cannot write mode ... as JPEG" for modes other than these.
    if pil_img.mode not in ("RGB", "L", "CMYK"):
        pil_img = pil_img.convert("RGB")
    img_bytes = io.BytesIO()
    pil_img.save(img_bytes, format="JPEG", quality=90)
    img_bytes.seek(0)
    upload_file(
        path_or_fileobj=img_bytes, path_in_repo=dest_rel_path,
        repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN,
        commit_message=f"Upload image {dest_rel_path}",
    )
def _append_jsonl_in_repo(new_row: dict) -> None:
    """Append one record to the dataset's metadata JSONL file on the Hub.

    The current file (if any) is downloaded, ``new_row`` is appended as one
    JSON line, and the whole file is re-uploaded in a single commit.

    Raises:
        Errors from checking for or downloading an existing metadata file are
        propagated on purpose. Continuing with an empty list would upload a
        file containing only ``new_row`` and silently wipe all prior entries.

    NOTE(review): this is a read-modify-write with no locking, so two saves
    racing each other can still lose one row.
    """
    from datetime import timezone

    existing_lines = []
    # 1. Download the existing metadata file, if there is one.
    if file_exists(DATASET_REPO, METADATA_PATH, repo_type="dataset", token=HF_TOKEN):
        local_path = hf_hub_download(
            repo_id=DATASET_REPO, filename=METADATA_PATH,
            repo_type="dataset", token=HF_TOKEN,
        )
        with open(local_path, "r", encoding="utf-8") as f:
            existing_lines = f.read().splitlines()
    # 2. Append the new line; end the file with "\n" so it stays valid JSONL.
    existing_lines.append(json.dumps(new_row, ensure_ascii=False))
    buf = io.BytesIO(("\n".join(existing_lines) + "\n").encode("utf-8"))
    # 3. Upload the updated file. The "Z" suffix requires a real UTC time,
    #    not datetime.now() (server-local).
    utc_now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    upload_file(
        path_or_fileobj=buf, path_in_repo=METADATA_PATH,
        repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN,
        commit_message=f"Append 1 entry at {utc_now}Z",
    )
def save_to_dataset(image, lat, lon, accuracy_m, device_ts):
    """Validate a capture and save the image plus GPS metadata to the HF dataset.

    Args:
        image: Captured image (PIL image, or a numpy array that is converted).
        lat, lon: Latitude / longitude from the form textboxes.
        accuracy_m: Optional GPS accuracy in metres; empty means unknown.
        device_ts: Optional device timestamp string; empty means unknown.

    Returns:
        ``(status_markdown, preview)``. ``preview`` is the saved row as a JSON
        string, or ``None`` when nothing was saved. NOTE(review): the preview
        feeds a ``gr.JSON`` component, which parses ``str`` values as JSON, so
        an empty string is not a safe "no preview" value there. Confirm this
        against the installed Gradio version.

    In test mode (``api`` is falsy) nothing is uploaded; the row is only built.
    """
    from datetime import timezone

    try:
        if image is None:
            return "❌ **Error**: No image captured.", None
        if not lat or not lon:
            return "❌ **Error**: GPS coordinates missing.", None

        # Parse and validate every numeric field BEFORE anything is uploaded,
        # so a bad textbox value cannot leave an orphan image without a
        # metadata row.
        lat_f = float(lat)
        lon_f = float(lon)
        # Chained comparisons are False for NaN, so this also rejects NaN/inf,
        # which json.dumps would otherwise emit as non-standard JSON tokens.
        if not (-90.0 <= lat_f <= 90.0 and -180.0 <= lon_f <= 180.0):
            return "❌ **Error**: GPS coordinates out of range.", None
        accuracy_f = float(accuracy_m) if accuracy_m else None

        # Convert image to PIL if it's a numpy array (common in Gradio)
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))

        # --- Test Mode ---
        if not api:
            img_id = str(uuid.uuid4())
            timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
            row = {"id": img_id, "image": f"test_{timestamp_str}_{img_id[:8]}.jpg",
                   "latitude": lat_f, "longitude": lon_f,
                   "mode": "test"}
            status = f"πŸ” **Test Mode**: Data validated successfully! Sample {img_id[:8]}"
            return status, json.dumps(row, indent=2)

        # --- Production Mode ---
        sample_id = str(uuid.uuid4())
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        image_rel_path = f"images/lanternfly_{timestamp_str}_{sample_id[:8]}.jpg"
        # 1. Save image
        _save_image_to_repo(image, image_rel_path)
        # datetime.now() is server-local; use real UTC so the "Z" is truthful
        # (same ISO-8601 text format as before).
        server_ts_utc = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        # 2. Prepare and save metadata
        row = {
            "id": sample_id, "image": image_rel_path,
            "latitude": lat_f, "longitude": lon_f,
            "accuracy_m": accuracy_f,
            "device_timestamp": device_ts if device_ts else None,
            "server_timestamp_utc": server_ts_utc,
        }
        _append_jsonl_in_repo(row)
        status = f"βœ… **Success!** Saved to dataset! Image: `{image_rel_path}`"
        return status, json.dumps(row, indent=2)
    except Exception as e:
        error_msg = f"❌ **Error during save**: {str(e)}"
        return error_msg, None
# --- 3. KDE CONFIGURATION & FUNCTIONS (UPDATED FOR LIVE DATA) ---
# Read from the same dataset the capture tab writes to: same DATASET_REPO
# environment variable and same default. Otherwise the map could show a
# different repo than the one sightings are saved to.
HUGGINGFACE_DATA_REPO = os.getenv("DATASET_REPO") or "rlogh/lanternfly-data"
# The metadata file path is METADATA_PATH from section 2 (not redefined here).
# Define the Pittsburgh coordinate range (used to filter points for the map)
pittsburgh_lat_min, pittsburgh_lat_max = 40.3, 40.6
pittsburgh_lon_min, pittsburgh_lon_max = -80.2, -79.8
def load_lanternfly_data_from_hf():
    """Download the metadata JSONL from the Hub and extract in-area coordinates.

    Only records with numeric ``latitude``/``longitude`` that fall inside the
    Pittsburgh bounding box are kept. Malformed or non-object lines are skipped.

    Returns:
        ``(latitudes, longitudes, None)`` as numpy arrays on success, or
        ``(None, None, error_message)`` on failure or when no usable point exists.
    """
    try:
        # Pass the same credentials the writer uses (HF_TOKEN / HF_TOKEN_SPACE)
        # so a private dataset can be read as well.
        local_path = hf_hub_download(
            repo_id=HUGGINGFACE_DATA_REPO,
            filename=METADATA_PATH,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        latitudes = []
        longitudes = []
        with open(local_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # Skip malformed/blank lines
                if not isinstance(data, dict):
                    # e.g. a stray list/number line; .get() would raise and
                    # abort the whole load.
                    continue
                lat = data.get('latitude')
                lon = data.get('longitude')
                # bool is an int subclass; it is not a coordinate.
                if not all(isinstance(v, (float, int)) and not isinstance(v, bool) for v in (lat, lon)):
                    continue
                # Keep only points within the Pittsburgh area (NaN fails these checks).
                if pittsburgh_lat_min <= lat <= pittsburgh_lat_max and \
                        pittsburgh_lon_min <= lon <= pittsburgh_lon_max:
                    latitudes.append(lat)
                    longitudes.append(lon)
        if not latitudes:
            return None, None, "Error: Found no valid coordinates in the dataset."
        return np.array(latitudes), np.array(longitudes), None
    except Exception as e:
        return None, None, f"Error downloading or parsing HF data: {type(e).__name__} - {e}"
def calculate_kde_and_points():
    """Fetch the sighting coordinates and fit a 2-D Gaussian KDE over them.

    Returns:
        ``(latitudes, longitudes, kde, error)``. On success ``error`` is None.
        On failure the first three are None and ``error`` is a message.
    """
    lats, lons, load_error = load_lanternfly_data_from_hf()
    if load_error:
        return None, None, None, load_error
    try:
        # gaussian_kde takes a (n_dims, n_points) array: row 0 = lon, row 1 = lat.
        fitted_kde = gaussian_kde(np.vstack([lons, lats]))
    except Exception as exc:
        return None, None, None, f"Error calculating KDE: {type(exc).__name__} - {exc}"
    return lats, lons, fitted_kde, None
def plot_kde_and_points(min_lat, max_lat, min_lon, max_lon, original_latitudes, original_longitudes, kde_object):
    """Render an interactive Folium map with each sighting colored by KDE density.

    Args:
        min_lat, max_lat, min_lon, max_lon: Intended map extent. NOTE(review):
            not used by this body (the map is centered on the mean point at
            zoom 12); kept so the caller's call still matches.
        original_latitudes, original_longitudes: Sighting coordinates
            (1-D, equal length).
        kde_object: KDE callable fitted on rows ``[longitude, latitude]``.

    Returns:
        ``(None, html)``. The first slot is the retired Matplotlib image output
        (always None); ``html`` is the Folium map markup.
    """
    # 1. Density at each sighting, min-max scaled for the colormap. The epsilon
    #    avoids 0/0 when every density is equal (all points get one color).
    original_coordinates = np.vstack([original_longitudes, original_latitudes])
    density = kde_object(original_coordinates)
    density_normalized = (density - density.min()) / (density.max() - density.min() + 1e-9)

    # 2. Colormap lookup. matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
    #    the matplotlib.colormaps registry (3.5+) is the supported API.
    try:
        colormap = matplotlib.colormaps['viridis']
    except AttributeError:  # Matplotlib older than 3.5
        colormap = cm.get_cmap('viridis')

    m_colored_points = folium.Map(
        location=[np.mean(original_latitudes), np.mean(original_longitudes)],
        zoom_start=12,
    )
    # 3. One circle marker per sighting.
    for lat, lon, density_norm in zip(original_latitudes, original_longitudes, density_normalized):
        color = matplotlib.colors.rgb2hex(colormap(density_norm))
        folium.CircleMarker(
            location=[lat, lon], radius=5, color=color, fill=True, fill_color=color, fill_opacity=0.7,
            tooltip=f"Lat: {lat:.5f}, Lon: {lon:.5f}"
        ).add_to(m_colored_points)

    # The Matplotlib image output was removed from the UI; None keeps the
    # two-value return shape the caller unpacks.
    return None, m_colored_points._repr_html_()
def update_visualization_live():
    """Refresh handler: reload the data, refit the KDE and re-render the map.

    Returns:
        ``(image_placeholder, map_html, error_text)``, matching the three Gradio
        outputs. ``image_placeholder`` is always None. ``error_text`` is ``""``
        on success; on failure ``map_html`` shows the (HTML-escaped) error.
    """
    import html

    latitudes, longitudes, kde_object, error = calculate_kde_and_points()
    if not error:
        try:
            # Use the predefined Pittsburgh coordinate bounds for the map extent
            pil_image, colored_points_map_html = plot_kde_and_points(
                pittsburgh_lat_min, pittsburgh_lat_max, pittsburgh_lon_min, pittsburgh_lon_max,
                latitudes, longitudes, kde_object
            )
        except Exception as e:
            # Report map-rendering failures like data/KDE failures instead of
            # letting the exception escape the Gradio handler.
            error = f"Error rendering map: {type(e).__name__} - {e}"
        else:
            return pil_image, colored_points_map_html, ""
    # Messages embed exception text, which may contain '<' or '&'.
    return None, f"<h1>{html.escape(error)}</h1>", f"Error: {error}"
# --- 4. GRADIO INTERFACE (COMBINED) ---
# One Blocks app with two tabs:
#   1. capture an image, classify it, grab browser GPS and save both to the dataset;
#   2. a Folium map of saved sightings colored by KDE density.
# NOTE: the order of component creation inside each `with` context defines the
# rendered layout, so statements here should not be reordered casually.
with gr.Blocks(title="Lanternfly Tracking Tool") as app:
    gr.Markdown("# Lanternfly Tracking Tool")
    with gr.Tab("1. Field Capture & Classification"):
        gr.Markdown(f"## πŸ“Έ Lanternfly Classification and GPS Data Capture")
        # Status strings are computed once at startup (model load / repo check).
        gr.Markdown(f"**Model Status**: {MODEL_STATUS}")
        gr.Markdown(f"**GPS Save Status**: {GPS_SAVE_STATUS}")
        with gr.Row():
            # --- Column 1: Image Input & Classification Output ---
            with gr.Column(scale=1):
                image_in = gr.Image(
                    type="pil", label="1. Upload or Capture Image",
                    value="https://placehold.co/224x224/ff6347/ffffff?text=Lanternfly",
                    sources=["upload", "webcam"]
                )
                # Disable classification button if model failed to load
                run_classify_btn = gr.Button("πŸ” Run Classification", variant="primary", interactive=PREDICTOR is not None)
                gr.Markdown("### Classification Result")
                final_result_box = gr.Textbox(label="Prediction Result", interactive=False)
                # One read-only score box per label, in CLASSIFICATION_LABELS order
                # (matches the order of classify_image's returned scores).
                with gr.Row():
                    conf_0 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[0]}", interactive=False)
                    conf_1 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[1]}", interactive=False)
                    conf_2 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[2]}", interactive=False)
            # --- Column 2: GPS Capture & Save ---
            with gr.Column(scale=1):
                gr.Markdown("## πŸ“ GPS Data Capture")
                gps_btn = gr.Button("πŸ“ Get GPS", variant="primary")
                # Hidden textbox to receive location data from JavaScript; the
                # elem_id must match the selector used in get_gps_js().
                hidden_gps_input = gr.Textbox(visible=False, elem_id="hidden_gps_input")
                # Coordinate fields stay editable so values can be corrected by hand.
                with gr.Row():
                    lat_box = gr.Textbox(label="Latitude", interactive=True)
                    lon_box = gr.Textbox(label="Longitude", interactive=True)
                with gr.Row():
                    accuracy_box = gr.Textbox(label="Accuracy (m)", interactive=True)
                    device_ts_box = gr.Textbox(label="Device Timestamp", interactive=True)
                # Disable save button if HF credentials are missing (api is None
                # when there is no token or the dataset repo check failed)
                save_btn = gr.Button("πŸ’Ύ Save Image & GPS to Dataset", variant="secondary", interactive=api is not None)
                gr.Markdown("### Save Status & Preview")
                gps_status = gr.Markdown("πŸ”„ **Ready for GPS capture and saving.**")
                preview = gr.JSON(label="Preview JSON")
        # Handlers for Classification (only wired when a model is loaded)
        if PREDICTOR is not None:
            run_classify_btn.click(
                fn=classify_image,
                inputs=[image_in],
                outputs=[final_result_box, conf_0, conf_1, conf_2]
            )
        # Handlers for GPS: the button runs browser-side JS only (fn=None); that
        # JS writes JSON into the hidden textbox, whose change event parses it
        # into the visible fields.
        gps_btn.click(
            fn=None, inputs=[], outputs=[], js=get_gps_js()
        )
        hidden_gps_input.change(
            fn=handle_gps_location,
            inputs=[hidden_gps_input],
            outputs=[gps_status, lat_box, lon_box, accuracy_box, device_ts_box]
        )
        save_btn.click(
            fn=save_to_dataset,
            inputs=[image_in, lat_box, lon_box, accuracy_box, device_ts_box],
            outputs=[gps_status, preview]
        )
    with gr.Tab("2. Spatial Data Visualization (KDE)"):
        gr.Markdown("## πŸ—ΊοΈ Kernel Density Estimation of Lanternfly Sightings")
        gr.Markdown(f"**Data Source**: {HUGGINGFACE_DATA_REPO} - Automatically loaded from `metadata/entries.jsonl`")
        refresh_btn = gr.Button("πŸ”„ Refresh Map from Hugging Face Data", variant="primary")
        kde_error_box = gr.Textbox(label="Error/Debug Message", visible=False)
        with gr.Row():
            interactive_map_out = gr.HTML(label="Interactive Points Map Colored by KDE (Folium)")
        # Placeholder for the removed Matplotlib image output (to keep update_visualization_live signature intact)
        matplotlib_placeholder = gr.State(value=None)
        # Handler for Refresh Button
        refresh_btn.click(
            fn=update_visualization_live,
            inputs=[],
            outputs=[matplotlib_placeholder, interactive_map_out, kde_error_box]
        )
    # Trigger initial load: render the map once when the page is opened.
    app.load(
        fn=update_visualization_live,
        inputs=[],
        outputs=[matplotlib_placeholder, interactive_map_out, kde_error_box]
    )
# Launch the combined app
if __name__ == "__main__":
    app.launch()