license: apache-2.0


Model Card for LightBulb

Installation

To install the necessary dependencies, run:

pip install huggingface_hub torch transformers datasets

Getting Started

Download the Repository

Use huggingface_hub to download the repository:

from huggingface_hub import snapshot_download

# Download the repository
repo_path = snapshot_download("RobbiePasquale/lightbulb")

print(f"Repository downloaded to: {repo_path}")

Main Features

LightBulb provides seven primary functionalities, each accessible through command-line arguments to the main_menu.py script.

1. Train a Web Search Agent

Description:
Trains an autonomous web search agent that navigates the web, gathers relevant content, and learns to summarize and generate responses based on user queries.

Overview of the AutonomousWebAgent

The AutonomousWebAgent is a multi-component search and retrieval agent designed to navigate the web, gather relevant content, and perform summarization and generation based on user queries. This agent integrates reinforcement learning (RL), Monte Carlo Tree Search (MCTS), a Retrieval Augmented Generation (RAG) Summarizer, and a Hierarchical Reinforcement Learning (HRL) architecture to select, execute, and optimize its actions based on past experiences.

Key Components

  1. Prioritized Experience Replay:

    • The agent uses a PrioritizedReplayMemory and a SumTree to prioritize and store experiences (transitions between states).
    • The SumTree structure maintains a binary tree where each parent node's value is the sum of its children, helping to efficiently store, update, and retrieve experiences based on priority.
    • These experiences are critical in training both high-level (manager) and low-level (worker) components through prioritized sampling during replay, allowing the model to focus on more significant transitions.
  2. Hierarchical Reinforcement Learning (HRL):

    • HRL is employed to allow a Manager (high-level) model to select options, which are then executed by a Worker (low-level) model. The ManagerModel selects tasks (such as searching, summarizing, or generating), while the WorkerModel determines specific actions to take.
    • The manager and worker use LSTM networks with fully connected layers, and each has its own replay memory and optimization process.
    • The Manager focuses on broad decisions and options, while the Worker operates on specific actions, enabling a layered approach to decision-making.
  3. RAGSummarizer:

    • The RAGSummarizer leverages a pre-trained language model (e.g., GPT-2) for summarizing, and a SentenceTransformer for embedding-based retrieval. This module breaks down the input content into chunks, retrieves relevant sections based on cosine similarity with the query, and generates a coherent summary.
    • Additionally, it implements a Least Recently Used (LRU) cache to avoid redundant computation and enhance efficiency, along with persistent storage for cache data.
    • Summarized results are stored, and this module contributes directly to the generation of LLM training data.
  4. WorldModel:

    • This module encapsulates an LSTM architecture with linear layers and a value_head to estimate state values, allowing the agent to anticipate the long-term value of its actions.
    • It is utilized in the HRL architecture, specifically by the Worker for evaluating actions and by the Manager in long-term decision-making.
  5. Knowledge Base:

    • The knowledge base acts as a repository for collected data, maintaining embeddings for efficient search and retrieval.
    • It supports saving and loading document embeddings, so the agent can retrieve relevant information for new queries from previously collected knowledge.
    • Adding to and retrieving from the knowledge base enriches the agent's context, letting it draw on information from past experiences to inform current tasks.
  6. Monte Carlo Tree Search (MCTS):

    • The MCTS component guides the agent through complex decision trees to determine the most promising paths for query refinement.
    • Nodes in the tree represent states (possible query refinements), and child nodes represent possible expansions (e.g., related query variations).
    • MCTS utilizes a select, expand, simulate, and backpropagate strategy to iteratively refine queries, scoring them based on relevance and other metrics to converge on optimal searches.
    • It also integrates RL by backpropagating rewards based on the ranking score from retrieved results.
  7. Ranking Model:

    • The ranking model, built with a neural network and the SentenceTransformer, ranks search results based on various features such as cosine similarity with the query, content length, keyword overlap, and domain authority.
    • This model assigns scores to results, which are then used to guide the MCTS process by enhancing the combined reward with ranking scores.
  8. Tree of Thought (ToT) Search:

    • This module enhances the agent's capability to generate a series of interconnected thoughts, exploring different perspectives or angles on a given query.
    • ToTNode and ToTSearch classes enable the agent to generate thoughts, evaluate them, and navigate through them as a tree, considering various potential paths to best answer the query.
    • It combines MCTS and RAG to synthesize responses based on the generated thought paths.
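The SumTree behind prioritized replay can be illustrated with a short standard-library sketch. The names below (`SumTree`, `add`, `update`, `get`, `sample`) are hypothetical and this is not the repository's PrioritizedReplayMemory: leaves store priorities, each internal node stores the sum of its two children, and sampling descends from the root with a value drawn uniformly from [0, total), so each experience is picked with probability proportional to its priority.

```python
import random


class SumTree:
    """Flat-array binary tree: leaves hold priorities, each parent holds the sum of its children."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity - 1)  # internal nodes first, then leaves
        self.data = [None] * capacity           # experiences aligned with the leaves
        self.write = 0                          # next slot to overwrite (ring buffer)

    def total(self):
        return self.tree[0]

    def update(self, tree_idx, priority):
        # Set one leaf's priority and propagate the difference up to the root.
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def add(self, priority, experience):
        tree_idx = self.write + self.capacity - 1
        self.data[self.write] = experience
        self.update(tree_idx, priority)
        self.write = (self.write + 1) % self.capacity

    def get(self, value):
        # Descend from the root in O(log capacity): go left if value falls inside the
        # left subtree's mass, otherwise subtract that mass and go right.
        idx = 0
        while 2 * idx + 1 < len(self.tree):
            left = 2 * idx + 1
            if value < self.tree[left]:
                idx = left
            else:
                value -= self.tree[left]
                idx = left + 1
        return idx, self.tree[idx], self.data[idx - self.capacity + 1]

    def sample(self):
        # Draw one experience with probability proportional to its priority.
        return self.get(random.random() * self.total())


tree = SumTree(capacity=4)
for priority, transition in [(1.0, "a"), (3.0, "b"), (0.0, "c"), (4.0, "d")]:
    tree.add(priority, transition)
print(tree.total(), tree.get(4.0)[2])  # 8.0 d
```

Prioritized replay commonly derives each priority from the TD error, for example (|δ| + ε)^α; the source does not say how LightBulb computes its priorities, so that part is left to the caller here.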

Usage:

python main_menu.py --task train_agent

Key Components:

  • Hierarchical Reinforcement Learning (HRL): Manages high-level (Manager) and low-level (Worker) decision-making.
  • Monte Carlo Tree Search (MCTS): Guides the agent through complex decision trees.
  • RAGSummarizer: Summarizes retrieved web content.
  • Knowledge Base: Stores and retrieves information to inform future queries.
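The RAGSummarizer's retrieval step (split content into chunks, score each chunk against the query by cosine similarity, keep the best matches) can be sketched without a model. This version replaces the SentenceTransformer embeddings with a toy bag-of-words vector so it runs on the standard library alone; `chunk_text`, `embed`, `cosine`, and `retrieve` are hypothetical helpers, and a real pipeline would call the embedding model in place of `embed`.

```python
import math
import re
from collections import Counter


def chunk_text(text, chunk_size=20):
    # Split content into chunks of at most chunk_size words.
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]


def embed(text):
    # Toy stand-in for a SentenceTransformer embedding: lowercase word counts.
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(query, text, top_k=2, chunk_size=20):
    # Rank chunks by cosine similarity to the query and keep the top_k.
    q = embed(query)
    chunks = chunk_text(text, chunk_size)
    return sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)[:top_k]


text = ("solar panels convert sunlight into electricity wind turbines turn moving air "
        "into power the stock market closed higher today")
print(retrieve("how do solar panels make electricity", text, top_k=1, chunk_size=6))
# ['solar panels convert sunlight into electricity']
```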

2. Use a Web Search Agent (Inference)

Description:
Utilizes the trained web search agent to process queries, perform web searches, and generate summarized responses.

Usage:

python main_menu.py --task test_agent

Options:

  • Interactive Mode:
    python main_menu.py --task test_agent
    
  • Single Query Mode:
    python main_menu.py --task test_agent --query "Your query here"
    

3. Train Language Model

Description:
Trains a Language Model and World Model using datasets from Hugging Face, enabling the model to handle complex reasoning and long sequences.

Training Procedure

  • Data Loading: The data is tokenized and prepared with attention to padding and truncation. Text data is grouped into sequences of fixed length for efficient training.
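  • Optimization: Training uses the AdamW optimizer with a CosineAnnealingLR scheduler for learning-rate adjustment. During mixed-precision training, the gradient scaler multiplies the loss by a scale factor so that small float16 gradients do not underflow to zero, and skips optimizer steps whose gradients overflow to inf or NaN.
  • Gradient Accumulation: Because the model is computationally heavy, gradients are accumulated over several small batches before each optimizer step. This gives the effect of a larger batch without the memory cost of processing it all at once.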
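The fixed-length grouping described under Data Loading is commonly done by concatenating all token IDs and slicing them into equal blocks, dropping the remainder. The sketch below assumes that approach (it follows the `group_texts` pattern from the Hugging Face causal-LM examples) and works on plain lists of token IDs; the repository's own preprocessing may differ, for example by padding the tail instead of dropping it.

```python
def group_texts(examples, block_size=8):
    """Concatenate tokenized examples and split them into block_size chunks.

    `examples` maps column names (e.g. "input_ids", "attention_mask") to lists of
    token-ID lists, the batched format passed by datasets.Dataset.map(batched=True).
    """
    concatenated = {k: [tok for seq in examples[k] for tok in seq] for k in examples}
    total_length = len(concatenated["input_ids"])
    # Drop the tail that does not fill a complete block.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal language modeling the labels are a copy of the inputs; the model shifts them.
    result["labels"] = [block[:] for block in result["input_ids"]]
    return result


batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1], [1, 1]],
}
print(group_texts(batch, block_size=4)["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

With the datasets library this would typically be applied as `tokenized.map(group_texts, batched=True)`.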

Usage:

python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256

Key Arguments:

  • --model_name: Pretrained model (e.g., gpt2, bert).
  • --dataset_name: Dataset from Hugging Face (e.g., wikitext).
  • --num_epochs: Number of training epochs.
  • --batch_size: Number of samples per batch.
  • --max_length: Maximum sequence length.

4. Inference Using Language Model

Description:
Generates responses using the trained language model, leveraging multi-token prediction, beam search, and MCTS for enhanced coherence and strategic reasoning.

Usage:

python main_menu.py --task inference_llm --query "Your query here"

Process:

  1. Multi-Token Prediction: Predicts multiple tokens at each step to improve generation speed.
  2. Beam Search: Maintains multiple candidate sequences to ensure diverse and high-quality outputs.
  3. MCTS Integration: Uses MCTS to evaluate and select the most promising token sequences based on policy and value estimates.
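Beam search itself is independent of the model. The following sketch keeps the `beam_size` highest-scoring partial sequences by cumulative log-probability at each step; `next_token_logprobs` is a hypothetical stand-in for the language model, the toy probability table is invented for illustration, and the multi-token and MCTS refinements listed above are deliberately left out.

```python
import math


def beam_search(next_token_logprobs, start, beam_size=2, steps=3):
    """Return beams as (sequence, cumulative log-probability), best first.

    next_token_logprobs(seq) -> dict mapping each candidate token to log P(token | seq).
    """
    beams = [(list(start), 0.0)]
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            for token, logp in next_token_logprobs(seq).items():
                candidates.append((seq + [token], score + logp))
        # Keep only the top beam_size candidates by cumulative log-probability.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


# Toy next-token distributions keyed by the last token.
TABLE = {
    "<s>": {"A": 0.6, "B": 0.4},
    "A": {"x": 0.34, "y": 0.33, "z": 0.33},
    "B": {"C": 0.9, "x": 0.1},
}


def toy_model(seq):
    return {tok: math.log(p) for tok, p in TABLE.get(seq[-1], {}).items()}


print(beam_search(toy_model, ["<s>"], beam_size=2, steps=2)[0][0])  # ['<s>', 'B', 'C']
print(beam_search(toy_model, ["<s>"], beam_size=1, steps=2)[0][0])  # ['<s>', 'A', 'x']
```

With `beam_size=1` the search reduces to greedy decoding and commits to "A" (0.6 × 0.34 = 0.204), while a beam of two keeps "B" alive and finds the more probable sequence (0.4 × 0.9 = 0.36).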

5. Train World Model

Description:
Develops a comprehensive World Model that encapsulates state representations, dynamics, and prediction networks to simulate and predict state transitions within the Tree of Thought framework.

Usage:

python main_menu.py --task train_world_model [additional arguments]

Key Components:

  • Representation Network: Encodes Transformer outputs into state representations.

  • Dynamics Network: Predicts next states based on current states and actions.

  • Prediction Network: Generates policy logits and value estimates.

  • Action Encoder: Encodes actions into embeddings for state transitions.

  • Loss Functions: The training process leverages a comprehensive set of custom loss functions:

1. InfoNCE Loss (Info Noise Contrastive Estimation Loss): Definition: This loss function is used for contrastive learning, encouraging similar samples to be close in the embedding space while pushing dissimilar samples apart.

Formula: L_InfoNCE = -log[ exp(sim(z_i, z_j) / τ) / Σ_k exp(sim(z_i, z_k) / τ) ]

where sim() is the cosine similarity, τ is the temperature parameter, z_i and z_j are paired samples, and the sum in the denominator runs over all other samples z_k (k ≠ i) in the batch, including the positive z_j.

2. Covariance Regularization: Definition: This regularization term encourages the learned representations to have uncorrelated dimensions, promoting more diverse and informative embeddings.

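Formula: L_cov = λ * (Σ_i,j Cov(i,j)^2 - Σ_i Cov(i,i)^2)

where Cov is the covariance matrix of the embeddings and λ is a regularization coefficient. Subtracting the squared diagonal entries (the per-dimension variances) leaves the sum of squared off-diagonal covariances, which is zero when the dimensions are uncorrelated.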

3. Dynamics Performance Loss: Definition: This loss measures the accuracy of predicted next states while also encouraging diverse predictions.

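Formula: L_dynamics = MSE(true_next_state, predicted_next_state) + λ * Var(predicted_next_state)

where MSE is the mean squared error, Var is the variance, and λ is a weighting factor. Because the loss is minimized, the variance term encourages diverse predictions only if λ is negative; with λ > 0 it penalizes spread in the predictions.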

4. Thought Consistency Loss: Definition: This loss encourages consistency between true next states and perturbed next states.

Formula: L_consistency = MSE(true_next_state, perturbed_next_state)

5. Policy Value Joint Loss: Definition: This loss combines policy and value losses for reinforcement learning tasks.

Formula: L_joint = CrossEntropy(policy_logits, true_policy) + λ * MSE(value_pred, true_value)

where λ is a weighting factor balancing policy and value losses.

6. Action Diversity Reward: Definition: This reward encourages diversity in action embeddings.

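Formula: R_diversity = λ * Σ_i,j (cos_sim(a_i, a_j)^2)

where cos_sim is the cosine similarity between action embeddings, and λ is a scaling factor. The sum grows as action embeddings become more alike, so diversity is encouraged by minimizing this term (that is, adding it to the loss rather than maximizing it).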

7. Expected Thought Value Loss: Definition: This loss aims to maximize the expected value from Monte Carlo Tree Search.

Formula: L_ETV = -mean(mcts_best_values)

8. Exploration Regularization: Definition: This regularization encourages exploration by rewarding less-visited actions.

Formula: R_exploration = λ * mean(Σ_a (1 / (visit_count(a) + 1)))

where λ is a scaling factor.

9. KL Divergence Loss: Definition: This loss measures the difference between old and new policies in policy optimization.

Formula: L_KL = KL(old_policy || new_policy) = Σ_{i=1}^{n} old_policy_i · log(old_policy_i / new_policy_i)

where KL is the Kullback-Leibler divergence, old_policy and new_policy are probability distributions over the same set of outcomes or actions, i indexes those outcomes or actions, and n is their total number.
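Several of these formulas are easy to check numerically. The dependency-free sketch below implements the InfoNCE loss for a single anchor, the off-diagonal covariance penalty, the KL term as the expanded sum Σ_i old_policy_i · log(old_policy_i / new_policy_i) (KL(old || new) in standard notation), and the exploration term. Embeddings and policies are plain lists of floats, the function names are hypothetical, and the λ defaults are placeholders rather than values from the repository; a real implementation would vectorize these over a batch in PyTorch.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(anchor, positive, negatives, tau=0.1):
    # Temperature-scaled similarities; the positive sits at index 0.
    logits = [cosine(anchor, c) / tau for c in [positive, *negatives]]
    # -log softmax(logits)[0], computed with log-sum-exp for numerical stability.
    m = max(logits)
    log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_denom - logits[0]


def covariance_penalty(batch, lam=1.0):
    # lam * sum of squared off-diagonal entries of the unbiased covariance matrix.
    n, d = len(batch), len(batch[0])
    means = [sum(row[k] for row in batch) / n for k in range(d)]
    centered = [[row[k] - means[k] for k in range(d)] for row in batch]
    cov = [[sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)] for i in range(d)]
    return lam * sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j)


def kl_term(old_policy, new_policy):
    # Sum_i old_i * log(old_i / new_i); actions with old_i == 0 contribute nothing.
    return sum(p * math.log(p / q) for p, q in zip(old_policy, new_policy) if p > 0)


def exploration_term(visit_counts_per_node, lam=0.1):
    # lam * mean over nodes of Sum_a 1 / (visit_count(a) + 1).
    per_node = [sum(1.0 / (n + 1) for n in counts) for counts in visit_counts_per_node]
    return lam * sum(per_node) / len(per_node)
```

As a quick sanity check, `info_nce([1, 0], [1, 0], [[0, 1], [-1, 0]])` is close to zero because the positive matches the anchor, while swapping the positive for an orthogonal vector pushes the loss to roughly 1/τ = 10.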

6. Inference with Language World Model

Description:
Utilizes the trained World Model to perform advanced reasoning and generate responses based on structured thought processes and state simulations.

Usage:

python main_menu.py --task inference_world_model --query "Your query here"

Features:

  • Tree of Thought (ToT): Structures reasoning paths hierarchically.
  • Beam Search with MCTS: Enhances decision-making by balancing exploration and exploitation.
  • Integration with Knowledge Base: Leverages stored information for informed responses.

7. Inference with World Model, Tree of Thought, and Multi-Token Beam Search

Description:
Executes inference using the World Model integrated with ToT and multi-token beam search for highly coherent and contextually rich outputs.

Usage:

python main_menu.py --task advanced_inference --query "Your complex query here"

Process:

  1. State Initialization: Converts input queries into state representations.
  2. MCTS with Beam Search: Explores multiple reasoning paths simultaneously.
  3. Thought Sequence Generation: Produces a sequence of interconnected thoughts/actions.
  4. Final Response Generation: Synthesizes the best thought path into a coherent response.

Mode: With World Model and Tree of Thought

Step 1: Input Tokenization and Encoding

  1. Tokenization:

    • The input query is converted into token IDs using the tokenizer. This numerical representation is essential for processing by the Transformer model.
    • Shape: The resulting tensor has a shape corresponding to the batch size and the sequence length of the input.
  2. Encoding via Transformer:

    • The tokenized input is passed through the Transformer model to generate contextual embeddings. These embeddings capture the semantic information of the input.
    • Shape: The output tensor includes the batch size, sequence length, and the dimensionality of the model.
  3. State Representation:

    • The RepresentationNetwork processes the transformer's output to create a condensed state representation. This state serves as the foundation for further reasoning steps.
    • Shape: The state representation tensor has dimensions for the batch size, a sequence dimension that is typically of length one, and the state dimensionality.
  4. State Initialization:

    • A State object is created, encapsulating the initial representation, dynamics network, action encoder, and the root node of the Tree of Thought. This object maintains the current context and facilitates state transitions as actions are applied.

Step 2: MCTS Initialization and Root Node Evaluation

  1. MCTS Instance Creation:

    • An instance of the MCTS class is initialized with the necessary networks and parameters. This instance will manage the search process through the Tree of Thought.
    • Key parameters include the prediction network, dynamics network, action encoder, number of iterations, and the exploration constant.
  2. Root Node Creation:

    • An MCTSNode representing the root of the search tree is created. This node is associated with the initial state and the root thought node from the Tree of Thought.
  3. Root Node Evaluation:

    • The root node is evaluated using the MCTS's evaluation function, which assesses the potential value of the current state based on the prediction network's output.
  4. Backpropagation:

    • The evaluation result (value estimate) is backpropagated through the tree, updating the visit counts and value sums of the nodes. This process informs future selections by providing aggregated value information.
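The bookkeeping in Step 2 (visit counts and value sums updated on the path back to the root) can be sketched as follows. `Node`, `value_sum`, `backpropagate`, and the action label `refine_query` are illustrative names, not necessarily those used by the repository's MCTSNode, and the sketch assumes single-agent values, so no sign flipping between tree levels is applied.

```python
class Node:
    """Minimal MCTS node holding the statistics used during backpropagation."""

    def __init__(self, parent=None):
        self.parent = parent
        self.children = {}      # action -> Node
        self.visit_count = 0
        self.value_sum = 0.0

    def is_leaf(self):
        return not self.children

    def value(self):
        # Mean value estimate; 0.0 for a node that has never been visited.
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def backpropagate(node, value):
    # Walk from the evaluated node up to the root, updating every node on the path.
    while node is not None:
        node.visit_count += 1
        node.value_sum += value
        node = node.parent


root = Node()
child = Node(parent=root)
root.children["refine_query"] = child
backpropagate(child, 0.5)   # e.g. a value estimate from the prediction network
backpropagate(child, 1.0)
print(child.visit_count, child.value(), root.value())  # 2 0.75 0.75
```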

Step 3: MCTS Iterations with Beam Search

  1. Beam Initialization:

    • The search beam is initialized with the root node. Each beam element includes the current node, a cumulative score, cumulative entropy, cumulative variance, and an empty action sequence.
  2. Iterative Expansion:

    • For each iteration up to the specified number of MCTS iterations:

      • Candidate Collection:

        • For each node in the current beam:

          • Leaf Evaluation:

            • If the node is a leaf (has no children), it is evaluated to estimate its value. The evaluation result is then backpropagated to update the node's statistics.
          • Child Selection:

            • If the node has children, the total number of visits to all its children is calculated. The children are then sorted based on their Upper Confidence Bound (UCB) scores, which consider exploration and exploitation factors. The top actions (up to the beam size) are selected for expansion.
      • Action Sequence Prediction:

        • For each selected action:

          • Initialization:

            • Set the current node to the selected child node.
            • Initialize the current action sequence with the selected action.
            • Initialize the current score, cumulative entropy, and cumulative variance to zero.
          • Multi-Token Prediction:

            • For each step in the number of tokens to predict:

              • Leaf Evaluation:

                • If the current node is a leaf, evaluate it to obtain a value estimate and backpropagate the result.
              • Child Check:

                • If the current node has no children, exit the multi-token prediction loop for this sequence.
              • Action Selection:

                • Calculate the total number of visits to all children of the current node.
                • Select the action with the highest UCB score from the children.
              • Score Update:

                • If the selected child node has been visited before, increment the current score by the average value estimate of that node.
                • If the child node has not been visited, the score remains unchanged or is updated with a default value.
              • Entropy and Variance Update:

                • Accumulate the entropy and variance metrics from the selected node to guide the search towards more confident and diverse actions.
              • Sequence Extension:

                • Append the selected action to the current action sequence.
                • Update the current node to the selected child node.
          • Candidate Aggregation:

            • Add the new candidate sequence, along with its updated score, entropy, variance, and action sequence, to the list of all candidates for this iteration.
      • Beam Pruning:

        • After collecting all candidates from the current beam, sort them based on a scoring function that balances the cumulative score, entropy, and variance.
        • Retain only the top sequences up to the specified beam size to form the new beam for the next iteration.
  3. Termination:

    • The iterative expansion process continues until the specified number of MCTS iterations is reached or there are no more candidates to explore.
  4. Result Extraction:

    • After completing the iterations, select the best action sequence from the final beam based on the accumulated scores and metrics.
    • Return this sequence as the generated series of actions (thoughts) in response to the input query.
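The source does not spell out the beam-pruning scoring function beyond saying that it balances cumulative score, entropy, and variance. The sketch below assumes a simple linear combination, `score - alpha * entropy + beta * variance`, which rewards value, penalizes uncertainty, and rewards diversity in line with the description above; `Candidate`, `rank_key`, `prune_beam`, `alpha`, and `beta` are all hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class Candidate:
    score: float      # cumulative value estimate along the sequence
    entropy: float    # cumulative entropy of the action distributions
    variance: float   # cumulative variance metric
    actions: list = field(default_factory=list)


def rank_key(c, alpha=0.5, beta=0.1):
    # Assumed balance: favor high value, low entropy, and high variance.
    return c.score - alpha * c.entropy + beta * c.variance


def prune_beam(candidates, beam_size, alpha=0.5, beta=0.1):
    # Keep the beam_size best candidates under the combined ranking.
    ranked = sorted(candidates, key=lambda c: rank_key(c, alpha, beta), reverse=True)
    return ranked[:beam_size]


cands = [
    Candidate(score=1.0, entropy=0.2, variance=0.0, actions=["a"]),
    Candidate(score=1.2, entropy=2.0, variance=0.0, actions=["b"]),
    Candidate(score=0.9, entropy=0.1, variance=1.0, actions=["c"]),
]
print([c.actions for c in prune_beam(cands, beam_size=2)])  # [['c'], ['a']]
```

The highest-value candidate `b` is dropped because its large cumulative entropy outweighs its score under these weights; changing `alpha` and `beta` shifts that trade-off.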


Key Components and Concepts

1. Monte Carlo Tree Search (MCTS)

MCTS is a heuristic search algorithm used to make decisions in various domains, including game playing and strategic planning. It balances exploration (trying new actions) and exploitation (using known good actions) by simulating multiple potential action sequences and evaluating their outcomes.

  • Selection: Traverses the tree by selecting child nodes with the highest UCB scores, balancing exploration and exploitation.

  • Expansion: Adds new child nodes based on policy probabilities obtained from the Prediction Network.

  • Evaluation: Uses the Prediction Network to estimate the value of newly expanded nodes.

  • Backpropagation: Updates the visit counts and value sums of nodes based on the evaluations to inform future selections.
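Selection is often driven by the UCB1 score: a child's mean value plus an exploration bonus c · sqrt(ln N / n), where N is the total visit count across the parent's children and n is the child's own count. The exact variant used in the repository is not stated here (MuZero-style implementations often use a prior-weighted PUCT formula instead), so treat this as an illustration of the idea; `ucb_score` and `select_child` are hypothetical names.

```python
import math


def ucb_score(value_sum, visits, total_visits, c=1.41):
    # Unvisited children get +inf so each one is tried at least once.
    if visits == 0:
        return math.inf
    exploitation = value_sum / visits
    exploration = c * math.sqrt(math.log(total_visits) / visits)
    return exploitation + exploration


def select_child(children, c=1.41):
    """children: dict action -> (value_sum, visit_count). Returns the action to follow."""
    total = sum(visits for _, visits in children.values())
    return max(children, key=lambda a: ucb_score(*children[a], total, c))


children = {"a": (9.0, 10), "b": (0.5, 1)}
print(select_child(children))         # b: the exploration bonus dominates
print(select_child(children, c=0.0))  # a: pure exploitation of the mean value
```

A larger exploration constant `c` favors rarely visited children; `c = 0` reduces selection to picking the best mean value.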

2. Beam Search with Multi-Token Prediction

Beam Search is an optimization technique that explores multiple potential sequences simultaneously, retaining only the top-performing ones based on a scoring function.

  • Multi-Token Prediction: Instead of predicting one token at a time, the model predicts multiple tokens (n_tokens_predict) in each step. This reduces the number of decoding steps and is intended to produce more coherent multi-token sequences.

  • Beam Size: Determines the number of candidate sequences retained at each step. A larger beam size allows for more extensive exploration but increases computational complexity.

3. Entropy and Variance Metrics

These statistical measures guide the search process by quantifying uncertainty and diversity in action choices.

  • Entropy:

    • Definition: Measures the uncertainty or randomness in the action probability distribution.
    • Purpose: Lower entropy indicates more confident predictions, steering the search towards more decisive actions.
  • Variance:

    • Definition: Measures the dispersion or diversity in the action probability distribution.
    • Purpose: Higher variance encourages the exploration of diverse actions, preventing the search from converging prematurely on suboptimal paths.
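Both metrics can be computed directly from an action probability distribution. The sketch assumes Shannon entropy in nats and takes "variance" to mean the population variance of the probability values; the repository may define these metrics differently, and `entropy` and `variance` here are hypothetical helper names.

```python
import math


def entropy(probs):
    # Shannon entropy in nats; zero-probability actions contribute nothing.
    return -sum(p * math.log(p) for p in probs if p > 0)


def variance(probs):
    # Population variance of the probability values themselves.
    mean = sum(probs) / len(probs)
    return sum((p - mean) ** 2 for p in probs) / len(probs)


uniform = [0.25, 0.25, 0.25, 0.25]
peaked = [1.0, 0.0, 0.0, 0.0]
print(entropy(uniform), variance(uniform))  # maximal entropy (ln 4), zero variance
print(entropy(peaked), variance(peaked))    # zero entropy, variance 0.1875
```

Under this particular definition a uniform distribution has maximal entropy and zero variance, while a peaked one has zero entropy and positive variance, so the two metrics move in opposite directions for the same distribution.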

4. Tree of Thought (ToT)

The ToT is a hierarchical structure representing various reasoning paths or actions the model can take to address a query. Each node corresponds to a specific thought or action, and the edges represent transitions between thoughts.

  • Thought Nodes: Represent individual reasoning steps or actions.

  • Hierarchical Structure: Allows the model to navigate through complex reasoning processes systematically.
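A tree of thoughts can be represented with ordinary nodes that hold a thought and a score. The sketch below, with hypothetical names `Thought` and `best_path` and invented example scores, returns the root-to-leaf path with the highest total score by exhaustive depth-first search, whereas the inference process described above uses MCTS to explore the tree without enumerating every path.

```python
from dataclasses import dataclass, field


@dataclass
class Thought:
    text: str
    score: float = 0.0
    children: list = field(default_factory=list)

    def add(self, child):
        self.children.append(child)
        return child


def best_path(node):
    """Return (total_score, [texts]) for the best root-to-leaf path."""
    if not node.children:
        return node.score, [node.text]
    sub_score, sub_texts = max((best_path(c) for c in node.children), key=lambda r: r[0])
    return node.score + sub_score, [node.text] + sub_texts


root = Thought("query")
econ = root.add(Thought("economic angle", 0.6))
tech = root.add(Thought("technical angle", 0.8))
econ.add(Thought("labor markets", 0.9))
tech.add(Thought("compute costs", 0.2))
print(best_path(root))  # (1.5, ['query', 'economic angle', 'labor markets'])
```

A greedy walk would follow the higher-scoring first thought ("technical angle") and end on the weaker path; looking deeper into the tree is what recovers the better reasoning chain.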


Visual Representation

The following flowchart (Mermaid syntax) illustrates the advanced inference process:

graph TD
    A[Start Inference] --> B[Receive Query Input]
    B --> C[Tokenize and Encode Query]
    C --> D[Generate Initial State Representation]
    D --> E{Check Inference Mode}
    E -->|Without World Model| F[Perform Beam Search]
    E -->|With World Model| G{Check Inference Sub-Mode}
    G -->|World Model + ToT| H[Initialize MCTS]
    G -->|World Model Only| I[Perform Beam Search with World Model]
    H --> J[MCTS Iterations Loop]
    J --> K[Selection -> Expansion -> Evaluation -> Backpropagation]
    K --> L[Integrate Beam Search]
    L --> M[Compute Entropy and Variance]
    M --> N[Select Best Action Sequence]
    N --> O[Generate Thought Sequence]
    O --> P[Return Thought Sequence]
    F --> Q[Generate and Return Text]
    I --> R[Generate and Return Text]
    P --> Q
    R --> Q

General Arguments

| Argument | Required | Description | Default |
|---|---|---|---|
| --task | Yes | Specifies the task to run (train_llm_world, train_agent, test_agent, etc.). | None |
| --model_name | No | Pretrained model name for the LLM (gpt2, bert, etc.) or a custom model path. | gpt2 |
| --dataset_name | No | Name of the Hugging Face dataset for training the LLM and World Model (e.g., wikitext). | wikitext |
| --dataset_config | No | Configuration name for the dataset. | wikitext-2-raw-v1 |
| --batch_size | No | Number of samples per batch during training. | 4 |
| --num_epochs | No | Number of training epochs. | 3 |
| --max_length | No | Maximum sequence length for training/inference. | 128 |
| --mode | No | Mode for the LLM and World Model (train, inference). | train |
| --query | No | Query input for single-query runs (test_agent, inference_llm, inference_world_model, advanced_inference). | '' (empty) |
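For reference, the arguments in the table map naturally onto an argparse parser like the one below. This is a reconstruction from the table only, not the actual contents of main_menu.py: names and defaults come from the table, and the only `choices` constraint is the train/inference pair the table lists for --mode.

```python
import argparse


def build_parser():
    # Sketch reconstructed from the General Arguments table; not main_menu.py itself.
    parser = argparse.ArgumentParser(description="LightBulb main menu (sketch)")
    parser.add_argument("--task", required=True,
                        help="Task to run, e.g. train_llm_world, train_agent, test_agent")
    parser.add_argument("--model_name", default="gpt2")
    parser.add_argument("--dataset_name", default="wikitext")
    parser.add_argument("--dataset_config", default="wikitext-2-raw-v1")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mode", default="train", choices=["train", "inference"])
    parser.add_argument("--query", default="")
    return parser


args = build_parser().parse_args(["--task", "test_agent", "--query", "What is RAG?"])
print(args.task, args.batch_size, args.query)  # test_agent 4 What is RAG?
```

Every argument other than --task falls back to the table's default when omitted.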

Requirements

  • Python: 3.7+
  • Libraries:
    • torch>=1.7.1
    • transformers
    • datasets
    • argparse (part of the Python standard library; no separate installation needed)
    • huggingface_hub

Training the World Model

python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256

Training the Web Search Agent

python main_menu.py --task train_agent

Use the Web Search Agent in Interactive Mode

python main_menu.py --task test_agent

Use the Web Search Agent with a Single Query

python main_menu.py --task test_agent --query "What are the impacts of renewable energy on global sustainability?"

Inference with World Model and Tree of Thought

python main_menu.py --task advanced_inference --query "Analyze the economic effects of artificial intelligence in the next decade."

Citation

If you use LightBulb in your research, please cite the author:

@misc{RobbiePasquale_lightbulb,
  author       = {Robbie Pasquale},
  title        = {LightBulb: An Autonomous Web Search and Language Model Framework},
  year         = {2024},
  publisher    = {Huggingface},
  howpublished = {\url{https://huggingface.co/RobbiePasquale/lightbulb}},
}

License

This project is licensed under the Apache 2.0 License.


For more detailed information on each component and advanced configurations, please refer to the documentation.