Spaces:
Sleeping
Sleeping
| from enum import Enum | |
| import json | |
| import os | |
| from typing import Dict, List, Optional | |
| from lpm_kernel.api.domains.trainprocess.progress_enum import Status | |
| from lpm_kernel.api.domains.trainprocess.train_progress import TrainProgress | |
| from lpm_kernel.api.domains.trainprocess.process_step import ProcessStep | |
| from lpm_kernel.configs.logging import get_train_process_logger | |
| logger = get_train_process_logger() | |
class TrainProgressHolder:
    """Load, update and persist training-process progress in a JSON file.

    The progress file lives under ``<cwd>/data/progress``. Its name is
    ``trainprocess_progress.json`` by default, or
    ``trainprocess_progress_<model_name>.json`` when a model name is given.
    The in-memory state is a ``TrainProgress`` object; this class reads and
    writes its ``data``, ``stage_map`` and ``steps_map`` attributes and calls
    its ``to_dict()`` and ``update_progress()`` methods.
    """

    def __init__(self, model_name: Optional[str] = None):
        """Prepare the progress file path and load any previously saved state.

        Args:
            model_name: Optional model name used to build a per-model
                progress file name.

        Raises:
            ValueError: If the resulting file path would fall outside the
                progress directory.
        """
        progress_dir = os.path.join(os.getcwd(), "data", "progress")
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(progress_dir, exist_ok=True)

        self.progress_file = self._resolve_progress_file(progress_dir, model_name)
        self.progress = TrainProgress()

        # Maps every process step to the key of the stage it belongs to.
        self._stage_mapping = {
            ProcessStep.MODEL_DOWNLOAD: "downloading_the_base_model",
            ProcessStep.LIST_DOCUMENTS: "activating_the_memory_matrix",
            ProcessStep.GENERATE_DOCUMENT_EMBEDDINGS: "activating_the_memory_matrix",
            ProcessStep.CHUNK_DOCUMENT: "activating_the_memory_matrix",
            ProcessStep.CHUNK_EMBEDDING: "activating_the_memory_matrix",
            ProcessStep.EXTRACT_DIMENSIONAL_TOPICS: "synthesize_your_life_narrative",
            ProcessStep.GENERATE_BIOGRAPHY: "synthesize_your_life_narrative",
            ProcessStep.MAP_ENTITY_NETWORK: "synthesize_your_life_narrative",
            ProcessStep.DECODE_PREFERENCE_PATTERNS: "prepare_training_data_for_deep_comprehension",
            ProcessStep.REINFORCE_IDENTITY: "prepare_training_data_for_deep_comprehension",
            ProcessStep.AUGMENT_CONTENT_RETENTION: "prepare_training_data_for_deep_comprehension",
            ProcessStep.TRAIN: "training_to_create_second_me",
            ProcessStep.MERGE_WEIGHTS: "training_to_create_second_me",
            ProcessStep.CONVERT_MODEL: "training_to_create_second_me",
        }

        self._load_progress()

    @staticmethod
    def _resolve_progress_file(progress_dir: str, model_name: Optional[str] = None) -> str:
        """Build the normalized progress file path and verify its location.

        Args:
            progress_dir: Directory that must contain the progress file.
            model_name: Optional model name embedded in the file name.

        Returns:
            The normalized path of the progress file.

        Raises:
            ValueError: If the normalized path is not inside ``progress_dir``.
        """
        file_name = "trainprocess_progress.json"  # Default name
        if model_name:
            file_name = f"trainprocess_progress_{model_name}.json"

        base_dir = os.path.normpath(progress_dir)
        candidate = os.path.normpath(os.path.join(base_dir, file_name))

        # A plain string-prefix test would accept sibling directories such as
        # "<progress_dir>_evil"; compare whole path components instead.
        if os.path.commonpath([base_dir, candidate]) != base_dir:
            raise ValueError("Invalid progress file path")
        return candidate

    def _load_progress(self):
        """Load saved progress from disk, if a progress file exists.

        The lookup maps are rebuilt from the loaded data, keyed by the
        lower-cased names with spaces replaced by underscores. Any error
        while loading is logged and the progress is reset to a fresh
        ``TrainProgress``.
        """
        if not os.path.exists(self.progress_file):
            return

        try:
            with open(self.progress_file, "r", encoding="utf-8") as f:
                saved_progress = json.load(f)

            self.progress.data = saved_progress
            self.progress.stage_map = {
                stage["name"].lower().replace(" ", "_"): stage
                for stage in saved_progress["stages"]
            }
            self.progress.steps_map = {
                stage_name: {
                    step["name"].lower().replace(" ", "_"): step
                    for step in stage["steps"]
                }
                for stage_name, stage in self.progress.stage_map.items()
            }

            # Anything still marked in_progress in the file did not finish.
            self._reset_in_progress_status()
        except Exception as e:
            logger.error(f"Error loading progress: {str(e)}")
            # Reset progress on any error
            self.progress = TrainProgress()

    def _reset_in_progress_status(self):
        """Turn every ``in_progress`` status into ``failed`` and save if changed.

        Covers the overall status, each stage and each step. Steps that are
        reset also get ``completed`` set to ``False``.
        """
        need_save = False

        # Overall status
        if self.progress.data["status"] == "in_progress":
            self.progress.data["status"] = "failed"
            need_save = True
            logger.info("Reset overall in_progress status to failed")

        for stage in self.progress.data["stages"]:
            if stage["status"] == "in_progress":
                stage["status"] = "failed"
                need_save = True
                logger.info(f"Reset stage '{stage['name']}' in_progress status to failed")

            for step in stage["steps"]:
                if step["status"] == "in_progress":
                    step["status"] = "failed"
                    step["completed"] = False
                    need_save = True
                    logger.info(f"Reset step '{step['name']}' in_progress status to failed")

        if need_save:
            self._save_progress()
            logger.info("Saved progress after resetting in_progress statuses")

    def _save_progress(self):
        """Write the current progress to the progress file as indented JSON.

        The data is written to a temporary sibling file first and then moved
        over the real file, so an interrupted write cannot leave a truncated
        progress file behind (which the loader would treat as an error and
        answer with a full reset).
        """
        progress_dict = self.progress.to_dict()
        tmp_file = f"{self.progress_file}.tmp"
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(progress_dict, f, indent=2)
        os.replace(tmp_file, self.progress_file)

    def is_step_completed(self, step: "ProcessStep") -> bool:
        """Check if a step is completed.

        Args:
            step: The process step to look up.

        Returns:
            The truth value of the step's ``completed`` flag; ``False`` when
            the stage or the step is missing from the loaded progress.

        Raises:
            KeyError: If ``step`` has no entry in the stage mapping.
        """
        stage_name = self._stage_mapping[step]
        # A progress file may lack stages or steps; report those as not
        # completed instead of raising.
        stage_steps = self.progress.steps_map.get(stage_name, {})
        step_info = stage_steps.get(step.value, {})
        return bool(step_info.get("completed", False))

    def mark_step_status(self, step: "ProcessStep", status: "Status"):
        """Mark a step with the specified status and save the progress.

        Args:
            step: The process step to mark
            status: The status to set for the step
        """
        stage_name = self._stage_mapping[step]
        step_name = step.value
        self.progress.update_progress(stage_name, step_name, status)
        self._save_progress()

    def reset_progress(self):
        """Replace the progress with a fresh ``TrainProgress`` and save it."""
        self.progress = TrainProgress()
        self._save_progress()

    def get_last_successful_step(self) -> "Optional[ProcessStep]":
        """Get the last successfully completed step.

        Returns:
            The last step, in ``ProcessStep.get_ordered_steps()`` order, that
            is completed, or ``None`` when no step is completed.
        """
        ordered_steps = ProcessStep.get_ordered_steps()
        for step in reversed(ordered_steps):
            if self.is_step_completed(step):
                return step
        return None