# waterdb/utils/data_loading.py
# Data loading, filtering, and caching utilities for the WaterDB Streamlit app.
# (Deployed from GitHub Actions, commit bc24113.)
from datetime import date
from typing import TypedDict
import pandas as pd
import streamlit as st
from config import AppConfig
from utils.date_utils import get_reporting_year
from utils.summary import (
create_multiindex_columns,
create_overall_summary,
create_summary_by_station_and_position,
)
from utils.timing import timer
class DatasetMetadata(TypedDict):
    """Summary statistics for the loaded dataset, as built by get_dataset_metadata."""

    total_records: int  # number of rows in the dataframe
    date_range: dict[str, date]  # {"start": ..., "end": ...} from Activity_Start_Date_Time
    years: list[int]  # sorted distinct calendar years present in the data
    stations: int  # count of distinct Station_Number values
    records_by_year: dict[int, int]  # calendar year -> row count
    reporting_year_end_month: int  # month that ends the reporting year
class DataManager:
    """Coordinates loading, filtering, and summarizing the water-quality dataset.

    Owns the complete sector/station lists (derived once from the raw parquet
    file) plus lazily-loaded dataset metadata, and applies the exclusion
    filters stored in Streamlit session state when loading data.
    """

    def __init__(self, config: AppConfig):
        self.config = config
        # Reserved for future caching; never read in this module.
        self._data_cache = None
        self._metadata: DatasetMetadata | None = None
        self._all_sectors: list[str] | None = None
        self._all_stations: list[str] | None = None
        self._initialize_complete_lists()

    def _initialize_complete_lists(self) -> None:
        """Initialize complete lists of sectors and stations from raw data"""
        try:
            raw_df = get_raw_data(self.config.DATA_FILE_PATH)
            # Handle sectors
            sectors = raw_df["Sector"].dropna().unique().tolist()
            self._all_sectors = sorted(sectors)
            # Handle stations - convert to float first to standardize numeric
            # format (so e.g. "5" and "5.0" collapse to one key)
            stations = raw_df["Station_Number"].dropna()
            stations = stations.astype(float).astype(str).unique().tolist()
            self._all_stations = sorted(stations, key=lambda x: float(x))
        except Exception as e:
            # Degrade to empty lists so the UI stays usable after a load failure
            st.error(f"Failed to initialize complete lists: {str(e)}")
            self._all_sectors = []
            self._all_stations = []

    @property
    def all_sectors(self) -> list[str]:
        """Get complete list of all sectors in the dataset"""
        if self._all_sectors is None:
            self._initialize_complete_lists()
        return self._all_sectors if self._all_sectors is not None else []

    @property
    def all_stations(self) -> list[str]:
        """Get complete list of all stations in the dataset"""
        if self._all_stations is None:
            self._initialize_complete_lists()
        return self._all_stations if self._all_stations is not None else []

    @property
    def metadata(self) -> DatasetMetadata | None:
        """Dataset metadata, loaded lazily; None when loading failed."""
        if self._metadata is None:
            self._load_metadata()
        return self._metadata

    def _load_metadata(self) -> None:
        """Populate self._metadata from the raw file; set None on failure."""
        try:
            raw_df = get_raw_data(self.config.DATA_FILE_PATH)
            self._metadata = get_dataset_metadata(
                raw_df, self.config.DEFAULT_REPORTING_MONTH
            )
        except Exception as e:
            st.error(f"Failed to load dataset metadata: {str(e)}")
            self._metadata = None

    def _load_data_internal(
        self,
        reporting_month: int,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> dict:
        """Internal method to load and process data.

        Returns a dict with keys "raw_df" (filtered data), "downloads"
        (payloads from prepare_downloads), and "full_dataset_metadata"
        (metadata of the unfiltered dataset, for date-input controls).
        """
        raw_df = get_raw_data(self.config.DATA_FILE_PATH)
        raw_df = raw_df[raw_df["Station_Number"].notna()]
        # Get full dataset date range for the date input controls
        full_dataset_metadata = get_dataset_metadata(raw_df, reporting_month)
        # Apply date filters if provided
        if start_date and end_date:
            raw_df = filter_data_by_dates(raw_df, start_date, end_date)
        # Add reporting year based on provided reporting_month or default
        if reporting_month is not None:
            raw_df["Reporting_Year"] = raw_df["Activity_Start_Date_Time"].apply(
                lambda x: get_reporting_year(x, reporting_month)
            )
        # Apply exclusion filters if they exist in session state
        if (
            "persistent_excluded_sectors" in st.session_state
            and st.session_state.persistent_excluded_sectors
        ):
            raw_df = raw_df[
                ~raw_df["Sector"].isin(st.session_state.persistent_excluded_sectors)
            ]
        if (
            "persistent_excluded_stations" in st.session_state
            and st.session_state.persistent_excluded_stations
        ):
            # Convert station numbers to standardized string format for comparison
            df_stations = raw_df["Station_Number"].astype(float).astype(str)
            excluded_stations = [
                str(float(s)) for s in st.session_state.persistent_excluded_stations
            ]
            raw_df = raw_df[~df_stations.isin(excluded_stations)]
        downloads = prepare_downloads(raw_df)
        return {
            "raw_df": raw_df,
            "downloads": downloads,
            "full_dataset_metadata": full_dataset_metadata,
        }

    def _get_empty_data_structure(self) -> dict:
        """Return empty data structure for error cases"""
        return {
            "raw_df": pd.DataFrame(),
            "downloads": {"summary": {}, "raw": {}},
            "full_dataset_metadata": {
                "total_records": 0,
                "date_range": {"start": None, "end": None},
                "years": [],
                "stations": 0,
                "records_by_year": {},
                # Fix: this key was missing, making the empty metadata
                # inconsistent with DatasetMetadata / get_dataset_metadata
                "reporting_year_end_month": self.config.DEFAULT_REPORTING_MONTH,
            },
        }

    def load_data(
        self,
        start_date: date | None = None,
        end_date: date | None = None,
        reporting_month: int | None = None,
        force_refresh: bool = False,
    ) -> dict:
        """Load data with improved error handling and caching.

        Args:
            start_date / end_date: optional inclusive date-range filter.
            reporting_month: month that ends the reporting year; falls back
                to config.DEFAULT_REPORTING_MONTH when None.
            force_refresh: clear all st.cache_data entries before loading.

        Returns:
            The dict built by _load_data_internal, or the empty structure
            from _get_empty_data_structure() on failure.
        """
        if force_refresh:
            st.cache_data.clear()
        try:
            # Ensure we have the latest exclusions
            excluded_sectors = st.session_state.get("persistent_excluded_sectors", [])
            excluded_stations = st.session_state.get("persistent_excluded_stations", [])
            # Update session state with current exclusions
            st.session_state.persistent_excluded_sectors = excluded_sectors
            st.session_state.persistent_excluded_stations = excluded_stations
            return self._load_data_internal(
                # Fix: compare against None explicitly so an explicit falsy
                # value is not silently replaced by the default
                reporting_month=(
                    reporting_month
                    if reporting_month is not None
                    else self.config.DEFAULT_REPORTING_MONTH
                ),
                start_date=start_date,
                end_date=end_date,
            )
        except Exception as e:
            st.error(f"Failed to load data: {str(e)}")
            return self._get_empty_data_structure()
@timer(include_params=True)
def get_raw_data(file_path: str) -> pd.DataFrame:
    """Read the full water-quality dataset from the given parquet file."""
    frame = pd.read_parquet(file_path)
    return frame
@timer(include_params=False)
def get_dataset_metadata(df: pd.DataFrame, reporting_month: int) -> DatasetMetadata:
    """Build summary statistics (record counts, date range, years, stations)
    for the given dataframe, tagged with the reporting-year end month."""
    timestamps = df["Activity_Start_Date_Time"]
    calendar_years = timestamps.dt.year
    return {
        "total_records": len(df),
        "date_range": {
            "start": timestamps.min().date(),
            "end": timestamps.max().date(),
        },
        "years": sorted(calendar_years.unique()),
        "stations": df["Station_Number"].nunique(),
        "records_by_year": df.groupby(calendar_years).size().to_dict(),  # type: ignore
        "reporting_year_end_month": reporting_month,
    }
@timer(include_params=False)
def filter_data_by_dates(
    df: pd.DataFrame, start_date: date, end_date: date
) -> pd.DataFrame:
    """Filter dataframe to rows whose Activity_Start_Date_Time falls within
    [start_date, end_date], both inclusive (end_date covers the whole day).

    Best-effort: returns the dataframe unchanged when the range matches no
    rows (with a UI warning) or when filtering fails (with a UI error).
    """
    try:
        # Fix: convert on a local series instead of writing the converted
        # column back into the caller's dataframe (the previous version
        # mutated `df` in place; the column is already datetime-typed in
        # the call paths visible here, so dropping the write-back is safe).
        timestamps = pd.to_datetime(df["Activity_Start_Date_Time"])
        # start of start_date's day through the last microsecond of end_date's day
        start_datetime = pd.Timestamp(start_date).normalize()
        end_datetime = (
            pd.Timestamp(end_date) + pd.Timedelta(days=1) - pd.Timedelta(microseconds=1)
        )
        filtered_df = df[(timestamps >= start_datetime) & (timestamps <= end_datetime)]
        if filtered_df.empty:
            # Deliberate fallback to unfiltered data rather than an empty view
            st.warning("No data found for the selected date range")
            return df
        return filtered_df
    except Exception as e:
        st.error(f"Error filtering data: {str(e)}")
        return df
@st.cache_data
@timer(include_params=False)
def create_summaries(
    raw_df: pd.DataFrame,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Build the three summary views: per-station/position summary, overall
    summary, and the per-station summary with multi-index columns."""
    station_summary = create_summary_by_station_and_position(raw_df)
    return (
        station_summary,
        create_overall_summary(raw_df),
        create_multiindex_columns(station_summary),
    )
@timer(include_params=False)
def prepare_downloads(raw_df: pd.DataFrame) -> dict:
    """Build the download payloads for the UI.

    Each leaf value is a (data, file_extension, mime_type) tuple.

    The "summary" placeholder keeps the shape consistent with the empty
    structure returned by DataManager._get_empty_data_structure(), which
    the previous version omitted.
    """
    return {
        "summary": {},
        "raw": {
            "CSV": (raw_df.to_csv(index=False), "csv", "text/csv"),
        },
    }
def add_lat_long(raw_df: pd.DataFrame, stations_df: pd.DataFrame) -> pd.DataFrame:
    """
    Add latitude and longitude to raw data based on station number.

    Args:
        raw_df: Water-quality records with a "Station_Number" column holding
            numeric or numeric-like string values.
        stations_df: Station reference table with "Number" (float station id),
            "Latitude", and "Longitude" columns.

    Returns:
        A new dataframe with raw_df's columns plus "Latitude"/"Longitude"
        (left join, so unmatched stations get NaN coordinates).
    """
    # Fix: build the temporary float join key via assign() instead of writing
    # a "Number" column into the caller's dataframe — the previous version
    # mutated raw_df in place and never removed the column from the input.
    merged = raw_df.assign(Number=raw_df["Station_Number"].astype(float)).merge(
        stations_df[["Number", "Latitude", "Longitude"]],
        on="Number",
        how="left",
    )
    return merged.drop("Number", axis=1)
@timer(include_params=False)
def get_stations_data() -> pd.DataFrame:
    """
    Return stations data as a dataframe with the most recent and earliest
    sample dates (and total sample count) for each station.
    """
    raw_df = st.session_state.data["raw_df"]
    # Aggregate each station's first/last sample timestamps and sample count
    per_station = (
        raw_df.groupby("Station_Number")["Activity_Start_Date_Time"]
        .agg(["min", "max", "count"])
        .reset_index()
    )
    per_station = per_station.rename(
        columns={
            "min": "Earliest_Sample",
            "max": "Most_Recent_Sample",
            "count": "Total_Samples",
        }
    ).astype({"Station_Number": float, "Total_Samples": int})

    # Join onto the station-location reference, normalize the timestamps to
    # plain dates, and keep only stations that actually have samples
    locations = pd.read_csv("data/Stations-Locations.csv")
    stations = locations.merge(
        per_station, left_on="Number", right_on="Station_Number", how="left"
    ).drop("Station_Number", axis=1)
    stations = stations.assign(
        Most_Recent_Sample=lambda x: pd.to_datetime(x.Most_Recent_Sample).dt.date,
        Earliest_Sample=lambda x: pd.to_datetime(x.Earliest_Sample).dt.date,
    )
    return stations.dropna(subset=["Total_Samples"])
@timer(include_params=False)
def get_analyte_data_with_lat_long(df: pd.DataFrame, analyte: str) -> pd.DataFrame:
    """
    Extract and transform data for a specific analyte, adding geographical coordinates.

    This function processes raw water quality data by:
    1. Adding latitude/longitude coordinates from stations data
    2. Filtering for a specific analyte
    3. Removing rows with missing values
    4. Aggregating duplicate measurements using mean values

    Args:
        df (pd.DataFrame): Raw water quality data containing at minimum these columns:
            - Station_Number
            - Org_Analyte_Name
            - Org_Result_Value
            - Reporting_Year
        analyte (str): Name of the analyte to filter for (e.g., "Temperature, Water")

    Returns:
        pd.DataFrame: Processed dataframe with columns:
            - Activity_Start_Date_Time: Timestamp of measurement
            - Station_Number: Monitoring station identifier
            - Sector: Geographical sector
            - WBID: Waterbody ID
            - Sample_Position: Position of sample (e.g., "Surface", "Bottom")
            - Activity_Depth: Depth of measurement
            - Latitude: Station latitude
            - Longitude: Station longitude
            - Reporting_Year: Reporting year
            - {analyte}: Measured value for the specified analyte

    Note:
        Duplicate measurements at the same location and time are averaged.
    """
    located = df.pipe(add_lat_long, get_stations_data())
    # Fix: filter with a boolean mask instead of DataFrame.query on an
    # interpolated f-string — an analyte name containing a quote broke the
    # query expression and could inject arbitrary query syntax.
    located = located[located["Org_Analyte_Name"] == analyte]
    return (
        located.dropna(subset=["Org_Result_Value"])
        .pivot_table(
            index=[
                "Activity_Start_Date_Time",
                "Station_Number",
                "Sector",
                "WBID",
                "Sample_Position",
                "Activity_Depth",
                "Latitude",
                "Longitude",
                "Reporting_Year",
            ],
            values="Org_Result_Value",
            aggfunc="mean",
            observed=True,
        )
        .reset_index()
        .rename(columns={"Org_Result_Value": analyte})
    )
@st.cache_data
@timer(include_params=False)
def load_seasonal_data(raw_df: pd.DataFrame, analyte: str) -> pd.DataFrame:
    """Load and prepare data for seasonal trends analysis"""
    # Streamlit-cached wrapper around get_analyte_data_with_lat_long so
    # repeated reruns with the same inputs skip the pivot work.
    return get_analyte_data_with_lat_long(raw_df, analyte)