Spaces:
Sleeping
Sleeping
import inspect | |
import json | |
import logging | |
import os | |
from datetime import datetime | |
from typing import Any, Dict, Set | |
from uuid import UUID | |
logger = logging.getLogger(__name__) | |
class SafeLoaderUtils: | |
""" | |
Utility class for safely loading and saving object states while automatically | |
detecting and preserving class instances and complex objects. | |
""" | |
def is_class_instance(obj: Any) -> bool: | |
""" | |
Detect if an object is a class instance (excluding built-in types). | |
Args: | |
obj: Object to check | |
Returns: | |
bool: True if object is a class instance | |
""" | |
if obj is None: | |
return False | |
# Get the type of the object | |
obj_type = type(obj) | |
# Check if it's a class instance but not a built-in type | |
return ( | |
hasattr(obj, "__dict__") | |
and not isinstance(obj, type) | |
and obj_type.__module__ != "builtins" | |
) | |
def is_safe_type(value: Any) -> bool: | |
""" | |
Check if a value is of a safe, serializable type. | |
Args: | |
value: Value to check | |
Returns: | |
bool: True if the value is safe to serialize | |
""" | |
# Basic safe types | |
safe_types = ( | |
type(None), | |
bool, | |
int, | |
float, | |
str, | |
datetime, | |
UUID, | |
) | |
if isinstance(value, safe_types): | |
return True | |
# Check containers | |
if isinstance(value, (list, tuple)): | |
return all( | |
SafeLoaderUtils.is_safe_type(item) for item in value | |
) | |
if isinstance(value, dict): | |
return all( | |
isinstance(k, str) and SafeLoaderUtils.is_safe_type(v) | |
for k, v in value.items() | |
) | |
# Check for common serializable types | |
try: | |
json.dumps(value) | |
return True | |
except (TypeError, OverflowError, ValueError): | |
return False | |
def get_class_attributes(obj: Any) -> Set[str]: | |
""" | |
Get all attributes of a class, including inherited ones. | |
Args: | |
obj: Object to inspect | |
Returns: | |
Set[str]: Set of attribute names | |
""" | |
attributes = set() | |
# Get all attributes from class and parent classes | |
for cls in inspect.getmro(type(obj)): | |
attributes.update(cls.__dict__.keys()) | |
# Add instance attributes | |
attributes.update(obj.__dict__.keys()) | |
return attributes | |
def create_state_dict(obj: Any) -> Dict[str, Any]: | |
""" | |
Create a dictionary of safe values from an object's state. | |
Args: | |
obj: Object to create state dict from | |
Returns: | |
Dict[str, Any]: Dictionary of safe values | |
""" | |
state_dict = {} | |
for attr_name in SafeLoaderUtils.get_class_attributes(obj): | |
# Skip private attributes | |
if attr_name.startswith("_"): | |
continue | |
try: | |
value = getattr(obj, attr_name, None) | |
if SafeLoaderUtils.is_safe_type(value): | |
state_dict[attr_name] = value | |
except Exception as e: | |
logger.debug(f"Skipped attribute {attr_name}: {e}") | |
return state_dict | |
def preserve_instances(obj: Any) -> Dict[str, Any]: | |
""" | |
Automatically detect and preserve all class instances in an object. | |
Args: | |
obj: Object to preserve instances from | |
Returns: | |
Dict[str, Any]: Dictionary of preserved instances | |
""" | |
preserved = {} | |
for attr_name in SafeLoaderUtils.get_class_attributes(obj): | |
if attr_name.startswith("_"): | |
continue | |
try: | |
value = getattr(obj, attr_name, None) | |
if SafeLoaderUtils.is_class_instance(value): | |
preserved[attr_name] = value | |
except Exception as e: | |
logger.debug(f"Could not preserve {attr_name}: {e}") | |
return preserved | |
class SafeStateManager: | |
""" | |
Manages saving and loading object states while automatically handling | |
class instances and complex objects. | |
""" | |
def save_state(obj: Any, file_path: str) -> None: | |
""" | |
Save an object's state to a file, automatically handling complex objects. | |
Args: | |
obj: Object to save state from | |
file_path: Path to save state to | |
""" | |
try: | |
# Create state dict with only safe values | |
state_dict = SafeLoaderUtils.create_state_dict(obj) | |
# Ensure directory exists | |
os.makedirs(os.path.dirname(file_path), exist_ok=True) | |
# Save to file | |
with open(file_path, "w") as f: | |
json.dump(state_dict, f, indent=4, default=str) | |
logger.info(f"Successfully saved state to: {file_path}") | |
except Exception as e: | |
logger.error(f"Error saving state: {e}") | |
raise | |
def load_state(obj: Any, file_path: str) -> None: | |
""" | |
Load state into an object while preserving class instances. | |
Args: | |
obj: Object to load state into | |
file_path: Path to load state from | |
""" | |
try: | |
# Verify file exists | |
if not os.path.exists(file_path): | |
raise FileNotFoundError( | |
f"State file not found: {file_path}" | |
) | |
# Preserve existing instances | |
preserved = SafeLoaderUtils.preserve_instances(obj) | |
# Load state | |
with open(file_path, "r") as f: | |
state_dict = json.load(f) | |
# Set safe values | |
for key, value in state_dict.items(): | |
if ( | |
not key.startswith("_") | |
and key not in preserved | |
and SafeLoaderUtils.is_safe_type(value) | |
): | |
setattr(obj, key, value) | |
# Restore preserved instances | |
for key, value in preserved.items(): | |
setattr(obj, key, value) | |
logger.info( | |
f"Successfully loaded state from: {file_path}" | |
) | |
except Exception as e: | |
logger.error(f"Error loading state: {e}") | |
raise | |
# # Example decorator for easy integration | |
# def safe_state_methods(cls: Type) -> Type: | |
# """ | |
# Class decorator to add safe state loading/saving methods to a class. | |
# Args: | |
# cls: Class to decorate | |
# Returns: | |
# Type: Decorated class | |
# """ | |
# def save(self, file_path: str) -> None: | |
# SafeStateManager.save_state(self, file_path) | |
# def load(self, file_path: str) -> None: | |
# SafeStateManager.load_state(self, file_path) | |
# cls.save = save | |
# cls.load = load | |
# return cls | |