#!/usr/bin/env python3
"""
Dataset utilities for Trackio experiment data management.

Provides functions for safe dataset operations with data preservation.
"""

import json
import logging
from datetime import datetime
from typing import Dict, Any, List, Optional, Union

from datasets import Dataset, load_dataset

logger = logging.getLogger(__name__)


class TrackioDatasetManager:
    """
    Manager class for Trackio experiment datasets with data preservation.

    This class ensures that existing experiment data is always preserved
    when adding new experiments or updating existing ones.
    """

    def __init__(self, dataset_repo: str, hf_token: str):
        """
        Initialize the dataset manager.

        Args:
            dataset_repo (str): HF dataset repository ID (e.g., "username/dataset-name")
            hf_token (str): Hugging Face token for authentication
        """
        self.dataset_repo = dataset_repo
        self.hf_token = hf_token
        self._validate_repo_format()

    def _validate_repo_format(self):
        """Validate the dataset repository format."""
        if not self.dataset_repo or '/' not in self.dataset_repo:
            raise ValueError(f"Invalid dataset repository format: {self.dataset_repo}")

    def check_dataset_exists(self) -> bool:
        """
        Check if the dataset repository exists and is accessible.

        Returns:
            bool: True if the dataset exists and is accessible, False otherwise
        """
        try:
            # Try a standard load first
            load_dataset(self.dataset_repo, token=self.hf_token)
            logger.info(f"✅ Dataset {self.dataset_repo} exists and is accessible")
            return True
        except Exception as e:
            # Some hubs raise a split-metadata mismatch; retry with relaxed verification
            try:
                logger.info(f"🔍 Standard load failed: {e}. Retrying with relaxed verification...")
                load_dataset(
                    self.dataset_repo,
                    token=self.hf_token,
                    verification_mode="no_checks"  # type: ignore[arg-type]
                )
                logger.info(f"✅ Dataset {self.dataset_repo} accessible with relaxed verification")
                return True
            except Exception as e2:
                logger.info(f"🔍 Dataset {self.dataset_repo} doesn't exist or isn't accessible: {e2}")
                return False

    def load_existing_experiments(self) -> List[Dict[str, Any]]:
        """
        Load all existing experiments from the dataset.

        Returns:
            List[Dict[str, Any]]: List of existing experiment dictionaries
        """
        try:
            if not self.check_dataset_exists():
                logger.info("📋 No existing dataset found, returning empty list")
                return []

            # Load with relaxed verification to avoid split-metadata mismatches blocking reads
            try:
                dataset = load_dataset(self.dataset_repo, token=self.hf_token)
            except Exception:
                dataset = load_dataset(
                    self.dataset_repo,
                    token=self.hf_token,
                    verification_mode="no_checks"  # type: ignore[arg-type]
                )

            if 'train' not in dataset:
                logger.info("📋 No 'train' split found in dataset")
                return []

            experiments = list(dataset['train'])
            logger.info(f"📋 Loaded {len(experiments)} existing experiments")

            # Validate experiment structure
            valid_experiments = []
            for exp in experiments:
                if self._validate_experiment_structure(exp):
                    valid_experiments.append(exp)
                else:
                    logger.warning(f"⚠️ Skipping invalid experiment: {exp.get('experiment_id', 'unknown')}")

            logger.info(f"📋 {len(valid_experiments)} valid experiments loaded")
            return valid_experiments

        except Exception as e:
            logger.error(f"❌ Failed to load existing experiments: {e}")
            return []

    def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
        """
        Validate and sanitize an experiment structure.

        This method is intentionally lenient to avoid dropping any existing
        rows from the remote dataset during union-merge saves.

        Rules:
            - 'experiment_id' must exist; otherwise the row is skipped
            - All other required fields are auto-filled with safe defaults
            - JSON fields are normalized to valid JSON strings

        Args:
            experiment (Dict[str, Any]): Experiment dictionary to validate/sanitize

        Returns:
            bool: True if the experiment has (or was sanitized to) a valid structure.
        """
        # Hard requirement: experiment_id must be present
        if not experiment.get('experiment_id'):
            logger.warning("⚠️ Missing required field 'experiment_id' in experiment; skipping row")
            return False
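
        # Illustrative effect of the sanitization below (hypothetical row): an input
        # like {"experiment_id": "exp_1", "metrics": None} leaves this method with
        # metrics == "[]", status == "running", and an ISO-formatted created_at.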

        # Fill defaults for non-JSON scalar fields
        defaults = {
            'name': '',
            'description': '',
            'created_at': datetime.now().isoformat(),
            'status': 'running',
        }
        for key, default_value in defaults.items():
            if experiment.get(key) in (None, ''):
                experiment[key] = default_value

        # Normalize JSON fields to valid JSON strings
        def _ensure_json_string(field_name: str, default_value: Any):
            raw_value = experiment.get(field_name)
            try:
                if isinstance(raw_value, str):
                    # Validate the JSON string; if empty, use the default
                    if raw_value.strip() == '':
                        experiment[field_name] = json.dumps(default_value, default=str)
                    else:
                        json.loads(raw_value)  # keep as-is if it is valid JSON
                else:
                    # Convert the object to a JSON string
                    experiment[field_name] = json.dumps(
                        raw_value if raw_value is not None else default_value,
                        default=str
                    )
            except Exception:
                # On any error, fall back to the default JSON
                experiment[field_name] = json.dumps(default_value, default=str)

        for json_field, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
            _ensure_json_string(json_field, default)

        return True

    def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
        """
        Save a list of experiments to the dataset using a non-destructive union merge.

        - Loads existing experiments (if any) and builds a union keyed by `experiment_id`.
        - For overlapping IDs, merges the JSON fields:
            - metrics: concatenates lists and de-duplicates nested entries by (step, timestamp)
            - parameters: dict update (new values override)
            - artifacts: union with de-duplication
            - logs: concatenation with de-duplication
        - Non-JSON scalar fields from incoming experiments take precedence.

        Args:
            experiments (List[Dict[str, Any]]): List of experiment dictionaries
            commit_message (Optional[str]): Custom commit message

        Returns:
            bool: True if the save was successful, False otherwise
        """
        try:
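            # Illustrative merge (hypothetical values): if the hub already holds exp_1 with
            #   metrics == '[{"step": 1, "timestamp": "...", "metrics": {"loss": 2.1}}]'
            # and `experiments` carries exp_1 with a step-2 entry and status "completed",
            # the saved row keeps both metric steps and takes "completed" for the scalar field.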
            if not experiments:
                logger.warning("⚠️ No experiments to save")
                return False

            # Helpers
            def _parse_json_field(value, default):
                try:
                    if value is None:
                        return default
                    if isinstance(value, str):
                        return json.loads(value) if value else default
                    return value
                except Exception:
                    return default

            def _metrics_key(entry: Dict[str, Any]):
                if isinstance(entry, dict):
                    return (entry.get('step'), entry.get('timestamp'))
                return (None, json.dumps(entry, sort_keys=True))

            # Load existing experiments for the union merge
            existing = {}
            dataset_exists = self.check_dataset_exists()
            try:
                existing_list = self.load_existing_experiments()
                for row in existing_list:
                    exp_id = row.get('experiment_id')
                    if exp_id:
                        existing[exp_id] = row
            except Exception:
                existing = {}

            # Safety guard: avoid a destructive overwrite if the dataset exists but
            # we failed to read any existing records (e.g., a transient HF issue)
            if dataset_exists and len(existing) == 0 and len(experiments) <= 3:
                logger.error(
                    "❌ Refusing to overwrite dataset: existing records could not be loaded "
                    "but repository exists. Skipping save to prevent data loss."
                )
                return False
            # Validate and merge
            merged_map: Dict[str, Dict[str, Any]] = {}

            # Seed with existing records
            for exp_id, row in existing.items():
                merged_map[exp_id] = row

            # Apply incoming records
            for exp in experiments:
                if not self._validate_experiment_structure(exp):
                    logger.error(f"❌ Invalid experiment structure: {exp.get('experiment_id', 'unknown')}")
                    return False

                exp_id = exp['experiment_id']
                incoming = exp

                if exp_id not in merged_map:
                    incoming['last_updated'] = incoming.get('last_updated') or datetime.now().isoformat()
                    merged_map[exp_id] = incoming
                    continue

                # Merge with the existing record
                base = merged_map[exp_id]

                # Parse JSON fields
                base_metrics = _parse_json_field(base.get('metrics'), [])
                base_params = _parse_json_field(base.get('parameters'), {})
                base_artifacts = _parse_json_field(base.get('artifacts'), [])
                base_logs = _parse_json_field(base.get('logs'), [])

                inc_metrics = _parse_json_field(incoming.get('metrics'), [])
                inc_params = _parse_json_field(incoming.get('parameters'), {})
                inc_artifacts = _parse_json_field(incoming.get('artifacts'), [])
                inc_logs = _parse_json_field(incoming.get('logs'), [])

                # Merge metrics with de-dup (by step + timestamp), then collapse per step
                merged_metrics = []
                seen = set()
                for entry in base_metrics + inc_metrics:
                    try:
                        key = _metrics_key(entry)
                    except Exception:
                        key = (None, None)
                    if key not in seen:
                        seen.add(key)
                        merged_metrics.append(entry)
                # Collapse duplicate steps by merging their metric dicts and keeping the latest timestamp
                try:
                    step_to_entry: Dict[Any, Dict[str, Any]] = {}
                    for e in merged_metrics:
                        if not isinstance(e, dict):
                            continue
                        # Ensure the nested structure {timestamp, step, metrics}
                        if 'metrics' not in e:
                            e = {
                                'timestamp': e.get('timestamp'),
                                'step': e.get('step'),
                                'metrics': {k: v for k, v in e.items() if k not in ('step', 'timestamp')}
                            }
                        step_val = e.get('step')
                        if step_val in step_to_entry:
                            existing_e = step_to_entry[step_val]
                            try:
                                existing_metrics_dict = existing_e.get('metrics', {})
                                if isinstance(existing_metrics_dict, dict):
                                    existing_metrics_dict.update(e.get('metrics', {}))
                                else:
                                    existing_e['metrics'] = e.get('metrics', {})
                            except Exception:
                                existing_e['metrics'] = e.get('metrics', {})
                            try:
                                if str(e.get('timestamp', '')) > str(existing_e.get('timestamp', '')):
                                    existing_e['timestamp'] = e.get('timestamp')
                            except Exception:
                                pass
                        else:
                            step_to_entry[step_val] = dict(e)

                    def _step_key(x: Dict[str, Any]):
                        try:
                            return float(x.get('step'))
                        except Exception:
                            return -1.0

                    merged_metrics = sorted(step_to_entry.values(), key=_step_key)
                except Exception:
                    # On any error, keep the de-duplicated list
                    pass
                # Merge parameters
                merged_params = {}
                if isinstance(base_params, dict):
                    merged_params.update(base_params)
                if isinstance(inc_params, dict):
                    merged_params.update(inc_params)

                # Merge artifacts and logs with de-dup
                def _dedup_list(lst):
                    out = []
                    seen_local = set()
                    for item in lst:
                        key = json.dumps(item, sort_keys=True, default=str) if not isinstance(item, str) else item
                        if key not in seen_local:
                            seen_local.add(key)
                            out.append(item)
                    return out

                merged_artifacts = _dedup_list(list(base_artifacts) + list(inc_artifacts))
                merged_logs = _dedup_list(list(base_logs) + list(inc_logs))

                # Rebuild the merged record, preferring incoming scalars
                merged_rec = dict(base)
                merged_rec.update({k: v for k, v in incoming.items() if k not in ('metrics', 'parameters', 'artifacts', 'logs')})
                merged_rec['metrics'] = json.dumps(merged_metrics, default=str)
                merged_rec['parameters'] = json.dumps(merged_params, default=str)
                merged_rec['artifacts'] = json.dumps(merged_artifacts, default=str)
                merged_rec['logs'] = json.dumps(merged_logs, default=str)
                merged_rec['last_updated'] = datetime.now().isoformat()

                merged_map[exp_id] = merged_rec
            # Prepare the final list
            valid_experiments = list(merged_map.values())

            # Ensure all records have their mandatory fields encoded
            normalized = []
            for rec in valid_experiments:
                # Normalize JSON fields to strings
                for f, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
                    val = rec.get(f)
                    if not isinstance(val, str):
                        rec[f] = json.dumps(val if val is not None else default, default=str)
                if 'last_updated' not in rec:
                    rec['last_updated'] = datetime.now().isoformat()
                normalized.append(rec)

            dataset = Dataset.from_list(normalized)

            # Generate a commit message if not provided
            if not commit_message:
                commit_message = f"Union-merge update with {len(normalized)} experiments ({datetime.now().isoformat()})"

            # Push to the Hub
            dataset.push_to_hub(
                self.dataset_repo,
                token=self.hf_token,
                private=True,
                commit_message=commit_message
            )

            logger.info(f"✅ Successfully saved {len(normalized)} experiments (union-merged) to {self.dataset_repo}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to save experiments to dataset: {e}")
            return False

    def upsert_experiment(self, experiment: Dict[str, Any]) -> bool:
        """
        Insert a new experiment or update an existing one, preserving all other data.

        Args:
            experiment (Dict[str, Any]): Experiment dictionary to upsert

        Returns:
            bool: True if the operation was successful, False otherwise
        """
        try:
            # Validate the experiment structure
            if not self._validate_experiment_structure(experiment):
                logger.error(f"❌ Invalid experiment structure for {experiment.get('experiment_id', 'unknown')}")
                return False

            # Load existing experiments
            existing_experiments = self.load_existing_experiments()

            # Find out whether the experiment already exists
            experiment_id = experiment['experiment_id']
            experiment_found = False
            updated_experiments = []

            for existing_exp in existing_experiments:
                if existing_exp.get('experiment_id') == experiment_id:
                    # Update the existing experiment
                    logger.info(f"🔄 Updating existing experiment: {experiment_id}")
                    experiment['last_updated'] = datetime.now().isoformat()
                    updated_experiments.append(experiment)
                    experiment_found = True
                else:
                    # Preserve the existing experiment
                    updated_experiments.append(existing_exp)

            # If the experiment doesn't exist yet, add it
            if not experiment_found:
                logger.info(f"✅ Adding new experiment: {experiment_id}")
                experiment['last_updated'] = datetime.now().isoformat()
                updated_experiments.append(experiment)

            # Save all experiments
            commit_message = (
                f"{'Update' if experiment_found else 'Add'} experiment {experiment_id} "
                f"(preserving {len(existing_experiments)} existing experiments)"
            )
            return self.save_experiments(updated_experiments, commit_message)

        except Exception as e:
            logger.error(f"❌ Failed to upsert experiment: {e}")
            return False

    def get_experiment_by_id(self, experiment_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve a specific experiment by its ID.

        Args:
            experiment_id (str): The experiment ID to search for

        Returns:
            Optional[Dict[str, Any]]: The experiment dictionary if found, None otherwise
        """
        try:
            experiments = self.load_existing_experiments()
            for exp in experiments:
                if exp.get('experiment_id') == experiment_id:
                    logger.info(f"✅ Found experiment: {experiment_id}")
                    return exp

            logger.info(f"🔍 Experiment not found: {experiment_id}")
            return None

        except Exception as e:
            logger.error(f"❌ Failed to get experiment {experiment_id}: {e}")
            return None

    def list_experiments(self, status_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        List all experiments, optionally filtered by status.

        Args:
            status_filter (Optional[str]): Filter by experiment status (running, completed, failed, paused)

        Returns:
            List[Dict[str, Any]]: List of experiments matching the filter
        """
        try:
            experiments = self.load_existing_experiments()

            if status_filter:
                filtered_experiments = [exp for exp in experiments if exp.get('status') == status_filter]
                logger.info(f"📋 Found {len(filtered_experiments)} experiments with status '{status_filter}'")
                return filtered_experiments

            logger.info(f"📋 Found {len(experiments)} total experiments")
            return experiments

        except Exception as e:
            logger.error(f"❌ Failed to list experiments: {e}")
            return []

    def backup_dataset(self, backup_suffix: Optional[str] = None) -> str:
        """
        Create a backup of the current dataset.

        Args:
            backup_suffix (Optional[str]): Optional suffix for the backup repo name

        Returns:
            str: Backup repository name if successful, empty string otherwise
        """
        try:
            if not backup_suffix:
                backup_suffix = datetime.now().strftime('%Y%m%d_%H%M%S')

            backup_repo = f"{self.dataset_repo}-backup-{backup_suffix}"

            # Load current experiments
            experiments = self.load_existing_experiments()
            if not experiments:
                logger.warning("⚠️ No experiments to backup")
                return ""

            # Create a dataset manager for the backup repo
            backup_manager = TrackioDatasetManager(backup_repo, self.hf_token)

            # Save to the backup
            success = backup_manager.save_experiments(
                experiments,
                f"Backup of {self.dataset_repo} created on {datetime.now().isoformat()}"
            )

            if success:
                logger.info(f"✅ Backup created: {backup_repo}")
                return backup_repo
            else:
                logger.error("❌ Failed to create backup")
                return ""

        except Exception as e:
            logger.error(f"❌ Failed to create backup: {e}")
            return ""


def create_dataset_manager(dataset_repo: str, hf_token: str) -> TrackioDatasetManager:
    """
    Factory function to create a TrackioDatasetManager instance.

    Args:
        dataset_repo (str): HF dataset repository ID
        hf_token (str): Hugging Face token

    Returns:
        TrackioDatasetManager: Configured dataset manager instance
    """
    return TrackioDatasetManager(dataset_repo, hf_token)
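

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the repository ID and the HF_TOKEN
    # environment variable below are assumptions, not values defined by this module.
    import os

    manager = create_dataset_manager(
        dataset_repo="username/trackio-experiments",  # hypothetical repo ID
        hf_token=os.environ.get("HF_TOKEN", ""),
    )

    # Upsert a single experiment; JSON fields are stored as JSON strings.
    example_experiment = {
        "experiment_id": "exp_demo_001",  # hypothetical ID
        "name": "demo-run",
        "description": "Illustrative experiment record",
        "created_at": datetime.now().isoformat(),
        "status": "running",
        "metrics": json.dumps([
            {"timestamp": datetime.now().isoformat(), "step": 1, "metrics": {"loss": 2.31}}
        ]),
        "parameters": json.dumps({"learning_rate": 3e-4, "batch_size": 8}),
        "artifacts": json.dumps([]),
        "logs": json.dumps([]),
    }
    if manager.upsert_experiment(example_experiment):
        print(f"Upserted {example_experiment['experiment_id']}")

    # List everything currently stored with status "running".
    for exp in manager.list_experiments(status_filter="running"):
        print(exp.get("experiment_id"), exp.get("status"))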