import os
import re
import time
import psutil
from typing import Optional, Dict
from lpm_kernel.L1.utils import save_true_topics
from lpm_kernel.L1.serializers import NotesStorage
from lpm_kernel.kernel.note_service import NoteService
from lpm_kernel.L2.l2_generator import L2Generator
from lpm_kernel.L2.utils import save_hf_model
from lpm_kernel.api.common.responses import APIResponse
from lpm_kernel.api.domains.loads.services import LoadService
from lpm_kernel.kernel.chunk_service import ChunkService
from lpm_kernel.kernel.l1.l1_manager import (
extract_notes_from_documents,
document_service,
get_latest_status_bio,
get_latest_global_bio,
)
from lpm_kernel.api.common.script_executor import ScriptExecutor
from lpm_kernel.configs.config import Config
from lpm_kernel.file_data.chunker import DocumentChunker
from lpm_kernel.kernel.l1.l1_manager import generate_l1_from_l0
import threading
from lpm_kernel.api.domains.trainprocess.progress_enum import Status
from lpm_kernel.api.domains.trainprocess.process_step import ProcessStep
from lpm_kernel.api.domains.trainprocess.progress_holder import TrainProgressHolder
from lpm_kernel.api.domains.trainprocess.training_params_manager import TrainingParamsManager
from lpm_kernel.models.l1 import L1Bio, L1Shade
from lpm_kernel.common.repository.database_session import DatabaseSession
from lpm_kernel.api.domains.kernel.routes import store_l1_data
from lpm_kernel.api.domains.trainprocess.L1_exposure_manager import output_files, query_l1_version_data, read_file_content
import gc
import subprocess
from lpm_kernel.configs.logging import get_train_process_logger, TRAIN_LOG_FILE
logger = get_train_process_logger()
class TrainProcessService:
"""Training process service (singleton pattern)"""
_instance = None
_initialized = False
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, current_model_name: str):
if current_model_name is None:
raise ValueError("current_model_name cannot be None")
if not self._initialized:
# Generate a unique progress file name based on model name
self.progress = TrainProgressHolder(current_model_name)
self.model_name = current_model_name # Set model name directly
self._initialized = True
# Initialize stop flag
self.is_stopped = False
self.current_step = None
# Initialize L2 data dictionary
self.l2_data = {
"notes": None,
"basic_info": None,
"data_output_base_dir": None,
"topics_path": None,
"entitys_path": None,
"graph_path": None,
"config_path": None
}
self.l2_data_prepared = False
# Update model name and progress instance if model name changes
if current_model_name != self.model_name:
self.model_name = current_model_name
# Create new progress instance with updated progress file name
self.progress = TrainProgressHolder(current_model_name)
@classmethod
def get_instance(cls, current_model_name: str = None):
"""Get the current instance of TrainProcessService
Args:
current_model_name: Optional model name to update the instance with
Returns:
TrainProcessService: The singleton instance
"""
if cls._instance is None:
if current_model_name is None:
logger.warning("current_model_name must be provided when creating a new instance")
return None
return cls(current_model_name)
if current_model_name is not None:
# Update the existing instance with new model name
cls._instance.model_name = current_model_name
cls._instance.progress = TrainProgressHolder(current_model_name)
return cls._instance
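    # Illustrative usage sketch (the model name is a placeholder, not a value defined in this file):
    #   service = TrainProcessService.get_instance("<model_name>")
    #   if service is not None:
    #       service.start_process()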
def list_documents(self):
"""List all documents"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.LIST_DOCUMENTS, Status.IN_PROGRESS)
# Directly call document service instead of API
documents = document_service.list_documents()
            # Mark step as completed
self.progress.mark_step_status(ProcessStep.LIST_DOCUMENTS, Status.COMPLETED)
return [doc.to_dict() for doc in documents]
except Exception as e:
logger.error(f"List documents failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.LIST_DOCUMENTS, Status.FAILED)
return []
def generate_document_embeddings(self) -> bool:
"""Process embeddings for all documents"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.GENERATE_DOCUMENT_EMBEDDINGS, Status.IN_PROGRESS)
documents = self.list_documents()
for doc in documents:
doc_id = doc.get("id")
# Directly call document service instead of API
embedding = document_service.process_document_embedding(doc_id)
if embedding is None:
logger.error(
f"Generate document embeddings failed for doc_id: {doc_id}"
)
self.progress.mark_step_status(ProcessStep.GENERATE_DOCUMENT_EMBEDDINGS, Status.FAILED)
return False
self.progress.mark_step_status(ProcessStep.GENERATE_DOCUMENT_EMBEDDINGS, Status.COMPLETED)
logger.info(f"Successfully generated embedding for document {doc_id}")
return True
except Exception as e:
logger.error(f"Generate document embeddings failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.GENERATE_DOCUMENT_EMBEDDINGS, Status.FAILED)
return False
def process_chunks(self) -> bool:
"""Process document chunks"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.CHUNK_DOCUMENT, Status.IN_PROGRESS)
config = Config.from_env()
chunker = DocumentChunker(
chunk_size=int(config.get("DOCUMENT_CHUNK_SIZE")),
overlap=int(config.get("DOCUMENT_CHUNK_OVERLAP")),
)
documents = document_service.list_documents()
processed, failed = 0, 0
chunk_service = ChunkService()
for doc in documents:
try:
if not doc.raw_content:
logger.warning(f"Document {doc.id} has no content, skipping...")
failed += 1
continue
# Split into chunks and save
chunks = chunker.split(doc.raw_content)
for chunk in chunks:
chunk.document_id = doc.id
chunk_service.save_chunk(chunk)
processed += 1
logger.info(
f"Document {doc.id} processed: {len(chunks)} chunks created"
)
except Exception as e:
logger.error(f"Failed to process document {doc.id}: {str(e)}")
failed += 1
            logger.info(f"Document chunking finished: {processed} documents processed, {failed} failed")
            self.progress.mark_step_status(ProcessStep.CHUNK_DOCUMENT, Status.COMPLETED)
return True
except Exception as e:
logger.error(f"Process chunks failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.CHUNK_DOCUMENT, Status.FAILED)
return False
def chunk_embedding(self) -> bool:
"""Process embeddings for all document chunks"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.CHUNK_EMBEDDING, Status.IN_PROGRESS)
documents = self.list_documents()
for doc in documents:
doc_id = doc.get("id")
try:
# Directly call document service to generate chunk embeddings
processed_chunks = document_service.generate_document_chunk_embeddings(doc_id)
if not processed_chunks:
logger.warning(f"No chunks to process for document: {doc_id}")
continue
except Exception as e:
logger.error(
f"Generate chunk embeddings failed for doc_id: {doc_id}: {str(e)}"
)
self.progress.mark_step_status(ProcessStep.CHUNK_EMBEDDING, Status.FAILED)
return False
# All documents' chunks processed successfully
self.progress.mark_step_status(ProcessStep.CHUNK_EMBEDDING, Status.COMPLETED)
return True
except Exception as e:
logger.error(f"Generate chunk embeddings failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.CHUNK_EMBEDDING, Status.FAILED)
return False
def extract_dimensional_topics(self) -> bool:
"""Extract dimensional topics (L0)"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.EXTRACT_DIMENSIONAL_TOPICS, Status.IN_PROGRESS)
logger.info("Starting dimensional topics extraction (L0)...")
# Generate L0 - Call document_service to analyze all documents
logger.info("Generating L0 data...")
analyzed_docs = document_service.analyze_all_documents()
logger.info(f"Successfully analyzed {len(analyzed_docs)} documents for L0")
# Mark step as completed
self.progress.mark_step_status(ProcessStep.EXTRACT_DIMENSIONAL_TOPICS, Status.COMPLETED)
logger.info("Dimensional topics extraction (L0) completed successfully")
return True
except Exception as e:
logger.error(f"Extract dimensional topics (L0) failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.EXTRACT_DIMENSIONAL_TOPICS, Status.FAILED)
return False
def generate_biography(self) -> bool:
"""Generate biography using L1 data"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.GENERATE_BIOGRAPHY, Status.IN_PROGRESS)
logger.info("Starting biography generation...")
# Generate L1 data and biography
logger.info("Generating L1 data and biography...")
l1_data = generate_l1_from_l0()
logger.info("Successfully generated L1 data and biography")
# Store L1 data
with DatabaseSession.session() as session:
store_l1_data(session, l1_data)
# Mark step as completed
self.progress.mark_step_status(ProcessStep.GENERATE_BIOGRAPHY, Status.COMPLETED)
logger.info("Biography generation completed successfully")
return True
except Exception as e:
logger.error(f"Biography generation failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.GENERATE_BIOGRAPHY, Status.FAILED)
return False
def model_download(self) -> bool:
"""Download model"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.MODEL_DOWNLOAD, Status.IN_PROGRESS)
# Directly call save_hf_model function to download model
logger.info(f"Starting model download: {self.model_name}")
# Start monitoring the download progress in a separate thread
monitor_thread = threading.Thread(target=self._monitor_model_download)
monitor_thread.daemon = True
monitor_thread.start()
# Start the actual download
model_path = save_hf_model(self.model_name)
if model_path and os.path.exists(model_path):
logger.info(f"Model downloaded successfully to {model_path}")
self.progress.mark_step_status(ProcessStep.MODEL_DOWNLOAD, Status.COMPLETED)
return True
else:
logger.error(f"Model path does not exist after download: {model_path}")
self.progress.mark_step_status(ProcessStep.MODEL_DOWNLOAD, Status.FAILED)
return False
except Exception as e:
logger.error(f"Download model failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.MODEL_DOWNLOAD, Status.FAILED)
return False
    def map_your_entity_network(self) -> bool:
"""Map entity network using notes and basic info"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.MAP_ENTITY_NETWORK, Status.IN_PROGRESS)
logger.info("Starting entity network mapping...")
# Get or prepare L2 data
self._prepare_l2_data()
l2_generator = L2Generator(
data_path=os.path.join(os.getcwd(), "resources")
)
l2_generator.data_preprocess(self.l2_data["notes"], self.l2_data["basic_info"])
self.progress.mark_step_status(ProcessStep.MAP_ENTITY_NETWORK, Status.COMPLETED)
logger.info("Entity network mapping completed successfully")
return True
except Exception as e:
logger.error(f"Map entity network failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.MAP_ENTITY_NETWORK, Status.FAILED)
self._cleanup_resources()
return False
    def decode_preference_patterns(self) -> bool:
"""Decode preference patterns using notes and related data"""
try:
params_manager = TrainingParamsManager()
training_params = params_manager.get_latest_training_params()
concurrency_threads = training_params.get("concurrency_threads")
data_synthesis_mode = training_params.get("data_synthesis_mode")
os.environ["CONCURRENCY_THREADS"] = str(concurrency_threads)
os.environ["DATA_SYNTHESIS_MODE"] = data_synthesis_mode
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.DECODE_PREFERENCE_PATTERNS, Status.IN_PROGRESS)
logger.info("Starting preference patterns decoding...")
# Get or prepare L2 data
self._prepare_l2_data()
            # Use data from the l2_data dictionary; training_params were already fetched above
            L2Generator(is_cot=training_params.get("is_cot", False)).gen_preference_data(
self.l2_data["notes"],
self.l2_data["basic_info"],
self.l2_data["data_output_base_dir"],
self.l2_data["topics_path"],
self.l2_data["entitys_path"],
self.l2_data["graph_path"],
self.l2_data["config_path"]
)
self.progress.mark_step_status(ProcessStep.DECODE_PREFERENCE_PATTERNS, Status.COMPLETED)
logger.info("Preference patterns decoding completed successfully")
return True
except Exception as e:
logger.error(f"Decode preference patterns failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.DECODE_PREFERENCE_PATTERNS, Status.FAILED)
return False
    def reinforce_identity(self) -> bool:
"""Reinforce identity using notes and related data"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.REINFORCE_IDENTITY, Status.IN_PROGRESS)
logger.info("Starting identity reinforcement...")
# Get or prepare L2 data
self._prepare_l2_data()
# Get training parameters
training_params = TrainingParamsManager.get_latest_training_params()
# Use data from l2_data dictionary
l2_generator = L2Generator(
data_path=os.path.join(os.getcwd(), "resources"), is_cot=training_params.get("is_cot", False)
)
l2_generator.gen_selfqa_data(
self.l2_data["notes"],
self.l2_data["basic_info"],
self.l2_data["data_output_base_dir"],
self.l2_data["topics_path"],
self.l2_data["entitys_path"],
self.l2_data["graph_path"],
self.l2_data["config_path"]
)
self.progress.mark_step_status(ProcessStep.REINFORCE_IDENTITY, Status.COMPLETED)
logger.info("Identity reinforcement completed successfully")
return True
except Exception as e:
logger.error(f"Reinforce identity failed: {str(e)}")
self.progress.mark_step_status(ProcessStep.REINFORCE_IDENTITY, Status.FAILED)
return False
def _cleanup_resources(self):
"""Clean up resources to prevent memory leaks"""
logger.info("Cleaning up resources to prevent memory leaks")
# Clean up large data structures in l2_data dictionary
for key in self.l2_data:
self.l2_data[key] = None
self.l2_data_prepared = False
# Force garbage collection
gc.collect()
# Log memory usage after cleanup
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
logger.info(f"Memory usage after cleanup: {memory_info.rss / 1024 / 1024:.2f} MB")
def augment_content_retention(self) -> bool:
"""Augment content retention using notes, basic info and graph data"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.AUGMENT_CONTENT_RETENTION, Status.IN_PROGRESS)
logger.info("Starting content retention augmentation...")
# Get or prepare L2 data
self._prepare_l2_data()
# Get training parameters
training_params = TrainingParamsManager.get_latest_training_params()
# Use data from l2_data dictionary
l2_generator = L2Generator(data_path=os.path.join(os.getcwd(), "resources"), is_cot=training_params.get("is_cot", False))
l2_generator.gen_diversity_data(
self.l2_data["notes"],
self.l2_data["basic_info"],
self.l2_data["data_output_base_dir"],
self.l2_data["topics_path"],
self.l2_data["entitys_path"],
self.l2_data["graph_path"],
self.l2_data["config_path"]
)
l2_generator.merge_json_files(self.l2_data["data_output_base_dir"])
# Mark step as completed
logger.info("Content retention augmentation completed successfully")
self.progress.mark_step_status(ProcessStep.AUGMENT_CONTENT_RETENTION, Status.COMPLETED)
# Clean up resources after completion
self._cleanup_resources()
return True
except Exception as e:
logger.error(f"Failed to augment content retention: {str(e)}")
self.progress.mark_step_status(ProcessStep.AUGMENT_CONTENT_RETENTION, Status.FAILED)
# Clean up resources even if there was an error
self._cleanup_resources()
return False
def _prepare_l2_data(self) -> dict:
"""Prepare common data needed for L2 generation tasks using lazy loading
Returns:
Dictionary containing all L2 data:
- notes: List of prepared notes
- basic_info: Dict containing user information
- data_output_base_dir: Path to output directory
- topics_path: Path to topics data
- entitys_path: Path to entity mapping file
- graph_path: Path to graph data
- config_path: Path to config file
"""
# If data is already prepared, return cached data directly
if self.l2_data_prepared and all(self.l2_data.values()):
logger.info("Using cached L2 data")
return self.l2_data
logger.info("Preparing L2 data...")
# Setup directories and paths
config = Config.from_env()
base_dir = os.path.join(
os.getcwd(), config.get("USER_DATA_PIPELINE_DIR") + "/raw_data"
)
os.makedirs(base_dir, exist_ok=True)
        # Generate topics data
        topics_path = os.path.join(base_dir, "topics.json")
        self.l2_data["topics_path"] = topics_path
        logger.info("Generating topics data...")
chunk_service = ChunkService()
topics_data = chunk_service.query_topics_data()
save_true_topics(topics_data, topics_path)
# Initialize storage
storage = NotesStorage()
logger.info("Notes not found, preparing them...")
documents = document_service.list_documents_with_l0()
logger.info(f"list_documents_with_l0 len: {len(documents)}")
notes_list, _ = extract_notes_from_documents(documents)
logger.info(f"extract_notes_from_documents len: {len(notes_list)}")
note_service = NoteService()
note_service.prepareNotes(notes_list)
storage.save_notes(notes_list)
self.l2_data["notes"] = storage.load_notes()
# Get paths
self.l2_data["config_path"] = os.path.join(
os.getcwd(),
"resources/L2/data_pipeline/data_prep/subjective/config/config.json",
)
self.l2_data["entitys_path"] = os.path.join(
os.getcwd(),
"resources/L2/data_pipeline/raw_data/id_entity_mapping_subjective_v2.json",
)
self.l2_data["graph_path"] = os.path.join(
os.getcwd(),
"resources/L1/graphrag_indexing_output/subjective/entities.parquet",
)
self.l2_data["data_output_base_dir"] = os.path.join(os.getcwd(), "resources/L2/data")
# Lazy load user information
logger.info("Loading user information...")
status_bio = get_latest_status_bio()
global_bio = get_latest_global_bio()
self.l2_data["basic_info"] = {
"username": LoadService.get_current_upload_name(),
"aboutMe": LoadService.get_current_upload_description(),
"statusBio": status_bio.content if status_bio else "Currently working on an AI project.",
"globalBio": global_bio.content_third_view if global_bio
else "The User is a software engineer who loves programming and learning new technologies.",
"lang": "English",
}
# Mark data as prepared
self.l2_data_prepared = True
return self.l2_data
def train(self) -> bool:
"""Start model training"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.TRAIN, Status.IN_PROGRESS)
# Get paths for the model
paths = self._get_model_paths(self.model_name)
# Check if the model directory exists and has the necessary files
config_file = os.path.join(paths["base_path"], "config.json")
if not os.path.exists(paths["base_path"]) or not os.path.exists(config_file):
logger.info(f"Model '{self.model_name}' needs to be downloaded or is missing config.json")
# Call model_download to download the model
download_success = self.model_download()
if not download_success:
logger.error(f"Failed to download model '{self.model_name}'")
self.progress.mark_step_status(ProcessStep.MODEL_DOWNLOAD, Status.FAILED)
return False
# Prepare log directory and file
log_dir = os.path.join(os.getcwd(), "logs")
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir, "train", "train.log")
logger.info(f"Log file path: {log_path}")
# Ensure output directory exists
os.makedirs(paths["personal_dir"], exist_ok=True)
# Set USER_NAME environment variable
os.environ["USER_NAME"] = LoadService.get_current_upload_name()
logger.info(f"USER_NAME environment variable set: {os.environ['USER_NAME']}")
script_path = os.path.join(os.getcwd(), "lpm_kernel/L2/train_for_user.sh")
# First start monitoring progress in a separate thread
logger.info("Starting monitoring thread first...")
monitor_thread = threading.Thread(
target=self._monitor_training_progress,
args=(log_path,),
daemon=True
)
monitor_thread.start()
# Allow a moment for the monitoring thread to initialize
time.sleep(1)
# Then directly execute training process (blocking)
logger.info("Now starting training process (blocking)...")
training_result = self._start_training(script_path, log_path)
if not training_result:
logger.error("Training process failed to start")
self.progress.mark_step_status(ProcessStep.TRAIN, Status.FAILED)
return False
# Wait for the monitoring thread to finish
logger.info("Training process completed, waiting for monitoring to finish...")
monitor_thread.join(timeout=10) # Wait up to 10 seconds for monitor to finish
# Check if the training was successful by checking the returncode
if hasattr(self, 'training_result') and self.training_result:
if self.training_result.get('returncode', 1) != 0:
error_msg = f"Training failed: {self.training_result.get('error', 'Unknown error')}"
logger.error(error_msg)
self.progress.mark_step_status(ProcessStep.TRAIN, Status.FAILED)
return False
return True
except Exception as e:
logger.error(f"Failed to start training: {str(e)}")
self.progress.mark_step_status(ProcessStep.TRAIN, Status.FAILED)
return False
def _get_model_paths(self, model_name):
"""Get all relevant paths for a model and set environment variables
Args:
model_name: Model name
Returns:
Dictionary containing all related paths:
- base_path: Base model path
- personal_dir: Personal trained model output directory
- merged_dir: Merged model output directory
- gguf_dir: GGUF model output directory
"""
base_dir = os.getcwd()
paths = {
"base_path": os.path.join(base_dir, "resources/L2/base_models", model_name),
"personal_dir": os.path.join(base_dir, "resources/model/output/personal_model", model_name),
"merged_dir": os.path.join(base_dir, "resources/model/output/merged_model", model_name),
"gguf_dir": os.path.join(base_dir, "resources/model/output/gguf", model_name)
}
# Ensure all directories exist
for path in paths.values():
os.makedirs(path, exist_ok=True)
# Set environment variables
os.environ["MODEL_BASE_PATH"] = paths["base_path"]
os.environ["MODEL_PERSONAL_DIR"] = paths["personal_dir"]
os.environ["MODEL_MERGED_DIR"] = paths["merged_dir"]
os.environ["MODEL_GGUF_DIR"] = paths["gguf_dir"]
# Log environment variables
logger.info("Set environment variables:")
logger.info(f"MODEL_BASE_PATH: {paths['base_path']}")
logger.info(f"MODEL_PERSONAL_DIR: {paths['personal_dir']}")
logger.info(f"MODEL_MERGED_DIR: {paths['merged_dir']}")
logger.info(f"MODEL_GGUF_DIR: {paths['gguf_dir']}")
return paths
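    # Illustrative directory layout produced above for a given <model_name>, relative to the working directory:
    #   resources/L2/base_models/<model_name>               -> MODEL_BASE_PATH
    #   resources/model/output/personal_model/<model_name>  -> MODEL_PERSONAL_DIR
    #   resources/model/output/merged_model/<model_name>    -> MODEL_MERGED_DIR
    #   resources/model/output/gguf/<model_name>            -> MODEL_GGUF_DIR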
def _start_training(self, script_path, log_path):
"""Start training process
Args:
script_path: Path to training script
log_path: Path to log file
Returns:
bool: True if the training process started successfully, False otherwise
"""
try:
# Reset stop flag before starting
self.is_stopped = False
# Get the latest training parameters from the class
params_manager = TrainingParamsManager()
training_params = params_manager.get_latest_training_params()
learning_rate = training_params.get("learning_rate")
num_train_epochs = training_params.get("number_of_epochs")
concurrency_threads = training_params.get("concurrency_threads")
data_synthesis_mode = training_params.get("data_synthesis_mode")
use_cuda = training_params.get("use_cuda", False)
is_cot = training_params.get("is_cot", False)
# Log training parameters
logger.info("Training parameters from latest settings:")
logger.info(f" Learning rate: {learning_rate}")
logger.info(f" Number of epochs: {num_train_epochs}")
logger.info(f" Concurrency threads: {concurrency_threads}")
logger.info(f" Data synthesis mode: {data_synthesis_mode}")
logger.info(f" Use CUDA: {use_cuda}")
logger.info(f" Is CoT: {is_cot}")
# Prepare arguments for the script
# Build command line arguments, need to include script path as the first parameter
cmd = [
script_path,
"--lr", str(learning_rate),
"--epochs", str(num_train_epochs),
"--threads", str(concurrency_threads),
"--mode", str(data_synthesis_mode),
"--cuda", str(use_cuda),
"--is_cot", str(is_cot)
]
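            # With placeholder values, the resulting invocation has roughly this shape
            # (actual values come from the stored training parameters):
            #   lpm_kernel/L2/train_for_user.sh --lr <learning_rate> --epochs <number_of_epochs> \
            #     --threads <concurrency_threads> --mode <data_synthesis_mode> --cuda <use_cuda> --is_cot <is_cot>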
# Ensure log directory exists
os.makedirs(os.path.dirname(log_path), exist_ok=True)
# Set environment variables to improve tqdm output
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1" # Force Python to be unbuffered
env["FORCE_COLOR"] = "1" # Force colored output
env["TQDM_FORCE_TTY"] = "1" # Force tqdm to use TTY features
            # Open the log file and run the training script, redirecting output to it;
            # the context manager ensures the file is closed even if an exception is raised
            with open(log_path, "ab") as log_file:
                process = subprocess.Popen(
                    cmd,
                    env=env,
                    stdout=log_file,
                    stderr=subprocess.STDOUT,
                    bufsize=0,  # Unbuffered
                )
                self.process = process
                self.current_pid = process.pid
                logger.info(f"Training process started with PID: {self.current_pid}")
                # Wait for the process to finish (blocking)
                logger.info("Waiting for training process to complete...")
                return_code = process.wait()
# Save results for train method to check
self.training_result = {
"returncode": return_code,
"error": f"Execution failed, return code: {return_code}" if return_code != 0 else None
}
if return_code != 0:
logger.error(f"Command execution failed, return code: {return_code}")
return False
else:
logger.info(f"Command execution successful, return code: {return_code}")
return True
except Exception as e:
logger.error(f"Failed to start training process: {str(e)}")
return False
def _monitor_training_progress(self, log_file) -> bool:
"""Monitor training progress"""
try:
# Initialize last_position to the end of file to only process new content
try:
with open(log_file, 'r') as f:
f.seek(0, 2) # Move to the end of file
last_position = f.tell()
except FileNotFoundError:
# If file doesn't exist yet, start from beginning when it's created
last_position = 0
# variable to track training status
total_steps = None
current_step = 0
last_update_time = time.time()
training_started = False
while True:
try:
# read new log content
with open(log_file, 'r') as f:
f.seek(last_position)
new_lines = f.readlines()
last_position = f.tell()
for line in new_lines:
line = line.strip()
# Check if training has started
if not training_started:
if "***** Running training *****" in line:
training_started = True
logger.info("Training started")
continue # Skip progress matching until training starts
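                        # Illustrative tqdm-style line the pattern below is expected to match:
                        #   " 45%|████▌     | 450/1000 [02:15<02:45,  3.32it/s]"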
progress_match = re.search(r"(\d+)%\|[^|]+\| (\d+)/(\d+)", line)
if progress_match and len(progress_match.groups()) == 3:
percentage = int(progress_match.group(1))
current_step = int(progress_match.group(2))
total_steps = int(progress_match.group(3))
# Update progress at most once per second
current_time = time.time()
if current_time - last_update_time >= 1.0:
# logger.info(f"Training progress: {percentage}% ({current_step}/{total_steps})")
if percentage == 100.0:
self.progress.mark_step_status(ProcessStep.TRAIN, Status.COMPLETED)
return True
self._update_progress("training_to_create_second_me", "train", percentage, f"Current step: {current_step}/{total_steps}")
last_update_time = current_time
                        # Check whether the training log section has ended
                        if "=== Training Ended ===" in line:
                            logger.info("Training log section ended")
# Briefly pause to avoid excessive CPU usage
time.sleep(0.1)
except IOError as e:
logger.error(f"Failed to read log file: {str(e)}")
time.sleep(0.1)
continue
except Exception as e:
logger.error(f"Failed to monitor training progress: {str(e)}")
self.progress.mark_step_status(ProcessStep.TRAIN, Status.FAILED)
return False
def _update_progress(self, stage: str, step: str, percentage: float, message: str):
"""Update progress for any stage and step"""
try:
self.progress.progress.update_progress(
stage, # stage
step, # step
Status.IN_PROGRESS,
percentage
)
logger.info(f"Progress updated: {percentage}% - {message}")
except Exception as e:
logger.error(f"Progress callback error: {str(e)}")
def _monitor_model_download(self) -> bool:
"""Monitor model download progress"""
try:
# log_dir = os.path.join(os.getcwd(), "logs")
# log_file = os.path.join(log_dir, "model_download.log")
log_file = TRAIN_LOG_FILE
# Initialize last_position to the end of file to only process new content
try:
with open(log_file, 'r') as f:
f.seek(0, 2) # Move to the end of file
last_position = f.tell()
except FileNotFoundError:
# If file doesn't exist yet, start from beginning when it's created
last_position = 0
# Variables to track download status
current_file = ""
file_size = 0
total_size = 0 # Total size of all files
file_sizes = {} # Dictionary to store file sizes
last_update_time = time.time()
while True:
try:
# Read new log content
with open(log_file, 'r') as f:
f.seek(last_position)
new_lines = f.readlines()
last_position = f.tell()
for line in new_lines:
line = line.strip()
# Check for download start
if "Starting download of model:" in line:
logger.info("Model download started")
continue
# Get file size information when a download starts
if "Starting download of file:" in line:
match = re.search(r"Starting download of file: (.+) \(Size: ([\d\.]+) MB\)", line)
if match:
current_file = match.group(1)
file_size = float(match.group(2))
file_sizes[current_file] = file_size
total_size = sum(file_sizes.values())
# logger.info(f"Starting download of {current_file} ({file_size} MB)")
# Track file download progress
if "Downloaded" in line and "MB /" in line:
match = re.search(r"File (.+): Downloaded ([\d\.]+) MB / ([\d\.]+) MB \(([\d\.]+)%\)", line)
if match:
file_name = match.group(1)
downloaded_mb = float(match.group(2))
total_mb = float(match.group(3))
percentage = float(match.group(4))
# Update file size if it was updated (especially for model.safetensors)
if total_mb > file_sizes.get(file_name, 0):
file_sizes[file_name] = total_mb
total_size = sum(file_sizes.values())
# Calculate overall progress
if total_size > 0:
# Sum up all downloaded data
completed_files_size = sum([file_sizes.get(f, 0) for f in file_sizes if f != file_name])
current_file_downloaded = (percentage / 100.0) * total_mb
overall_downloaded = completed_files_size + current_file_downloaded
current_progress = (overall_downloaded / total_size) * 100
current_progress = min(99.0, current_progress) # Cap at 99% until fully complete
                                    # Update progress at most once every three seconds
current_time = time.time()
if current_time - last_update_time >= 3.0:
self._update_progress(
"downloading_the_base_model",
"model_download",
current_progress,
f"Overall: {current_progress:.1f}% - Downloading {file_name}: {percentage}% ({downloaded_mb:.1f}/{total_mb:.1f} MB)"
)
last_update_time = current_time
if "Model downloaded successfully" in line:
self.progress.mark_step_status(ProcessStep.MODEL_DOWNLOAD, Status.COMPLETED)
logger.info("Model download completed")
return True
# Briefly pause to avoid excessive CPU usage
time.sleep(0.1)
except IOError as e:
logger.error(f"Failed to read log file: {str(e)}")
time.sleep(0.1)
continue
except Exception as e:
logger.error(f"Failed to monitor model download progress: {str(e)}")
return False
def merge_weights(self) -> bool:
"""Merge weights"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.MERGE_WEIGHTS, Status.IN_PROGRESS)
paths = self._get_model_paths(self.model_name)
# Check if model exists
if not os.path.exists(paths["base_path"]):
logger.error(f"Model '{self.model_name}' does not exist, please download first")
self.progress.mark_step_status(ProcessStep.MERGE_WEIGHTS, Status.FAILED)
return False
# Check if training output exists
if not os.path.exists(paths["personal_dir"]):
                logger.error(f"Model '{self.model_name}' training output does not exist, please train model first")
                self.progress.mark_step_status(ProcessStep.MERGE_WEIGHTS, Status.FAILED)
                return False
# Ensure merged output directory exists
os.makedirs(paths["merged_dir"], exist_ok=True)
script_path = os.path.join(
os.getcwd(), "lpm_kernel/L2/merge_weights_for_user.sh"
)
log_path = os.path.join(os.getcwd(), "logs", f"merge_weights_{self.model_name}.log")
# Ensure log directory exists
os.makedirs(os.path.dirname(log_path), exist_ok=True)
# Use script executor to execute merge script
script_executor = ScriptExecutor()
result = script_executor.execute(
script_path=script_path, script_type="merge_weights", log_file=log_path
)
logger.info(f"Weight merge task result: {result}")
# Check if script execution was successful
if result.get('returncode', 1) != 0:
error_msg = f"Merge weights failed: {result.get('error', 'Unknown error')}"
logger.error(error_msg)
self.progress.mark_step_status(ProcessStep.MERGE_WEIGHTS, Status.FAILED)
return False
# Check if merged model files exist
config_path = os.path.join(paths["merged_dir"], "config.json")
if not os.path.exists(config_path):
error_msg = f"Merged model files not found in {paths['merged_dir']}"
logger.error(error_msg)
self.progress.mark_step_status(ProcessStep.MERGE_WEIGHTS, Status.FAILED)
return False
logger.info("Weight merge completed successfully")
self.progress.mark_step_status(ProcessStep.MERGE_WEIGHTS, Status.COMPLETED)
return True
except Exception as e:
self.progress.mark_step_status(ProcessStep.MERGE_WEIGHTS, Status.FAILED)
logger.error(f"Merge weights failed: {str(e)}")
return False
def convert_model(self) -> bool:
"""Convert model to GGUF format"""
try:
# Mark step as in progress
self.progress.mark_step_status(ProcessStep.CONVERT_MODEL, Status.IN_PROGRESS)
# Get paths for the model
paths = self._get_model_paths(self.model_name)
# Check if merged model exists
merged_model_dir = paths["merged_dir"]
logger.info(f"Merged model path: {merged_model_dir}")
if not os.path.exists(merged_model_dir):
logger.error(f"Model '{self.model_name}' merged output does not exist, please merge model first")
self.progress.mark_step_status(ProcessStep.CONVERT_MODEL, Status.FAILED)
return False
# Get GGUF output directory
gguf_dir = paths["gguf_dir"]
logger.info(f"GGUF output directory: {gguf_dir}")
script_path = os.path.join(os.getcwd(), "lpm_kernel/L2/convert_hf_to_gguf.py")
gguf_path = os.path.join(gguf_dir, "model.gguf")
logger.info(f"GGUF output path: {gguf_path}")
# Build parameters
args = [
merged_model_dir,
"--outfile",
gguf_path,
"--outtype",
"f16",
]
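            # Roughly equivalent to invoking the converter as follows (paths are resolved at runtime;
            # the exact interpreter handling depends on ScriptExecutor):
            #   convert_hf_to_gguf.py <merged_model_dir> --outfile <gguf_dir>/model.gguf --outtype f16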
logger.info(f"Parameters: {args}")
# Ensure GGUF output directory exists
os.makedirs(os.path.dirname(gguf_path), exist_ok=True)
# Use script executor to execute conversion script
script_executor = ScriptExecutor()
result = script_executor.execute(
script_path=script_path,
script_type="convert_model",
args=args
)
logger.info(f"Model conversion result: {result}")
# Check if script execution was successful
if result.get('returncode', 1) != 0:
error_msg = f"Model conversion failed: {result.get('error', 'Unknown error')}"
logger.error(error_msg)
self.progress.mark_step_status(ProcessStep.CONVERT_MODEL, Status.FAILED)
return False
# Check if GGUF model file exists
if not os.path.exists(gguf_path):
error_msg = f"GGUF model file not found at {gguf_path}"
logger.error(error_msg)
self.progress.mark_step_status(ProcessStep.CONVERT_MODEL, Status.FAILED)
return False
logger.info("Model conversion completed successfully")
self.progress.mark_step_status(ProcessStep.CONVERT_MODEL, Status.COMPLETED)
return True
except Exception as e:
self.progress.mark_step_status(ProcessStep.CONVERT_MODEL, Status.FAILED)
logger.error(f"Convert model failed: {str(e)}")
return False
def check_training_condition(self) -> bool:
"""
Check if the conditions for training are met
Returns:
bool: True if conditions are met, False otherwise
"""
try:
# Check if there are any documents that need embedding
if document_service.check_all_documents_embeding_status():
logger.warning("Cannot start training: There are documents that need embedding process first")
return False
return True
except Exception as e:
logger.error(f"Error checking training conditions: {str(e)}", exc_info=True)
if self.progress.progress.current_stage:
current_step = self.progress.progress.stages[self.progress.progress.current_stage].current_step
if current_step:
step = ProcessStep(current_step)
self.progress.mark_step_status(step, Status.FAILED)
return False
def start_process(self) -> bool:
"""Start training process"""
try:
self.is_stopped = False
            # Store the PID of the process driving this training pipeline
            self.current_pid = os.getpid()
            logger.info(f"Training pipeline started in process with PID: {self.current_pid}")
# Get the ordered list of all steps
ordered_steps = ProcessStep.get_ordered_steps()
# Get the last successfully completed step
last_successful_step = self.progress.get_last_successful_step()
start_index = 0
if last_successful_step:
start_index = ordered_steps.index(last_successful_step) + 1
# Start executing from the step after the last successful one
for step in ordered_steps[start_index:]:
self.current_step = step
if self.is_stopped:
logger.info("Training process aborted during step")
self.progress.mark_step_status(step, Status.SUSPENDED)
break # If stop is requested, exit the loop
logger.info(f"Starting step: {step.value}")
# Execute the corresponding method
method_name = step.get_method_name()
if not hasattr(self, method_name):
logger.error(f"Method {method_name} not found")
self.progress.mark_step_status(step, Status.FAILED)
return False
method = getattr(self, method_name)
success = method()
if not success:
logger.error(f"Step {step.value} failed")
logger.info(f'Marking step as failed: stage={step.value}, step={step.value}')
self.progress.mark_step_status(step, Status.FAILED)
return False
logger.info(f"Step {step.value} completed successfully")
# self.progress.mark_step_status(step, Status.COMPLETED)
if self.is_stopped:
logger.info("Training process was stopped during a step")
else:
logger.info("Training process completed...")
return True
except Exception as e:
logger.error(f"Exception occurred: {str(e)}", exc_info=True)
if self.current_step:
self.progress.mark_step_status(self.current_step, Status.FAILED)
return False
    def reset_progress(self):
        """Reset training progress
        This method resets the recorded training progress and persists the reset state to the progress file.
        """
        try:
            self.progress.reset_progress()
            logger.info("Progress reset successfully")
        except Exception as e:
            logger.error(f"Failed to reset progress: {str(e)}", exc_info=True)
def get_step_output_content(self, step_name: str = None) -> Optional[Dict]:
"""Get content of output file for a specific training step
Args:
step_name: Name of the step to get content for. Required parameter.
Returns:
Optional[Dict]: Content of the output file for the specified step, or None if not found
"""
try:
if step_name == "generate_biography":
logger.info("Querying L1 version data for biography")
return query_l1_version_data(1)
# If step_name is not provided or invalid, return None
if not step_name or step_name not in output_files:
return None
# Get file path for the requested step
file_path = output_files[step_name]
if not os.path.exists(file_path):
return None
# Read and return file content
return read_file_content(file_path)
except Exception as e:
logger.error(f"Error getting step output content: {str(e)}")
return None
def stop_process(self):
"""Stop training process
Returns:
bool: True if the process was stopped successfully, False otherwise
"""
try:
# Set the stop flag
self.is_stopped = True
logger.info("Training process has been requested to stop")
# mark train stop
if self.current_step == ProcessStep.TRAIN:
self.progress.mark_step_status(ProcessStep.TRAIN, Status.SUSPENDED)
# First check if we have the current process PID
if not hasattr(self, 'current_pid') or not self.current_pid:
logger.info("No active process PID found")
if self.progress.progress.data["current_stage"]:
current_stage_name = self.progress.progress.data["current_stage"]
current_stage = next((s for s in self.progress.progress.data["stages"] if s["name"] == current_stage_name), None)
if current_stage and current_stage["current_step"]:
step = ProcessStep(current_stage["current_step"].lower().replace(" ", "_"))
self.progress.mark_step_status(step, Status.SUSPENDED)
return True
try:
logger.info(f"Attempting to terminate process with PID: {self.current_pid}")
# Check if the process exists
if psutil.pid_exists(self.current_pid):
# Get the process object
process = psutil.Process(self.current_pid)
# Get all child processes
children = process.children(recursive=True)
# Terminate all child processes first
for child in children:
logger.info(f"Terminating child process with PID: {child.pid}")
try:
child.terminate()
except psutil.NoSuchProcess:
pass
# Wait for children to terminate
gone, still_alive = psutil.wait_procs(children, timeout=3)
# Kill any remaining children
for child in still_alive:
logger.info(f"Killing child process with PID: {child.pid}")
try:
child.kill()
except psutil.NoSuchProcess:
pass
# Note: We don't terminate the main process as it's this process
logger.info(f"All child processes of {self.current_pid} have been terminated")
gc.collect()
return True
else:
logger.warning(f"Process with PID {self.current_pid} no longer exists")
return True
            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess) as e:
                logger.error(f"Failed to terminate process: {str(e)}", exc_info=True)
                return False
except Exception as e:
logger.error(f"Error stopping training process: {str(e)}", exc_info=True)
return False