|
|
|
|
|
"""Untitled2.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1UcsSFSmZqIdAQTsD_4_CmwwAcAzz0h60 |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import os |
|
|
import io |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.cm as cm |
|
|
import folium |
|
|
import matplotlib.colors |
|
|
from scipy.stats import gaussian_kde |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import huggingface_hub |
|
|
from huggingface_hub import HfApi, hf_hub_download, create_repo, file_exists, upload_file |
|
|
import tempfile |
|
|
import pathlib |
|
|
import json |
|
|
import uuid |
|
|
import shutil |
|
|
import zipfile |
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
from autogluon.multimodal import MultiModalPredictor |
|
|
AUTOGLUON_IMPORTED = True |
|
|
except ImportError: |
|
|
|
|
|
AUTOGLUON_IMPORTED = False |
|
|
class MultiModalPredictor: |
|
|
@staticmethod |
|
|
def load(path): |
|
|
raise ImportError("AutoGluon MultiModalPredictor is not installed or failed to import.") |
|
|
|
|
|
|
|
|
MODEL_REPO_ID = "ddecosmo/lanternfly_classifier" |
|
|
ZIP_FILENAME = "autogluon_image_predictor_dir.zip" |
|
|
MODEL_DIR_NAME = "autogluon_predictor_extracted" |
|
|
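# Labels shown in the classification UI; assumed to match the label names the
# uploaded AutoGluon predictor was trained with (probabilities are looked up by
# these exact strings in classify_image).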
CLASSIFICATION_LABELS = ["Lanternfly", "Other Insect", "Neither"] |
|
|
|
|
|
PREDICTOR = None |
|
|
MODEL_STATUS = "Attempting to load model..." |
|
|
|
|
|
|
|
|
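# Model bootstrap: the trained predictor directory is stored on the Hugging Face Hub
# as a single zip archive. It is downloaded once, extracted into a clean local
# folder, and then loaded with MultiModalPredictor.load().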
def _prepare_predictor_dir(repo_id, zip_filename, extract_dir_name) -> str: |
|
|
"""Downloads the zipped model and extracts it to a clean directory.""" |
|
|
base_extract_dir = os.path.join(os.getcwd(), extract_dir_name) |
|
|
try: |
|
|
|
|
|
zip_path = hf_hub_download(repo_id=repo_id, filename=zip_filename) |
|
|
|
|
|
|
|
|
if os.path.exists(base_extract_dir): |
|
|
shutil.rmtree(base_extract_dir) |
|
|
temp_extract_dir = os.path.join(os.getcwd(), "temp_ag_extract") |
|
|
os.makedirs(temp_extract_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
with zipfile.ZipFile(zip_path, 'r') as zip_ref: |
|
|
zip_ref.extractall(temp_extract_dir) |
|
|
|
|
|
|
|
|
extracted_contents = os.listdir(temp_extract_dir) |
|
|
if len(extracted_contents) == 1 and os.path.isdir(os.path.join(temp_extract_dir, extracted_contents[0])): |
|
|
final_model_dir = os.path.join(temp_extract_dir, extracted_contents[0]) |
|
|
shutil.move(final_model_dir, base_extract_dir) |
|
|
shutil.rmtree(temp_extract_dir) |
|
|
else: |
|
|
os.rename(temp_extract_dir, base_extract_dir) |
|
|
|
|
|
return base_extract_dir |
|
|
except Exception as e: |
|
|
print(f"Error during model prep: {e}") |
|
|
return "" |
|
|
|
|
|
|
|
|
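# Load the predictor once at startup so every Gradio request can reuse it; any
# failure is reported through MODEL_STATUS instead of crashing the app.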
if AUTOGLUON_IMPORTED: |
|
|
try: |
|
|
predictor_dir = _prepare_predictor_dir(MODEL_REPO_ID, ZIP_FILENAME, MODEL_DIR_NAME) |
|
|
if predictor_dir: |
|
|
PREDICTOR = MultiModalPredictor.load(predictor_dir) |
|
|
MODEL_STATUS = f"β
Model Active: {MODEL_REPO_ID}" |
|
|
else: |
|
|
MODEL_STATUS = "β Initialization failed during extraction/download." |
|
|
except Exception as e: |
|
|
PREDICTOR = None |
|
|
MODEL_STATUS = f"β Error loading model: {type(e).__name__} (Load Fail)" |
|
|
else: |
|
|
MODEL_STATUS = "β AutoGluon not imported. Classification tab is disabled." |
|
|
|
|
|
|
|
|
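# AutoGluon's MultiModalPredictor takes a DataFrame of image file paths rather than
# in-memory images, so the incoming PIL image is written to a temporary PNG first
# (the column name "image" is assumed to match the training data) and the temp
# directory is removed in the finally block.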
def classify_image(img: Image.Image): |
|
|
"""Predicts the class of the input image using the loaded AutoGluon model.""" |
|
|
if PREDICTOR is None: |
|
|
return "MODEL FAILED TO LOAD", 0.0, 0.0, 0.0 |
|
|
|
|
|
if img is None: |
|
|
return "NO IMAGE PROVIDED", 0.0, 0.0, 0.0 |
|
|
|
|
|
final_output = [0.0] * len(CLASSIFICATION_LABELS) |
|
|
final_result = "PREDICTION FAILED" |
|
|
|
|
|
|
|
|
temp_dir = pathlib.Path(tempfile.mkdtemp()) |
|
|
img_path = temp_dir / "input.png" |
|
|
img.save(img_path) |
|
|
|
|
|
try: |
|
|
df_path = pd.DataFrame({"image": [str(img_path)]}) |
|
|
proba_df = PREDICTOR.predict_proba(df_path, as_pandas=True) |
|
|
scores_dict = proba_df.iloc[0].to_dict() |
|
|
|
|
|
|
|
|
scores = [float(scores_dict.get(label, 0.0)) |
|
|
for label in CLASSIFICATION_LABELS] |
|
|
|
|
|
predicted_class_label = max(scores_dict, key=scores_dict.get) |
|
|
final_output = scores |
|
|
final_result = f"Predicted Class: **{predicted_class_label}**" |
|
|
|
|
|
except Exception as e: |
|
|
final_result = f"CRITICAL PREDICTION FAILURE: {type(e).__name__} - Check AutoGluon dependencies." |
|
|
finally: |
|
|
shutil.rmtree(temp_dir) |
|
|
|
|
|
return final_result, final_output[0], final_output[1], final_output[2] |
|
|
|
|
|
|
|
|
|
|
|
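# Dataset-saving configuration. Credentials are read from the environment; without a
# token the app still runs, but uploads are disabled and save_to_dataset() falls back
# to a dry-run "test mode".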
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_TOKEN_SPACE") |
|
|
DATASET_REPO = os.getenv("DATASET_REPO", "rlogh/lanternfly-data") |
|
|
METADATA_PATH = "metadata/entries.jsonl" |
|
|
api = None |
|
|
|
|
|
if HF_TOKEN and DATASET_REPO: |
|
|
api = HfApi(token=HF_TOKEN) |
|
|
try: |
|
|
|
|
|
create_repo(DATASET_REPO, repo_type="dataset", exist_ok=True, token=HF_TOKEN) |
|
|
GPS_SAVE_STATUS = "β
Dataset saving enabled." |
|
|
except Exception as e: |
|
|
GPS_SAVE_STATUS = f"β οΈ Error creating dataset repo: {e}" |
|
|
api = None |
|
|
else: |
|
|
GPS_SAVE_STATUS = "β οΈ Running in test mode - no HF credentials (dataset saving disabled)." |
|
|
|
|
|
|
|
|
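# GPS capture happens in two steps: the "Get GPS" button runs the JavaScript below in
# the browser, which writes the coordinates as JSON into a hidden textbox; the
# textbox's change event then calls handle_gps_location() on the server side.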
def get_gps_js(): |
|
|
"""JavaScript function to be injected into Gradio to capture GPS coordinates.""" |
|
|
return """ |
|
|
() => { |
|
|
// Look for the hidden textbox element by its ID |
|
|
const textarea = document.querySelector('#hidden_gps_input textarea'); |
|
|
if (!textarea) return; |
|
|
|
|
|
if (!navigator.geolocation) { |
|
|
textarea.value = JSON.stringify({error: "Geolocation not supported by this browser/device."}); |
|
|
textarea.dispatchEvent(new Event('input', { bubbles: true })); |
|
|
return; |
|
|
} |
|
|
// Request current position |
|
|
navigator.geolocation.getCurrentPosition( |
|
|
function(position) { |
|
|
const data = { |
|
|
latitude: position.coords.latitude, |
|
|
longitude: position.coords.longitude, |
|
|
accuracy: position.coords.accuracy, |
|
|
timestamp: position.timestamp |
|
|
}; |
|
|
// Write JSON string to the hidden textbox and trigger a change event |
|
|
textarea.value = JSON.stringify(data); |
|
|
textarea.dispatchEvent(new Event('input', { bubbles: true })); |
|
|
}, |
|
|
function(err) { |
|
|
textarea.value = JSON.stringify({ error: err.message }); |
|
|
textarea.dispatchEvent(new Event('input', { bubbles: true })); |
|
|
}, |
|
|
{ enableHighAccuracy: true, timeout: 10000 } |
|
|
); |
|
|
} |
|
|
""" |
|
|
|
|
|
def handle_gps_location(json_str): |
|
|
"""Parses the GPS JSON string and updates the Gradio text boxes.""" |
|
|
try: |
|
|
data = json.loads(json_str) |
|
|
if 'error' in data: |
|
|
status_msg = f"β **GPS Error**: {data['error']}" |
|
|
return status_msg, "", "", "", "" |
|
|
|
|
|
lat = str(data.get('latitude', '')) |
|
|
lon = str(data.get('longitude', '')) |
|
|
accuracy = str(data.get('accuracy', '')) |
|
|
timestamp_ms = data.get('timestamp') |
|
|
|
|
|
|
|
|
device_ts = "" |
|
|
if timestamp_ms and isinstance(timestamp_ms, (int, float)): |
|
|
device_ts = datetime.fromtimestamp(timestamp_ms / 1000).isoformat() |
|
|
|
|
|
status_msg = f"β
**GPS Captured**: {lat[:8]}, {lon[:8]} (accuracy: {accuracy}m)" |
|
|
return status_msg, lat, lon, accuracy, device_ts |
|
|
|
|
|
except Exception as e: |
|
|
status_msg = f"β **Error parsing GPS data**: {str(e)}" |
|
|
return status_msg, "", "", "", "" |
|
|
|
|
|
|
|
|
def _save_image_to_repo(pil_img: Image.Image, dest_rel_path: str) -> None: |
|
|
"""Uploads a PIL image into the dataset repo via a memory buffer.""" |
|
|
img_bytes = io.BytesIO() |
|
|
    # JPEG cannot store an alpha channel, so convert (e.g. RGBA PNG uploads) before saving.
    pil_img.convert("RGB").save(img_bytes, format="JPEG", quality=90)
|
|
img_bytes.seek(0) |
|
|
upload_file( |
|
|
path_or_fileobj=img_bytes, path_in_repo=dest_rel_path, |
|
|
repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN, |
|
|
commit_message=f"Upload image {dest_rel_path}", |
|
|
) |
|
|
|
|
|
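# Metadata lives in a single JSONL file inside the dataset repo. Appending works by
# downloading the current file, adding one line, and re-uploading the whole file,
# which is simple but not safe under concurrent submissions.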
def _append_jsonl_in_repo(new_row: dict) -> None: |
|
|
"""Appends a new JSON line to the metadata file in the dataset repo.""" |
|
|
buf = io.BytesIO() |
|
|
existing_lines = [] |
|
|
|
|
|
try: |
|
|
|
|
|
if file_exists(DATASET_REPO, METADATA_PATH, repo_type="dataset", token=HF_TOKEN): |
|
|
local_path = hf_hub_download( |
|
|
repo_id=DATASET_REPO, filename=METADATA_PATH, |
|
|
repo_type="dataset", token=HF_TOKEN |
|
|
) |
|
|
with open(local_path, "r", encoding="utf-8") as f: |
|
|
existing_lines = f.read().splitlines() |
|
|
except Exception: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
existing_lines.append(json.dumps(new_row, ensure_ascii=False)) |
|
|
data = "\n".join(existing_lines).encode("utf-8") |
|
|
    buf.write(data)
    buf.seek(0)
|
|
|
|
|
|
|
|
upload_file( |
|
|
path_or_fileobj=buf, path_in_repo=METADATA_PATH, |
|
|
repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN, |
|
|
commit_message=f"Append 1 entry at {datetime.now().isoformat()}Z", |
|
|
) |
|
|
|
|
|
|
|
|
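# One saved observation produces two uploads: the JPEG under images/ and a metadata
# row with id, image path, latitude/longitude, GPS accuracy, and device/server
# timestamps appended to metadata/entries.jsonl.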
def save_to_dataset(image, lat, lon, accuracy_m, device_ts): |
|
|
"""Validates data and saves the image and metadata to the Hugging Face dataset.""" |
|
|
try: |
|
|
if image is None: |
|
|
return "β **Error**: No image captured.", "" |
|
|
if not lat or not lon: |
|
|
return "β **Error**: GPS coordinates missing.", "" |
|
|
|
|
|
|
|
|
if isinstance(image, np.ndarray): |
|
|
image = Image.fromarray(image.astype('uint8')) |
|
|
|
|
|
|
|
|
if not api: |
|
|
img_id = str(uuid.uuid4()) |
|
|
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
row = {"id": img_id, "image": f"test_{timestamp_str}_{img_id[:8]}.jpg", |
|
|
"latitude": float(lat), "longitude": float(lon), |
|
|
"mode": "test"} |
|
|
status = f"π **Test Mode**: Data validated successfully! Sample {img_id[:8]}" |
|
|
preview = json.dumps(row, indent=2) |
|
|
return status, preview |
|
|
|
|
|
|
|
|
sample_id = str(uuid.uuid4()) |
|
|
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
image_rel_path = f"images/lanternfly_{timestamp_str}_{sample_id[:8]}.jpg" |
|
|
|
|
|
|
|
|
_save_image_to_repo(image, image_rel_path) |
|
|
        server_ts_utc = datetime.now(timezone.utc).isoformat()
|
|
|
|
|
|
|
|
row = { |
|
|
"id": sample_id, "image": image_rel_path, |
|
|
"latitude": float(lat), "longitude": float(lon), |
|
|
"accuracy_m": float(accuracy_m) if accuracy_m else None, |
|
|
"device_timestamp": device_ts if device_ts else None, |
|
|
"server_timestamp_utc": server_ts_utc, |
|
|
} |
|
|
_append_jsonl_in_repo(row) |
|
|
|
|
|
status = f"β
**Success!** Saved to dataset! Image: `{image_rel_path}`" |
|
|
preview = json.dumps(row, indent=2) |
|
|
return status, preview |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"β **Error during save**: {str(e)}" |
|
|
return error_msg, "" |
|
|
|
|
|
|
|
|
HUGGINGFACE_DATA_REPO = "rlogh/lanternfly-data" |
|
|
METADATA_PATH = "metadata/entries.jsonl" |
|
|
|
|
|
|
|
|
pittsburgh_lat_min, pittsburgh_lat_max = 40.3, 40.6 |
|
|
pittsburgh_lon_min, pittsburgh_lon_max = -80.2, -79.8 |
|
|
|
|
|
|
|
|
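# The visualization tab reads the same metadata file back from the Hub and keeps only
# points inside the Pittsburgh bounding box defined above; malformed JSONL lines are
# skipped.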
def load_lanternfly_data_from_hf(): |
|
|
"""Downloads the JSONL metadata file from HF and extracts latitude/longitude.""" |
|
|
try: |
|
|
|
|
|
local_path = hf_hub_download( |
|
|
repo_id=HUGGINGFACE_DATA_REPO, |
|
|
filename=METADATA_PATH, |
|
|
repo_type="dataset" |
|
|
) |
|
|
|
|
|
latitudes = [] |
|
|
longitudes = [] |
|
|
|
|
|
|
|
|
with open(local_path, 'r', encoding='utf-8') as f: |
|
|
for line in f: |
|
|
try: |
|
|
data = json.loads(line) |
|
|
lat = data.get('latitude') |
|
|
lon = data.get('longitude') |
|
|
|
|
|
if isinstance(lat, (float, int)) and isinstance(lon, (float, int)): |
|
|
|
|
|
if pittsburgh_lat_min <= lat <= pittsburgh_lat_max and \ |
|
|
pittsburgh_lon_min <= lon <= pittsburgh_lon_max: |
|
|
latitudes.append(lat) |
|
|
longitudes.append(lon) |
|
|
|
|
|
except json.JSONDecodeError: |
|
|
continue |
|
|
|
|
|
if not latitudes: |
|
|
return None, None, "Error: Found no valid coordinates in the dataset." |
|
|
|
|
|
return np.array(latitudes), np.array(longitudes), None |
|
|
|
|
|
except Exception as e: |
|
|
return None, None, f"Error downloading or parsing HF data: {type(e).__name__} - {e}" |
|
|
|
|
|
|
|
|
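# scipy's gaussian_kde fits a 2-D kernel density estimate over (longitude, latitude)
# pairs. It requires at least two distinct points (a singular covariance matrix
# raises an error), which the except branch below reports back to the UI.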
def calculate_kde_and_points(): |
|
|
"""Loads data, calculates KDE, and prepares data for visualization.""" |
|
|
latitudes, longitudes, error = load_lanternfly_data_from_hf() |
|
|
|
|
|
if error: |
|
|
return None, None, None, error |
|
|
|
|
|
try: |
|
|
|
|
|
coordinates = np.vstack([longitudes, latitudes]) |
|
|
|
|
|
|
|
|
kde_object = gaussian_kde(coordinates) |
|
|
|
|
|
return latitudes, longitudes, kde_object, None |
|
|
|
|
|
except Exception as e: |
|
|
return None, None, None, f"Error calculating KDE: {type(e).__name__} - {e}" |
|
|
|
|
|
|
|
|
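# Each sighting is drawn as a Folium CircleMarker whose color encodes the KDE density
# at that point, min-max normalized to [0, 1] and mapped through the viridis
# colormap. The first return value is a placeholder kept for the Gradio output
# signature.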
def plot_kde_and_points(min_lat, max_lat, min_lon, max_lon, original_latitudes, original_longitudes, kde_object): |
|
|
"""Generates an interactive Folium map with points colored by KDE density.""" |
|
|
|
|
|
|
|
|
|
|
|
original_coordinates = np.vstack([original_longitudes, original_latitudes]) |
|
|
density_at_original_points = kde_object(original_coordinates) |
|
|
|
|
|
density_normalized = (density_at_original_points - density_at_original_points.min()) / (density_at_original_points.max() - density_at_original_points.min() + 1e-9) |
|
|
|
|
|
|
|
|
    # cm.get_cmap() was removed in recent matplotlib releases; use the colormap registry instead.
    colormap = matplotlib.colormaps['viridis']
|
|
map_center_lat = np.mean(original_latitudes) |
|
|
map_center_lon = np.mean(original_longitudes) |
|
|
m_colored_points = folium.Map(location=[map_center_lat, map_center_lon], zoom_start=12) |
|
|
|
|
|
|
|
|
for lat, lon, density_norm in zip(original_latitudes, original_longitudes, density_normalized): |
|
|
color = matplotlib.colors.rgb2hex(colormap(density_norm)) |
|
|
folium.CircleMarker( |
|
|
location=[lat, lon], radius=5, color=color, fill=True, fill_color=color, fill_opacity=0.7, |
|
|
tooltip=f"Lat: {lat:.5f}, Lon: {lon:.5f}" |
|
|
).add_to(m_colored_points) |
|
|
|
|
|
colored_points_map_html = m_colored_points._repr_html_() |
|
|
|
|
|
|
|
|
|
|
|
return None, colored_points_map_html |
|
|
|
|
|
|
|
|
def update_visualization_live(): |
|
|
"""Main visualization function for the Gradio interface.""" |
|
|
latitudes, longitudes, kde_object, error = calculate_kde_and_points() |
|
|
|
|
|
if error: |
|
|
|
|
|
return None, f"<h1>{error}</h1>", f"Error: {error}" |
|
|
|
|
|
|
|
|
pil_image, colored_points_map_html = plot_kde_and_points( |
|
|
pittsburgh_lat_min, pittsburgh_lat_max, pittsburgh_lon_min, pittsburgh_lon_max, |
|
|
latitudes, longitudes, kde_object |
|
|
) |
|
|
|
|
|
|
|
|
return pil_image, colored_points_map_html, "" |
|
|
|
|
|
|
|
|
|
|
|
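# UI layout: Tab 1 covers image capture, classification, and GPS-tagged saving;
# Tab 2 renders the KDE map. matplotlib_placeholder is a hidden State that only
# absorbs the unused first return value of update_visualization_live().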
with gr.Blocks(title="Lanternfly Tracking Tool") as app: |
|
|
|
|
|
gr.Markdown("# Lanternfly Tracking Tool") |
|
|
|
|
|
with gr.Tab("1. Field Capture & Classification"): |
|
|
gr.Markdown(f"## πΈ Lanternfly Classification and GPS Data Capture") |
|
|
gr.Markdown(f"**Model Status**: {MODEL_STATUS}") |
|
|
gr.Markdown(f"**GPS Save Status**: {GPS_SAVE_STATUS}") |
|
|
|
|
|
with gr.Row(): |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
image_in = gr.Image( |
|
|
type="pil", label="1. Upload or Capture Image", |
|
|
value="https://placehold.co/224x224/ff6347/ffffff?text=Lanternfly", |
|
|
sources=["upload", "webcam"] |
|
|
) |
|
|
|
|
|
run_classify_btn = gr.Button("π Run Classification", variant="primary", interactive=PREDICTOR is not None) |
|
|
|
|
|
gr.Markdown("### Classification Result") |
|
|
final_result_box = gr.Textbox(label="Prediction Result", interactive=False) |
|
|
with gr.Row(): |
|
|
conf_0 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[0]}", interactive=False) |
|
|
conf_1 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[1]}", interactive=False) |
|
|
conf_2 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[2]}", interactive=False) |
|
|
|
|
|
|
|
|
|
|
|
with gr.Column(scale=1): |
|
|
gr.Markdown("## π GPS Data Capture") |
|
|
gps_btn = gr.Button("π Get GPS", variant="primary") |
|
|
|
|
|
hidden_gps_input = gr.Textbox(visible=False, elem_id="hidden_gps_input") |
|
|
|
|
|
with gr.Row(): |
|
|
lat_box = gr.Textbox(label="Latitude", interactive=True) |
|
|
lon_box = gr.Textbox(label="Longitude", interactive=True) |
|
|
with gr.Row(): |
|
|
accuracy_box = gr.Textbox(label="Accuracy (m)", interactive=True) |
|
|
device_ts_box = gr.Textbox(label="Device Timestamp", interactive=True) |
|
|
|
|
|
|
|
|
save_btn = gr.Button("πΎ Save Image & GPS to Dataset", variant="secondary", interactive=api is not None) |
|
|
|
|
|
gr.Markdown("### Save Status & Preview") |
|
|
gps_status = gr.Markdown("π **Ready for GPS capture and saving.**") |
|
|
preview = gr.JSON(label="Preview JSON") |
|
|
|
|
|
|
|
|
if PREDICTOR is not None: |
|
|
run_classify_btn.click( |
|
|
fn=classify_image, |
|
|
inputs=[image_in], |
|
|
outputs=[final_result_box, conf_0, conf_1, conf_2] |
|
|
) |
|
|
|
|
|
|
|
|
gps_btn.click( |
|
|
fn=None, inputs=[], outputs=[], js=get_gps_js() |
|
|
) |
|
|
hidden_gps_input.change( |
|
|
fn=handle_gps_location, |
|
|
inputs=[hidden_gps_input], |
|
|
outputs=[gps_status, lat_box, lon_box, accuracy_box, device_ts_box] |
|
|
) |
|
|
save_btn.click( |
|
|
fn=save_to_dataset, |
|
|
inputs=[image_in, lat_box, lon_box, accuracy_box, device_ts_box], |
|
|
outputs=[gps_status, preview] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("2. Spatial Data Visualization (KDE)"): |
|
|
gr.Markdown("## πΊοΈ Kernel Density Estimation of Lanternfly Sightings") |
|
|
gr.Markdown(f"**Data Source**: {HUGGINGFACE_DATA_REPO} - Automatically loaded from `metadata/entries.jsonl`") |
|
|
|
|
|
refresh_btn = gr.Button("π Refresh Map from Hugging Face Data", variant="primary") |
|
|
kde_error_box = gr.Textbox(label="Error/Debug Message", visible=False) |
|
|
|
|
|
with gr.Row(): |
|
|
interactive_map_out = gr.HTML(label="Interactive Points Map Colored by KDE (Folium)") |
|
|
|
|
|
|
|
|
matplotlib_placeholder = gr.State(value=None) |
|
|
|
|
|
|
|
|
refresh_btn.click( |
|
|
fn=update_visualization_live, |
|
|
inputs=[], |
|
|
outputs=[matplotlib_placeholder, interactive_map_out, kde_error_box] |
|
|
) |
|
|
|
|
|
|
|
|
app.load( |
|
|
fn=update_visualization_live, |
|
|
inputs=[], |
|
|
outputs=[matplotlib_placeholder, interactive_map_out, kde_error_box] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
app.launch() |