"""Dataset generation functions for testing BackpropNEAT."""

import numpy as np
import jax.numpy as jnp

def generate_xor_data(n_samples: int = 200, complexity: float = 1.0) -> tuple:
    """Generate a multi-cluster XOR dataset with per-cluster and global rotation.

    Every quadrant holds three disc-shaped clusters placed along its diagonal at
    distances ``0.7``, ``1.0`` and ``1.3`` per axis.  Before the global rotation,
    quadrants whose coordinates share a sign are labelled ``-1`` and the other
    two are labelled ``1``.

    Args:
        n_samples: Number of samples per quadrant.  The samples are split as
            evenly as possible over the three clusters; any remainder goes to
            the innermost clusters, so each quadrant gets exactly
            ``n_samples`` points.
        complexity: Scales the rotation applied to each cluster's local offsets
            (30 degrees times the cluster index) and the global rotation
            (45 degrees).

    Returns:
        Tuple of (features, labels) as JAX arrays.  ``features`` has shape
        ``(4 * n_samples, 2)`` and ``labels`` has shape ``(4 * n_samples,)``
        with values ``-1`` or ``1``.  Rows are shuffled.

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples < 0:
        raise ValueError("n_samples must be non-negative")

    n_clusters = 3
    radius = 0.2

    # Distribute the per-quadrant budget without dropping the remainder
    # (e.g. 200 -> 67 + 67 + 66 instead of 3 * 66).
    base, extra = divmod(n_samples, n_clusters)
    cluster_sizes = [base + (1 if idx < extra else 0) for idx in range(n_clusters)]

    points = []
    labels = []

    for cluster, size in enumerate(cluster_sizes):
        # Each subsequent cluster is rotated by a further 30 degrees
        rotation = complexity * cluster * np.pi / 6
        cos_rot, sin_rot = np.cos(rotation), np.sin(rotation)

        # Clusters move outwards along the diagonals, leaving gaps between them
        offset = 0.7 + 0.3 * cluster
        centers = [
            # (x, y, label)
            (-offset, -offset, -1),  # Bottom-left
            (offset, offset, -1),    # Top-right
            (-offset, offset, 1),    # Top-left
            (offset, -offset, 1),    # Bottom-right
        ]

        for cx, cy, label in centers:
            # Polar samples inside a disc of the given radius
            theta = np.random.uniform(0, 2 * np.pi, size)
            r = np.random.uniform(0, radius, size)

            dx = r * np.cos(theta)
            dy = r * np.sin(theta)

            # Rotate the local offsets about the cluster center
            dx_rot = dx * cos_rot - dy * sin_rot
            dy_rot = dx * sin_rot + dy * cos_rot

            # Translate to the cluster center and add Gaussian jitter
            px = cx + dx_rot + np.random.normal(0, 0.05, size)
            py = cy + dy_rot + np.random.normal(0, 0.05, size)

            points.append(np.column_stack([px, py]))
            labels.extend([label] * size)

    X = np.vstack(points)
    y = np.array(labels, dtype=np.float32)

    # Global rotation by complexity * 45 degrees.  Row vectors are
    # right-multiplied by the matrix, which turns them clockwise.
    theta = complexity * np.pi / 4
    rotation_matrix = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])
    X = X @ rotation_matrix

    # Shuffle features and labels with the same permutation
    perm = np.random.permutation(len(X))
    X = X[perm]
    y = y[perm]

    return jnp.array(X), jnp.array(y)

def generate_circle_data(n_samples: int = 1000, noise: float = 0.1) -> tuple:
    """Build a two-ring (concentric circles) classification dataset.

    One set of angles is drawn and reused for both rings; each ring then gets
    its own Gaussian perturbation of the radius.

    Args:
        n_samples: Number of points on each ring (one ring per class).
        noise: Standard deviation of the Gaussian noise added to each radius.

    Returns:
        Tuple of (features, labels) as NumPy arrays.  ``features`` has shape
        ``(2 * n_samples, 2)``; ``labels`` holds ``-1.0`` for the inner ring
        (mean radius 0.5) and ``1.0`` for the outer ring (mean radius 1.5).
        Rows are shuffled.
    """
    # Angles shared by the inner and the outer ring
    angles = np.random.uniform(0, 2*np.pi, n_samples)
    cos_a = np.cos(angles)
    sin_a = np.sin(angles)

    ring_points = []
    ring_labels = []
    # (mean radius, class value): the inner ring is drawn first, then the outer
    for mean_radius, class_value in ((0.5, -1.0), (1.5, 1.0)):
        radii = mean_radius + np.random.normal(0, noise, n_samples)
        ring_points.append(np.stack((radii * cos_a, radii * sin_a), axis=1))
        ring_labels.append(np.full(n_samples, class_value))

    features = np.concatenate(ring_points, axis=0)
    targets = np.concatenate(ring_labels)

    # Apply one random ordering to both arrays
    order = np.random.permutation(features.shape[0])
    return features[order], targets[order]

def generate_spiral_dataset(n_points=1000, noise=0.1):
    """Generate a two-arm spiral dataset with engineered polar features.

    For every sampled angle one point is produced on each of the two spiral
    arms (the second arm is the first one rotated by pi), so the result holds
    ``2 * n_points`` samples.  Samples are not shuffled: rows alternate between
    the first arm (label ``-1``) and the second arm (label ``1``).

    Feature columns:
        0. ``x / 2``
        1. ``y / 2``
        2. ``r / 2`` where ``r`` is the distance from the origin
        3. sine of the polar angle
        4. cosine of the polar angle
        5. ``tanh(2 / (r + eps))``
        6. unwrapped angle divided by ``8 * pi``
        7. ``tanh(r / (unwrapped angle + eps))``
        8. ``tanh(10 * dr/dtheta)``, a constant because ``dr/dtheta`` is fixed
        9. distance to the point mirrored through the origin, divided by 4

    Args:
        n_points: Number of angles to sample; each yields one point per arm.
        noise: Base standard deviation of the radial noise.  The effective
            noise grows with the radius, and the angular noise is a tenth of
            the radial noise.

    Returns:
        Tuple of (features, labels) as NumPy arrays with shapes
        ``(2 * n_points, 10)`` and ``(2 * n_points,)``.
    """
    eps = 1e-8
    two_pi = 2 * np.pi

    # Square root of a uniform sample: the density of theta grows linearly, so
    # samples get denser towards the outer end of the spiral.
    theta = np.sqrt(np.random.uniform(0, 1, n_points)) * 4 * np.pi

    # Noise-free radius grows linearly with theta and stays within [0, 1)
    r_base = theta / (4 * np.pi)

    # Noise scales with the radius so the center of the spiral stays clean
    noise_scale = noise * (1 - np.exp(-2 * r_base))

    # One (radius, angle) noise pair per point and arm, laid out as
    # [point, arm, (radius noise, angle noise)].
    scales = np.empty((n_points, 2, 2))
    scales[..., 0] = noise_scale[:, None]
    scales[..., 1] = noise_scale[:, None] * 0.1  # Less noise in angle
    jitter = np.random.normal(0.0, scales)

    arm = np.arange(2)
    r = r_base[:, None] + jitter[..., 0]
    # The second arm is the first one rotated by pi
    angle = theta[:, None] + np.pi * arm + jitter[..., 1]

    # Cartesian coordinates
    x = r * np.cos(angle)
    y = r * np.sin(angle)

    # Polar coordinates of the resulting point
    r_point = np.sqrt(x * x + y * y)
    theta_point = np.arctan2(y, x)

    # Unwrap: arctan2 returns values in (-pi, pi], so add the whole number of
    # revolutions that brings theta_point closest to the generating angle.
    # (Flooring the generating angle instead is off by one revolution whenever
    # angle mod 2*pi lies in (pi, 2*pi).)
    revolutions = np.round((angle - theta_point) / two_pi)
    theta_unwrapped = theta_point + two_pi * revolutions

    # 1. Curvature proxy: inverse distance from the origin
    curvature = 1 / (r_point + eps)

    # 2. Phase: position within the current revolution, in [0, 1)
    phase = theta_unwrapped % two_pi / two_pi

    # 3. Radial velocity dr/dtheta of the noise-free spiral (constant)
    dr_dtheta = 1 / (4 * np.pi)

    # 4. Angular position normalised by the two revolutions of one arm
    angular_pos = theta_unwrapped / (4 * np.pi)

    # 5. Tightness: radius gained per unit of unwrapped angle
    tightness = r_point / (theta_unwrapped + eps)

    # 6. Distance to the mirror point at the same radius and angle + pi.
    # That point is (-x, -y), so the distance is twice the radius.
    dist_to_other = 2.0 * r_point

    # 7. Sine and cosine of the phase
    sin_phase = np.sin(phase * two_pi)
    cos_phase = np.cos(phase * two_pi)

    features = np.stack([
        x / 2.0,  # Normalize coordinates
        y / 2.0,
        r_point / 2.0,  # Normalize radius
        sin_phase,  # Already normalized
        cos_phase,  # Already normalized
        np.tanh(curvature * 2),  # Normalize curvature
        angular_pos / 2.0,  # Normalize angular position
        np.tanh(tightness),  # Normalize tightness
        np.full_like(r_point, np.tanh(dr_dtheta * 10)),  # Normalize radial velocity
        dist_to_other / 4.0,  # Normalize distance to mirror point
    ], axis=-1)

    # Flatten [point, arm] into rows; arm index 0/1 maps to label -1/1
    data = features.reshape(-1, features.shape[-1])
    labels = np.tile(arm * 2 - 1, n_points)

    return data, labels

def generate_checkerboard_data(n_samples: int = 200) -> tuple:
    """Generate a checkerboard classification dataset on the square [-2, 2)^2.

    Points are drawn uniformly at random and labelled by the parity of the
    unit cell they fall into, so the two classes are not forced to be balanced.

    Args:
        n_samples: Half the number of points to draw; ``2 * n_samples`` points
            are generated in total.

    Returns:
        Tuple of (features, labels) as JAX arrays.  ``features`` has shape
        ``(2 * n_samples, 2)``; ``labels`` has shape ``(2 * n_samples,)`` and
        is ``1`` where ``floor(x1) + floor(x2)`` is even and ``0`` otherwise.
    """
    # Generate random points
    X = np.random.uniform(-2, 2, (n_samples * 2, 2))

    # Integer cell coordinates of every point; the parity of their sum picks
    # the checkerboard colour (computed for all points at once).
    cell_parity = np.floor(X).astype(np.int64).sum(axis=1) % 2
    y = np.where(cell_parity == 0, 1.0, 0.0)

    return jnp.array(X), jnp.array(y)

# Public API: the dataset generators exported by a wildcard import of this module.
__all__ = ['generate_xor_data', 'generate_circle_data', 'generate_spiral_dataset', 
           'generate_checkerboard_data']