Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def _vector_to_matrix(vector):
    """Expand a 1-D strict-upper-triangle vector into a symmetric matrix.

    A 1-D input is read as the entries strictly above the diagonal of an
    ``m x m`` symmetric matrix, in row-major order (the order produced by
    ``np.triu_indices(m, k=1)``). Such a vector has ``m * (m - 1) / 2``
    entries. The returned matrix has zeros on its diagonal. Inputs that are
    not 1-D are returned unchanged.

    Args:
        vector: array-like; either a 1-D upper-triangle vector or an
            already-square matrix.

    Returns:
        A symmetric ``(m, m)`` float ndarray for 1-D input, otherwise
        ``vector`` itself.

    Raises:
        ValueError: if a 1-D input's length is not ``m * (m - 1) / 2`` for
            any integer ``m``.
    """
    arr = np.asarray(vector)
    if arr.ndim != 1:
        return vector

    n = arr.shape[0]
    # The diagonal is excluded, so invert n = m*(m-1)/2  ->  m = (1 + sqrt(1 + 8n)) / 2.
    size = int(round((1 + np.sqrt(1 + 8 * n)) / 2))
    if size * (size - 1) // 2 != n:
        raise ValueError(
            f"Expected a strict upper-triangle vector of length m*(m-1)/2, got length {n}"
        )

    matrix = np.zeros((size, size))
    rows, cols = np.triu_indices(size, k=1)
    matrix[rows, cols] = arr
    # Mirror the upper triangle; the diagonal stays 0.
    return matrix + matrix.T


def plot_fc_matrices(original, reconstructed, generated, *, vmin=-1, vmax=1, cmap='RdBu_r'):
    """Plot original, reconstructed and generated FC matrices side by side.

    Each input may be a 2-D matrix or a 1-D strict-upper-triangle vector,
    which is expanded to a symmetric matrix with a zero diagonal. All three
    panels share the same colour limits so they can be compared directly.

    Args:
        original: original FC matrix or upper-triangle vector.
        reconstructed: reconstructed FC matrix or upper-triangle vector.
        generated: generated FC matrix or upper-triangle vector.
        vmin: lower colour limit shared by all panels (default -1).
        vmax: upper colour limit shared by all panels (default 1).
        cmap: matplotlib colormap name (default 'RdBu_r').

    Returns:
        The matplotlib Figure holding the three panels, each with a colorbar.

    Raises:
        ValueError: if a 1-D input has a length that is not m*(m-1)/2.
    """
    # Convert everything first so an invalid input fails before a figure is created.
    panels = (
        (_vector_to_matrix(original), 'Original FC'),
        (_vector_to_matrix(reconstructed), 'Reconstructed FC'),
        (_vector_to_matrix(generated), 'Generated FC'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (matrix, title) in zip(axes, panels):
        im = ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        plt.colorbar(im, ax=ax)
    plt.tight_layout()
    return fig
def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke, prediction_std=None):
    """Plot a straight-line trajectory from the current to the predicted score.

    The current score is drawn at x = 0 and the predicted score at
    x = ``months_post_stroke``, joined by a dashed green line. When
    ``prediction_std`` is given, a vertical error bar of +/- 2 *
    ``prediction_std`` is drawn at the predicted point. It is labelled as a
    95% interval, which assumes roughly normal prediction error (not checked
    here).

    Args:
        current_score: score plotted at x = 0.
        predicted_score: score plotted at x = ``months_post_stroke``.
        months_post_stroke: x position of the predicted point.
        prediction_std: optional standard deviation of the prediction. It
            must be non-negative, because matplotlib rejects negative error
            bar sizes.

    Returns:
        The newly created matplotlib Figure.
    """
    fig = plt.figure(figsize=(10, 6))

    # Plot current and predicted points
    plt.scatter([0], [current_score], label='Current Status', color='blue', s=100)
    plt.scatter([months_post_stroke], [predicted_score],
                label='Predicted Outcome', color='red', s=100)

    # Plot trajectory
    plt.plot([0, months_post_stroke], [current_score, predicted_score],
             'g--', label='Predicted Trajectory')

    # Add prediction interval if available
    if prediction_std is not None:
        # fill_between over a single x value spans zero width and draws nothing,
        # so the interval is shown as a vertical error bar at the predicted point.
        plt.errorbar([months_post_stroke], [predicted_score],
                     yerr=[2 * prediction_std],
                     fmt='none', ecolor='red', alpha=0.6,
                     elinewidth=2, capsize=8,
                     label='95% Prediction Interval')

    # NOTE(review): the axis says "Post Treatment" while the parameter is
    # months_post_stroke -- confirm which reference point is intended.
    plt.xlabel('Months Post Treatment')
    plt.ylabel('WAB Score')
    plt.title('Predicted Treatment Trajectory')
    plt.legend()
    plt.grid(True)
    return fig
def plot_learning_curves(train_losses, val_losses):
    """Draw the training and validation loss curves on one set of axes.

    Each sequence is plotted against its index, and the x axis is labelled
    'Epoch'.

    Args:
        train_losses: sequence of training-loss values.
        val_losses: sequence of validation-loss values.

    Returns:
        The newly created matplotlib Figure.
    """
    figure = plt.figure(figsize=(10, 6))
    axis = figure.gca()

    curves = (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for losses, caption in curves:
        axis.plot(losses, label=caption)

    axis.set_xlabel('Epoch')
    axis.set_ylabel('Loss')
    axis.set_title('VAE Learning Curves')
    axis.legend()
    axis.grid(True)
    return figure