|
|
import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
import os
|
|
|
|
|
|
def plot_sigma_sequence(sigs, stopping_index, log_filename, save_directory="modules/sd_simple_kes_v2/image_generation_data", show_plot=False):
    """
    Plot the sigma sequence, mark the early stopping point, and save the plot as a PNG.

    The image is written to ``<save_directory>/<log basename>_sigma_plot.png``,
    where the log basename is ``log_filename`` stripped of its directory and
    extension. The save directory is created if it does not exist yet.

    Parameters:
    - sigs: The sigma tensor or numpy array (can be truncated if stopping early).
      Objects exposing ``.cpu()`` are detached and moved to the CPU before
      conversion; anything else is passed to ``np.array``.
    - stopping_index: The step index where early stopping was triggered.
      Drawn as a red dashed vertical line.
    - log_filename: The filename of the generation log (used to match the graph name).
    - save_directory: The folder where the plot should be saved.
    - show_plot: Set to True to display the plot interactively.

    Returns:
    - The path of the saved plot image.
    """
    base_filename = os.path.splitext(os.path.basename(log_filename))[0]
    graph_filename = f"{base_filename}_sigma_plot.png"
    graph_path = os.path.join(save_directory, graph_filename)

    # savefig does not create missing folders; make sure the target exists.
    # An empty save_directory means "current directory", which needs no mkdir.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    if hasattr(sigs, 'cpu'):
        # Tensor-like input: detach first, since .numpy() refuses tensors
        # that are still attached to an autograd graph.
        tensor = sigs.detach() if hasattr(sigs, 'detach') else sigs
        sigs_np = tensor.cpu().numpy()
    else:
        sigs_np = np.array(sigs)
    x = np.arange(len(sigs_np))

    fig = plt.figure(figsize=(10, 6))
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(x, sigs_np, label='Sigma Sequence', marker='o')
        ax.axvline(x=stopping_index, color='red', linestyle='--', label=f'Stopping Point: {stopping_index}')
        ax.set_xlabel('Step Index')
        ax.set_ylabel('Sigma Value')
        ax.set_title('Sigma Sequence with Early Stopping Point')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(graph_path)

        if show_plot:
            plt.show()
    finally:
        # Close exactly the figure created here, even when plotting or saving
        # fails, so repeated calls do not accumulate open figures.
        plt.close(fig)

    return graph_path