|
|
|
|
|
import argparse |
|
import sys |
|
import os |
|
from train_agent import train_agent |
|
from test_agent import TestAgent, run_test_session |
|
from twisted.internet import reactor, task |
|
from lightbulb_custom import main as lightbulb_custom_main |
|
from distillation_pipeline import distill_model |
|
from transformers import logging |
|
|
|
|
|
# `logging` here is `transformers.logging` (imported above), not the stdlib
# module: restrict the transformers library's logger to errors so its
# info/warning output does not clutter any of the tasks run from this menu.
logging.set_verbosity_error()
|
|
|
def parse_main_args():
    """Parse ``sys.argv`` into the options understood by the task menu.

    ``--task`` is mandatory and restricted to the names in ``task_names``.
    Every other option is optional. Options declared without a default come
    back as ``None``; the ``--distill_full_model`` switch comes back as
    ``False`` unless given.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    task_names = [
        'train_llm_world',
        'train_agent',
        'test_agent',
        'inference_llm',
        'inference_world_model',
        'advanced_inference',
        'distill_full_model',
        'distill_domain_specific',
    ]

    # (flag, add_argument keyword arguments). Options are registered in list
    # order, which is also the order in which --help lists them.
    option_specs = [
        ('--task', dict(
            type=str,
            choices=task_names,
            required=True,
            help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference, distill_full_model, distill_domain_specific',
        )),

        # Model, dataset and optimisation settings.
        ('--model_name', dict(type=str, default='gpt2', help='Pretrained model name for LLM')),
        ('--student_model_name', dict(type=str, default='distilgpt2', help='Name of the student model for distillation')),
        ('--dataset_name', dict(type=str, default='wikitext', help='Dataset name for training')),
        ('--dataset_config', dict(type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')),
        ('--batch_size', dict(type=int, default=4, help='Batch size for training')),
        ('--num_epochs', dict(type=int, default=3, help='Number of epochs for training')),
        ('--max_length', dict(type=int, default=128, help='Maximum sequence length for training')),
        ('--temperature', dict(type=float, default=2.0, help='Distillation temperature')),
        ('--learning_rate', dict(type=float, default=5e-5, help='Learning rate')),

        # Output locations and early stopping.
        ('--save_path', dict(type=str, default="./distilled_model", help="Path to save the distilled model")),
        ('--log_dir', dict(type=str, default="./logs", help="Directory for TensorBoard logs")),
        ('--checkpoint_dir', dict(type=str, default="./checkpoints", help="Directory to save checkpoints")),
        ('--early_stopping_patience', dict(type=int, default=3, help="Early stopping patience")),

        # Query / inference settings.
        ('--query', dict(type=str, default='', help='Query for the test_agent or inference tasks')),
        ('--inference_mode', dict(
            type=str,
            choices=['without_world_model', 'world_model', 'world_model_tree_of_thought'],
            help='Inference mode',
        )),
        ('--beam_size', dict(type=int, default=5, help='Beam size for beam search during inference')),
        ('--n_tokens_predict', dict(type=int, default=3, help='Number of tokens to predict at each step during inference')),
        ('--mcts_iterations', dict(type=int, default=10, help='Number of MCTS iterations during inference')),
        ('--mcts_exploration_constant', dict(type=float, default=1.414, help='Exploration constant for MCTS during inference')),

        # Distillation switches.
        ('--distill_full_model', dict(action="store_true", help="Whether to distill the full model or not")),
        ('--query_terms', dict(type=str, nargs="+", help="Query terms for domain-specific distillation")),

        # Loading a previously distilled model.
        ('--load_model', dict(type=str, help='Path to load the distilled model for inference')),
    ]

    menu = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    for flag, options in option_specs:
        menu.add_argument(flag, **options)
    return menu.parse_args()
|
|
|
# Maps each inference task name to the --inference_mode value forwarded to
# lightbulb_custom.
_INFERENCE_MODE_BY_TASK = {
    'inference_llm': 'without_world_model',
    'inference_world_model': 'world_model',
    'advanced_inference': 'world_model_tree_of_thought',
}


def _run_lightbulb(argv):
    """Call ``lightbulb_custom.main`` with ``sys.argv`` temporarily set to *argv*.

    ``lightbulb_custom_main`` takes no arguments, so its options are passed by
    swapping ``sys.argv`` (presumably it parses ``sys.argv`` itself -- confirm in
    lightbulb_custom). The previous ``sys.argv`` is restored afterwards, even if
    the call raises.
    """
    saved_argv = sys.argv
    sys.argv = list(argv)
    try:
        lightbulb_custom_main()
    finally:
        sys.argv = saved_argv


def _train_llm_world_argv(args):
    """Return the lightbulb_custom argument vector for the ``train_llm_world`` task."""
    return [
        'lightbulb_custom.py',
        '--mode', 'train',
        '--model_name', args.model_name,
        '--dataset_name', args.dataset_name,
        '--dataset_config', args.dataset_config,
        '--batch_size', str(args.batch_size),
        '--num_epochs', str(args.num_epochs),
        '--max_length', str(args.max_length),
    ]


def _inference_argv(args):
    """Return the lightbulb_custom argument vector for one of the inference tasks.

    The inference mode is derived from ``args.task`` via ``_INFERENCE_MODE_BY_TASK``.
    ``--load_model`` is appended only when a model path was given.
    """
    # NOTE(review): the --inference_mode CLI option is parsed but not consulted
    # here; the task name alone decides the mode.
    mode = _INFERENCE_MODE_BY_TASK.get(args.task, 'world_model_tree_of_thought')
    argv = [
        'lightbulb_custom.py',
        '--mode', 'inference',
        '--model_name', args.model_name,
        '--query', args.query,
        '--max_length', str(args.max_length),
        '--inference_mode', mode,
        '--beam_size', str(args.beam_size),
        '--n_tokens_predict', str(args.n_tokens_predict),
        '--mcts_iterations', str(args.mcts_iterations),
        '--mcts_exploration_constant', str(args.mcts_exploration_constant),
    ]
    if args.load_model:
        argv += ['--load_model', args.load_model]
    return argv


def _run_distillation(args, distill_full_model):
    """Run ``distill_model`` with the shared CLI settings.

    Args:
        args: parsed command-line options.
        distill_full_model: True for full-model distillation (no query terms are
            passed). False for domain-specific distillation, which forwards
            ``args.query_terms``.
    """
    # NOTE(review): the --distill_full_model CLI switch is not read; the task
    # name decides which mode is used.
    distill_model(
        teacher_model_name=args.model_name,
        student_model_name=args.student_model_name,
        dataset_name=args.dataset_name,
        config=args.dataset_config,
        distill_full_model=distill_full_model,
        query_terms=None if distill_full_model else args.query_terms,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        save_path=args.save_path,
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        early_stopping_patience=args.early_stopping_patience,
    )


def _report_failure(failure):
    """Deferred errback: print the failure's full traceback and consume it.

    Returning None turns the failure into a normal result, so the following
    ``addBoth`` callback still stops the reactor. (The previous version passed
    ``exc_info=True`` to ``print``, which raised TypeError and hid the real
    error.)
    """
    print(f"An error occurred: {failure.getTraceback()}", file=sys.stderr)


def main():
    """Parse the command line and run the selected task.

    Exits with status 1 when ``distill_domain_specific`` is chosen without
    ``--query_terms``, or when the task name is not recognised.
    """
    args = parse_main_args()

    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        _run_lightbulb(_train_llm_world_argv(args))

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Run training on the first reactor iteration. Whether it succeeds or
        # fails, stop the reactor afterwards so reactor.run() returns.
        d = task.deferLater(reactor, 0, train_agent)
        d.addErrback(_report_failure)
        d.addBoth(lambda _: reactor.stop())
        reactor.run()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        agent = TestAgent()
        if args.query:
            # One-shot mode: answer the single query given on the command line.
            result = agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # Interactive mode: run_test_session is started once the reactor is
            # running (presumably it stops the reactor itself -- confirm).
            reactor.callWhenRunning(run_test_session)
            reactor.run()

    elif args.task in _INFERENCE_MODE_BY_TASK:
        print("Starting Inference Task...")
        _run_lightbulb(_inference_argv(args))

    elif args.task == 'distill_full_model':
        print("Starting Full Model Distillation...")
        _run_distillation(args, distill_full_model=True)

    elif args.task == 'distill_domain_specific':
        print("Starting Domain-Specific Distillation...")
        if not args.query_terms:
            print("Error: --query_terms must be provided for domain-specific distillation.",
                  file=sys.stderr)
            sys.exit(1)
        _run_distillation(args, distill_full_model=False)

    else:
        # Defensive only: argparse `choices` already rejects unknown task names.
        print(f"Unknown task: {args.task}", file=sys.stderr)
        sys.exit(1)
|
|
|
# Script entry point: run the task menu only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|