diff --git a/app.py b/app.py index b7be78d3040e66b8c463f966cb9775000d826e6c..adfb6aaacdc8d14fd545d61ab65bb96084f35714 100644 --- a/app.py +++ b/app.py @@ -3,13 +3,8 @@ Chat Interface App """ import gradio as gr - -# from axolotl.cli import print_axolotl_text_art -# print_axolotl_text_art() - import sys -sys.path.append('axolotl/src/axolotl/common/cli.py') -from cli import test_func +from .axolotl import test_func def generate(instruction): return "Hello " + instruction + "!" + test_func() diff --git a/axolotl.py b/axolotl.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6fbc449c00e18ef16d95b20e7b38e9ff3ca071 --- /dev/null +++ b/axolotl.py @@ -0,0 +1,3 @@ + +def test_func(): + return "Hello Everybodyyyy!" diff --git a/axolotl/.bandit b/axolotl/.bandit deleted file mode 100644 index 2d81286aee751afae8630a21986eb04f36e7e6d7..0000000000000000000000000000000000000000 --- a/axolotl/.bandit +++ /dev/null @@ -1,3 +0,0 @@ -[bandit] -exclude = tests -skips = B101 diff --git a/axolotl/.editorconfig b/axolotl/.editorconfig deleted file mode 100644 index df80c53a12dbe60f3e3f60c486cedad044e81ecd..0000000000000000000000000000000000000000 --- a/axolotl/.editorconfig +++ /dev/null @@ -1,14 +0,0 @@ -root = true - -[*] -end_of_line = lf -insert_final_newline = true -trim_trailing_whitespace = true - -[*.py] -indent_style = space -indent_size = 4 - -[**.yml] -indent_style = space -indent_size = 2 diff --git a/axolotl/.flake8 b/axolotl/.flake8 deleted file mode 100644 index fd69af775696d49c5397839d5bd6320fe6a73f4b..0000000000000000000000000000000000000000 --- a/axolotl/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -max-line-length = 88 - -select = C,E,F,W,B,B950 -extend-ignore = E203, E501, W503 diff --git a/axolotl/.gitattributes b/axolotl/.gitattributes deleted file mode 100644 index 7b52c8631c166406a5a80fafc25a4c566eae9ebc..0000000000000000000000000000000000000000 --- a/axolotl/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -data/*.jsonl filter=lfs diff=lfs merge=lfs -text diff --git a/axolotl/.github/CODE_OF_CONDUCT.md b/axolotl/.github/CODE_OF_CONDUCT.md deleted file mode 100644 index 865b95d2a747f258c211edf28a96e9e32fa5179d..0000000000000000000000000000000000000000 --- a/axolotl/.github/CODE_OF_CONDUCT.md +++ /dev/null @@ -1,129 +0,0 @@ -# Contributor Covenant Code of Conduct - -## Our Pledge - -We as members, contributors, and leaders pledge to make participation in our -community a harassment-free experience for everyone, regardless of age, body -size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, -nationality, personal appearance, race, religion, or sexual identity -and orientation. - -We pledge to act and interact in ways that contribute to an open, welcoming, -diverse, inclusive, and healthy community. 
- -## Our Standards - -Examples of behavior that contributes to a positive environment for our -community include: - -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, - and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community - -Examples of unacceptable behavior include: - -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting - -## Enforcement Responsibilities - -Community leaders are responsible for clarifying and enforcing our standards of -acceptable behavior and will take appropriate and fair corrective action in -response to any behavior that they deem inappropriate, threatening, offensive, -or harmful. - -Community leaders have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions that are -not aligned to this Code of Conduct, and will communicate reasons for moderation -decisions when appropriate. - -## Scope - -This Code of Conduct applies within all community spaces, and also applies when -an individual is officially representing the community in public spaces. -Examples of representing our community include using an official e-mail address, -posting via an official social media account, or acting as an appointed -representative at an online or offline event. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior may be -reported to the community leaders responsible for enforcement on Discord -at https://discord.gg/QYF8QrtEUm - -All complaints will be reviewed and investigated promptly and fairly. - -All community leaders are obligated to respect the privacy and security of the -reporter of any incident. - -## Enforcement Guidelines - -Community leaders will follow these Community Impact Guidelines in determining -the consequences for any action they deem in violation of this Code of Conduct: - -### 1. Correction - -**Community Impact**: Use of inappropriate language or other behavior deemed -unprofessional or unwelcome in the community. - -**Consequence**: A private, written warning from community leaders, providing -clarity around the nature of the violation and an explanation of why the -behavior was inappropriate. A public apology may be requested. - -### 2. Warning - -**Community Impact**: A violation through a single incident or series -of actions. - -**Consequence**: A warning with consequences for continued behavior. No -interaction with the people involved, including unsolicited interaction with -those enforcing the Code of Conduct, for a specified period of time. This -includes avoiding interactions in community spaces as well as external channels -like social media. Violating these terms may lead to a temporary or -permanent ban. - -### 3. Temporary Ban - -**Community Impact**: A serious violation of community standards, including -sustained inappropriate behavior. 
- -**Consequence**: A temporary ban from any sort of interaction or public -communication with the community for a specified period of time. No public or -private interaction with the people involved, including unsolicited interaction -with those enforcing the Code of Conduct, is allowed during this period. -Violating these terms may lead to a permanent ban. - -### 4. Permanent Ban - -**Community Impact**: Demonstrating a pattern of violation of community -standards, including sustained inappropriate behavior, harassment of an -individual, or aggression toward or disparagement of classes of individuals. - -**Consequence**: A permanent ban from any sort of public interaction within -the community. - -## Attribution - -This Code of Conduct is adapted from the [Contributor Covenant][homepage], -version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. - -Community Impact Guidelines were inspired by [Mozilla's code of conduct -enforcement ladder](https://github.com/mozilla/diversity). - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see the FAQ at -https://www.contributor-covenant.org/faq. Translations are available at -https://www.contributor-covenant.org/translations. diff --git a/axolotl/.github/CONTRIBUTING.md b/axolotl/.github/CONTRIBUTING.md deleted file mode 100644 index 9eec23e1a3c484063f62e8b825af901dc5f03f8d..0000000000000000000000000000000000000000 --- a/axolotl/.github/CONTRIBUTING.md +++ /dev/null @@ -1,76 +0,0 @@ -# Contributing to axolotl - -First of all, thank you for your interest in contributing to axolotl! We appreciate the time and effort you're willing to invest in making our project better. This document provides guidelines and information to make the contribution process as smooth as possible. - -## Table of Contents - -- [Code of Conduct](#code-of-conduct) -- [Getting Started](#getting-started) -- [How to Contribute](#how-to-contribute) - - [Reporting Bugs](#reporting-bugs) - - [Suggesting Enhancements](#suggesting-enhancements) - - [Submitting Pull Requests](#submitting-pull-requests) -- [Style Guidelines](#style-guidelines) - - [Code Style](#code-style) - - [Commit Messages](#commit-messages) -- [Additional Resources](#additional-resources) - -## Code of Conductcode - -All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community. - -## Getting Started - -Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new). - -PRs are **greatly welcome**! - -1. Fork the repository and clone it to your local machine. -2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file. -3. Explore the codebase, run tests, and verify that everything works as expected. - -Please run below to setup env -```bash -pip3 install -r requirements-dev.txt -r requirements-tests.txt -pre-commit install - -# test -pytest tests/ -``` - -## How to Contribute - -### Reporting Bugs - -If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs. - -### Suggesting Enhancements - -We welcome ideas for improvements and new features. 
To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project. - -### Submitting Pull Requests - -1. Create a new branch for your feature or bugfix. Use a descriptive name like `feature/your-feature-name` or `fix/your-bugfix-name`. -2. Make your changes, following the [Style Guidelines](#style-guidelines) below. -3. Test your changes and ensure that they don't introduce new issues or break existing functionality. -4. Commit your changes, following the [commit message guidelines](#commit-messages). -5. Push your branch to your fork on GitHub. -6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues. - -## Style Guidelines - -### Code Style - -axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines. - -### Commit Messages - -Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function". - -## Additional Resources - -- [GitHub Help](https://help.github.com/) -- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests) -- [{codestyle}]({URLofCodestyle}) - -Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together! diff --git a/axolotl/.github/FUNDING.yml b/axolotl/.github/FUNDING.yml deleted file mode 100644 index 30ddf31da3efc4194c8cb336d70c49a388989202..0000000000000000000000000000000000000000 --- a/axolotl/.github/FUNDING.yml +++ /dev/null @@ -1,13 +0,0 @@ -# These are supported funding model platforms - -github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] -patreon: # Replace with a single Patreon username -open_collective: # Replace with a single Open Collective username -ko_fi: # Replace with a single Ko-fi username -tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel -community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry -liberapay: # Replace with a single Liberapay username -issuehunt: # Replace with a single IssueHunt username -otechie: # Replace with a single Otechie username -lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry -custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/axolotl/.github/ISSUE_TEMPLATE/bug-report.yaml b/axolotl/.github/ISSUE_TEMPLATE/bug-report.yaml deleted file mode 100644 index 5cf332587ab829ffa5d3d512a7e60e74f8864384..0000000000000000000000000000000000000000 --- a/axolotl/.github/ISSUE_TEMPLATE/bug-report.yaml +++ /dev/null @@ -1,112 +0,0 @@ -name: Bug Report -description: File a bug report -labels: ["bug", "needs triage"] -body: - - type: markdown - attributes: - value: | - ## Before you start - Please **make sure you are on the latest version.** - If you encountered the issue after you installed, updated, or reloaded, **please try restarting before reporting the bug**. - - - type: checkboxes - id: no-duplicate-issues - attributes: - label: "Please check that this issue hasn't been reported before." 
- description: "The **Label filters** may help make your search more focussed." - options: - - label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports." - required: true - - - type: textarea - id: expected - attributes: - label: Expected Behavior - description: Tell us what **should** happen. - validations: - required: true - - - type: textarea - id: what-happened - attributes: - label: Current behaviour - description: | - Tell us what happens instead of the expected behavior. - Provide stacktrace and/or screenshots. - validations: - required: true - - - type: textarea - id: reproduce - attributes: - label: Steps to reproduce - description: | - Which exact steps can a developer take to reproduce the issue? - The more detail you provide, the easier it will be to narrow down and fix the bug. - Please paste in tasks and/or queries **as text, not screenshots**. - placeholder: | - Example of the level of detail needed to reproduce any bugs efficiently and reliably. - 1. Go to the '...' page. - 2. Click on the '...' button. - 3. Scroll down to '...'. - 4. Observe the error. - validations: - required: true - - - type: textarea - id: config - attributes: - label: Config yaml - description: | - Please attach the config yaml! - - - type: textarea - id: possible-solution - attributes: - label: Possible solution - description: | - Not obligatory, but please suggest a fix or reason for the bug, if you have an idea. - - - - type: checkboxes - id: operating-systems - attributes: - label: Which Operating Systems are you using? - description: You may select more than one. - options: - - label: Linux - - label: macOS - - label: Windows - - - type: input - id: Python-version - attributes: - label: Python Version - description: Which {Programming} version are you using? - placeholder: 3.10 / please change accordingly - validations: - required: true - - - type: input - id: axolotl-branch-commit - attributes: - label: axolotl branch-commit - description: On which branch/commit are you? - placeholder: main/4d6490b - validations: - required: true - - - type: checkboxes - id: acknowledgements - attributes: - label: 'Acknowledgements' - description: 'Please confirm the following:' - options: - - label: 'My issue title is concise, descriptive, and in title casing.' - required: true - - label: 'I have searched the existing issues to make sure this bug has not been reported yet.' - required: true - - label: 'I am using the latest version of axolotl.' - required: true - - label: 'I have provided enough information for the maintainers to reproduce and diagnose the issue.' 
- required: true diff --git a/axolotl/.github/ISSUE_TEMPLATE/config.yml b/axolotl/.github/ISSUE_TEMPLATE/config.yml deleted file mode 100644 index e0c5754fb9b395438ff2e7aa04ce8b5c7d93dc0d..0000000000000000000000000000000000000000 --- a/axolotl/.github/ISSUE_TEMPLATE/config.yml +++ /dev/null @@ -1,7 +0,0 @@ -blank_issues_enabled: false -contact_links: - - name: Ask a question - url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a - about: Ask questions and discuss with other community members - - name: Discuss the Project in Discord - url: https://discord.gg/HhrNrHJPRb diff --git a/axolotl/.github/ISSUE_TEMPLATE/docs.yml b/axolotl/.github/ISSUE_TEMPLATE/docs.yml deleted file mode 100644 index 2c27af1aa07c7c053413a5d78a3df0e9bc7e1f85..0000000000000000000000000000000000000000 --- a/axolotl/.github/ISSUE_TEMPLATE/docs.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: Documentation Improvement / Clarity -description: Make a suggestion to improve the project documentation. -labels: ['needs triage', 'docs'] -body: - - type: markdown - attributes: - value: '## :book: Documentation :book:' - - type: markdown - attributes: - value: | - * Ask questions in [Discord](https://discord.gg/HhrNrHJPRb). - * Before you file an issue read the [Contributing guide](./CONTRIBUTING.md). - * Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues). - - type: textarea - attributes: - label: What piece of documentation is affected? - description: Please link to the article you'd like to see updated. - validations: - required: true - - type: textarea - attributes: - label: What part(s) of the article would you like to see updated? - description: | - - Give as much detail as you can to help us understand the change you want to see. - - Why should the docs be changed? What use cases does it support? - - What is the expected outcome? - validations: - required: true - - type: textarea - attributes: - label: Additional Information - description: Add any other context or screenshots about the feature request here. - validations: - required: false - - type: checkboxes - id: acknowledgements - attributes: - label: 'Acknowledgements' - description: 'Please confirm the following:' - options: - - label: 'My issue title is concise, descriptive, and in title casing.' - required: true - - label: 'I have searched the existing issues to make sure this feature has not been requested yet.' - required: true - - label: 'I have provided enough information for the maintainers to understand and evaluate this request.' - required: true diff --git a/axolotl/.github/ISSUE_TEMPLATE/feature-request.yaml b/axolotl/.github/ISSUE_TEMPLATE/feature-request.yaml deleted file mode 100644 index 39b6cb74e1c3b7876806dcbcd6c8a34df8e64e22..0000000000000000000000000000000000000000 --- a/axolotl/.github/ISSUE_TEMPLATE/feature-request.yaml +++ /dev/null @@ -1,63 +0,0 @@ -name: Feature Request / Enhancement -description: Suggest a new feature or feature enhancement for the project -labels: ["enhancement", "needs triage"] -body: - - type: checkboxes - id: no-duplicate-issues - attributes: - label: "⚠️ Please check that this feature request hasn't been suggested before." - description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed." 
- options: - - label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests." - required: true - - label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests." - required: true - - - type: textarea - id: feature-description - validations: - required: true - attributes: - label: "πŸ”– Feature description" - description: "A clear and concise description of what the feature request is." - placeholder: "You should add ..." - - - type: textarea - id: solution - validations: - required: true - attributes: - label: "βœ”οΈ Solution" - description: "A clear and concise description of what you want to happen, and why." - placeholder: "In my use-case, ..." - - - type: textarea - id: alternatives - validations: - required: false - attributes: - label: "❓ Alternatives" - description: "A clear and concise description of any alternative solutions or features you've considered." - placeholder: "I have considered ..." - - - type: textarea - id: additional-context - validations: - required: false - attributes: - label: "πŸ“ Additional Context" - description: "Add any other context or screenshots about the feature request here." - placeholder: "..." - - - type: checkboxes - id: acknowledgements - attributes: - label: 'Acknowledgements' - description: 'Please confirm the following:' - options: - - label: 'My issue title is concise, descriptive, and in title casing.' - required: true - - label: 'I have searched the existing issues to make sure this feature has not been requested yet.' - required: true - - label: 'I have provided enough information for the maintainers to understand and evaluate this request.' - required: true diff --git a/axolotl/.github/PULL_REQUEST_TEMPLATE/pull_request_template_simple.md b/axolotl/.github/PULL_REQUEST_TEMPLATE/pull_request_template_simple.md deleted file mode 100644 index 2f1fe20bd43e9b4c170975a638397ca4a1e953c8..0000000000000000000000000000000000000000 --- a/axolotl/.github/PULL_REQUEST_TEMPLATE/pull_request_template_simple.md +++ /dev/null @@ -1,22 +0,0 @@ - - -# Description - - - -## Motivation and Context - - - - -## How has this been tested? - - - - - -## Screenshots (if appropriate) - -## Types of changes - - diff --git a/axolotl/.github/SECURITY.md b/axolotl/.github/SECURITY.md deleted file mode 100644 index aceb0d1a2e26eb5ad73f56692501b7682cb847b3..0000000000000000000000000000000000000000 --- a/axolotl/.github/SECURITY.md +++ /dev/null @@ -1,9 +0,0 @@ -# Security Policy - -## Supported Versions - -Due to the nature of the fast development that is happening in this project, only the latest released version can be supported. - -## Reporting a Vulnerability - -If you find a vulnerability, please contact us on [Discord](https://discord.gg/xcu3ECkH9a) rather than creating a GitHub issue to allow us some time to fix it before it is a known vulnerability to others. diff --git a/axolotl/.github/SUPPORT.md b/axolotl/.github/SUPPORT.md deleted file mode 100644 index e625ec6d1067cab5cf1aa06db49c6562ff16f2d4..0000000000000000000000000000000000000000 --- a/axolotl/.github/SUPPORT.md +++ /dev/null @@ -1,10 +0,0 @@ -# Support - -If you need help with this project or have questions, please: - -1. Check the documentation. -2. Search the existing issues and pull requests. -3. Create a new issue if your question is not answered or your problem is not solved. -4. 
Have a look in the [Discord server](https://discord.gg/HhrNrHJPRb) - -Please note that this project is maintained by volunteers who have limited availability. We'll do our best to address your questions and concerns in a timely manner. diff --git a/axolotl/.github/release-drafter.yml b/axolotl/.github/release-drafter.yml deleted file mode 100644 index 348a573a1c58e8112dac16c6772de5fe90bb15bd..0000000000000000000000000000000000000000 --- a/axolotl/.github/release-drafter.yml +++ /dev/null @@ -1,31 +0,0 @@ -name-template: 'v$RESOLVED_VERSION' -tag-template: 'v$RESOLVED_VERSION' -categories: - - title: 'πŸš€ Features' - labels: - - 'feature' - - 'enhancement' - - title: 'πŸ› Bug Fixes' - labels: - - 'fix' - - 'bugfix' - - 'bug' - - title: '🧰 Maintenance' - label: 'chore' -change-template: '- $TITLE @$AUTHOR (#$NUMBER)' -change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks. -version-resolver: - major: - labels: - - 'major' - minor: - labels: - - 'minor' - patch: - labels: - - 'patch' - default: patch -template: | - ## What’s Changed - - $CHANGES diff --git a/axolotl/.github/workflows/base.yml b/axolotl/.github/workflows/base.yml deleted file mode 100644 index 6b90d1b501fbb5389267d198073ec66aa7095108..0000000000000000000000000000000000000000 --- a/axolotl/.github/workflows/base.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: ci-cd-base - -on: - push: - branches: - - "main-base" - - "dev-base" - -jobs: - build-base: - if: github.repository_owner == 'OpenAccess-AI-Collective' - # this job needs to be run on self-hosted GPU runners... - runs-on: self-hosted - strategy: - fail-fast: false - matrix: - include: - - cuda: "118" - cuda_version: 11.8.0 - python_version: "3.9" - pytorch: 2.0.1 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" - - cuda: "118" - cuda_version: 11.8.0 - python_version: "3.10" - pytorch: 2.0.1 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" - - cuda: "118" - cuda_version: 11.8.0 - python_version: "3.10" - pytorch: 2.1.1 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" - - cuda: "121" - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.1 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Docker metadata - id: metadata - uses: docker/metadata-action@v3 - with: - images: winglian/axolotl-base - - name: Login to Docker Hub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - name: Build - uses: docker/build-push-action@v4 - with: - context: . 
- file: ./docker/Dockerfile-base - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} - labels: ${{ steps.metadata.outputs.labels }} - build-args: | - CUDA_VERSION=${{ matrix.cuda_version }} - CUDA=${{ matrix.cuda }} - PYTHON_VERSION=${{ matrix.python_version }} - PYTORCH_VERSION=${{ matrix.pytorch }} - TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }} diff --git a/axolotl/.github/workflows/main.yml b/axolotl/.github/workflows/main.yml deleted file mode 100644 index 2f0b0745016ecf92c66796343207f66c4dc92844..0000000000000000000000000000000000000000 --- a/axolotl/.github/workflows/main.yml +++ /dev/null @@ -1,136 +0,0 @@ -name: ci-cd - -on: - push: - branches: - - "main" - -jobs: - build-axolotl: - if: github.repository_owner == 'OpenAccess-AI-Collective' - # this job needs to be run on self-hosted GPU runners... - strategy: - fail-fast: false - matrix: - include: - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.9" - pytorch: 2.0.1 - axolotl_extras: - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.10" - pytorch: 2.0.1 - axolotl_extras: - is_latest: true - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.10" - pytorch: 2.1.1 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.1 - axolotl_extras: - runs-on: [self-hosted, gpu, docker] - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Docker metadata - id: metadata - uses: docker/metadata-action@v5 - with: - images: winglian/axolotl - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Login to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/ - - name: Build and export to Docker - uses: docker/build-push-action@v5 - with: - context: . 
- load: true - build-args: | - BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} - CUDA=${{ matrix.cuda }} - PYTORCH_VERSION=${{ matrix.pytorch }} - file: ./docker/Dockerfile - tags: | - ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} - ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} - labels: ${{ steps.metadata.outputs.labels }} - - name: Unit Tests - run: | - docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ - - name: Push to Docker Hub - if: github.event_name != 'pull_request' - run: | - docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} - latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} - if [ -n "$latest_tag" ]; then - docker push "$latest_tag" - fi - - build-axolotl-runpod: - needs: build-axolotl - if: github.repository_owner == 'OpenAccess-AI-Collective' - # this job needs to be run on self-hosted GPU runners... - strategy: - matrix: - include: - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.9" - pytorch: 2.0.1 - axolotl_extras: - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.10" - pytorch: 2.0.1 - axolotl_extras: - is_latest: true - - cuda: 118 - cuda_version: 11.8.0 - python_version: "3.10" - pytorch: 2.1.1 - axolotl_extras: - - cuda: 121 - cuda_version: 12.1.0 - python_version: "3.10" - pytorch: 2.1.1 - axolotl_extras: - runs-on: [self-hosted, gpu, docker] - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Docker metadata - id: metadata - uses: docker/metadata-action@v5 - with: - images: winglian/axolotl-runpod - - name: Login to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - name: Build - uses: docker/build-push-action@v5 - with: - context: . 
- build-args: | - BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} - CUDA=${{ matrix.cuda }} - file: ./docker/Dockerfile-runpod - push: ${{ github.event_name != 'pull_request' }} - tags: | - ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} - ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} - labels: ${{ steps.metadata.outputs.labels }} diff --git a/axolotl/.github/workflows/pypi.yml b/axolotl/.github/workflows/pypi.yml deleted file mode 100644 index 27b1cb8d69d282b6913002cd5542be18af13d50a..0000000000000000000000000000000000000000 --- a/axolotl/.github/workflows/pypi.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: publish pypi - -on: - push: - tags: - - '*' - -jobs: - pypi-publish: - name: Upload release to PyPI - runs-on: ubuntu-latest - environment: - name: pypi - url: https://pypi.org/p/axolotl - permissions: - id-token: write # IMPORTANT: this permission is mandatory for trusted publishing - steps: - - name: Check out repository code - uses: actions/checkout@v3 - - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: "3.10" - - - name: Install dependencies - run: | - pip3 install wheel - pip3 install -e . - pip3 install -r requirements-tests.txt - - - name: Extract tag name - id: tag - run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3) - - - name: Update version in setup.py - run: >- - sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py - - - name: Build a binary wheel - run: >- - python setup.py sdist bdist_wheel - - - name: Publish package distributions to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/axolotl/.github/workflows/tests.yml b/axolotl/.github/workflows/tests.yml deleted file mode 100644 index ad2cb428b06a442b40f0f07229686902bd529aa2..0000000000000000000000000000000000000000 --- a/axolotl/.github/workflows/tests.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: Tests -on: - # check on push/merge to main, PRs, and manual triggers - push: - branches: - - "main" - paths: - - '**.py' - - 'requirements.txt' - pull_request: - paths: - - '**.py' - - 'requirements.txt' - workflow_dispatch: - -jobs: - pre-commit: - name: pre-commit - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: "3.9" - cache: 'pip' # caching pip dependencies - - uses: pre-commit/action@v3.0.0 - - pytest: - name: PyTest - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python_version: ["3.9", "3.10"] - timeout-minutes: 10 - - steps: - - name: Check out repository code - uses: actions/checkout@v3 - - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python_version }} - cache: 'pip' # caching pip dependencies - - - name: Install dependencies - run: | - pip3 install -U -e . 
- pip3 install -r requirements-tests.txt - - - name: Run tests - run: | - pytest --ignore=tests/e2e/ tests/ - - e2e-test: - name: E2E Tests - runs-on: [self-hosted, gpu] - timeout-minutes: 20 - needs: [pre-commit, pytest] - - steps: - - name: Check out repository code - uses: actions/checkout@v3 - - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: "3.10" -# cache: 'pip' # caching pip dependencies - - - name: Install dependencies - run: | - pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1 - pip3 uninstall -y transformers accelerate - pip3 install -U -e .[flash-attn,mamba-ssm] - pip3 install -r requirements-tests.txt - - - name: Run e2e tests - run: | - pytest tests/e2e/ diff --git a/axolotl/.gitignore b/axolotl/.gitignore deleted file mode 100644 index f46b6808f5d156081145f99423a6d83ccc037293..0000000000000000000000000000000000000000 --- a/axolotl/.gitignore +++ /dev/null @@ -1,167 +0,0 @@ -**/axolotl.egg-info -configs - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -.idea/ - -# WandB -# wandb creates a folder to store logs for training runs -wandb diff --git a/axolotl/.isort.cfg b/axolotl/.isort.cfg deleted file mode 100644 index 79067a7c91c9364b6a26abeaa8420725f25050c4..0000000000000000000000000000000000000000 --- a/axolotl/.isort.cfg +++ /dev/null @@ -1,3 +0,0 @@ -[settings] -profile=black -known_third_party=wandb diff --git a/axolotl/.mypy.ini b/axolotl/.mypy.ini deleted file mode 100644 index 1bbe04d2c7819ebbed70c1c0b8ce9a6b52ace3ab..0000000000000000000000000000000000000000 --- a/axolotl/.mypy.ini +++ /dev/null @@ -1,48 +0,0 @@ -[mypy] - -exclude = venv - -[mypy-alpaca_lora_4bit.*] -ignore_missing_imports = True - -[mypy-axolotl.monkeypatch.*] -ignore_errors = True - -[mypy-axolotl.models.mixtral.*] -ignore_errors = True - -[mypy-axolotl.models.phi.*] -ignore_errors = True - -[mypy-flash_attn.*] -ignore_missing_imports = True - -[mypy-huggingface_hub] -ignore_missing_imports = True - -[mypy-transformers.*] -ignore_missing_imports = True - -[mypy-peft] -ignore_missing_imports = True - -[mypy-wandb] -ignore_missing_imports = True - -[mypy-bitsandbytes] -ignore_missing_imports = True - -[mypy-datasets] -ignore_missing_imports = True - -[mypy-fire] -ignore_missing_imports = True - -[mypy-setuptools] -ignore_missing_imports = True - -[mypy-addict] -ignore_missing_imports = True - -[mypy-xformers.*] -ignore_missing_imports = True diff --git a/axolotl/.pre-commit-config.yaml b/axolotl/.pre-commit-config.yaml deleted file mode 100644 index c811a6eb30906f71846a95d451e93639fe6cf50f..0000000000000000000000000000000000000000 --- a/axolotl/.pre-commit-config.yaml +++ /dev/null @@ -1,42 +0,0 @@ -default_language_version: - python: python3 - -repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort -- repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 -- repo: https://github.com/PyCQA/pylint - rev: v2.17.4 - hooks: - - id: pylint -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.3.0 - hooks: - - id: mypy - additional_dependencies: - [ - 'types-PyYAML', - ] -- repo: https://github.com/PyCQA/bandit - rev: 1.7.5 - hooks: - - id: bandit - args: [ - '--ini', - '.bandit', - ] diff --git a/axolotl/.pylintrc b/axolotl/.pylintrc deleted file mode 100644 index ed973d2859b1faa54ec53706a2e4ac15b70556a4..0000000000000000000000000000000000000000 --- a/axolotl/.pylintrc +++ /dev/null @@ -1,14 +0,0 
@@ -[MASTER] -init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))" - -[TYPECHECK] - -# List of members which are set dynamically and missed by Pylint inference -# system, and so shouldn't trigger E1101 when accessed. -generated-members=numpy.*, torch.* - - -[pylint.messages_control] -disable=missing-function-docstring, line-too-long, import-error, - too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, - too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, diff --git a/axolotl/FAQS.md b/axolotl/FAQS.md deleted file mode 100644 index f3c9dd52536aa7e9579f8058961b437e80cc8162..0000000000000000000000000000000000000000 --- a/axolotl/FAQS.md +++ /dev/null @@ -1,7 +0,0 @@ -# FAQs - -- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874) -- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases -- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c` -`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.` -This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source. diff --git a/axolotl/LICENSE b/axolotl/LICENSE deleted file mode 100644 index d645695673349e3947e8e5ae42332d0ac3164cd7..0000000000000000000000000000000000000000 --- a/axolotl/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/axolotl/README.md b/axolotl/README.md deleted file mode 100644 index 2e0a07855c28b9b4b5fc4db09943eeb68f93cf7d..0000000000000000000000000000000000000000 --- a/axolotl/README.md +++ /dev/null @@ -1,1132 +0,0 @@ -# Axolotl - -Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. - -Features: -- Train various Huggingface models such as llama, pythia, falcon, mpt -- Supports fullfinetune, lora, qlora, relora, and gptq -- Customize configurations using a simple yaml file or CLI overwrite -- Load different dataset formats, use custom formats, or bring your own tokenized datasets -- Integrated with xformer, flash attention, rope scaling, and multipacking -- Works with single GPU or multiple GPUs via FSDP or Deepspeed -- Easily run with Docker locally or on the cloud -- Log results and optionally checkpoints to wandb -- And more! - - - - - - - -
- -## Table of Contents -- [Introduction](#axolotl) -- [Supported Features](#axolotl-supports) -- [Quickstart](#quickstart-) -- [Installation](#installation) - - [Docker](#docker) - - [Conda/Pip venv](#condapip-venv) - - [Runpod](#runpod) - - [LambdaLabs](#lambdalabs) - - [Windows](#windows) - - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) -- [Dataset](#dataset) - - [How to Add Custom Prompts](#how-to-add-custom-prompts) - - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) -- [Config](#config) - - [Train](#train) - - [Inference](#inference) - - [Merge LORA to Base](#merge-lora-to-base) - - [Special Tokens](#special-tokens) -- [Common Errors](#common-errors-) - - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) -- [Need Help?](#need-help-) -- [Badge](#badge-) -- [Community Showcase](#community-showcase) -- [Contributing](#contributing-) - - - -
- [deleted README header block: axolotl logo; "Axolotl provides a unified repository for fine-tuning a variety of AI models with ease"; "Go ahead and Axolotl questions!!"; pre-commit and PyTest status badges]
- -## Axolotl supports - -| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn | -|-------------|:----------|:-----|-------|------|-------------------|------------|--------------| -| llama | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | -| Mistral | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | -| Mixtral-MoE | βœ… | βœ… | βœ… | ❓ | ❓ | ❓ | ❓ | -| Pythia | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❓ | -| cerebras | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❓ | -| btlm | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❓ | -| mpt | βœ… | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ | -| falcon | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❓ | -| gpt-j | βœ… | βœ… | βœ… | ❌ | ❌ | ❓ | ❓ | -| XGen | βœ… | ❓ | βœ… | ❓ | ❓ | ❓ | βœ… | -| phi | βœ… | βœ… | βœ… | ❓ | ❓ | ❓ | ❓ | -| RWKV | βœ… | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | -| Qwen | βœ… | βœ… | βœ… | ❓ | ❓ | ❓ | ❓ | - - -## Quickstart ⚑ - -Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task. - -**Requirements**: Python >=3.9 and Pytorch >=2.0. - -`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"` - -### For developers -```bash -git clone https://github.com/OpenAccess-AI-Collective/axolotl -cd axolotl - -pip3 install packaging -pip3 install -e '.[flash-attn,deepspeed]' -``` - -### Usage -```bash -# finetune lora -accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml - -# inference -accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ - --lora_model_dir="./lora-out" - -# gradio -accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ - --lora_model_dir="./lora-out" --gradio -``` - -## Installation - -### Environment - -#### Docker - ```bash - docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 - ``` - - Or run on the current files for development: - - ```sh - docker compose up -d - ``` - -
- - Docker advanced - - A more powerful Docker command to run would be this: - - ```bash - docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1 - ``` - - It additionally: - * Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args. - * Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args. - * The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal. - * The `--privileged` flag gives all capabilities to the container. - * The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed. - - [More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem) - -
#### Conda/Pip venv

1. Install Python >=**3.9**

2. Install stable PyTorch: https://pytorch.org/get-started/locally/

3. Install Axolotl along with its Python dependencies
   ```bash
   pip3 install packaging
   pip3 install -e '.[flash-attn,deepspeed]'
   ```
4. (Optional) Log in to Hugging Face to use gated models/datasets.
   ```bash
   huggingface-cli login
   ```
   Get a token at huggingface.co/settings/tokens

#### Runpod

Use the `winglian/axolotl-runpod:main-latest` image, or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)

#### LambdaLabs
- - Click to Expand - - 1. Install python - ```bash - sudo apt update - sudo apt install -y python3.9 - - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1 - sudo update-alternatives --config python # pick 3.9 if given option - python -V # should be 3.9 - - ``` - - 2. Install pip - ```bash - wget https://bootstrap.pypa.io/get-pip.py - python get-pip.py - ``` - - 3. Install torch - ```bash - pip3 install -U torch --index-url https://download.pytorch.org/whl/cu118 - ``` - - 4. Axolotl - ```bash - git clone https://github.com/OpenAccess-AI-Collective/axolotl - cd axolotl - - pip3 install packaging - pip3 install -e '.[flash-attn,deepspeed]' - pip3 install protobuf==3.20.3 - pip3 install -U --ignore-installed requests Pillow psutil scipy - ``` - - 5. Set path - ```bash - export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH - ``` -
#### Windows

Please use WSL or Docker!


#### Launching on public clouds via SkyPilot

To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]"  # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) for using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN

# Managed spot (auto-recovery on preemption)
HF_TOKEN=xx BUCKET= sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```


### Dataset

Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have your dataset(s) in one of the following formats (JSONL recommended); a sketch of pointing a config at a local JSONL file follows this list:

- `alpaca`: instruction; input(optional)
  ```json
  {"instruction": "...", "input": "...", "output": "..."}
  ```
- `sharegpt`: conversations where `from` is `human`/`gpt` (optional: `system` to override the default system prompt)
  ```json
  {"conversations": [{"from": "...", "value": "..."}]}
  ```
- `llama-2`: the JSON is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
  ```yml
  datasets:
    - path:
      type: sharegpt
      conversation: llama-2
  ```
- `completion`: raw corpus
  ```json
  {"text": "..."}
  ```
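For example, a local JSONL file in the `alpaca` format above could be wired into a config like this sketch (the file name is illustrative; the `path`/`ds_type`/`type` keys are described in the [config section](#config)):

```yaml
datasets:
  - path: data/alpaca_samples.jsonl   # local file, one JSON object per line
    ds_type: json                     # on-disk format of the file
    type: alpaca                      # prompt format from the list above
```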
- -See other formats - -- `jeopardy`: question and answer - ```json - {"question": "...", "category": "...", "answer": "..."} - ``` -- `oasst`: instruction - ```json - {"INSTRUCTION": "...", "RESPONSE": "..."} - ``` -- `gpteacher`: instruction; input(optional) - ```json - {"instruction": "...", "input": "...", "response": "..."} - ``` -- `reflection`: instruction with reflect; input(optional) - ```json - {"instruction": "...", "input": "...", "output": "...", "reflection": "...", "corrected": "..."} - ``` -- `explainchoice`: question, choices, (solution OR explanation) - ```json - {"question": "...", "choices": ["..."], "solution": "...", "explanation": "..."} - ``` -- `concisechoice`: question, choices, (solution OR explanation) - ```json - {"question": "...", "choices": ["..."], "solution": "...", "explanation": "..."} - ``` -- `summarizetldr`: article and summary - ```json - {"article": "...", "summary": "..."} - ``` -- `alpaca_chat`: basic instruct for alpaca chat - ```json - {"instruction": "...", "input": "...", "response": "..."} - ``` -- `alpaca_chat.load_qa`: question and answer for alpaca chat - ```json - {"question": "...", "answer": "..."} - ``` -- `alpaca_chat.load_concise`: question and answer for alpaca chat, for concise answers - ```json - {"instruction": "...", "input": "...", "response": "..."} - ``` -- `alpaca_chat.load_camel_ai`: question and answer for alpaca chat, for load_camel_ai - ```json - {"message_1": "...", "message_2": "..."} - ``` -- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct - ```json - {"system_prompt": "...", "question": "...", "response": "..."} - ``` -- `context_qa`: in context question answering from an article - ```json - {"article": "...", "question": "...", "answer": "..."} - ``` -- `context_qa.load_v2`: in context question answering (alternate) - ```json - {"context": "...", "question": "...", "answer": "..."} - ``` -- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context - ```json - {"article": "...", "unanswerable_question": "..."} - ``` -- `creative_acr.load_answer`: instruction and revision - ```json - {"instruction": "...", "revision": "..."} - ``` -- `creative_acr.load_critique`: critique - ```json - {"scores": "...", "critiques": "...", "instruction": "...", "answer": "..."} - ``` -- `creative_acr.load_revise`: critique and revise - ```json - {"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."} - ``` -- `pygmalion`: pygmalion - ```json - {"conversations": [{"role": "...", "value": "..."}]} - ``` -- `metharme`: instruction, adds additional eos tokens - ```json - {"prompt": "...", "generation": "..."} - ``` -- `sharegpt.load_role`: conversations where `role` is used instead of `from` - ```json - {"conversations": [{"role": "...", "value": "..."}]} - ``` -- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt - ```json - {"conversations": [{"from": "...", "value": "..."}]} - ``` -- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny - ```json - {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} - ``` - -
- -#### How to add custom prompts - -For a dataset that is preprocessed for instruction purposes: - -```json -{"instruction": "...", "output": "..."} -``` - -You can use this example in your YAML config: - -```yaml -datasets: - - path: repo - type: - system_prompt: "" - field_system: system - format: "[INST] {instruction} [/INST]" - no_input_format: "[INST] {instruction} [/INST]" -``` - -#### How to use your custom pretokenized dataset - -- Do not pass a `type:` -- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels` - - -### Config - -See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are: - -- model - ```yaml - base_model: ./llama-7b-hf # local or huggingface repo - ``` - Note: The code will load the right architecture. - -- dataset - ```yaml - sequence_len: 2048 # max token length for prompt - - # huggingface repo - datasets: - - path: vicgalle/alpaca-gpt4 - type: alpaca # format from earlier - - # huggingface repo with specific configuration/subset - datasets: - - path: EleutherAI/pile - name: enron_emails - type: completion # format from earlier - field: text # Optional[str] default: text, field to use for completion data - - # huggingface repo with multiple named configurations/subsets - datasets: - - path: bigcode/commitpackft - name: - - ruby - - python - - typescript - type: ... # unimplemented custom format - - # fastchat conversation - # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - datasets: - - path: ... - type: sharegpt - conversation: chatml - - # local - datasets: - - path: data.jsonl # or json - ds_type: json # see other options below - type: alpaca - - # dataset with splits, but no train split - dataset: - - path: knowrohit07/know_sql - type: context_qa.load_v2 - train_on_split: validation - - # loading from s3 or gcs - # s3 creds will be loaded from the system default and gcs only supports public access - dataset: - - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs. - ... - ``` - -- loading - ```yaml - load_in_4bit: true - load_in_8bit: true - bf16: true # require >=ampere - fp16: true - tf32: true # require >=ampere - bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision) - float16: true # use instead of fp16 when you don't want AMP - ``` - Note: Repo does not do 4-bit quantization. - -- lora - ```yaml - adapter: lora # qlora or leave blank for full finetune - lora_r: 8 - lora_alpha: 16 - lora_dropout: 0.05 - lora_target_modules: - - q_proj - - v_proj - ``` - -
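Putting the options above together, a minimal LoRA config could look like the following sketch (the model, dataset, and hyperparameter values are illustrative; the full option list is below):

```yaml
base_model: ./llama-7b-hf   # local path or huggingface repo
sequence_len: 2048

datasets:
  - path: vicgalle/alpaca-gpt4
    type: alpaca

load_in_8bit: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj

micro_batch_size: 2
num_epochs: 4
learning_rate: 0.0002
output_dir: ./lora-out
```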
- -All yaml options (click me) - -```yaml -# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files -# This can also be a relative path to a model on disk -base_model: ./llama-7b-hf -# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc) -base_model_ignore_patterns: -# If the base_model repo on hf hub doesn't include configuration .json files, -# You can set that here, or leave this empty to default to base_model -base_model_config: ./llama-7b-hf -# You can specify to choose a specific model revision from huggingface hub -model_revision: -# Optional tokenizer configuration override in case you want to use a different tokenizer -# than the one defined in the base model -tokenizer_config: -# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too -model_type: AutoModelForCausalLM -# Corresponding tokenizer for the model AutoTokenizer is a good choice -tokenizer_type: AutoTokenizer -# Trust remote code for untrusted source -trust_remote_code: -# use_fast option for tokenizer loading from_pretrained, default to True -tokenizer_use_fast: -# Whether to use the legacy tokenizer setting, defaults to True -tokenizer_legacy: -# Resize the model embeddings when new tokens are added to multiples of 32 -# This is reported to improve training speed on some models -resize_token_embeddings_to_32x: - -# Used to identify which the model is based on -is_falcon_derived_model: -is_llama_derived_model: -# Please note that if you set this to true, `padding_side` will be set to "left" by default -is_mistral_derived_model: -is_qwen_derived_model: - -# optional overrides to the base model configuration -model_config: - # RoPE Scaling https://github.com/huggingface/transformers/pull/24653 - rope_scaling: - type: # linear | dynamic - factor: # float - - -# Whether you are training a 4-bit GPTQ quantized model -gptq: true -gptq_groupsize: 128 # group size -gptq_model_v1: false # v1 or v2 - -# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer -load_in_8bit: true -# Use bitsandbytes 4 bit -load_in_4bit: - -# Use CUDA bf16 -bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere -# Use CUDA fp16 -fp16: true -# Use CUDA tf32 -tf32: true # require >=ampere - -# No AMP (automatic mixed precision) -bfloat16: true # require >=ampere -float16: true - -# A list of one or more datasets to finetune the model with -datasets: - # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files - - path: vicgalle/alpaca-gpt4 - # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] - type: alpaca # format | format: (chat/instruct) | .load_ - ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file - data_files: # Optional[str] path to source data files - shards: # Optional[int] number of shards to split data into - name: # Optional[str] name of dataset configuration to load - train_on_split: train # Optional[str] name of dataset split to load from - - # Optional[str] fastchat conversation type, only used with type: sharegpt - conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - field_human: # Optional[str]. Human key to use for conversation. - field_model: # Optional[str]. Assistant key to use for conversation. - - # Custom user prompt - - path: repo - type: - # The below are defaults. only set what's needed. 
- system_prompt: "" - system_format: "{system}" - field_system: system - field_instruction: instruction - field_input: input - field_output: output - - # Customizable to be single line or multi-line - # 'format' can include {input} - format: |- - User: {instruction} {input} - Assistant: - # 'no_input_format' cannot include {input} - no_input_format: "{instruction} " - - # For `completion` datsets only, uses the provided field instead of `text` column - field: - -# Axolotl attempts to save the dataset as an arrow after packing the data together so -# subsequent training attempts load faster, relative path -dataset_prepared_path: data/last_run_prepared -# Push prepared dataset to hub -push_dataset_to_hub: # repo path -# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` -# if not set. -dataset_processes: # defaults to os.cpu_count() if not set -# push checkpoints to hub -hub_model_id: # repo path to push finetuned model -# how to push checkpoints to hub -# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy -hub_strategy: -# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets -# Required to be true when used in combination with `push_dataset_to_hub` -hf_use_auth_token: # boolean -# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval. -val_set_size: 0.04 -# Num shards for whole dataset -dataset_shard_num: -# Index of shard to use for whole dataset -dataset_shard_idx: - -# The maximum length of an input to train with, this should typically be less than 2048 -# as most models have a token/context limit of 2048 -sequence_len: 2048 -# Pad inputs so each step uses constant sized buffers -# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently -pad_to_sequence_len: -# Max sequence length to concatenate training samples together up to -# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning -# FutureWarning: This will soon be DEPRECATED -max_packed_sequence_len: 1024 -# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' -sample_packing: -# Set to 'false' if getting errors during eval with sample_packing on. -eval_sample_packing: -# You can set these packing optimizations AFTER starting a training at least once. -# The trainer will provide recommended values for these values. -sample_packing_eff_est: -total_num_tokens: - -# Passed through to transformers when loading the model when launched without accelerate -# Use `sequential` when training w/ model parallelism to limit memory -device_map: -# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model. -max_memory: - -# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model -adapter: lora -# If you already have a lora model trained that you want to load, put that here. -# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`. 
-lora_model_dir: - -# LoRA hyperparameters -# For more details about the following options, see: -# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2 -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: - - q_proj - - v_proj -# - k_proj -# - o_proj -# - gate_proj -# - down_proj -# - up_proj -lora_target_linear: # If true, will target all linear layers - -# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. -# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. -# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities. -# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994 -lora_modules_to_save: -# - embed_tokens -# - lm_head - -# Once you complete training, the model will be saved to the following directory. -# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory. -# Make sure `lora_model_dir` points to this directory if you want to use the trained model. -lora_out_dir: -lora_fan_in_fan_out: false - -# ReLoRA configuration -# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed -relora_steps: # Number of steps per ReLoRA restart -relora_warmup_steps: # Number of per-restart warmup steps -relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings - -# wandb configuration if you're using it -# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. -wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb -wandb_project: # Your wandb project name -wandb_entity: # A wandb Team name if using a Team -wandb_watch: -wandb_name: # Set the name of your wandb run -wandb_run_id: # Set the ID of your wandb run -wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training - -# Where to save the full-finetuned model to -output_dir: ./completed-model - -# Whether to use torch.compile and which backend to use -torch_compile: # bool -torch_compile_backend: # Optional[str] - -# Training hyperparameters - -# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps. -gradient_accumulation_steps: 1 -# The number of samples to include in each batch. This is the number of samples sent to each GPU. -micro_batch_size: 2 -eval_batch_size: -num_epochs: 4 -warmup_steps: 100 # cannot use with warmup_ratio -warmup_ratio: 0.05 # cannot use with warmup_steps -learning_rate: 0.00003 -lr_quadratic_warmup: -logging_steps: -eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps -evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps -save_strategy: # Set to `no` to skip checkpoint saves -save_steps: # Leave empty to save at each epoch -saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps -save_total_limit: # Checkpoints saved at a time -# Maximum number of iterations to train for. It precedes num_epochs which means that -# if both are set, num_epochs will not be guaranteed. 
-# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps -max_steps: - -eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 -eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 - -loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) -loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) - -# Save model as safetensors (require safetensors package) -save_safetensors: - -# Whether to mask out or include the human's prompt from the training labels -train_on_inputs: false -# Group similarly sized data to minimize padding. -# May be slower to start, as it must download and sort the entire dataset. -# Note that training loss may have an oscillating pattern with this enabled. -group_by_length: false - -# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing -gradient_checkpointing: false - -# Stop training after this many evaluation losses have increased in a row -# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback -early_stopping_patience: 3 - -# Specify a scheduler and kwargs to use with the optimizer -lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine -lr_scheduler_kwargs: - -# For one_cycle optim -lr_div_factor: # Learning rate div factor - -# For log_sweep optim -log_sweep_min_lr: -log_sweep_max_lr: - -# Specify optimizer -# Valid values are driven by the Transformers OptimizerNames class, see: -# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 -# -# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of -# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used -# in the examples/ for your model and fine-tuning use case. 
-# -# Valid values for 'optimizer' include: -# - adamw_hf -# - adamw_torch -# - adamw_torch_fused -# - adamw_torch_xla -# - adamw_apex_fused -# - adafactor -# - adamw_anyprecision -# - sgd -# - adagrad -# - adamw_bnb_8bit -# - lion_8bit -# - lion_32bit -# - paged_adamw_32bit -# - paged_adamw_8bit -# - paged_lion_32bit -# - paged_lion_8bit -optimizer: -# Specify weight decay -weight_decay: -# adamw hyperparams -adam_beta1: -adam_beta2: -adam_epsilon: -# Gradient clipping max norm -max_grad_norm: - -# Augmentation techniques -# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings -# currently only supported on Llama and Mistral -neftune_noise_alpha: - -# Whether to bettertransformers -flash_optimum: -# Whether to use xformers attention patch https://github.com/facebookresearch/xformers: -xformers_attention: -# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention: -flash_attention: -flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only -flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only -flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation -flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation -# Whether to use scaled-dot-product attention -# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html -sdp_attention: - -# Resume from a specific checkpoint dir -resume_from_checkpoint: -# If resume_from_checkpoint isn't set and you simply want it to start where it left off. -# Be careful with this being turned on between different models. -auto_resume_from_checkpoints: false - -# Don't mess with this, it's here for accelerate and torchrun -local_rank: - -# Add or change special tokens. -# If you add tokens here, you don't need to add them to the `tokens` list. -special_tokens: - # bos_token: "" - # eos_token: "" - # unk_token: "" - -# Add extra tokens. -tokens: - -# FSDP -fsdp: -fsdp_config: - -# Deepspeed config path. e.g., deepspeed/zero3.json -deepspeed: - -# Advanced DDP Arguments -ddp_timeout: -ddp_bucket_cap_mb: -ddp_broadcast_buffers: - -# Path to torch distx for optim 'adamw_anyprecision' -torchdistx_path: - -# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize -pretraining_dataset: - -# Debug mode -debug: - -# Seed -seed: - -# Allow overwrite yml config using from cli -strict: -``` - -
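As one concrete combination of the options above, sample packing is typically paired with flash attention and padding to a fixed sequence length, e.g. this sketch (whether the values suit your model and GPU is an assumption to verify):

```yaml
sequence_len: 4096
sample_packing: true        # multipack with block-diagonal attention
pad_to_sequence_len: true   # constant-size buffers, less memory fragmentation
flash_attention: true       # requires a model/GPU supported by flash-attention
```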
- -
- Understanding of batch size and gradient accumulation steps -
-Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning. - -This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why: - -1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption. - -2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch. - -**Example 1:** -Micro batch size: 3 -Gradient accumulation steps: 2 -Number of GPUs: 3 -Total batch size = 3 * 2 * 3 = 18 - -``` -| GPU 1 | GPU 2 | GPU 3 | -|----------------|----------------|----------------| -| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 | -| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 | -|----------------|----------------|----------------| -| β†’ (accumulate) | β†’ (accumulate) | β†’ (accumulate) | -|----------------|----------------|----------------| -| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 | -| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 | -|----------------|----------------|----------------| -| β†’ (apply) | β†’ (apply) | β†’ (apply) | - -Accumulated gradient for the weight w1 after the second iteration (considering all GPUs): -Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18 - -Weight update for w1: -w1_new = w1_old - learning rate x (Total gradient for w1 / 18) -``` - -**Example 2:** -Micro batch size: 2 -Gradient accumulation steps: 1 -Number of GPUs: 3 -Total batch size = 2 * 1 * 3 = 6 - -``` -| GPU 1 | GPU 2 | GPU 3 | -|-----------|-----------|-----------| -| S1, S2 | S3, S4 | S5, S6 | -| e1, e2 | e3, e4 | e5, e6 | -|-----------|-----------|-----------| -| β†’ (apply) | β†’ (apply) | β†’ (apply) | - -Accumulated gradient for the weight w1 (considering all GPUs): -Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 - -Weight update for w1: -w1_new = w1_old - learning rate Γ— (Total gradient for w1 / 6) -``` - -
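In config terms, Example 1 above corresponds to settings like these (a sketch; the effective batch size also depends on how many GPUs you launch across):

```yaml
micro_batch_size: 3             # samples per GPU per forward/backward pass
gradient_accumulation_steps: 2  # micro-batches accumulated before each optimizer step
# effective batch size = micro_batch_size * gradient_accumulation_steps * number of GPUs
#                      = 3 * 2 * 3 = 18 with the 3 GPUs of Example 1
```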
- -### Train - -Run -```bash -accelerate launch -m axolotl.cli.train your_config.yml -``` - -#### Preprocess dataset - -You can optionally pre-tokenize dataset with the following before finetuning. -This is recommended for large datasets. - -- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface. -- Use `--debug` to see preprocessed examples. - -```bash -python -m axolotl.cli.preprocess your_config.yml -``` - -#### Multi-GPU - -Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed -is the recommended multi-GPU option currently because FSDP may experience -[loss instability](https://github.com/huggingface/transformers/issues/26498). - -##### DeepSpeed - -Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you -might typically be able to fit into your GPU's VRAM. More information about the various optimization types -for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated - -We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3. - -```yaml -deepspeed: deepspeed/zero1.json -``` - -```shell -accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json -``` - -##### FSDP - -- llama FSDP -```yaml -fsdp: - - full_shard - - auto_wrap -fsdp_config: - fsdp_offload_params: true - fsdp_state_dict_type: FULL_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer -``` - -##### Weights & Biases Logging - -Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. - -- wandb options -```yaml -wandb_mode: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -``` - -##### Special Tokens - -It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: - -```yml -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" -tokens: # these are delimiters - - "<|im_start|>" - - "<|im_end|>" -``` - -When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary. - -### Inference Playground - -Axolotl allows you to load your model in an interactive terminal playground for quick experimentation. -The config file is the same config file used for training. 
- -Pass the appropriate flag to the inference command, depending upon what kind of model was trained: - -- Pretrained LORA: - ```bash - python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir" - ``` -- Full weights finetune: - ```bash - python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model" - ``` -- Full weights finetune w/ a prompt from a text file: - ```bash - cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \ - --base_model="./completed-model" --prompter=None --load_in_8bit=True - ``` --- With gradio hosting - ```bash - python -m axolotl.cli.inference examples/your_config.yml --gradio - ``` - -Please use `--sample_packing False` if you have it on and receive the error similar to below: - -> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1 - -### Merge LORA to base - -Add below flag to train command above - -```bash -python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False -``` - -If you run out of CUDA memory, you can try to merge in system RAM with - -```bash -CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ... -``` - -## Common Errors 🧰 - -See also the [FAQ's](./docs/faq.md). - -> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it: - -Please reduce any below - - `micro_batch_size` - - `eval_batch_size` - - `gradient_accumulation_steps` - - `sequence_len` - -If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command. - -Using adamw_bnb_8bit might also save you some memory. - -> `failed (exitcode: -9)` - -Usually means your system has run out of system memory. -Similarly, you should consider reducing the same settings as when you run out of VRAM. -Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades. - -> RuntimeError: expected scalar type Float but found Half - -Try set `fp16: true` - -> NotImplementedError: No operator found for `memory_efficient_attention_forward` ... - -Try to turn off xformers. - -> accelerate config missing - -It's safe to ignore it. - -> NCCL Timeouts during training - -See the [NCCL](docs/nccl.md) guide. - - -### Tokenization Mismatch b/w Inference & Training - -For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks. - -If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following: - -1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer. -2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string. -3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly. -4. 
As an additional troubleshooting step, you can look look at the token ids between 1 and 2 to make sure they are identical. - -Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example. - -## Need help? πŸ™‹β™‚οΈ - -Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you - -## Badge ❀🏷️ - -Building something cool with Axolotl? Consider adding a badge to your model card. - -```markdown -[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) -``` - -[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) - -## Community Showcase - -Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model. - -Open Access AI Collective -- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b) -- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b) -- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat) - -PocketDoc Labs -- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA) - -## Contributing 🀝 - -Please read the [contributing guide](./.github/CONTRIBUTING.md) - -Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug) else create a new Issue. - -PRs are **greatly welcome**! - -Please run below to setup env -```bash -pip3 install -r requirements-dev.txt -r requirements-tests.txt -pre-commit install - -# test -pytest tests/ -``` diff --git a/axolotl/TODO.md b/axolotl/TODO.md deleted file mode 100644 index 2002bbbaf1b8781e9f632856c032702151895bb2..0000000000000000000000000000000000000000 --- a/axolotl/TODO.md +++ /dev/null @@ -1,10 +0,0 @@ -# todo list - -- [] Validation of parameters for combinations that won't work - - - -## things that are known not to work - -- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203 -- adamw_bnb_8bit doesn't play well with FSDP offload diff --git a/axolotl/deepspeed/zero1.json b/axolotl/deepspeed/zero1.json deleted file mode 100644 index c76a20637f3498986f38bcdaf7aaeefac1245e67..0000000000000000000000000000000000000000 --- a/axolotl/deepspeed/zero1.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "zero_optimization": { - "stage": 1, - "overlap_comm": true - }, - "bf16": { - "enabled": "auto" - }, - "fp16": { - "enabled": "auto", - "auto_cast": false, - "loss_scale": 0, - "initial_scale_power": 32, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "gradient_accumulation_steps": "auto", - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} diff --git a/axolotl/deepspeed/zero2.json b/axolotl/deepspeed/zero2.json deleted file mode 100644 index 3f3baa3ece9e5f134c1a826ef387bba3abeb5717..0000000000000000000000000000000000000000 --- a/axolotl/deepspeed/zero2.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "zero_optimization": { - "stage": 2, - "offload_optimizer": { - "device": "cpu" - }, - "contiguous_gradients": true, - "overlap_comm": true - }, - "bf16": { - "enabled": "auto" - }, - "fp16": { - "enabled": "auto", - "auto_cast": false, 
- "loss_scale": 0, - "initial_scale_power": 32, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "gradient_accumulation_steps": "auto", - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} diff --git a/axolotl/deepspeed/zero3.json b/axolotl/deepspeed/zero3.json deleted file mode 100644 index cf64e83aca49ce5c7610c49e8563ddbd7f7b4082..0000000000000000000000000000000000000000 --- a/axolotl/deepspeed/zero3.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "zero_optimization": { - "stage": 3, - "overlap_comm": true, - "contiguous_gradients": true, - "sub_group_size": 0, - "reduce_bucket_size": "auto", - "stage3_prefetch_bucket_size": "auto", - "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 0, - "stage3_max_reuse_distance": 0, - "stage3_gather_16bit_weights_on_model_save": true - }, - "bf16": { - "enabled": "auto" - }, - "fp16": { - "enabled": "auto", - "auto_cast": false, - "loss_scale": 0, - "initial_scale_power": 32, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "gradient_accumulation_steps": "auto", - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} diff --git a/axolotl/deepspeed/zero3_bf16.json b/axolotl/deepspeed/zero3_bf16.json deleted file mode 100644 index 42d10b6bd88f5425b39a573b47c7c79671b4c948..0000000000000000000000000000000000000000 --- a/axolotl/deepspeed/zero3_bf16.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "zero_optimization": { - "stage": 3, - "overlap_comm": true, - "contiguous_gradients": true, - "sub_group_size": 0, - "reduce_bucket_size": "auto", - "stage3_prefetch_bucket_size": "auto", - "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 0, - "stage3_max_reuse_distance": 0, - "stage3_gather_16bit_weights_on_model_save": true - }, - "bf16": { - "enabled": true - }, - "fp16": { - "enabled": "auto", - "auto_cast": false, - "loss_scale": 0, - "initial_scale_power": 32, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "gradient_accumulation_steps": "auto", - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} diff --git a/axolotl/docker-compose.yaml b/axolotl/docker-compose.yaml deleted file mode 100644 index a16be726cfaa8c85e78f5abccdd33faba59028ba..0000000000000000000000000000000000000000 --- a/axolotl/docker-compose.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# version: '3.8' -services: - axolotl: - build: - context: . 
- dockerfile: ./docker/Dockerfile - volumes: - - .:/workspace/axolotl - - ~/.cache/huggingface/:/root/.cache/huggingface/ - # set environment variables - environment: - # Set environment variables - - GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME} - - GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL} - - GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME} - - GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL} - - WANDB_API_KEY=${WANDB_API_KEY} - deploy: - resources: - reservations: - devices: - - driver: nvidia - # count: 1 - capabilities: [gpu] - command: tail -f /dev/null diff --git a/axolotl/docker/Dockerfile b/axolotl/docker/Dockerfile deleted file mode 100644 index f8e05285629046bdb0fcdf0af7d2f5826db74f68..0000000000000000000000000000000000000000 --- a/axolotl/docker/Dockerfile +++ /dev/null @@ -1,36 +0,0 @@ -ARG BASE_TAG=main-base -FROM winglian/axolotl-base:$BASE_TAG - -ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" -ARG AXOLOTL_EXTRAS="" -ARG CUDA="118" -ENV BNB_CUDA_VERSION=$CUDA -ARG PYTORCH_VERSION="2.0.1" - -ENV PYTORCH_VERSION=$PYTORCH_VERSION - -RUN apt-get update && \ - apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev - -WORKDIR /workspace - -RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git - -WORKDIR /workspace/axolotl - -# If AXOLOTL_EXTRAS is set, append it in brackets -RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ - else \ - pip install -e .[deepspeed,flash-attn]; \ - fi - -# So we can test the Docker image -RUN pip install pytest - -# fix so that git fetch/pull from remote works -RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ - git config --get remote.origin.fetch - -# helper for huggingface-login cli -RUN git config --global credential.helper store diff --git a/axolotl/docker/Dockerfile-base b/axolotl/docker/Dockerfile-base deleted file mode 100644 index ca43c96521d3983067a699e60b928bb1c1ddd0f6..0000000000000000000000000000000000000000 --- a/axolotl/docker/Dockerfile-base +++ /dev/null @@ -1,37 +0,0 @@ -ARG CUDA_VERSION="11.8.0" -ARG CUDNN_VERSION="8" -ARG UBUNTU_VERSION="22.04" -ARG MAX_JOBS=4 - -FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder - -ENV PATH="/root/miniconda3/bin:${PATH}" - -ARG PYTHON_VERSION="3.9" -ARG PYTORCH_VERSION="2.0.1" -ARG CUDA="118" -ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" - -ENV PYTHON_VERSION=$PYTHON_VERSION -ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST - -RUN apt-get update \ - && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ - && wget \ - https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ - && mkdir /root/.conda \ - && bash Miniconda3-latest-Linux-x86_64.sh -b \ - && rm -f Miniconda3-latest-Linux-x86_64.sh \ - && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" - -ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" - -WORKDIR /workspace - -RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ - python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA - -RUN git lfs install --skip-repo && \ - pip3 install awscli && \ - # The base image ships with `pydantic==1.8.2` which is not working - pip3 install -U --no-cache-dir pydantic==1.10.10 diff --git a/axolotl/docker/Dockerfile-runpod b/axolotl/docker/Dockerfile-runpod deleted file mode 100644 index 
eaf470b5ea511fa161ee7925304546daf49ad6f3..0000000000000000000000000000000000000000 --- a/axolotl/docker/Dockerfile-runpod +++ /dev/null @@ -1,19 +0,0 @@ -ARG BASE_TAG=main -FROM winglian/axolotl:$BASE_TAG - -ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" -ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" -ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub" -ENV HF_HOME="/workspace/data/huggingface-cache/hub" - -COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh - -RUN apt install --yes --no-install-recommends openssh-server tmux && \ - mkdir -p ~/.ssh && \ - chmod 700 ~/.ssh && \ - printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ - chmod +x /workspace/axolotl/scripts/runpod-entrypoint.sh && \ - chmod +x /root/runpod-entrypoint.sh - -ENTRYPOINT ["/root/runpod-entrypoint.sh"] -CMD ["sleep", "infinity"] diff --git a/axolotl/docs/faq.md b/axolotl/docs/faq.md deleted file mode 100644 index 65423065384cab8375fd49ba5e0dfdf652ebeee9..0000000000000000000000000000000000000000 --- a/axolotl/docs/faq.md +++ /dev/null @@ -1,18 +0,0 @@ -# Axolotl FAQ's - - -> The trainer stopped and hasn't progressed in several minutes. - -Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md) - -> Exitcode -9 - -This usually happens when you run out of system RAM. - -> Exitcode -7 while using deepspeed - -Try upgrading deepspeed w: `pip install -U deepspeed` - -> AttributeError: 'DummyOptim' object has no attribute 'step' - -You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli. diff --git a/axolotl/docs/multi-node.md b/axolotl/docs/multi-node.md deleted file mode 100644 index 6806159690c9bad8e24f4f33e88f05b3ac1880dc..0000000000000000000000000000000000000000 --- a/axolotl/docs/multi-node.md +++ /dev/null @@ -1,45 +0,0 @@ -# Multi Node - -You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below: - -~/.cache/huggingface/accelerate/default_config.yaml -```yaml -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines -main_process_ip: 10.0.0.4 # Set to main machine's IP -main_process_port: 5000 -main_training_function: main -mixed_precision: bf16 -num_machines: 2 # Change to the number of machines -num_processes: 4 # That's the total number of GPUs, (for example: if you have 2 machines with 4 GPU, put 8) -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false -``` - -Configure your model to use FSDP with for example: -```yaml -fsdp: - - full_shard - - auto_wrap -fsdp_config: - fsdp_offload_params: true - fsdp_state_dict_type: FULL_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer -``` - -## Machine configuration - -On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility. - -You will also need to have the same configuration file for your model on each machine. - -On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines. - -All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine. 
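For example, the accelerate config on a second machine would typically differ from the preset above only in `machine_rank` (a sketch; IPs and counts are illustrative):

```yaml
# ~/.cache/huggingface/accelerate/default_config.yaml on the second machine
machine_rank: 1           # 0 on the main machine, incremented for each additional machine
main_process_ip: 10.0.0.4 # still points at the main machine
main_process_port: 5000
num_machines: 2
num_processes: 4          # total number of GPUs across all machines
```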
diff --git a/axolotl/docs/multipack.md b/axolotl/docs/multipack.md deleted file mode 100644 index 2a55148b2a20bce50ad54c2a022791dc1285c3e5..0000000000000000000000000000000000000000 --- a/axolotl/docs/multipack.md +++ /dev/null @@ -1,51 +0,0 @@ -# Multipack - -4k context, bsz =4, -each character represents 256 tokens -X represents a padding token - -``` - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -[[ A A A A A A A A A A A ] - B B B B B B ] - C C C C C C C ] - D D D D ]] - -[[ E E E E E E E E ] - [ F F F F ] - [ G G G ] - [ H H H H ]] - -[[ I I I ] - [ J J J ] - [ K K K K K] - [ L L L ]] -``` - -after padding to longest input in each step -``` - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -[[ A A A A A A A A A A A ] - B B B B B B X X X X X X ] - C C C C C C C X X X X ] - D D D D X X X X X X X ]] - -[[ E E E E E E E E ] - [ F F F F X X X X ] - [ G G G X X X X X ] - [ H H H H X X X X ]] - -[[ I I I X X ] - [ J J J X X ] - [ K K K K K ] - [ L L L X X ]] -``` - -w packing ( note it's the same effective number of tokens per step, but a true bsz of 1) -``` - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -[[ A A A A A A A A A A A B B B B B - B C C C C C C C D D D D E E E E - E E E E F F F F F G G G H H H H - I I I J J J J K K K K K L L L X ]] -``` diff --git a/axolotl/docs/nccl.md b/axolotl/docs/nccl.md deleted file mode 100644 index 4a7ff5d5d6c1a9fee50a6d182728df3c851ee6c0..0000000000000000000000000000000000000000 --- a/axolotl/docs/nccl.md +++ /dev/null @@ -1,46 +0,0 @@ -# NCCL - -NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort: - -```text -Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out. -``` - -Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you. - -Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command: - -```shell -nvidia-smi nvlink --status -``` - -To force NCCL to use NVLink, simply set this in the environment: - -```shell -export NCCL_P2P_LEVEL=NVL -``` - -If NVLink is not available in your environment there are other options for ``NCCL_P2P_LEVEL`` in the table below: - -| NCCL_P2P_LEVEL | Description | -| -------------- | ----------- | -| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates vs to paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. | -| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. 
| -| PHB | P2P data transfers occur over the PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (ex PIX, NVL) | - -To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example: - -```shell -./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3 -``` - -It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL: - -```shell -export NCCL_DEBUG=INFO -export NCCL_DEBUG_SUBSYS=ALL -export TORCH_DISTRIBUTED_DEBUG=INFO -export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log -``` - -Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value. diff --git a/axolotl/examples/cerebras/btlm-ft.yml b/axolotl/examples/cerebras/btlm-ft.yml deleted file mode 100644 index d0975214b26d95610a7f0883d5f35f91517703ef..0000000000000000000000000000000000000000 --- a/axolotl/examples/cerebras/btlm-ft.yml +++ /dev/null @@ -1,89 +0,0 @@ -base_model: cerebras/btlm-3b-8k-base -model_type: AutoModelForCausalLM -tokenizer_type: GPT2Tokenizer -trust_remote_code: true -tokenizer_use_fast: true -tokenizer_legacy: true - -load_in_8bit: false -load_in_4bit: false -strict: false -push_dataset_to_hub: -hf_use_auth_token: true -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: last_prepared_run -val_set_size: 0.05 - -adapter: -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -sample_packing: false -sample_packing_eff_est: -sample_packing_seq_len_multiplier: -total_num_tokens: - -lora_r: -lora_alpha: -lora_dropout: -lora_target_modules: -lora_target_linear: -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -output_dir: btlm-out -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 1 -optimizer: adamw_torch -adam_beta2: 0.95 -adam_eps: 0.000000001 -max_grad_norm: 1.0 - -torchdistx_path: -lr_scheduler: cosine -lr_quadratic_warmup: true -learning_rate: 0.000085 -train_on_inputs: true -group_by_length: false -bf16: true -fp16: false -tf32: true - -gradient_checkpointing: false -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 - -xformers_attention: -flash_attention: true -sdp_attention: -flash_optimum: - -gptq_groupsize: -gptq_model_v1: - -warmup_steps: 32 -evals_per_epoch: 4 -saves_per_epoch: 1 -save_total_limit: - -debug: -deepspeed: -weight_decay: 0.1 -special_tokens: - pad_token: "<|endoftext|>" -fsdp: -# - full_shard -# - auto_wrap -fsdp_config: -# fsdp_state_dict_type: FULL_STATE_DICT -# fsdp_transformer_layer_cls_to_wrap: BTLMBlock diff --git a/axolotl/examples/cerebras/qlora.yml b/axolotl/examples/cerebras/qlora.yml deleted file mode 100644 index 03155c6c234352b916af4f8d8d9f59bcbd1ef752..0000000000000000000000000000000000000000 --- a/axolotl/examples/cerebras/qlora.yml +++ /dev/null @@ -1,60 +0,0 @@ -base_model: cerebras/Cerebras-GPT-1.3B -load_in_8bit: false -load_in_4bit: true -strict: false -push_dataset_to_hub: -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca -dataset_prepared_path: -val_set_size: 
0.05 -adapter: qlora -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: 2048 -lora_r: 16 -lora_alpha: 32 -lora_dropout: 0.05 -lora_target_modules: - - c_fc - - c_attn - - c_proj -lora_target_linear: -lora_fan_in_fan_out: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./qlora-out -batch_size: 4 -micro_batch_size: 4 -num_epochs: 2 -optimizer: paged_adamw_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.0002 -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: true -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: true -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -special_tokens: - pad_token: "<|endoftext|>" diff --git a/axolotl/examples/code-llama/13b/lora.yml b/axolotl/examples/code-llama/13b/lora.yml deleted file mode 100644 index fc43ad14e2647acbe5da776d4fe0f0be4b545f61..0000000000000000000000000000000000000000 --- a/axolotl/examples/code-llama/13b/lora.yml +++ /dev/null @@ -1,67 +0,0 @@ -base_model: codellama/CodeLlama-13b-hf -model_type: LlamaForCausalLM -tokenizer_type: CodeLlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: true -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./lora-out - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -adapter: lora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/code-llama/13b/qlora.yml b/axolotl/examples/code-llama/13b/qlora.yml deleted file mode 100644 index 06b9ac72f38b0465abbfcb84111b79bf788da811..0000000000000000000000000000000000000000 --- a/axolotl/examples/code-llama/13b/qlora.yml +++ /dev/null @@ -1,69 +0,0 @@ -base_model: codellama/CodeLlama-13b-hf -model_type: LlamaForCausalLM -tokenizer_type: CodeLlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./qlora-out - -adapter: qlora -lora_model_dir: - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: paged_adamw_32bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - 
-gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/code-llama/34b/lora.yml b/axolotl/examples/code-llama/34b/lora.yml deleted file mode 100644 index c2f1d5ce150852bfcce57afbd614ef93fe9a3f73..0000000000000000000000000000000000000000 --- a/axolotl/examples/code-llama/34b/lora.yml +++ /dev/null @@ -1,67 +0,0 @@ -base_model: codellama/CodeLlama-34b-hf -model_type: LlamaForCausalLM -tokenizer_type: CodeLlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: true -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./lora-out - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -adapter: lora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/code-llama/34b/qlora.yml b/axolotl/examples/code-llama/34b/qlora.yml deleted file mode 100644 index ad1e21675e9cab32f01e1b871e038ca74d1c7549..0000000000000000000000000000000000000000 --- a/axolotl/examples/code-llama/34b/qlora.yml +++ /dev/null @@ -1,69 +0,0 @@ -base_model: codellama/CodeLlama-34b-hf -model_type: LlamaForCausalLM -tokenizer_type: CodeLlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./qlora-out - -adapter: qlora -lora_model_dir: - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: paged_adamw_32bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/code-llama/7b/lora.yml b/axolotl/examples/code-llama/7b/lora.yml deleted file mode 100644 index 630c8da6fcc9b8ccaad54949997805c262518338..0000000000000000000000000000000000000000 --- 
a/axolotl/examples/code-llama/7b/lora.yml +++ /dev/null @@ -1,67 +0,0 @@ -base_model: codellama/CodeLlama-7b-hf -model_type: LlamaForCausalLM -tokenizer_type: CodeLlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: true -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./lora-out - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -adapter: lora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/code-llama/7b/qlora.yml b/axolotl/examples/code-llama/7b/qlora.yml deleted file mode 100644 index 12462dcb76fcd8d6211906bf2c613a132dba6bd0..0000000000000000000000000000000000000000 --- a/axolotl/examples/code-llama/7b/qlora.yml +++ /dev/null @@ -1,69 +0,0 @@ -base_model: codellama/CodeLlama-7b-hf -model_type: LlamaForCausalLM -tokenizer_type: CodeLlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./qlora-out - -adapter: qlora -lora_model_dir: - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: paged_adamw_32bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/code-llama/README.md b/axolotl/examples/code-llama/README.md deleted file mode 100644 index a5011e347283ab67b2f276de85e7fba4344bf09f..0000000000000000000000000000000000000000 --- a/axolotl/examples/code-llama/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Overview - -This is an example of CodeLLaMA configuration for 7b, 13b and 34b. - -The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes. 
- -The 13b variant will fit if you change these settings to these values: -gradient_accumulation_steps: 2 -micro_batch_size: 1 - -The 34b variant does not fit on 24GB of VRAM - you will need something with +40 gb VRAM that also supports flash attention v2 - A6000 or A100 are good choices. - -```shell -accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml - -``` -or - -```shell -accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml - -``` diff --git a/axolotl/examples/falcon/config-7b-lora.yml b/axolotl/examples/falcon/config-7b-lora.yml deleted file mode 100644 index 13bad9425a16ea28e64ca020349983a6da3b84d2..0000000000000000000000000000000000000000 --- a/axolotl/examples/falcon/config-7b-lora.yml +++ /dev/null @@ -1,64 +0,0 @@ -base_model: tiiuae/falcon-7b -trust_remote_code: true -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer -is_falcon_derived_model: true -load_in_8bit: true -load_in_4bit: false -gptq: false -strict: false -push_dataset_to_hub: -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca:chat -dataset_prepared_path: -val_set_size: 0.05 -adapter: lora -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -lora_r: 16 -lora_alpha: 32 -lora_dropout: 0.0 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./falcon-7b -batch_size: 2 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.00003 -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: true -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: true -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 40 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - pad_token: "<|endoftext|>" - bos_token: ">>ABSTRACT<<" - eos_token: "<|endoftext|>" diff --git a/axolotl/examples/falcon/config-7b-qlora.yml b/axolotl/examples/falcon/config-7b-qlora.yml deleted file mode 100644 index a89124bb8c3e132de629ed61c60a8d54a083f537..0000000000000000000000000000000000000000 --- a/axolotl/examples/falcon/config-7b-qlora.yml +++ /dev/null @@ -1,93 +0,0 @@ -# 1b: tiiuae/falcon-rw-1b -# 40b: tiiuae/falcon-40b -base_model: tiiuae/falcon-7b -# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main -trust_remote_code: true -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer -is_falcon_derived_model: true -load_in_8bit: false -# enable 4bit for QLoRA -load_in_4bit: true -gptq: false -strict: false -push_dataset_to_hub: -datasets: - - path: QingyiSi/Alpaca-CoT - data_files: - - Chain-of-Thought/formatted_cot_data/gsm8k_train.json - type: "alpaca:chat" -dataset_prepared_path: -val_set_size: 0.05 -# enable QLoRA -adapter: qlora -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: - -# hyperparameters from QLoRA paper Appendix B.2 -# "We find hyperparameters to be largely robust across datasets" -lora_r: 64 -lora_alpha: 16 -# 0.1 for models up to 13B -# 0.05 for 33B and 65B models -lora_dropout: 0.05 -# add LoRA modules on all linear layers of the base model -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./qlora-out - -# QLoRA paper Table 9 -# - 16 for 7b & 
13b -# - 32 for 33b, 64 for 64b -# Max size tested on A6000 -# - 7b: 40 -# - 40b: 4 -# decrease if OOM, increase for max VRAM utilization -micro_batch_size: 1 -gradient_accumulation_steps: 2 -num_epochs: 4 -# Optimizer for QLoRA -optimizer: paged_adamw_32bit -torchdistx_path: -lr_scheduler: cosine -# QLoRA paper Table 9 -# - 2e-4 for 7b & 13b -# - 1e-4 for 33b & 64b -learning_rate: 0.0002 -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: true -gradient_checkpointing: true -# stop training after this many evaluation losses have increased in a row -# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback -early_stopping_patience: 3 -resume_from_checkpoint: -auto_resume_from_checkpoints: true -local_rank: -logging_steps: 1 -xformers_attention: true -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.000001 -fsdp: -fsdp_config: -special_tokens: - pad_token: "<|endoftext|>" - bos_token: ">>ABSTRACT<<" - eos_token: "<|endoftext|>" diff --git a/axolotl/examples/falcon/config-7b.yml b/axolotl/examples/falcon/config-7b.yml deleted file mode 100644 index ff37dcf8516a6f7b01857aff19966a87e0140baa..0000000000000000000000000000000000000000 --- a/axolotl/examples/falcon/config-7b.yml +++ /dev/null @@ -1,64 +0,0 @@ -base_model: tiiuae/falcon-7b -trust_remote_code: true -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer -is_falcon_derived_model: true -load_in_8bit: false -load_in_4bit: false -gptq: false -strict: false -push_dataset_to_hub: -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca:chat -dataset_prepared_path: -val_set_size: 0.05 -adapter: -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -lora_r: 64 -lora_alpha: 32 -lora_dropout: 0.0 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./falcon-7b -batch_size: 2 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.00003 -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: true -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: true -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 40 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - pad_token: "<|endoftext|>" - bos_token: ">>ABSTRACT<<" - eos_token: "<|endoftext|>" diff --git a/axolotl/examples/gptj/qlora.yml b/axolotl/examples/gptj/qlora.yml deleted file mode 100644 index 700d10e67ed2016207af09eda6b00b2ea97beccb..0000000000000000000000000000000000000000 --- a/axolotl/examples/gptj/qlora.yml +++ /dev/null @@ -1,57 +0,0 @@ -base_model: EleutherAI/gpt-j-6b -load_in_8bit: false -load_in_4bit: true -strict: false -push_dataset_to_hub: -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -adapter: qlora -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -lora_r: 8 -lora_alpha: 32 -lora_dropout: 0.05 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./qlora-out -gradient_accumulation_steps: 2 -micro_batch_size: 2 -num_epochs: 2 -optimizer: paged_adamw_8bit 
-torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.0001 -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: true -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: true -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -special_tokens: - pad_token: "<|endoftext|>" diff --git a/axolotl/examples/jeopardy-bot/config.yml b/axolotl/examples/jeopardy-bot/config.yml deleted file mode 100644 index ac8814b0bea1e9731e3bb488f1fd3b011b3c909d..0000000000000000000000000000000000000000 --- a/axolotl/examples/jeopardy-bot/config.yml +++ /dev/null @@ -1,55 +0,0 @@ -base_model: huggyllama/llama-7b -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -load_in_8bit: false -datasets: - - path: openaccess-ai-collective/jeopardy - type: jeopardy -dataset_prepared_path: -val_set_size: 0.02 -adapter: -lora_model_dir: -sequence_len: 512 -max_packed_sequence_len: -lora_r: -lora_alpha: -lora_dropout: -lora_target_modules: -lora_fan_in_fan_out: false -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./jeopardy-bot-7b -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.00003 -train_on_inputs: false -group_by_length: false -bf16: true -tf32: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 5 -xformers_attention: true -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 20 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/llama-2/README.md b/axolotl/examples/llama-2/README.md deleted file mode 100644 index 2ddd711e29a34b72f34efb5745e241ca8e015675..0000000000000000000000000000000000000000 --- a/axolotl/examples/llama-2/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Overview - -This is an example of a llama-2 configuration for 7b and 13b. The yaml file contains configuration for the 7b variant, but you can just aswell use the same settings for 13b. - -The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes. 
- -The 13b variant will fit if you change these settings to these values: -gradient_accumulation_steps: 2 -micro_batch_size: 1 - -```shell -accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml -``` -or - -```shell -accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml -``` - -To launch a full finetuning with 16-bit precision: - -```shell -accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml -``` diff --git a/axolotl/examples/llama-2/fft_optimized.yml b/axolotl/examples/llama-2/fft_optimized.yml deleted file mode 100644 index 5530283bfa40e70f49fbe110906eea1c9b4e0d79..0000000000000000000000000000000000000000 --- a/axolotl/examples/llama-2/fft_optimized.yml +++ /dev/null @@ -1,72 +0,0 @@ -base_model: NousResearch/Llama-2-7b-hf -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: false -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: last_run_prepared -val_set_size: 0.05 -output_dir: ./out - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -adapter: -lora_model_dir: -lora_r: -lora_alpha: -lora_dropout: -lora_target_linear: -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 1 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true -flash_attn_cross_entropy: false -flash_attn_rms_norm: true -flash_attn_fuse_qkv: false -flash_attn_fuse_mlp: true - -warmup_steps: 100 -evals_per_epoch: 4 -eval_table_size: -saves_per_epoch: 1 -debug: -deepspeed: #deepspeed/zero2.json # multi-gpu only -weight_decay: 0.1 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/llama-2/gptq-lora.yml b/axolotl/examples/llama-2/gptq-lora.yml deleted file mode 100644 index a3235c1fba437fe3c485da4715b4d11d8e617f34..0000000000000000000000000000000000000000 --- a/axolotl/examples/llama-2/gptq-lora.yml +++ /dev/null @@ -1,73 +0,0 @@ -base_model: TheBloke/Llama-2-7B-GPTQ -is_llama_derived_model: false -gptq: true -gptq_disable_exllama: true -model_type: AutoModelForCausalLM -tokenizer_type: LlamaTokenizer -tokenizer_use_fast: true -tokenizer_legacy: true -load_in_8bit: false -load_in_4bit: false -strict: false -push_dataset_to_hub: -hf_use_auth_token: true -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -adapter: lora -lora_model_dir: -sequence_len: 4096 -sample_packing: -lora_r: 8 -lora_alpha: 32 -lora_dropout: 0.05 -lora_target_modules: - - k_proj - - o_proj - - q_proj - - v_proj -lora_target_linear: -lora_fan_in_fan_out: -wandb_project: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./model-out -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_torch -adam_beta2: 0.95 -adam_eps: 0.00001 -max_grad_norm: 1.0 -torchdistx_path: -lr_scheduler: cosine -lr_quadratic_warmup: true -learning_rate: 0.000017 -train_on_inputs: false -group_by_length: false -bf16: false -fp16: false -float16: true -tf32: true -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: 
-logging_steps: 1 -xformers_attention: -flash_attention: -sdp_attention: -flash_optimum: -warmup_steps: 100 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/llama-2/lora.yml b/axolotl/examples/llama-2/lora.yml deleted file mode 100644 index afb7dcd06f8c38cce4853e04f5a05f1de983c7d4..0000000000000000000000000000000000000000 --- a/axolotl/examples/llama-2/lora.yml +++ /dev/null @@ -1,69 +0,0 @@ -base_model: NousResearch/Llama-2-7b-hf -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: true -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./lora-out - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -adapter: lora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -eval_table_max_new_tokens: 128 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/llama-2/qlora.yml b/axolotl/examples/llama-2/qlora.yml deleted file mode 100644 index d68882d6a5d790db040c92043669bb6ac8169f1a..0000000000000000000000000000000000000000 --- a/axolotl/examples/llama-2/qlora.yml +++ /dev/null @@ -1,70 +0,0 @@ -base_model: NousResearch/Llama-2-7b-hf -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./qlora-out - -adapter: qlora -lora_model_dir: - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: paged_adamw_32bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/llama-2/relora.yml b/axolotl/examples/llama-2/relora.yml deleted file mode 100644 index ff76ddbea778bcdeadbeff773a3bdfd4fdc63a97..0000000000000000000000000000000000000000 --- a/axolotl/examples/llama-2/relora.yml +++ /dev/null @@ -1,73 +0,0 @@ -base_model: NousResearch/Llama-2-7b-hf 
-model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./relora-out - -adapter: qlora -lora_model_dir: - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: - -relora_steps: 150 -relora_warmup_steps: 10 -relora_cpu_offload: false - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 4 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/llama-2/tiny-llama.yml b/axolotl/examples/llama-2/tiny-llama.yml deleted file mode 100644 index c72db4e5b2c58f0d3269bd68bb4a5e247c28b5d9..0000000000000000000000000000000000000000 --- a/axolotl/examples/llama-2/tiny-llama.yml +++ /dev/null @@ -1,68 +0,0 @@ -base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T - -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: true -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./lora-out - -sequence_len: 4096 -sample_packing: true - -adapter: lora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/mamba/config.yml b/axolotl/examples/mamba/config.yml deleted file mode 100644 index 946bbe731f0b9c3dce868ee9f16f460654cf2227..0000000000000000000000000000000000000000 --- a/axolotl/examples/mamba/config.yml +++ /dev/null @@ -1,61 +0,0 @@ -base_model: state-spaces/mamba-2.8b -model_type: MambaLMHeadModel -tokenizer_type: AutoTokenizer -tokenizer_config: EleutherAI/gpt-neox-20b - -load_in_8bit: false -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.0 -output_dir: ./out - -sequence_len: 2048 -sample_packing: false -pad_to_sequence_len: false - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 1 -num_epochs: 2 -optimizer: paged_adamw_8bit 
-lr_scheduler: cosine -learning_rate: 5e-5 - -train_on_inputs: false -group_by_length: true - -bf16: true -fp16: false -tf32: true - -gradient_checkpointing: false -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -eval_table_max_new_tokens: 128 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: -tokens: -save_safetensors: False diff --git a/axolotl/examples/mistral/README.md b/axolotl/examples/mistral/README.md deleted file mode 100644 index d1efb2cab374afad2fb478d9ce8c21f8281fd847..0000000000000000000000000000000000000000 --- a/axolotl/examples/mistral/README.md +++ /dev/null @@ -1,12 +0,0 @@ -**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks. - -Fine Tune: -```shell -accelerate launch -m axolotl.cli.train examples/mistral/config.yml - -``` - -If you run into CUDA OOM, use deepspeed with config zero2.json: -```shell -accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json -``` diff --git a/axolotl/examples/mistral/config.yml b/axolotl/examples/mistral/config.yml deleted file mode 100644 index ea62e9ebfec8eddebb1004ce628b7d1b7831d820..0000000000000000000000000000000000000000 --- a/axolotl/examples/mistral/config.yml +++ /dev/null @@ -1,62 +0,0 @@ -base_model: mistralai/Mistral-7B-v0.1 -model_type: MistralForCausalLM -tokenizer_type: LlamaTokenizer -is_mistral_derived_model: true - -load_in_8bit: false -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./out - -sequence_len: 8192 -sample_packing: true -pad_to_sequence_len: true -eval_sample_packing: false - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.000005 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -eval_table_max_new_tokens: 128 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/mistral/mixtral.yml b/axolotl/examples/mistral/mixtral.yml deleted file mode 100644 index 11c842d4eee842cab1014ab2dbca5a466a35d2bf..0000000000000000000000000000000000000000 --- a/axolotl/examples/mistral/mixtral.yml +++ /dev/null @@ -1,91 +0,0 @@ -base_model: mistralai/Mixtral-8x7B-v0.1 -model_type: AutoModelForCausalLM -tokenizer_type: LlamaTokenizer -trust_remote_code: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: tatsu-lab/alpaca - type: alpaca -dataset_prepared_path: last_run_prepared -val_set_size: 0.0 -output_dir: ./qlora-out - -## You can optionally freeze the entire model and unfreeze a subset of parameters -unfrozen_parameters: -# - lm_head.* -# - model.embed_tokens.* -# - model.layers.2[0-9]+.block_sparse_moe.gate.* -# - model.layers.2[0-9]+.block_sparse_moe.experts.* -# - model.layers.3[0-9]+.block_sparse_moe.gate.* -# - 
model.layers.3[0-9]+.block_sparse_moe.experts.* - -model_config: - output_router_logits: true - -adapter: qlora -lora_model_dir: - -sequence_len: 4096 -sample_packing: true -pad_to_sequence_len: true - -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: -#lora_target_modules: -# - gate -# - q_proj -# - k_proj -# - v_proj -# - o_proj -# - w1 -# - w2 -# - w3 - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 2 -micro_batch_size: 1 -num_epochs: 1 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -loss_watchdog_threshold: 5.0 -loss_watchdog_patience: 3 - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -eval_table_max_new_tokens: 128 -saves_per_epoch: 1 -debug: -deepspeed: deepspeed/zero2.json -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: diff --git a/axolotl/examples/mistral/qlora.yml b/axolotl/examples/mistral/qlora.yml deleted file mode 100644 index 35c79ebf4ee21f11a78565bf326c036abdd1ae3a..0000000000000000000000000000000000000000 --- a/axolotl/examples/mistral/qlora.yml +++ /dev/null @@ -1,81 +0,0 @@ -base_model: mistralai/Mistral-7B-v0.1 -model_type: MistralForCausalLM -tokenizer_type: LlamaTokenizer -is_mistral_derived_model: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: last_run_prepared -val_set_size: 0.1 -output_dir: ./qlora-out - -adapter: qlora -lora_model_dir: - -sequence_len: 8192 -sample_packing: true -pad_to_sequence_len: true - -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: -lora_target_modules: - - gate_proj - - down_proj - - up_proj - - q_proj - - v_proj - - k_proj - - o_proj - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 1 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -loss_watchdog_threshold: 5.0 -loss_watchdog_patience: 3 - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -eval_table_max_new_tokens: 128 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/mpt-7b/README.md b/axolotl/examples/mpt-7b/README.md deleted file mode 100644 index 12c55688d2aefc559e8acde1a6fc354ce863fe19..0000000000000000000000000000000000000000 --- a/axolotl/examples/mpt-7b/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# MPT-7B - -```shell -accelerate launch scripts/finetune.py examples/mpt-7b/config.yml - -``` diff --git a/axolotl/examples/mpt-7b/config.yml b/axolotl/examples/mpt-7b/config.yml deleted file mode 100644 index bc36b1c60935de7674616c451dbb212c882d0142..0000000000000000000000000000000000000000 --- a/axolotl/examples/mpt-7b/config.yml +++ /dev/null @@ -1,58 +0,0 @@ -base_model: mosaicml/mpt-7b -tokenizer_type: AutoTokenizer -trust_remote_code: true # required for 
mpt as their model class is not merged into transformers yet -load_in_8bit: false -datasets: - - path: vicgalle/alpaca-gpt4 - type: alpaca -dataset_prepared_path: -val_set_size: 0.02 -adapter: -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: - - q_proj - - v_proj -lora_fan_in_fan_out: false -wandb_project: mpt-alpaca-7b -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./mpt-alpaca-7b -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.0000002 -train_on_inputs: false -group_by_length: false -bf16: true -tf32: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 5 -xformers_attention: -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 20 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0001 -fsdp: -fsdp_config: -tokens: - pad_token: "<|padding|>" - bos_token: "<|endoftext|>" - eos_token: "<|endoftext|>" - unk_token: "<|endoftext|>" diff --git a/axolotl/examples/openllama-3b/README.md b/axolotl/examples/openllama-3b/README.md deleted file mode 100644 index 3e9501a54f61577ab8f7301b6caebc3d4802000c..0000000000000000000000000000000000000000 --- a/axolotl/examples/openllama-3b/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# openllama-3b - -Basic full tune -```shell -accelerate launch scripts/finetune.py examples/openllama-3b/config.yml -``` - -LoRA -```shell -accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml -``` - -QLoRA -```shell -accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml -``` diff --git a/axolotl/examples/openllama-3b/config.yml b/axolotl/examples/openllama-3b/config.yml deleted file mode 100644 index 0a404c79d85114359412622dbc642117a5fab7f7..0000000000000000000000000000000000000000 --- a/axolotl/examples/openllama-3b/config.yml +++ /dev/null @@ -1,62 +0,0 @@ -base_model: openlm-research/open_llama_3b_v2 -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -load_in_8bit: false -load_in_4bit: false -strict: false -push_dataset_to_hub: -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca -dataset_prepared_path: -val_set_size: 0.02 -adapter: -lora_model_dir: -sequence_len: 1024 -sample_packing: true -lora_r: -lora_alpha: -lora_dropout: -lora_target_modules: -lora_target_linear: -lora_fan_in_fan_out: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./openllama-out -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.000003 -train_on_inputs: false -group_by_length: false -float16: true -bf16: false -fp16: false -tf32: false -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true -gptq_groupsize: -gptq_model_v1: -warmup_steps: 20 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/openllama-3b/lora.yml b/axolotl/examples/openllama-3b/lora.yml deleted file mode 100644 index 4fbb634f9226750eb4e2e8a4d0195ca41cfc0ecc..0000000000000000000000000000000000000000 --- a/axolotl/examples/openllama-3b/lora.yml +++ /dev/null @@ -1,67 +0,0 @@ -base_model: openlm-research/open_llama_3b_v2 -model_type: 
LlamaForCausalLM -tokenizer_type: LlamaTokenizer -load_in_8bit: true -load_in_4bit: false -strict: false -push_dataset_to_hub: -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca -dataset_prepared_path: -val_set_size: 0.02 -adapter: lora -lora_model_dir: -sequence_len: 1024 -sample_packing: true -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.0 -lora_target_modules: - - gate_proj - - down_proj - - up_proj - - q_proj - - v_proj - - k_proj - - o_proj -lora_fan_in_fan_out: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./lora-out -gradient_accumulation_steps: 1 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.0002 -train_on_inputs: false -group_by_length: false -bf16: false -fp16: true -tf32: false -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true -gptq_groupsize: -gptq_model_v1: -warmup_steps: 20 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/openllama-3b/qlora.yml b/axolotl/examples/openllama-3b/qlora.yml deleted file mode 100644 index 3d6218b308a3133b35a1507a04360f8ce094a204..0000000000000000000000000000000000000000 --- a/axolotl/examples/openllama-3b/qlora.yml +++ /dev/null @@ -1,61 +0,0 @@ -base_model: openlm-research/open_llama_3b_v2 -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -load_in_8bit: false -load_in_4bit: true -strict: false -push_dataset_to_hub: -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -adapter: qlora -lora_model_dir: -sequence_len: 1024 -sample_packing: true -lora_r: 8 -lora_alpha: 32 -lora_dropout: 0.05 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./qlora-out -gradient_accumulation_steps: 1 -micro_batch_size: 2 -num_epochs: 4 -optimizer: paged_adamw_32bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 0.0002 -train_on_inputs: false -group_by_length: false -bf16: false -fp16: true -tf32: false -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true -gptq_groupsize: -gptq_model_v1: -warmup_steps: 20 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/axolotl/examples/phi/README.md b/axolotl/examples/phi/README.md deleted file mode 100644 index 1109db0b5042c2097c48b0cd65e2cd70cccd26a3..0000000000000000000000000000000000000000 --- a/axolotl/examples/phi/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Phi - -Due to some nuances with the phi code, please use deepspeed when training phi for full finetune. 
- -```shell -accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json - -# OR - -python -m axolotl.cli.train examples/phi/phi-qlora.yml -``` diff --git a/axolotl/examples/phi/phi-ft.yml b/axolotl/examples/phi/phi-ft.yml deleted file mode 100644 index eaebd21efd122350745355ef7b6307e03f0a5db3..0000000000000000000000000000000000000000 --- a/axolotl/examples/phi/phi-ft.yml +++ /dev/null @@ -1,74 +0,0 @@ -base_model: microsoft/phi-1_5 -model_type: PhiForCausalLM -tokenizer_type: AutoTokenizer -is_llama_derived_model: false -trust_remote_code: true - -load_in_8bit: false -load_in_4bit: false -strict: false - -datasets: - - path: garage-bAInd/Open-Platypus - type: alpaca - -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./phi-sft-out - -sequence_len: 2048 -sample_packing: true -pad_to_sequence_len: - -adapter: -lora_model_dir: -lora_r: -lora_alpha: -lora_dropout: -lora_target_linear: -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_torch -adam_beta2: 0.95 -adam_epsilon: 0.00001 -max_grad_norm: 1.0 -lr_scheduler: cosine -learning_rate: 0.000003 - -train_on_inputs: false -group_by_length: true -bf16: true -fp16: false -tf32: true - -gradient_checkpointing: -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: - -warmup_steps: 100 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -resize_token_embeddings_to_32x: true -special_tokens: - bos_token: "<|endoftext|>" - eos_token: "<|endoftext|>" - unk_token: "<|endoftext|>" - pad_token: "<|endoftext|>" diff --git a/axolotl/examples/phi/phi-qlora.yml b/axolotl/examples/phi/phi-qlora.yml deleted file mode 100644 index 691a83509bca9dcfcd560789e75a80f6e16df48d..0000000000000000000000000000000000000000 --- a/axolotl/examples/phi/phi-qlora.yml +++ /dev/null @@ -1,74 +0,0 @@ -base_model: microsoft/phi-1_5 -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer -is_llama_derived_model: false -trust_remote_code: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: garage-bAInd/Open-Platypus - type: alpaca - -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./phi-sft-out - -sequence_len: 1024 -sample_packing: false # not CURRENTLY compatible with LoRAs -pad_to_sequence_len: - -adapter: qlora -lora_model_dir: -lora_r: 64 -lora_alpha: 32 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_torch -adam_beta2: 0.95 -adam_epsilon: 0.00001 -max_grad_norm: 1.0 -lr_scheduler: cosine -learning_rate: 0.000003 - -train_on_inputs: false -group_by_length: true -bf16: true -fp16: false -tf32: true - -gradient_checkpointing: -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: - -warmup_steps: 100 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.1 -fsdp: -fsdp_config: -resize_token_embeddings_to_32x: true -special_tokens: - bos_token: "<|endoftext|>" - eos_token: "<|endoftext|>" - unk_token: "<|endoftext|>" - pad_token: "<|endoftext|>" diff --git a/axolotl/examples/pythia-12b/README.md b/axolotl/examples/pythia-12b/README.md deleted file mode 100644 index 
123ffa7109a439845c88039cdcb3968da0b26a16..0000000000000000000000000000000000000000 --- a/axolotl/examples/pythia-12b/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Pythia 12B - -- Single-GPU A100 only (?) - -```shell -python scripts/finetune.py examples/pythia-12b/config.yml -``` - -⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️ diff --git a/axolotl/examples/pythia-12b/config.yml b/axolotl/examples/pythia-12b/config.yml deleted file mode 100644 index e44bba74512ae7be47acbe1fef3b47f3ad2e8af4..0000000000000000000000000000000000000000 --- a/axolotl/examples/pythia-12b/config.yml +++ /dev/null @@ -1,48 +0,0 @@ -base_model: EleutherAI/pythia-12b-deduped -base_model_ignore_patterns: pytorch* # prefer safetensors -model_type: GPTNeoXForCausalLM -tokenizer_type: AutoTokenizer -load_in_8bit: false -load_in_4bit: false -gptq: false -device_map: auto -datasets: - - path: vicgalle/alpaca-gpt4 - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -adapter: -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: 2048 -lora_r: 64 -lora_alpha: 32 -lora_dropout: 0.0 -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./pythia-12b -gradient_accumulation_steps: 1 -micro_batch_size: 1 -num_epochs: 5 -learning_rate: 0.00003 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -train_on_inputs: false -group_by_length: false -bf16: false -fp16: false -float16: true -tf32: true -flash_optimum: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -gradient_checkpointing: true -fsdp: -fsdp_config: diff --git a/axolotl/examples/pythia/lora.yml b/axolotl/examples/pythia/lora.yml deleted file mode 100644 index 10c76c97359f12c86f00347be1e0bfb56b0a22f3..0000000000000000000000000000000000000000 --- a/axolotl/examples/pythia/lora.yml +++ /dev/null @@ -1,37 +0,0 @@ -base_model: EleutherAI/pythia-1.4b-deduped -load_in_8bit: true -datasets: - - path: teknium/GPT4-LLM-Cleaned - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -adapter: lora -lora_model_dir: -sequence_len: 512 -lora_r: 16 -lora_alpha: 32 -lora_dropout: 0.05 -lora_target_modules: - - query_key_value -lora_target_linear: -lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./lora-alpaca-pythia -gradient_accumulation_steps: 1 -micro_batch_size: 4 -num_epochs: 4 -learning_rate: 0.00001 -train_on_inputs: false -group_by_length: false -bf16: true -tf32: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -weight_decay: 0.1 -evals_per_epoch: 4 -logging_steps: 1 diff --git a/axolotl/examples/qwen/lora.yml b/axolotl/examples/qwen/lora.yml deleted file mode 100644 index 0ad9fc0f1d00d6250c45ae0a32570da6bad95fc9..0000000000000000000000000000000000000000 --- a/axolotl/examples/qwen/lora.yml +++ /dev/null @@ -1,68 +0,0 @@ -base_model: Qwen/Qwen-7B -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer - -is_qwen_derived_model: true -trust_remote_code: true - -load_in_8bit: true -load_in_4bit: false -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./lora-out - -sequence_len: 2048 # supports up to 8192 -sample_packing: false -pad_to_sequence_len: - -adapter: lora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true 
-lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: false -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -eval_table_max_new_tokens: 128 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: diff --git a/axolotl/examples/qwen/qlora.yml b/axolotl/examples/qwen/qlora.yml deleted file mode 100644 index 1ce0cbdc0f6f216a38bdcf4bb15d0e2e5ecd7f94..0000000000000000000000000000000000000000 --- a/axolotl/examples/qwen/qlora.yml +++ /dev/null @@ -1,68 +0,0 @@ -base_model: Qwen/Qwen-7B -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer - -is_qwen_derived_model: true -trust_remote_code: true - -load_in_8bit: false -load_in_4bit: true -strict: false - -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -output_dir: ./lora-out - -sequence_len: 2048 # supports up to 8192 -sample_packing: false -pad_to_sequence_len: - -adapter: qlora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: - -gradient_accumulation_steps: 4 -micro_batch_size: 2 -num_epochs: 4 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: false -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: - -warmup_steps: 10 -evals_per_epoch: 4 -eval_table_size: -eval_table_max_new_tokens: 128 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: diff --git a/axolotl/examples/redpajama/README.md b/axolotl/examples/redpajama/README.md deleted file mode 100644 index bbd9a93ecc1a548de071250edd14209c8bfedc8f..0000000000000000000000000000000000000000 --- a/axolotl/examples/redpajama/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# RedPajama 3B preview release - -```shell -accelerate launch scripts/finetune.py examples/redpajama/config-3b.yml - -``` diff --git a/axolotl/examples/redpajama/config-3b.yml b/axolotl/examples/redpajama/config-3b.yml deleted file mode 100644 index a369b6cef9783943c3dc0e02ea4d6b1157928216..0000000000000000000000000000000000000000 --- a/axolotl/examples/redpajama/config-3b.yml +++ /dev/null @@ -1,59 +0,0 @@ -base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1 -model_type: GPTNeoXForCausalLM -tokenizer_type: AutoTokenizer -trust_remote_code: -load_in_8bit: false -datasets: - - path: vicgalle/alpaca-gpt4 - type: alpaca -dataset_prepared_path: -val_set_size: 0.02 -adapter: -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: - - q_proj - - v_proj -lora_fan_in_fan_out: false -wandb_project: redpajama-alpaca-3b -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./redpajama-alpaca-3b -batch_size: 4 -micro_batch_size: 1 -num_epochs: 4 -optimizer: adamw_bnb_8bit -torchdistx_path: -lr_scheduler: cosine -learning_rate: 
0.0000002 -train_on_inputs: false -group_by_length: false -bf16: true -tf32: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 5 -xformers_attention: -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 20 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0001 -fsdp: -fsdp_config: -tokens: - pad_token: "<|padding|>" - bos_token: "<|endoftext|>" - eos_token: "<|endoftext|>" - unk_token: "<|endoftext|>" diff --git a/axolotl/examples/replit-3b/config-lora.yml b/axolotl/examples/replit-3b/config-lora.yml deleted file mode 100644 index 01314acc1c51e947b3dc160c5b6fbbe8fcbba3f9..0000000000000000000000000000000000000000 --- a/axolotl/examples/replit-3b/config-lora.yml +++ /dev/null @@ -1,55 +0,0 @@ -base_model: replit/replit-code-v1-3b -trust_remote_code: true -load_in_8bit: false -datasets: - - path: vicgalle/alpaca-gpt4 - type: alpaca -dataset_prepared_path: -val_set_size: 0.05 -adapter: lora -lora_model_dir: -sequence_len: 2048 -max_packed_sequence_len: -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: - - Wqkv - - mlp_up - - mlp_down -lora_fan_in_fan_out: -wandb_project: lora-replit -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./lora-replit -batch_size: 8 -micro_batch_size: 1 -num_epochs: 4 -optimizer: -torchdistx_path: -lr_scheduler: -learning_rate: 0.00001 -train_on_inputs: false -group_by_length: false -bf16: true -tf32: true -gradient_checkpointing: -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 20 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0 -fsdp: -fsdp_config: -#special_tokens: diff --git a/axolotl/examples/xgen-7b/xgen-7b-8k-qlora.yml b/axolotl/examples/xgen-7b/xgen-7b-8k-qlora.yml deleted file mode 100644 index 48924e5f7ad38f50d52268fe4bbb6d6c14192907..0000000000000000000000000000000000000000 --- a/axolotl/examples/xgen-7b/xgen-7b-8k-qlora.yml +++ /dev/null @@ -1,90 +0,0 @@ -# An example finetuning Saleforce's XGen-7b model with 8k context using qlora -# on Tim Dettmer's Guanaco dataset. 
-base_model: Salesforce/xgen-7b-8k-base -trust_remote_code: true -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer -load_in_8bit: false -# enable 4bit for QLoRA -load_in_4bit: true -gptq: false -strict: false -push_dataset_to_hub: -datasets: - - path: timdettmers/openassistant-guanaco - data_files: - - openassistant_best_replies_train.jsonl - type: "completion" -dataset_prepared_path: -val_set_size: 0.05 -# enable QLoRA -adapter: qlora -lora_model_dir: -sequence_len: 8192 -max_packed_sequence_len: - -# hyperparameters from QLoRA paper Appendix B.2 -# "We find hyperparameters to be largely robust across datasets" -lora_r: 64 -lora_alpha: 16 -# 0.1 for models up to 13B -# 0.05 for 33B and 65B models -lora_dropout: 0.05 -# add LoRA modules on all linear layers of the base model -lora_target_modules: -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: -wandb_entity: -wandb_watch: -wandb_name: -wandb_log_model: -output_dir: ./qlora-out - -# QLoRA paper Table 9 -# - 16 for 7b & 13b -# - 32 for 33b, 64 for 64b -# Max size tested on A6000 -# - 7b: 40 -# - 40b: 4 -# decrease if OOM, increase for max VRAM utilization -micro_batch_size: 1 -gradient_accumulation_steps: 1 -num_epochs: 4 -# Optimizer for QLoRA -optimizer: paged_adamw_32bit -torchdistx_path: -lr_scheduler: cosine -# QLoRA paper Table 9 -# - 2e-4 for 7b & 13b -# - 1e-4 for 33b & 64b -learning_rate: 0.00002 -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false -gradient_checkpointing: true -# stop training after this many evaluation losses have increased in a row -# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback -early_stopping_patience: 3 -resume_from_checkpoint: -auto_resume_from_checkpoints: true -local_rank: -logging_steps: 1 -xformers_attention: true -flash_attention: -gptq_groupsize: -gptq_model_v1: -warmup_steps: 10 -evals_per_epoch: 4 -saves_per_epoch: 1 -debug: -deepspeed: -weight_decay: 0.0 -special_tokens: - eos_token: "<|endoftext|>" - bos_token: "<|endoftext|>" - unk_token: "<|endoftext|>" - pad_token: "<|endoftext|>" diff --git a/axolotl/examples/yi-34B-chat/README.md b/axolotl/examples/yi-34B-chat/README.md deleted file mode 100644 index 07078850fbb9454b270253765e631bfcd78cf988..0000000000000000000000000000000000000000 --- a/axolotl/examples/yi-34B-chat/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Overview - -This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM. - -Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory. 
diff --git a/axolotl/examples/yi-34B-chat/qlora.yml b/axolotl/examples/yi-34B-chat/qlora.yml deleted file mode 100644 index 0c1a4b788936767ef36b7b44d4fa3529f61398e2..0000000000000000000000000000000000000000 --- a/axolotl/examples/yi-34B-chat/qlora.yml +++ /dev/null @@ -1,76 +0,0 @@ -base_model: 01-ai/Yi-34B-Chat -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -is_mistral_derived_model: false -is_llama_derived_model: true -load_in_8bit: false -load_in_4bit: true -strict: false -sequence_len: 1024 -bf16: true -fp16: false -tf32: false -flash_attention: true -special_tokens: - bos_token: "<|startoftext|>" - eos_token: "<|endoftext|>" - unk_token: "" - -# Data -datasets: - - path: mhenrichsen/alpaca_2k_test - type: alpaca -warmup_steps: 10 - -# Iterations -num_epochs: 1 - -# Evaluation -val_set_size: 0.1 -evals_per_epoch: 5 -eval_table_size: -eval_table_max_new_tokens: 128 -eval_sample_packing: false -eval_batch_size: 1 - -# LoRA -output_dir: ./qlora-out -adapter: qlora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: -lora_target_modules: - -# Sampling -sample_packing: false -pad_to_sequence_len: false - -# Batching -gradient_accumulation_steps: 4 -micro_batch_size: 1 -gradient_checkpointing: true - -# wandb -wandb_project: - -# Optimizer -optimizer: paged_adamw_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -# Misc -train_on_inputs: false -group_by_length: false -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -debug: -deepspeed: -weight_decay: 0 -fsdp: -fsdp_config: diff --git a/axolotl/image/axolotl-badge-web.png b/axolotl/image/axolotl-badge-web.png deleted file mode 100644 index 42217dca31ca758d822a97655350793fa4e4dc59..0000000000000000000000000000000000000000 Binary files a/axolotl/image/axolotl-badge-web.png and /dev/null differ diff --git a/axolotl/image/axolotl.png b/axolotl/image/axolotl.png deleted file mode 100644 index 21c27db852b0c6e8e00d28a2ebe666f3663c5784..0000000000000000000000000000000000000000 Binary files a/axolotl/image/axolotl.png and /dev/null differ diff --git a/axolotl/image/sticker_fixed.png b/axolotl/image/sticker_fixed.png deleted file mode 100644 index 0f5c67a1922120993bd350eb1213e2d331478b76..0000000000000000000000000000000000000000 Binary files a/axolotl/image/sticker_fixed.png and /dev/null differ diff --git a/axolotl/requirements-dev.txt b/axolotl/requirements-dev.txt deleted file mode 100644 index df7e312cb14411825603f85e6dfa624213e14a58..0000000000000000000000000000000000000000 --- a/axolotl/requirements-dev.txt +++ /dev/null @@ -1,3 +0,0 @@ -pre-commit -black -mypy diff --git a/axolotl/requirements-tests.txt b/axolotl/requirements-tests.txt deleted file mode 100644 index e079f8a6038dd2dc8512967540f96ee0de172067..0000000000000000000000000000000000000000 --- a/axolotl/requirements-tests.txt +++ /dev/null @@ -1 +0,0 @@ -pytest diff --git a/axolotl/requirements.txt b/axolotl/requirements.txt deleted file mode 100644 index c1c1cbc1324501c5559a0e1d9221ac9ac10c8a27..0000000000000000000000000000000000000000 --- a/axolotl/requirements.txt +++ /dev/null @@ -1,39 +0,0 @@ ---extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ -auto-gptq==0.5.1 -packaging -peft==0.6.0 -transformers==4.36.2 -tokenizers==0.15.0 -bitsandbytes>=0.41.1 -accelerate==0.24.1 -deepspeed -addict -fire -PyYAML>=6.0 -datasets>=2.15.0 -flash-attn==2.3.3 -sentencepiece -wandb -einops -xformers==0.0.22 -optimum==1.13.2 -hf_transfer -colorama 
-numba -numpy>=1.24.4 -# qlora things -bert-score==0.3.13 -evaluate==0.4.0 -rouge-score==0.1.2 -scipy -scikit-learn==1.2.2 -pynvml -art -fschat==0.2.34 -gradio==3.50.2 -tensorboard - -# remote filesystems -s3fs -gcsfs -# adlfs diff --git a/axolotl/scripts/finetune.py b/axolotl/scripts/finetune.py deleted file mode 100644 index d5bbcaf8f019b3b86733a8d70cdd6cff24b5d3a7..0000000000000000000000000000000000000000 --- a/axolotl/scripts/finetune.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" -import logging -from pathlib import Path - -import fire -import transformers - -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - do_inference, - do_merge_lora, - load_cfg, - load_datasets, - print_axolotl_text_art, -) -from axolotl.cli.shard import shard -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train - -LOG = logging.getLogger("axolotl.scripts.finetune") - - -def do_cli(config: Path = Path("examples/"), **kwargs): - print_axolotl_text_art() - LOG.warning( - str( - PendingDeprecationWarning( - "scripts/finetune.py will be replaced with calling axolotl.cli.train" - ) - ) - ) - parsed_cfg = load_cfg(config, **kwargs) - check_accelerate_default_config() - check_user_token() - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - if parsed_cli_args.inference: - do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) - elif parsed_cli_args.merge_lora: - do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) - elif parsed_cli_args.shard: - shard(cfg=parsed_cfg, cli_args=parsed_cli_args) - else: - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) - - -if __name__ == "__main__": - fire.Fire(do_cli) diff --git a/axolotl/scripts/runpod-entrypoint.sh b/axolotl/scripts/runpod-entrypoint.sh deleted file mode 100755 index b7cef80b4099b3fe678085fa08161a756b1f616c..0000000000000000000000000000000000000000 --- a/axolotl/scripts/runpod-entrypoint.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Export specific ENV variables to /etc/rp_environment -echo "Exporting environment variables..." 
-printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment -echo 'source /etc/rp_environment' >> ~/.bashrc - -if [[ $PUBLIC_KEY ]] -then - mkdir -p ~/.ssh - chmod 700 ~/.ssh - echo $PUBLIC_KEY >> ~/.ssh/authorized_keys - chmod 700 -R ~/.ssh - # Start the SSH service in the background - service ssh start -else - echo "No PUBLIC_KEY ENV variable provided, not starting openSSH daemon" -fi - -# Execute the passed arguments (CMD) -exec "$@" diff --git a/axolotl/setup.py b/axolotl/setup.py deleted file mode 100644 index fe4d2cfad8e053e6292df1080fd61681f7b162e0..0000000000000000000000000000000000000000 --- a/axolotl/setup.py +++ /dev/null @@ -1,61 +0,0 @@ -"""setup.py for axolotl""" - -from importlib.metadata import PackageNotFoundError, version - -from setuptools import find_packages, setup - - -def parse_requirements(): - _install_requires = [] - _dependency_links = [] - with open("./requirements.txt", encoding="utf-8") as requirements_file: - lines = [r.strip() for r in requirements_file.readlines()] - for line in lines: - if line.startswith("--extra-index-url"): - # Handle custom index URLs - _, url = line.split() - _dependency_links.append(url) - elif ( - "flash-attn" not in line - and "deepspeed" not in line - and line - and line[0] != "#" - ): - # Handle standard packages - _install_requires.append(line) - - try: - torch_version = version("torch") - if torch_version.startswith("2.1.1"): - _install_requires.pop(_install_requires.index("xformers==0.0.22")) - _install_requires.append("xformers==0.0.23") - except PackageNotFoundError: - pass - - return _install_requires, _dependency_links - - -install_requires, dependency_links = parse_requirements() - - -setup( - name="axolotl", - version="0.3.0", - description="LLM Trainer", - long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.", - package_dir={"": "src"}, - packages=find_packages(), - install_requires=install_requires, - dependency_links=dependency_links, - extras_require={ - "flash-attn": [ - "flash-attn==2.3.3", - ], - "deepspeed": [ - "deepspeed", - ], - "mamba-ssm": [ - "mamba-ssm==1.0.1", - ], - }, -) diff --git a/axolotl/src/axolotl/__init__.py b/axolotl/src/axolotl/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/axolotl/src/axolotl/cli/__init__.py b/axolotl/src/axolotl/cli/__init__.py deleted file mode 100644 index e6537ad05247382db424d86b305d05b7835f5160..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/cli/__init__.py +++ /dev/null @@ -1,345 +0,0 @@ -"""Prepare and train a model on a dataset. 
Can also infer from a model or merge lora""" - -import importlib -import logging -import os -import random -import sys -from pathlib import Path -from threading import Thread -from typing import Any, Dict, List, Optional, Union - -import gradio as gr -import torch -import yaml - -# add src to the pythonpath so we don't need to pip install this -from accelerate.commands.config import config_args -from art import text2art -from huggingface_hub import HfApi -from huggingface_hub.utils import LocalTokenNotFoundError -from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer - -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer -from axolotl.logging_config import configure_logging -from axolotl.train import TrainDatasetMeta -from axolotl.utils.config import normalize_config, validate_config -from axolotl.utils.data import prepare_dataset -from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process -from axolotl.utils.models import load_tokenizer -from axolotl.utils.tokenization import check_dataset_labels -from axolotl.utils.trainer import prepare_optim_env -from axolotl.utils.wandb_ import setup_wandb_env_vars - -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -src_dir = os.path.join(project_root, "src") -sys.path.insert(0, src_dir) - -configure_logging() -LOG = logging.getLogger("axolotl.scripts") - -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - - -def print_axolotl_text_art(suffix=None): - font = "nancyj" - ascii_text = " axolotl" - if suffix: - ascii_text += f" x {suffix}" - ascii_art = text2art(ascii_text, font=font) - - if is_main_process(): - print(ascii_art) - - -def get_multi_line_input() -> Optional[str]: - print("Give me an instruction (Ctrl + D to submit): ") - instruction = "" - for line in sys.stdin: - instruction += line # pylint: disable=consider-using-join - # instruction = pathlib.Path("/proc/self/fd/0").read_text() - return instruction - - -def do_merge_lora( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - safe_serialization = cfg.save_safetensors is True - - LOG.info("running merge of LoRA with base model") - model = model.merge_and_unload() - model.to(dtype=cfg.torch_dtype) - - if cfg.local_rank == 0: - LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") - model.save_pretrained( - str(Path(cfg.output_dir) / "merged"), - safe_serialization=safe_serialization, - ) - tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) - - -def do_inference( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - prompter = cli_args.prompter - default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} - - for token, symbol in default_tokens.items(): - # If the token isn't already specified in the config, add it - if not (cfg.special_tokens and token in cfg.special_tokens): - tokenizer.add_special_tokens({token: symbol}) - - prompter_module = None - if prompter: - prompter_module = getattr( - importlib.import_module("axolotl.prompters"), prompter - ) - - model = model.to(cfg.device) - - while True: - print("=" * 80) - # support for multiline inputs - instruction = get_multi_line_input() - if not instruction: - return - if prompter_module: - prompt: str = next( - prompter_module().build_prompt(instruction=instruction.strip("\n")) - ) - else: - prompt = instruction.strip() - batch = 
tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - - print("=" * 40) - model.eval() - with torch.no_grad(): - generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=1024, - temperature=0.9, - top_p=0.95, - top_k=40, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=True, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - streamer = TextStreamer(tokenizer) - generated = model.generate( - inputs=batch["input_ids"].to(cfg.device), - generation_config=generation_config, - streamer=streamer, - ) - print("=" * 40) - print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) - - -def do_inference_gradio( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - prompter = cli_args.prompter - default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} - - for token, symbol in default_tokens.items(): - # If the token isn't already specified in the config, add it - if not (cfg.special_tokens and token in cfg.special_tokens): - tokenizer.add_special_tokens({token: symbol}) - - prompter_module = None - if prompter: - prompter_module = getattr( - importlib.import_module("axolotl.prompters"), prompter - ) - - model = model.to(cfg.device) - - def generate(instruction): - if not instruction: - return - if prompter_module: - # pylint: disable=stop-iteration-return - prompt: str = next( - prompter_module().build_prompt(instruction=instruction.strip("\n")) - ) - else: - prompt = instruction.strip() - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - - model.eval() - with torch.no_grad(): - generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=1024, - temperature=0.9, - top_p=0.95, - top_k=40, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=True, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - streamer = TextIteratorStreamer(tokenizer) - generation_kwargs = { - "inputs": batch["input_ids"].to(cfg.device), - "generation_config": generation_config, - "streamer": streamer, - } - - thread = Thread(target=model.generate, kwargs=generation_kwargs) - thread.start() - - all_text = "" - - for new_text in streamer: - all_text += new_text - yield all_text - - demo = gr.Interface( - fn=generate, - inputs="textbox", - outputs="text", - title=cfg.get("gradio_title", "Axolotl Gradio Interface"), - ) - demo.queue().launch(show_api=False, share=True) - - -def choose_config(path: Path): - yaml_files = list(path.glob("*.yml")) - - if not yaml_files: - raise ValueError( - "No YAML config files found in the specified directory. Are you using a .yml extension?" - ) - - if len(yaml_files) == 1: - print(f"Using default YAML file '{yaml_files[0]}'") - return yaml_files[0] - - print("Choose a YAML file:") - for idx, file in enumerate(yaml_files): - print(f"{idx + 1}. {file}") - - chosen_file = None - while chosen_file is None: - try: - choice = int(input("Enter the number of your choice: ")) - if 1 <= choice <= len(yaml_files): - chosen_file = yaml_files[choice - 1] - else: - print("Invalid choice. Please choose a number from the list.") - except ValueError: - print("Invalid input. 
Please enter a number.") - - return chosen_file - - -def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool: - return not any(el in list2 for el in list1) - - -def load_cfg(config: Path = Path("examples/"), **kwargs): - if Path(config).is_dir(): - config = choose_config(config) - - # load the config from the yaml file - with open(config, encoding="utf-8") as file: - cfg: DictDefault = DictDefault(yaml.safe_load(file)) - cfg.axolotl_config_path = config - # if there are any options passed in the cli, if it is something that seems valid from the yaml, - # then overwrite the value - cfg_keys = cfg.keys() - for k, _ in kwargs.items(): - # if not strict, allow writing to cfg even if it's not in the yml already - if k in cfg_keys or not cfg.strict: - # handle booleans - if isinstance(cfg[k], bool): - cfg[k] = bool(kwargs[k]) - else: - cfg[k] = kwargs[k] - - validate_config(cfg) - - prepare_optim_env(cfg) - - normalize_config(cfg) - - setup_wandb_env_vars(cfg) - return cfg - - -def load_datasets( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -) -> TrainDatasetMeta: - tokenizer = load_tokenizer(cfg) - - train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( - cfg, tokenizer - ) - - if cli_args.debug or cfg.debug: - LOG.info("check_dataset_labels...") - check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, - ) - - LOG.info("printing prompters...") - for prompter in prompters: - LOG.info(prompter) - - return TrainDatasetMeta( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - total_num_steps=total_num_steps, - ) - - -def check_accelerate_default_config(): - if Path(config_args.default_yaml_config_file).exists(): - LOG.warning( - f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" - ) - - -def check_user_token(): - # Verify if token is valid - api = HfApi() - try: - user_info = api.whoami() - return bool(user_info) - except LocalTokenNotFoundError: - LOG.warning( - "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." 
- ) - return False diff --git a/axolotl/src/axolotl/cli/inference.py b/axolotl/src/axolotl/cli/inference.py deleted file mode 100644 index 86ad8409ff386ef857214207f460c22843af6be3..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/cli/inference.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -CLI to run inference on a trained model -""" -from pathlib import Path - -import fire -import transformers - -from axolotl.cli import ( - do_inference, - do_inference_gradio, - load_cfg, - print_axolotl_text_art, -) -from axolotl.common.cli import TrainerCliArgs - - -def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs): - # pylint: disable=duplicate-code - print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - parsed_cfg.sample_packing = False - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - parsed_cli_args.inference = True - - if gradio: - do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args) - else: - do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) - - -if __name__ == "__main__": - fire.Fire(do_cli) diff --git a/axolotl/src/axolotl/cli/merge_lora.py b/axolotl/src/axolotl/cli/merge_lora.py deleted file mode 100644 index 0caee4c28b3402905301c29add18105ffbc06803..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/cli/merge_lora.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -CLI to run merge a trained LoRA into a base model -""" -from pathlib import Path - -import fire -import transformers - -from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs - - -def do_cli(config: Path = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - print_axolotl_text_art() - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - parsed_cli_args.merge_lora = True - parsed_cfg = load_cfg(config, merge_lora=True, **kwargs) - - do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) - - -if __name__ == "__main__": - fire.Fire(do_cli) diff --git a/axolotl/src/axolotl/cli/preprocess.py b/axolotl/src/axolotl/cli/preprocess.py deleted file mode 100644 index e0eeea6b346d213787206295958292bd6a983cb9..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/cli/preprocess.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -CLI to run training on a model -""" -import logging -from pathlib import Path - -import fire -import transformers -from colorama import Fore - -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import PreprocessCliArgs -from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH - -LOG = logging.getLogger("axolotl.cli.preprocess") - - -def do_cli(config: Path = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - check_accelerate_default_config() - check_user_token() - parser = transformers.HfArgumentParser((PreprocessCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - if not parsed_cfg.dataset_prepared_path: - msg = ( - Fore.RED - + "preprocess CLI called without dataset_prepared_path set, " - + f"using default path: {DEFAULT_DATASET_PREPARED_PATH}" - + Fore.RESET - ) - LOG.warning(msg) - parsed_cfg.dataset_prepared_path = 
DEFAULT_DATASET_PREPARED_PATH - - _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - LOG.info( - Fore.GREEN - + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" - + Fore.RESET - ) - - -if __name__ == "__main__": - fire.Fire(do_cli) diff --git a/axolotl/src/axolotl/cli/shard.py b/axolotl/src/axolotl/cli/shard.py deleted file mode 100644 index 85901b0f2aa20f626fd556c879ad0fdac1bc5688..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/cli/shard.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -CLI to shard a trained model into 10GiB chunks -""" -import logging -from pathlib import Path - -import fire -import transformers - -from axolotl.cli import load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer -from axolotl.utils.dict import DictDefault - -LOG = logging.getLogger("axolotl.scripts") - - -def shard( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - safe_serialization = cfg.save_safetensors is True - LOG.debug("Re-saving model w/ sharding") - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) - - -def do_cli(config: Path = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - parsed_cli_args.shard = True - - shard(cfg=parsed_cfg, cli_args=parsed_cli_args) - - -if __name__ == "__main__": - fire.Fire(do_cli) diff --git a/axolotl/src/axolotl/cli/train.py b/axolotl/src/axolotl/cli/train.py deleted file mode 100644 index 81307b6b92bf36939f6a9c43d0ca4b1004a43450..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/cli/train.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -CLI to run training on a model -""" -import logging -from pathlib import Path - -import fire -import transformers - -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train - -LOG = logging.getLogger("axolotl.cli.train") - - -def do_cli(config: Path = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - parsed_cfg = load_cfg(config, **kwargs) - print_axolotl_text_art() - check_accelerate_default_config() - check_user_token() - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) - - -if __name__ == "__main__": - fire.Fire(do_cli) diff --git a/axolotl/src/axolotl/common/__init__.py b/axolotl/src/axolotl/common/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/axolotl/src/axolotl/common/cli.py b/axolotl/src/axolotl/common/cli.py deleted file mode 100644 index 74e66a3b6243ee4b3fdb8b8c11aa2512591bb946..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/common/cli.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -shared module for cli specific things -""" - -import logging -from dataclasses import dataclass, field -from typing import Optional - -from 
axolotl.logging_config import configure_logging -from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer - -configure_logging() -LOG = logging.getLogger("axolotl.common.cli") - -def test_func(): - return "Hello World to All!" - -@dataclass -class TrainerCliArgs: - """ - dataclass representing the various non-training arguments - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=5) - inference: bool = field(default=False) - merge_lora: bool = field(default=False) - prompter: Optional[str] = field(default=None) - shard: bool = field(default=False) - - -@dataclass -class PreprocessCliArgs: - """ - dataclass representing arguments for preprocessing only - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=1) - prompter: Optional[str] = field(default=None) - - -def load_model_and_tokenizer( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") - tokenizer = load_tokenizer(cfg) - LOG.info("loading model and (optionally) peft_config...") - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) - - return model, tokenizer diff --git a/axolotl/src/axolotl/common/const.py b/axolotl/src/axolotl/common/const.py deleted file mode 100644 index fd34ad4694f082fd8a1a0f135fba162e977b868c..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/common/const.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Various shared constants -""" - -DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" diff --git a/axolotl/src/axolotl/convert.py b/axolotl/src/axolotl/convert.py deleted file mode 100644 index 357e0ec50e1a055f28f1a26e39ca3f2c961a6954..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/convert.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" - - -import json -import sys - - -class FileReader: - """ - Reads a file and returns its contents as a string - """ - - def read(self, file_path): - with open(file_path, encoding="utf-8") as file: - return file.read() - - -class FileWriter: - """ - Writes a string to a file - """ - - def __init__(self, file_path): - self.file_path = file_path - - def write(self, content): - with open(self.file_path, "w", encoding="utf-8") as file: - file.write(content) - - -class StdoutWriter: - """ - Writes a string to stdout - """ - - def write(self, content): - sys.stdout.write(content) - sys.stdout.write("\n") - - -class JsonParser: - """ - Parses a string as JSON and returns the result - """ - - def parse(self, content): - return json.loads(content) - - -class JsonlSerializer: - """ - Serializes a list of JSON objects into a JSONL string - """ - - def serialize(self, data): - lines = [json.dumps(item) for item in data] - return "\n".join(lines) - - -class JsonToJsonlConverter: - """ - Converts a JSON file to JSONL - """ - - def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer): - self.file_reader = file_reader - self.file_writer = file_writer - self.json_parser = json_parser - self.jsonl_serializer = jsonl_serializer - - def convert( - self, input_file_path, output_file_path - ): # pylint: disable=unused-argument - content = self.file_reader.read(input_file_path) - data = self.json_parser.parse(content) - # data = [r for r in data if r["conversations"]] # vicuna 
cleaned has rows with empty conversations - jsonl_content = self.jsonl_serializer.serialize(data) - self.file_writer.write(jsonl_content) diff --git a/axolotl/src/axolotl/core/__init__.py b/axolotl/src/axolotl/core/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/axolotl/src/axolotl/core/trainer_builder.py b/axolotl/src/axolotl/core/trainer_builder.py deleted file mode 100644 index fed26de4644c9e8da42678f1b72f00007cb1b0cc..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/core/trainer_builder.py +++ /dev/null @@ -1,821 +0,0 @@ -""" -Builder for the training args and trainer -""" - -import abc -import importlib -import logging -import math -import sys -from abc import abstractmethod -from dataclasses import dataclass, field -from functools import wraps -from pathlib import Path -from typing import Optional - -import torch -import transformers -from datasets import Dataset -from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler -from transformers import EarlyStoppingCallback, Trainer, TrainingArguments -from transformers.trainer_utils import seed_worker - -from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler -from axolotl.utils.callbacks import ( - EvalFirstStepCallback, - GPUStatsCallback, - LossWatchDogCallback, - SaveAxolotlConfigtoWandBCallback, - SaveBetterTransformerModelCallback, - bench_eval_callback_factory, - log_prediction_callback_factory, -) -from axolotl.utils.collators import ( - BatchSamplerDataCollatorForSeq2Seq, - MambaDataCollator, -) -from axolotl.utils.samplers import MultipackBatchSampler -from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup - -try: - import torch._dynamo # pylint: disable=ungrouped-imports -except ImportError: - pass - -LOG = logging.getLogger("axolotl.core.trainer_builder") - - -@dataclass -class AxolotlTrainingArguments(TrainingArguments): - """ - Extend the base TrainingArguments for axolotl helpers - """ - - model_type: Optional[str] = field( - default=None, metadata={"help": "HF model configuration model_type."} - ) - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - sample_packing_seq_len_multiplier: int = field( - default=1, - metadata={"help": "the multiplier for the max len for packed sequences"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, 
or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." - }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - dataloader_prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "prefetch_factor argument to the dataloader"}, - ) - - -class AxolotlTrainer(Trainer): - """ - Extend the base Trainer for axolotl helpers - """ - - args = None # type: AxolotlTrainingArguments - tag_names = ["axolotl"] - - def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs): - self.num_epochs = num_epochs - self.bench_data_collator = bench_data_collator - super().__init__(*args, **kwargs) - - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ): - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - else: - return super().create_scheduler(num_training_steps, optimizer) - return self.lr_scheduler - - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing: - return MultipackBatchSampler( - RandomSampler(self.train_dataset), - self.args.train_batch_size, - drop_last=True, - batch_max_len=self._train_batch_size * self.args.max_seq_length, - lengths=( - self.train_dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - ) - return super()._get_train_sampler() - - def _get_eval_sampler( - self, eval_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and self.args.eval_sample_packing is not False: - return MultipackBatchSampler( - SequentialSampler(eval_dataset), - self.args.per_device_eval_batch_size, - drop_last=True, - batch_max_len=self.args.eval_batch_size * self.args.max_seq_length, - lengths=( - eval_dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - ) - return super()._get_eval_sampler(eval_dataset) - - def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing: - train_dataset = self.train_dataset - train_dataset = train_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self._train_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = 
self.args.dataloader_prefetch_factor - - sampler = self._get_train_sampler() - if isinstance(sampler, BatchSampler): - dataloader_params["batch_sampler"] = sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(train_dataset, **dataloader_params) - ) - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - if self.args.sample_packing and self.args.eval_sample_packing is not False: - eval_dataset = ( - eval_dataset if eval_dataset is not None else self.eval_dataset - ) - - eval_sampler = self._get_eval_sampler(eval_dataset) - eval_dataset = eval_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor - - if isinstance(eval_sampler, BatchSampler): - dataloader_params["batch_sampler"] = eval_sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = eval_sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(eval_dataset, **dataloader_params) - ) - return super().get_eval_dataloader(eval_dataset) - - def _get_bench_sampler( - self, bench_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.world_size <= 1: - return SequentialSampler(bench_dataset) - return None - - def get_bench_dataloader( - self, - bench_dataset: Dataset, - ) -> DataLoader: - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": self.bench_data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - if not isinstance(bench_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - return DataLoader(bench_dataset, **dataloader_params) - # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) - - def compute_loss(self, model, inputs, return_outputs=False): - # use one's weighted cross entropy loss calc - # if self.args.sample_packing: - # labels = inputs.pop("labels") - # outputs = model(**inputs) - # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) - # return (loss, outputs) if return_outputs else loss - return super().compute_loss(model, inputs, return_outputs=return_outputs) - - def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None): - if isinstance(tag_names, str): - tag_names = [tag_names] - - if kwargs is not None: - if "tags" not in kwargs: - kwargs["tags"] = tag_names - elif "tags" in kwargs and isinstance(kwargs["tags"], list): - kwargs["tags"].extend(tag_names) - elif "tags" in kwargs and isinstance(kwargs["tags"], str): - tag_names.append(kwargs["tags"]) - kwargs["tags"] = tag_names - - return kwargs - 
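The packing samplers wired up above fill each batch slot up to batch_max_len = per-device batch size × max_seq_length, taking each packed row's length from its last position_id + 1. A minimal first-fit sketch of that packing idea is shown below; `pack_first_fit` is my own simplified illustration, not the MultipackBatchSampler implementation or its efficiency estimate, and it assumes every length fits within batch_max_len.

```python
# Simplified, illustrative first-fit packing: group example lengths into bins whose
# total stays under batch_max_len. This mirrors the idea behind sample packing above,
# not the actual sampler. Assumes no single length exceeds batch_max_len.
from typing import List

def pack_first_fit(lengths: List[int], batch_max_len: int) -> List[List[int]]:
    bins: List[List[int]] = []   # each bin holds indices of examples packed together
    remaining: List[int] = []    # free token budget left in each bin
    for idx, length in enumerate(lengths):
        for b, space in enumerate(remaining):
            if length <= space:
                bins[b].append(idx)
                remaining[b] -= length
                break
        else:
            bins.append([idx])
            remaining.append(batch_max_len - length)
    return bins

# Example: sequences of these token lengths packed into 2048-token slots.
print(pack_first_fit([512, 1800, 300, 1024, 700], batch_max_len=2048))
# -> [[0, 2, 3], [1], [4]]  (512 + 300 + 1024 <= 2048, the rest get their own slot)
```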
- @wraps(Trainer.push_to_hub) - def push_to_hub(self, *args, **kwargs) -> str: - """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - """ - kwargs = self._sanitize_kwargs_for_tagging( - tag_names=self.tag_names, kwargs=kwargs - ) - - return super().push_to_hub(*args, **kwargs) - - -class AxolotlMambaTrainer(AxolotlTrainer): - """ - Mamba specific trainer to handle loss calculation - """ - - tag_names = ["axolotl", "mamba"] - - def compute_loss( - self, - model, - inputs, - return_outputs=False, # pylint: disable=unused-argument - ): - input_ids = inputs.pop("input_ids") - lm_logits = model(input_ids).logits - - labels = input_ids.to(lm_logits.device) - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - return lm_loss - - -class OneCycleLRSchedulerTrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - tag_names = ["axolotl", "onecycle"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: Optional[torch.optim.Optimizer] = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - num_warmup_steps = self.args.get_warmup_steps(num_training_steps) - pct_start = num_warmup_steps / num_training_steps - - self.lr_scheduler = OneCycleLR( - optimizer, - max_lr=self.args.learning_rate, - total_steps=num_training_steps, - pct_start=pct_start, - div_factor=6, - ) - - return self.lr_scheduler - - -class ReLoRATrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - tag_names = ["axolotl", "relora"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: Optional[torch.optim.Optimizer] = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - lr_scheduler = super().create_scheduler(num_training_steps, optimizer) - - if self.args.relora_steps: - warmup_steps = ( - self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 - ) - self.lr_scheduler = ReLoRAScheduler( - optimizer, - lr_scheduler, - self.args.relora_steps, - warmup_steps, - ) - else: - self.lr_scheduler = lr_scheduler - - return self.lr_scheduler - - -class TrainerBuilderBase(abc.ABC): - """ - Base class for trainer builder - """ - - _train_dataset = None - _eval_dataset = None - - def __init__(self, cfg, model, tokenizer): - self.cfg = cfg - self.model = model - self.tokenizer = tokenizer - - @property - def train_dataset(self): - return self._train_dataset - - @train_dataset.setter - def train_dataset(self, dataset): - self._train_dataset = dataset - - @property - def eval_dataset(self): - return self._eval_dataset - - @eval_dataset.setter - def eval_dataset(self, dataset): - self._eval_dataset = dataset - - @abstractmethod - def build(self, total_num_steps): - pass - - @abstractmethod - def get_callbacks(self): - pass - - @abstractmethod - def get_post_trainer_create_callbacks(self, trainer): - """ - Callbacks added after the trainer is created, usually b/c these need access to the trainer - """ - - -class HFCausalTrainerBuilder(TrainerBuilderBase): - """ - Build the 
HuggingFace training args/trainer for Causal models - """ - - def hook_pre_create_training_args(self, training_arguments_kwargs): - # TODO - return training_arguments_kwargs - - def hook_post_create_training_args(self, training_arguments): - # TODO - return training_arguments - - def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls): - # TODO - return trainer_kwargs, trainer_cls - - def hook_post_create_trainer(self, trainer): - # TODO - return trainer - - def get_callbacks(self): - callbacks = [] - callbacks.append(GPUStatsCallback(self.cfg)) - callbacks.append(EvalFirstStepCallback) - - if self.cfg.relora_steps: - callbacks.append(ReLoRACallback(self.cfg)) - - if ( - hasattr(self.model, "use_bettertransformer") - and self.model.use_bettertransformer is True - ): - callbacks.append(SaveBetterTransformerModelCallback) - - if self.cfg.use_wandb: - callbacks.append( - SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) - ) - - if self.cfg.loss_watchdog_threshold is not None: - callbacks.append(LossWatchDogCallback(self.cfg)) - - return callbacks - - def get_post_trainer_create_callbacks(self, trainer): - callbacks = [] - if self.cfg.use_wandb and self.cfg.eval_table_size > 0: - LogPredictionCallback = log_prediction_callback_factory( - trainer, self.tokenizer - ) - callbacks.append(LogPredictionCallback(self.cfg)) - - if self.cfg.do_bench_eval: - callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) - - if self.cfg.early_stopping_patience: - early_stop_cb = EarlyStoppingCallback( - self.cfg.early_stopping_patience, - ) - callbacks.append(early_stop_cb) - - return callbacks - - def _get_trainer_cls(self): - if self.cfg.lr_scheduler == "one_cycle" and ( - self.cfg.fsdp or self.cfg.adapter == "qlora" - ): - return OneCycleLRSchedulerTrainer - if self.cfg.relora_steps: - return ReLoRATrainer - if self.cfg.model_config_type == "mamba": - return AxolotlMambaTrainer - return AxolotlTrainer - - def build(self, total_num_steps): - warmup_steps = None - if self.cfg.warmup_steps is not None: - warmup_steps = self.cfg.warmup_steps - elif self.cfg.warmup_ratio is not None: - warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) - else: - warmup_steps = min(int(0.03 * total_num_steps), 100) - - logging_steps = ( - self.cfg.logging_steps - if self.cfg.logging_steps is not None - else max(min(int(0.005 * total_num_steps), 10), 1) - ) - - training_arguments_kwargs = {} - if self.cfg.bf16 == "full": - training_arguments_kwargs["bf16_full_eval"] = True - else: - training_arguments_kwargs["bf16"] = self.cfg.bf16 - training_arguments_kwargs["fp16"] = ( - self.cfg.fp16 and not self.cfg.bf16 - ) or False - training_arguments_kwargs["tf32"] = self.cfg.tf32 - training_arguments_kwargs["warmup_steps"] = warmup_steps - training_arguments_kwargs["logging_steps"] = logging_steps - - if self.cfg.seed: - training_arguments_kwargs["seed"] = self.cfg.seed - - if self.cfg.gradient_checkpointing: - training_arguments_kwargs[ - "gradient_checkpointing" - ] = self.cfg.gradient_checkpointing - if self.cfg.fsdp: - training_arguments_kwargs["fsdp"] = self.cfg.fsdp - if self.cfg.fsdp_config: - training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config) - - # deepspeed - if self.cfg.deepspeed: - training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed - - if self.cfg.lr_quadratic_warmup is not None: - training_arguments_kwargs[ - "lr_quadratic_warmup" - ] = self.cfg.lr_quadratic_warmup - - if self.cfg.adam_beta1: - training_arguments_kwargs["adam_beta1"] = 
self.cfg.adam_beta1 - if self.cfg.adam_beta2: - training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2 - if self.cfg.adam_epsilon: - training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon - if self.cfg.max_grad_norm: - training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm - - if self.cfg.hub_model_id: - training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id - training_arguments_kwargs["push_to_hub"] = True - training_arguments_kwargs["hub_private_repo"] = True - - if self.cfg.hub_strategy: - training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy - - if self.cfg.save_safetensors is not None: - training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors - - if self.cfg.sample_packing_eff_est: - training_arguments_kwargs[ - "sample_packing_efficiency" - ] = self.cfg.sample_packing_eff_est - - if self.cfg.dataloader_pin_memory is not None: - training_arguments_kwargs[ - "dataloader_pin_memory" - ] = self.cfg.dataloader_pin_memory - if self.cfg.dataloader_num_workers is not None: - training_arguments_kwargs[ - "dataloader_num_workers" - ] = self.cfg.dataloader_num_workers - if self.cfg.dataloader_prefetch_factor is not None: - training_arguments_kwargs[ - "dataloader_prefetch_factor" - ] = self.cfg.dataloader_prefetch_factor - - if self.cfg.val_set_size == 0: - # no eval set, so don't eval - training_arguments_kwargs["evaluation_strategy"] = "no" - elif self.cfg.eval_steps: - training_arguments_kwargs["evaluation_strategy"] = "steps" - training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps - elif self.cfg.evaluation_strategy: - training_arguments_kwargs[ - "evaluation_strategy" - ] = self.cfg.evaluation_strategy - else: - # we have an eval set, but no steps defined, default to use epoch - training_arguments_kwargs["evaluation_strategy"] = "epoch" - - if self.cfg.save_steps: - training_arguments_kwargs["save_strategy"] = "steps" - training_arguments_kwargs["save_steps"] = self.cfg.save_steps - elif self.cfg.save_strategy: - training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy - else: - # default to saving each epoch if not defined - training_arguments_kwargs["save_strategy"] = "epoch" - - if self.cfg.do_bench_eval: - training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval - if self.cfg.bench_dataset: - training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset - if self.cfg.metric_for_best_model: - training_arguments_kwargs[ - "metric_for_best_model" - ] = self.cfg.metric_for_best_model - if self.cfg.greater_is_better: - training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better - - if self.cfg.torch_compile: - if torch.__version__ < "2.1.0": # pylint: disable=protected-access - LOG.warning("torch>=2.1.0 required for torch_compile to work properly") - elif torch._dynamo: # pylint: disable=protected-access - torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access - True - ) - training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile - if self.cfg.torch_compile_backend: - training_arguments_kwargs[ - "torch_compile_backend" - ] = self.cfg.torch_compile_backend - - # DDP Config - if self.cfg.ddp_timeout: - training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout - # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html - if self.cfg.ddp_bucket_cap_mb: - training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb - if self.cfg.ddp_broadcast_buffers is not None: - 
training_arguments_kwargs[ - "ddp_broadcast_buffers" - ] = self.cfg.ddp_broadcast_buffers - - # these are all the "standard" kwargs that are def used - training_arguments_kwargs["max_steps"] = ( - total_num_steps if self.cfg.max_steps else -1 - ) - training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len - training_arguments_kwargs[ - "per_device_train_batch_size" - ] = self.cfg.micro_batch_size - training_arguments_kwargs[ - "per_device_eval_batch_size" - ] = self.cfg.eval_batch_size - training_arguments_kwargs[ - "gradient_accumulation_steps" - ] = self.cfg.gradient_accumulation_steps - training_arguments_kwargs[ - "eval_accumulation_steps" - ] = self.cfg.gradient_accumulation_steps - training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs - training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate - training_arguments_kwargs["output_dir"] = self.cfg.output_dir - training_arguments_kwargs["save_total_limit"] = ( - self.cfg.save_total_limit if self.cfg.save_total_limit else 4 - ) - training_arguments_kwargs["load_best_model_at_end"] = ( - ( - self.cfg.load_best_model_at_end is not False - or self.cfg.early_stopping_patience - ) - and self.cfg.val_set_size > 0 - and self.cfg.save_steps - and self.cfg.eval_steps - and self.cfg.save_steps % self.cfg.eval_steps == 0 - ) or False - training_arguments_kwargs["ddp_find_unused_parameters"] = ( - False if self.cfg.ddp else None - ) - training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length - training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None - training_arguments_kwargs["run_name"] = ( - self.cfg.wandb_name if self.cfg.use_wandb else None - ) - training_arguments_kwargs["optim"] = ( - self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" - ) - training_arguments_kwargs["lr_scheduler_type"] = ( - self.cfg.lr_scheduler - if self.cfg.lr_scheduler - and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") - else "cosine" - ) - training_arguments_kwargs["lr_scheduler_kwargs"] = ( - self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} - ) - training_arguments_kwargs["weight_decay"] = ( - self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 - ) - training_arguments_kwargs["sample_packing"] = ( - self.cfg.sample_packing if self.cfg.sample_packing else False - ) - training_arguments_kwargs["eval_sample_packing"] = ( - self.cfg.sample_packing - if self.cfg.eval_sample_packing is not False - else False - ) - training_arguments_kwargs[ - "sample_packing_seq_len_multiplier" - ] = self.cfg.micro_batch_size - training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps - training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps - training_arguments_kwargs = self.hook_pre_create_training_args( - training_arguments_kwargs - ) - training_arguments_kwargs["model_type"] = self.cfg.model_config_type - - if self.cfg.neftune_noise_alpha is not None: - training_arguments_kwargs[ - "neftune_noise_alpha" - ] = self.cfg.neftune_noise_alpha - - training_args = ( - AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg - **training_arguments_kwargs, - ) - ) - training_args = self.hook_post_create_training_args(training_args) - trainer_kwargs = {} - - if self.cfg.optimizer == "adamw_anyprecision": - if Path(self.cfg.torchdistx_path).exists(): - sys.path.append(self.cfg.torchdistx_path) - importlib.import_module("torchdistx") - - data_collator_kwargs = { - "padding": True, # True/"longest" is the default - } - if 
self.cfg.pad_to_sequence_len: - data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil( - self.cfg.sequence_len / 64 - ) - else: - # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check - # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html - data_collator_kwargs["pad_to_multiple_of"] = 64 - - trainer_cls = self._get_trainer_cls() - trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( - trainer_kwargs, trainer_cls - ) - trainer = trainer_cls( - model=self.model, - train_dataset=self.train_dataset, - eval_dataset=self.eval_dataset, - args=training_args, - data_collator=self.build_collator(**data_collator_kwargs), - bench_data_collator=transformers.DataCollatorForSeq2Seq( - self.tokenizer, - return_tensors="pt", - **data_collator_kwargs, - ), - callbacks=self.get_callbacks(), - num_epochs=self.cfg.num_epochs, - **trainer_kwargs, - ) - trainer = self.hook_post_create_trainer(trainer) - for callback in self.get_post_trainer_create_callbacks(trainer): - trainer.add_callback(callback) - - if self.cfg.deepspeed and self.cfg.sample_packing: - trainer.accelerator.state.deepspeed_plugin.deepspeed_config[ - "train_micro_batch_size_per_gpu" - ] = self.cfg.micro_batch_size - - return trainer - - def build_collator(self, **kwargs): - if self.cfg.model_config_type == "mamba": - return MambaDataCollator(tokenizer=self.tokenizer) - - return BatchSamplerDataCollatorForSeq2Seq( - self.tokenizer, - return_tensors="pt", - **kwargs, - ) diff --git a/axolotl/src/axolotl/datasets.py b/axolotl/src/axolotl/datasets.py deleted file mode 100644 index d5362643f61a61926f214c65e19df9f417bb3788..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/datasets.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Module containing Dataset functionality""" - -import logging -import os -from typing import List, Optional - -import torch -from datasets import Dataset, IterableDataset - -from .prompt_tokenizers import PromptTokenizingStrategy - -# We want this to be a wrapper for an existing dataset that we have loaded -# lets use the concept of middlewares to wrap each dataset, for example -# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])) -# let's check to ensure we don't truncate an item in the middle, we'll use -# the collators later on to pad the datasets - -LOG = logging.getLogger("axolotl") - - -class TokenizedPromptDataset(Dataset): - """ - Dataset that returns tokenized prompts from a stream of text files. - Args: - prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. - dataset (dataset.Dataset): Dataset with text files. 
- """ - - def __init__( # pylint: disable=super-init-not-called - self, - prompt_tokenizer: PromptTokenizingStrategy, - dataset: IterableDataset, - process_count: Optional[int] = None, - **kwargs, - ): - self.prompt_tokenizer = prompt_tokenizer - self.process_count = process_count - super().__init__(self.process(dataset).data, **kwargs) - - def process(self, dataset): - features = dataset.features.keys() - num_proc = ( - min(64, self.process_count) - if self.process_count - else min(64, os.cpu_count()) - ) - map_kwargs = {} - if self.prompt_tokenizer.supports_batched: - map_kwargs["batched"] = True - map_kwargs["batch_size"] = 100 - return dataset.map( - self.prompt_tokenizer.tokenize_prompt, - num_proc=num_proc, - remove_columns=features, - **map_kwargs, - ) - - -# TODO this isn't the best since it can't interleave datasets -class ConstantLengthDataset(IterableDataset): - """ - Iterable dataset that returns constant length chunks of tokens from stream of text files. - Args: - tokenizer (Tokenizer): The processor used for processing the data. - dataset (dataset.Dataset): Dataset with text files. - seq_length (int): Length of token sequences to return. - """ - - def __init__( # pylint: disable=super-init-not-called - self, - tokenizer, - datasets, - seq_length=2048, - ): - self.tokenizer = tokenizer - self.concat_token_id = tokenizer.eos_token_id - self.datasets: List[IterableDataset] = datasets - self.seq_length = seq_length - - vocab_size = len(tokenizer.get_vocab()) - - if vocab_size <= torch.iinfo(torch.int16).max: - self.tokens_dtype = torch.int16 - elif vocab_size <= torch.iinfo(torch.int32).max: - self.tokens_dtype = torch.int32 - else: - self.tokens_dtype = torch.int64 - - def __iter__(self): - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "position_ids": [], - } - buffer_len = 0 - for dataset in self.datasets: - idx = 0 - iterator = iter(dataset) - more_examples = True - while more_examples: - try: - example = next(iterator) - idx += 1 - except StopIteration: - more_examples = False - example = None - - add_concat_token = False - if example: - example_len = len(example["input_ids"]) - add_concat_token = example["input_ids"][-1] != self.concat_token_id - else: - example_len = 0 - - if not example_len or ( - buffer_len + int(add_concat_token) + example_len > self.seq_length - ): - if buffer["input_ids"]: - input_ids = torch.cat(buffer["input_ids"], dim=-1)[ - : self.seq_length - ] - attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ - : self.seq_length - ] - position_ids = torch.cat(buffer["position_ids"], dim=-1)[ - : self.seq_length - ] - labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] - if labels.size() == input_ids.size() and ( - attention_mask.size() == input_ids.size() - ): - yield { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "position_ids": position_ids, - } - else: - LOG.warning( - f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" - ) - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "position_ids": [], - } - buffer_len = 0 - idx = 1 - - if example: - # FIXME - # just going to drop data points that are too long - if len(example["input_ids"]) <= self.seq_length: - input_ids = example["input_ids"] - attention_mask = example["attention_mask"] - labels = example["labels"] - - if add_concat_token: - input_ids.append(self.concat_token_id) - attention_mask.append(1) - 
labels.append(self.concat_token_id) - - input_ids_with_concat = torch.tensor( - input_ids, dtype=self.tokens_dtype - ) - attention_mask_with_concat = torch.tensor( - [idx * m for m in attention_mask], dtype=torch.int16 - ) - labels_with_concat = torch.tensor( - labels, dtype=self.tokens_dtype - ) - position_ids = torch.arange( - len(input_ids), dtype=self.tokens_dtype - ) - - buffer["input_ids"].append(input_ids_with_concat) - buffer["attention_mask"].append(attention_mask_with_concat) - buffer["labels"].append(labels_with_concat) - buffer["position_ids"].append(position_ids) - buffer_len += len(input_ids) diff --git a/axolotl/src/axolotl/logging_config.py b/axolotl/src/axolotl/logging_config.py deleted file mode 100644 index 8f473aa24085943caafb99b5b9dcf3551359a049..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/logging_config.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Common logging module for axolotl -""" - -import os -import sys -from logging import Formatter -from logging.config import dictConfig -from typing import Any, Dict - -from colorama import Fore, Style, init - - -class ColorfulFormatter(Formatter): - """ - Formatter to add coloring to log messages by log type - """ - - COLORS = { - "WARNING": Fore.YELLOW, - "ERROR": Fore.RED, - "CRITICAL": Fore.RED + Style.BRIGHT, - } - - def format(self, record): - record.rank = int(os.getenv("LOCAL_RANK", "0")) - log_message = super().format(record) - return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET - - -DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { - "version": 1, - "formatters": { - "simple": { - "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", - }, - "colorful": { - "()": ColorfulFormatter, - "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s", - }, - }, - "filters": {}, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "formatter": "simple", - "filters": [], - "stream": sys.stdout, - }, - "color_console": { - "class": "logging.StreamHandler", - "formatter": "colorful", - "filters": [], - "stream": sys.stdout, - }, - }, - "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, - "loggers": { - "axolotl": { - "handlers": ["color_console"], - "level": "DEBUG", - "propagate": False, - }, - }, -} - - -def configure_logging(): - """Configure with default logging""" - init() # Initialize colorama - dictConfig(DEFAULT_LOGGING_CONFIG) diff --git a/axolotl/src/axolotl/models/__init__.py b/axolotl/src/axolotl/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/axolotl/src/axolotl/models/mamba/__init__.py b/axolotl/src/axolotl/models/mamba/__init__.py deleted file mode 100644 index 247c1d184b6fcbb3601908268512f33bf7427982..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/mamba/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Modeling module for Mamba models -""" - - -def fix_mamba_attn_for_loss(): - from mamba_ssm.models import mixer_seq_simple - - from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed - - mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed - return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name diff --git a/axolotl/src/axolotl/models/mamba/configuration_mamba.py b/axolotl/src/axolotl/models/mamba/configuration_mamba.py deleted file mode 100644 index 
5160ee8d7e06a6d7c2dd3d0d9ae401ba3a288ef8..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/mamba/configuration_mamba.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -HF Transformers MambaConfig -""" -from transformers import PretrainedConfig - - -class MambaConfig(PretrainedConfig): - """ - modeling configuration for state space model/mamba - """ - - model_type = "mamba" - - def __init__( - self, - vocab_size=50280, - d_model=2560, - n_layer=64, - rms_norm=True, - residual_in_fp32=True, - fused_add_norm=True, - pad_vocab_size_multiple=8, - pad_token_id=50277, - bos_token_id=0, - eos_token_id=0, - tie_word_embeddings=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.d_model = d_model - self.n_layer = n_layer - self.rms_norm = rms_norm - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.pad_vocab_size_multiple = pad_vocab_size_multiple - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/axolotl/src/axolotl/models/mamba/modeling_mamba.py b/axolotl/src/axolotl/models/mamba/modeling_mamba.py deleted file mode 100644 index 70e9c88c882f595b849a63e2c274f6853ccc5ec2..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/mamba/modeling_mamba.py +++ /dev/null @@ -1,128 +0,0 @@ -# pylint: skip-file -import os -from collections import namedtuple -from functools import partial -from typing import Optional, Union - -import torch -from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights -from mamba_ssm.utils.generation import GenerationMixin -from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf -from torch import nn -from torch.nn import CrossEntropyLoss - -from axolotl.models.mamba.configuration_mamba import MambaConfig - - -class MambaLMHeadModel(nn.Module, GenerationMixin): - def __init__( - self, - d_model: int, - n_layer: int, - vocab_size: int, - initializer_cfg=None, - pad_vocab_size_multiple: int = 1, - device=None, - dtype=None, - **backbone_kwargs, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - if vocab_size % pad_vocab_size_multiple != 0: - vocab_size += pad_vocab_size_multiple - ( - vocab_size % pad_vocab_size_multiple - ) - self.config = MambaConfig( - vocab_size=vocab_size, - d_model=d_model, - n_layer=n_layer, - pad_vocab_size_multiple=pad_vocab_size_multiple, - ) - self.backbone = MixerModel( - d_model=d_model, - n_layer=n_layer, - vocab_size=vocab_size, - initializer_cfg=initializer_cfg, - **backbone_kwargs, - **factory_kwargs, - ) - self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) - - # Initialize weights and apply final processing - self.apply( - partial( - _init_weights, - n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}), - ) - ) - self.tie_weights() - - def tie_weights(self): - self.lm_head.weight = self.backbone.embedding.weight - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.backbone.allocate_inference_cache( - batch_size, max_seqlen, dtype=dtype, **kwargs - ) - - def forward( - self, - input_ids, - position_ids=None, - inference_params=None, - num_last_tokens=0, - labels=None, - **kwargs, - ): - """ - "position_ids" is just to be compatible with Transformer generation. We don't use it. 
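
The `tie_weights` method above shares a single parameter between the token embedding and the output projection, so the LM head reuses the embedding matrix. A minimal standalone sketch of that pattern (toy dimensions, not the deleted Mamba model):

import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Toy LM head that reuses the embedding matrix as the output projection."""

    def __init__(self, vocab_size: int = 100, d_model: int = 16) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Tie: both modules now point at the same Parameter object.
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        hidden = self.embedding(input_ids)   # (batch, seq, d_model)
        return self.lm_head(hidden)          # (batch, seq, vocab_size)


model = TinyTiedLM()
assert model.lm_head.weight.data_ptr() == model.embedding.weight.data_ptr()
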
- num_last_tokens: if > 0, only return the logits for the last n tokens - """ - hidden_states = self.backbone(input_ids, inference_params=inference_params) - if num_last_tokens > 0: - hidden_states = hidden_states[:, -num_last_tokens:] - lm_logits = self.lm_head(hidden_states) - - CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) - return CausalLMOutput(logits=lm_logits) - - loss = None - if labels is not None: - logits = lm_logits - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"]) - print(loss) - return CausalLMOutput(logits=lm_logits, loss=loss) - - else: - CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) - return CausalLMOutput(logits=lm_logits) - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = None, - safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument - ): - if state_dict is None: - state_dict = self.state_dict() - torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin")) - - @classmethod - def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): - config = load_config_hf(pretrained_model_name) - model = cls(**config, device=device, dtype=dtype, **kwargs) - model.load_state_dict( - load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype) - ) - return model diff --git a/axolotl/src/axolotl/models/phi/__init__.py b/axolotl/src/axolotl/models/phi/__init__.py deleted file mode 100644 index 76d6a0e10b84b30c37be29c6b28cde5fc31f88b9..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/phi/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -MixFormers model architecture used for phi models -""" - -from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa -from .configuration_phi import PhiConfig # noqa -from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa -from .modeling_phi import PhiForCausalLM # noqa diff --git a/axolotl/src/axolotl/models/phi/configuration_mixformer_sequential.py b/axolotl/src/axolotl/models/phi/configuration_mixformer_sequential.py deleted file mode 100644 index ceba62093acf457ffc7644afac2aaac1b2d74559..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/phi/configuration_mixformer_sequential.py +++ /dev/null @@ -1,63 +0,0 @@ -# pylint: skip-file - -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
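
The loss computed in `MambaLMHeadModel.forward` above (and again by `CausalLMLoss` further down in this diff) is the standard next-token objective: shift logits and labels by one position, flatten, and apply cross-entropy. A small self-contained version:

import torch
from torch.nn import CrossEntropyLoss


def next_token_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Causal-LM loss: the logit at position t predicts the label at t+1.

    logits: (batch, seq, vocab), labels: (batch, seq); masked-token handling
    (e.g. -100 labels) is left to CrossEntropyLoss defaults.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss()
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )


# toy check: batch of 2, sequence of 5, vocab of 11
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
print(next_token_loss(logits, labels))
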
- -import math -from typing import Any, Dict, List, Optional, Union - -from transformers import PretrainedConfig - - -class MixFormerSequentialConfig(PretrainedConfig): - """MixFormer (sequential for DeepSpeed) configuration.""" - - model_type = "mixformer-sequential" - - attribute_map = { - "max_position_embeddings": "n_positions", - "hidden_size": "n_embd", - "num_attention_heads": "n_head", - "num_hidden_layers": "n_layer", - "input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility - "blocks": "architecture", # `blocks` key is for backward compatibility - } - - def __init__( - self, - vocab_size: Optional[int] = 50304, - n_positions: Optional[int] = 2048, - n_embd: Optional[int] = 1024, - n_layer: Optional[int] = 20, - n_inner: Optional[int] = None, - n_head: Optional[int] = 16, - rotary_dim: Optional[int] = 32, - activation_function: Optional[str] = "gelu_new", - embd_layer: Optional[str] = "default", - architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None, - embd_pdrop: Optional[float] = 0.0, - resid_pdrop: Optional[float] = 0.0, - layer_norm_epsilon: Optional[float] = 1e-5, - initializer_range: Optional[float] = 0.02, - tie_word_embeddings: Optional[bool] = False, - pad_vocab_size_multiple: Optional[int] = 64, - **kwargs - ) -> None: - self.vocab_size = int( - math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.rotary_dim = min(rotary_dim, n_embd // n_head) - self.activation_function = activation_function - self.embd_layer = embd_layer - self.architecture = architecture - self.embd_pdrop = embd_pdrop - self.resid_pdrop = resid_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/axolotl/src/axolotl/models/phi/configuration_phi.py b/axolotl/src/axolotl/models/phi/configuration_phi.py deleted file mode 100644 index e941bf7980467aa381edd02ad8db90023c7746ff..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/phi/configuration_phi.py +++ /dev/null @@ -1,65 +0,0 @@ -# pylint: skip-file -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
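
Both `MixFormerSequentialConfig` above and `PhiConfig` below round the vocabulary size up to a multiple of `pad_vocab_size_multiple` so the embedding and output matrices get hardware-friendly shapes. The rounding in isolation:

import math


def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    """Round the vocabulary size up to the next multiple, as the configs above do."""
    return int(math.ceil(vocab_size / multiple) * multiple)


assert pad_vocab(50280, 8) == 50280    # already a multiple of 8
assert pad_vocab(50257, 64) == 50304   # GPT-2-sized vocab padded to a multiple of 64
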
- -import math -from typing import Optional - -from transformers import PretrainedConfig - - -class PhiConfig(PretrainedConfig): - """Phi configuration.""" - - model_type = "phi" - attribute_map = { - "max_position_embeddings": "n_positions", - "hidden_size": "n_embd", - "num_attention_heads": "n_head", - "num_hidden_layers": "n_layer", - } - - def __init__( - self, - vocab_size: int = 50304, - n_positions: int = 2048, - n_embd: int = 1024, - n_layer: int = 20, - n_inner: Optional[int] = None, - n_head: int = 16, - n_head_kv: Optional[int] = None, - rotary_dim: Optional[int] = 32, - activation_function: Optional[str] = "gelu_new", - flash_attn: bool = False, - flash_rotary: bool = False, - fused_dense: bool = False, - attn_pdrop: float = 0.0, - embd_pdrop: float = 0.0, - resid_pdrop: float = 0.0, - layer_norm_epsilon: float = 1e-5, - initializer_range: float = 0.02, - tie_word_embeddings: bool = False, - pad_vocab_size_multiple: int = 64, - **kwargs - ) -> None: - self.vocab_size = int( - math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.n_head_kv = n_head_kv - self.rotary_dim = min(rotary_dim, n_embd // n_head) - self.activation_function = activation_function - self.flash_attn = flash_attn - self.flash_rotary = flash_rotary - self.fused_dense = fused_dense - self.attn_pdrop = attn_pdrop - self.embd_pdrop = embd_pdrop - self.resid_pdrop = resid_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/axolotl/src/axolotl/models/phi/modeling_mixformer_sequential.py b/axolotl/src/axolotl/models/phi/modeling_mixformer_sequential.py deleted file mode 100644 index fd2ec054c5e2e7e8efced983a355937abdb37d5e..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/phi/modeling_mixformer_sequential.py +++ /dev/null @@ -1,930 +0,0 @@ -# pylint: skip-file - -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -# BSD 3-Clause License -# -# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import annotations - -import copy -import inspect -from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.nn as nn -from einops import rearrange -from flash_attn.flash_attn_interface import ( - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_qkvpacked_func, -) -from transformers import PretrainedConfig, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import CausalLMOutputWithPast - -from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids -from .configuration_mixformer_sequential import MixFormerSequentialConfig - - -@dataclass -class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. - Adapted from https://github.com/Dao-AILab/flash-attention.""" - - max_sequence_len: int - max_batch_size: int - sequence_len_offset: int = 0 - batch_size_offset: int = 0 - key_value_memory_dict: dict = field(default_factory=dict) - fused_ft_kernel: bool = False - lengths_per_sample: Optional[torch.Tensor] = None - - -class Embedding(nn.Module): - """Token embedding with dropout.""" - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.wte = nn.Embedding(config.vocab_size, config.n_embd) - self.drop = nn.Dropout(config.embd_pdrop) - - def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.wte(input_ids) - hidden_states = self.drop(hidden_states) - - return hidden_states - - -class RotaryEmbedding(nn.Module): - """PyTorch implementation of `flash-attn` RotaryEmbedding layer. 
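
`get_cu_seqlens_from_pos_ids`, imported above from axolotl's monkeypatch utilities, is not shown in this diff. The sketch below is only a guess at the general idea, assuming that packed rows (such as those produced by `ConstantLengthDataset` earlier) restart `position_ids` at 0 for each packed sequence; the function name and behaviour here are illustrative, not the actual helper:

import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Hypothetical helper: recover cumulative sequence lengths from packed position ids.

    Assumes a single packed row whose position ids restart at 0 at each
    sequence boundary, e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3].
    """
    pos = position_ids.view(-1)
    starts = torch.nonzero(pos == 0, as_tuple=False).view(-1)       # boundary indices
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
    seqlens = ends - starts
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return cu_seqlens, int(seqlens.max())


cu, mx = cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]))
print(cu.tolist(), mx)  # [0, 3, 5, 9] 4
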
- Adapted from https://github.com/Dao-AILab/flash-attention.""" - - def __init__( - self, - dim: int, - base: Optional[int] = 10000, - scale_base: Optional[float] = None, - device: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__() - - if scale_base is not None: - raise NotImplementedError - - # Generate and save the inverse frequency buffer (non-trainable) - self.dim = dim - self.base = base - self.scale_base = scale_base - self.device = device - - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) - ) - self.register_buffer("inv_freq", inv_freq) - - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) - / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _update_cos_sin_cache( - self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0 - ) -> None: - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - seqlen = x.shape[1] + seqlen_offset - - # Re-generate the inverse frequency buffer if it's not fp32 - # (for instance if model.half() was called) - if self.inv_freq.dtype != "torch.float32": - self.inv_freq = 1.0 / ( - self.base - ** ( - torch.arange( - 0, self.dim, 2, device=self.device, dtype=torch.float32 - ) - / self.dim - ) - ) - - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != x.device - or self._cos_cached.dtype != x.dtype - ): - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=x.device, dtype=torch.float32) - - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer( - t, self.inv_freq.to(device=t.device, dtype=torch.float32) - ) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(x.dtype) - self._sin_cached = torch.sin(freqs).to(x.dtype) - else: - power = ( - torch.arange( - seqlen, dtype=self.scale.dtype, device=self.scale.device - ) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange( - power, "s -> s 1" - ) - - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) - - def apply_rotary_emb_qkv( - self, - qkv: torch.FloatTensor, - sin: torch.FloatTensor, - cos: torch.FloatTensor, - sin_k: Optional[torch.FloatTensor] = None, - cos_k: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - _, seqlen, three, _, headdim = qkv.shape - assert three == 3 - - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert ( - sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - ) - - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - - # Splits the queries and keys in half - q1, q2 = q_rot.chunk(2, dim=-1) - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - - # Casts 
to fp32 are necessary to prevent fp16 overflow issues - q1, q2, k1, k2, c, s = [ - t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s] - ] - - # Computes the new keys and queries, recasting to original dtype - q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) - - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype) - - return torch.cat( - [ - torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2), - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - qkv[:, :, 2:3, :, :], - ], - axis=2, - ) - - def forward( - self, qkv: torch.Tensor, seqlen_offset: int = 0 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Perform the forward pass. - - Args: - qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim). - seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch. - - Returns: - New `qkv` and the cached sinusoids. - - """ - - self._update_cos_sin_cache(qkv, seqlen_offset) - - return self.apply_rotary_emb_qkv( - qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:] - ) - - -def _update_kv_cache(kv, inference_params, layer_idx): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) - Adapted from https://github.com/Dao-AILab/flash-attention.""" - # Pre-allocate memory for key-values for inference. - num_heads, head_dim = kv.shape[-2:] - if layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, - inference_params.max_sequence_len, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - inference_params.key_value_memory_dict[layer_idx] = kv_cache - else: - kv_cache = inference_params.key_value_memory_dict[layer_idx] - - # Adjust key and value for inference - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + kv.shape[1] - assert batch_end <= ( - kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa - ) - assert sequence_end <= ( - kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa - ) - - assert kv_cache is not None - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = kv_cache[batch_start:batch_end, :sequence_end, ...] - return kv - - -class MLP(nn.Module): - """Multi-Layer Perceptron. - - Reference: - Attention Is All You Need. - https://arxiv.org/pdf/1706.03762.pdf. - - """ - - def __init__( - self, - config: PretrainedConfig, - n_inner: Optional[int] = None, - act_fn: Optional[str] = None, - ) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}." 
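
`apply_rotary_emb_qkv` above implements the split-in-half RoPE rotation on the query/key slices, upcasting to fp32 to avoid fp16 overflow. A compact single-tensor version of the same rotation (toy shapes, no caching):

import torch


def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embedding to x of shape (seqlen, dim), dim even.

    Mirrors the split-in-half formulation above: x = [x1, x2] maps to
    [x1*cos - x2*sin, x1*sin + x2*cos], with angles t / base**(2i/dim).
    """
    seqlen, dim = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(seqlen, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)          # (seqlen, dim/2)
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    x1, x2 = x.float().chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).to(x.dtype)


q = torch.randn(8, 16)          # 8 positions, head dim 16
print(rope_rotate(q).shape)     # torch.Size([8, 16])
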
- - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - self.fc1 = nn.Linear(config.n_embd, n_inner) - self.fc2 = nn.Linear(n_inner, config.n_embd) - self.act = ACT2FN[act_fn] - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - old_keys = [ - prefix + "fc_in.weight", - prefix + "fc_out.weight", - prefix + "fc_in.bias", - prefix + "fc_out.bias", - ] - new_keys = [ - prefix + "fc1.weight", - prefix + "fc2.weight", - prefix + "fc1.bias", - prefix + "fc2.bias", - ] - - if all(k in state_dict for k in old_keys) and not all( - k in state_dict for k in new_keys - ): - # Older version of `MLP` saved with different key names. - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - return super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc2(hidden_states) - - return hidden_states - - -class FusedMLP(nn.Module): - """Fused Multi-Layer Perceptron from `flash-attn`. - - Reference: - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py. - - """ - - def __init__( - self, - config: PretrainedConfig, - n_inner: Optional[int] = None, - act_fn: Optional[str] = None, - raise_on_missing: bool = False, - ) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}." - - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa - activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa - - self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - return self.mlp(hidden_states) - - -class SelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Adapted from https://github.com/Dao-AILab/flash-attention. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward( - self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None - ): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. 
(B, S) - """ - causal = self.causal if causal is None else causal - if cu_seqlens is not None: - return flash_attn_varlen_qkvpacked_func( - qkv.squeeze(0), - cu_seqlens, - max_seqlen, - dropout_p=self.drop.p, - softmax_scale=self.softmax_scale, - causal=causal, - ) - else: - return flash_attn_qkvpacked_func( - qkv, - dropout_p=self.drop.p, - softmax_scale=self.softmax_scale, - causal=causal, - ) - - -class CrossAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Adapted from https://github.com/Dao-AILab/flash-attention. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, q, kv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, Sk) - """ - causal = self.causal if causal is None else causal - return flash_attn_kvpacked_func( - q, - kv, - dropout_p=self.drop.p, - softmax_scale=self.softmax_scale, - causal=causal, - ) - - -def find_mha_dims( - config: PretrainedConfig, - n_head: Optional[int] = None, - head_dim: Optional[int] = None, -) -> Tuple[int, int]: - """Validate and return the number of heads and head dimension for multi-head attention. - - Args: - config: Model configuration. - n_head: Number of heads. - head_dim: Head dimension. - - Returns: - Number of heads and head dimension. - - """ - - assert all( - hasattr(config, attr) for attr in ["n_embd", "n_head"] - ), "`config` must have `n_embd` and `n_head` attributes." - - if head_dim is None: - assert ( - config.n_embd % config.n_head == 0 - ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})." - - if n_head is None and head_dim is None: - head_dim = config.n_embd // config.n_head - n_head = config.n_head - elif n_head is None or head_dim is None: - raise ValueError("`n_head` and `head_dim` must be both specified or `None`.") - - return n_head, head_dim - - -class MHA(nn.Module): - """Multi-head attention layer. 
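
`MLP._load_from_state_dict` above renames legacy checkpoint keys (`fc_in`/`fc_out`) to the current names (`fc1`/`fc2`) before delegating to the default loader. The same backward-compatibility remap as a plain function (hypothetical helper name):

import torch


def remap_legacy_mlp_keys(state_dict: dict, prefix: str = "") -> dict:
    """Rename old-style MLP keys to the new names, mirroring the shim above."""
    rename = {
        prefix + "fc_in.weight": prefix + "fc1.weight",
        prefix + "fc_in.bias": prefix + "fc1.bias",
        prefix + "fc_out.weight": prefix + "fc2.weight",
        prefix + "fc_out.bias": prefix + "fc2.bias",
    }
    return {rename.get(k, k): v for k, v in state_dict.items()}


old = {"fc_in.weight": torch.zeros(4, 2), "fc_in.bias": torch.zeros(4)}
print(sorted(remap_legacy_mlp_keys(old)))  # ['fc1.bias', 'fc1.weight']
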
- Adapted from https://github.com/Dao-AILab/flash-attention.""" - - def __init__( - self, - config: PretrainedConfig, - rotary_dim: Optional[int] = None, - n_head: Optional[int] = None, - head_dim: Optional[int] = None, - bias: Optional[bool] = True, - dropout: Optional[float] = 0.0, - softmax_scale: Optional[float] = None, - causal: Optional[bool] = True, - layer_idx: Optional[int] = None, - rotary_emb_scale_base: Optional[float] = None, - return_residual: Optional[bool] = False, - checkpointing: Optional[bool] = False, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - fused_dense: Optional[bool] = True, - flash_attn: Optional[bool] = True, - cutlass_attn: Optional[bool] = False, - flash_rotary: Optional[bool] = True, - raise_on_missing: Optional[bool] = False, - ) -> None: - super().__init__() - - factory_kwargs = {"device": device, "dtype": dtype} - n_head, head_dim = find_mha_dims(config, n_head, head_dim) - - self.hidden_size = config.n_embd - self.n_head = n_head - self.head_dim = head_dim - self.op_size = n_head * head_dim - - self.causal = causal - self.layer_idx = layer_idx - self.rotary_emb_dim = ( - rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) - ) - self.fused_dense = fused_dense - self.flash_attn = flash_attn - self.cutlass_attn = cutlass_attn - self.flash_rotary = flash_rotary - self.return_residual = return_residual - self.checkpointing = checkpointing - - if self.rotary_emb_dim > 0: - rotary_kwargs = {"device": device} - if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0: - rotary_kwargs["scale_base"] = rotary_emb_scale_base - - self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs) - else: - pass - - self.Wqkv = nn.Linear( - self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs - ) - self.out_proj = nn.Linear( - self.op_size, self.hidden_size, bias=bias, **factory_kwargs - ) - - self.inner_attn = SelfAttention( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - self.inner_cross_attn = CrossAttention( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - - def _update_kv_cache( - self, kv: torch.FloatTensor, inference_params: InferenceParams - ) -> None: - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) - Adapted from https://github.com/Dao-AILab/flash-attention.""" - - assert ( - self.layer_idx is not None - ), "Generation requires layer_idx in the constructor" - - return _update_kv_cache(kv, inference_params, self.layer_idx) - - def forward( - self, - x: torch.FloatTensor, - x_kv: Optional[torch.FloatTensor] = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - mixer_subset: Optional[torch.LongTensor] = None, - past_cache: Optional[InferenceParams] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - """Perform the forward pass. - - Args: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if - cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total - is the is the sum of the sequence lengths in the batch. - x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. - key_padding_mask: boolean mask, True means to keep, False means to mask out. - (batch, seqlen). Only applicable when not using FlashAttention. - cu_seqlens: (batch_size + 1,), dtype torch.int32. 
The cumulative sequence lengths - of the sequences in the batch, used to index into x. Only applicable when using - FlashAttention. - max_seqlen: int. Maximum sequence length in the batch. - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. - past_cache: For generation only. - - Returns: - (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None, - else (total, hidden_dim) where total is the is the sum of the sequence lengths - in the batch. - - """ - - if cu_seqlens is not None: - assert max_seqlen is not None - assert key_padding_mask is None - assert self.flash_attn - # assert self.rotary_emb_dim == 0 - - if key_padding_mask is not None: - assert cu_seqlens is None - assert max_seqlen is None - assert not self.flash_attn - - if past_cache is not None: - assert key_padding_mask is None - assert cu_seqlens is None and max_seqlen is None - - attn_kwargs = {"key_padding_mask": key_padding_mask} - - assert x_kv is None and mixer_subset is None - - qkv = self.Wqkv(x) - qkv = rearrange( - qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim - ) - - if past_cache is None: - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv) - context = self.inner_attn( - qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs - ) - - else: - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset) - q = qkv[:, :, 0] - kv = self._update_kv_cache(qkv[:, :, 1:], past_cache) - # If we're processing the prompt, causal=None (use self.causal). - # If we're decoding, then causal=False. - causal = None if past_cache.sequence_len_offset == 0 else False - context = self.inner_cross_attn(q, kv, causal=causal) - - out = rearrange(context, "... h d -> ... (h d)") - out = self.out_proj(out) - - return out if not self.return_residual else (out, x) - - -class ParallelBlock(nn.Module): - """Parallel block. - - This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen). - - """ - - def __init__( - self, - config: PretrainedConfig, - mixer: Optional[Dict[str, Any]] = None, - mlp: Optional[Dict[str, Any]] = None, - block_idx: Optional[int] = None, - ) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.block_idx = block_idx - - self.mixer = MHA(config, layer_idx=block_idx) - self.mlp = MLP(config) - - def forward( - self, - hidden_states: torch.FloatTensor, - past_cache: Optional[torch.FloatTensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - ) -> torch.FloatTensor: - residual = hidden_states - hidden_states = self.ln(hidden_states) - - attn_outputs = self.mixer( - hidden_states, - past_cache=past_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - if isinstance(attn_outputs, tuple): - attn_outputs = attn_outputs[0] - - attn_outputs = self.resid_dropout(attn_outputs) - feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) - - hidden_states = attn_outputs + feed_forward_hidden_states + residual - - return hidden_states - - -class CausalLMHead(nn.Module): - """Causal Language Modeling head. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. 
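
`ParallelBlock` above applies attention and the MLP to one shared pre-LayerNorm output and sums both with the residual, the GPT-J/CodeGen layout. A minimal sketch using `nn.MultiheadAttention` as a stand-in for the deleted MHA class:

import torch
import torch.nn as nn


class ParallelBlockSketch(nn.Module):
    """Parallel attention + MLP over a single pre-LayerNorm, as in ParallelBlock above."""

    def __init__(self, d_model: int = 32, n_head: int = 4, p_drop: float = 0.0) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.drop = nn.Dropout(p_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        h = self.ln(x)                                        # one shared pre-norm
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        return self.drop(attn_out) + self.drop(self.mlp(h)) + residual


x = torch.randn(2, 6, 32)
print(ParallelBlockSketch()(x).shape)   # torch.Size([2, 6, 32])
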
- - """ - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.linear = nn.Linear(config.n_embd, config.vocab_size) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.ln(hidden_states) - logits = self.linear(hidden_states).to(torch.float32) - - return logits - - -class CausalLMLoss(nn.Module): - """Causal Language Modeling loss. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. - - """ - - def __init__(self, shift_labels: Optional[bool] = True) -> None: - super().__init__() - - self.shift_labels = shift_labels - self.loss_fct = nn.CrossEntropyLoss() - - def forward( - self, logits: torch.FloatTensor, labels: torch.LongTensor - ) -> torch.FloatTensor: - if self.shift_labels: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - - loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - - return loss - - -class MixFormerSequentialPreTrainedModel(PreTrainedModel): - """MixFormer (sequential for DeepSpeed) pre-trained model.""" - - config_class = MixFormerSequentialConfig - base_model_prefix = "transformer" - supports_gradient_checkpointing = True - - def __init__(self, *inputs, **kwargs) -> None: - super().__init__(*inputs, **kwargs) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, **kwargs - ) -> Dict[str, Any]: - if "use_cache" in kwargs and not kwargs["use_cache"]: - return {"input_ids": input_ids} - - if past_key_values is None or not ( - isinstance(past_key_values, InferenceParams) - ): - past_key_values = InferenceParams( - max_batch_size=input_ids.shape[0], - max_sequence_len=self.config.n_positions, - sequence_len_offset=0, - batch_size_offset=0, - fused_ft_kernel=False, - key_value_memory_dict={}, - ) - else: - # assume past_key_values has cached all but last token in input_ids - past_key_values.sequence_len_offset = len(input_ids[0]) - 1 - input_ids = input_ids[:, -1].unsqueeze(-1) - - return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs} - - -class PackedSequential(nn.Sequential): - def forward( - self, - input, - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - ): - for module in self: - sig = inspect.signature(module.forward) - if "cu_seqlens" in sig.parameters: - input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - else: - input = module(input) - return input - - -class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel): - """MixFormer (sequential for DeepSpeed) for Causal Language Modeling.""" - - _keys_to_ignore_on_load_missing = [""] - _keys_to_ignore_on_load_unexpected = [ - r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)" - ] - _no_split_modules = ["ParallelBlock"] - - def __init__(self, config: MixFormerSequentialConfig) -> None: - super().__init__(config) - - modules = [Embedding(config)] - block_config = config.architecture - - if not isinstance(block_config, list): - block_config = [block_config for _ in range(config.n_layer)] - - if config.n_layer != len(block_config): - config.n_layer = len(block_config) - - for block_idx, block in enumerate(block_config): - # `block_cls` with `legacy` value is for backward compatibility - # `path` key is for backward compatibility - block = copy.deepcopy(block) or {"block_cls": "parallel"} 
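
`PackedSequential` above inspects each submodule's `forward` signature so that `cu_seqlens`/`max_seqlen` only reach blocks that accept them (the embedding and LM head do not). The same routing idea in a toy container (names here are illustrative):

import inspect

import torch
import torch.nn as nn


class KwargAwareSequential(nn.Sequential):
    """Pass extra kwargs only to submodules whose forward() declares them,
    in the spirit of PackedSequential above."""

    def forward(self, x, **extra):
        for module in self:
            params = inspect.signature(module.forward).parameters
            accepted = {k: v for k, v in extra.items() if k in params}
            x = module(x, **accepted)
        return x


class NeedsSeqlens(nn.Module):
    def forward(self, x, cu_seqlens=None):
        return x + (0 if cu_seqlens is None else cu_seqlens.sum())


seq = KwargAwareSequential(nn.Identity(), NeedsSeqlens())
print(seq(torch.ones(3), cu_seqlens=torch.tensor([0, 3])))  # tensor([4., 4., 4.])
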
- block.pop("path", None) or block.pop("block_cls", None) - - block["block_idx"] = block_idx - modules.append(ParallelBlock(config, **block)) - - modules.append(CausalLMHead(config)) - - self.layers = PackedSequential(*modules) - self.loss = CausalLMLoss() - - self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.layers[0].wte - - def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: - self.layers[0].wte = new_embeddings - - def get_output_embeddings(self) -> nn.Linear: - return self.layers[-1].linear - - def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: - self.layers[-1].linear = new_embeddings - - def forward( - self, - input_ids: torch.LongTensor, - labels: Optional[torch.LongTensor] = None, - past_key_values: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - **kwargs, - ) -> CausalLMOutputWithPast: - cu_seqlens: Optional[torch.LongTensor] = None - max_seqlen: Optional[int] = None - if position_ids is not None: - batch_size, seq_length = input_ids.shape - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if not past_key_values: - lm_logits = self.layers( - input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - else: - hidden_layer = self.layers[0](input_ids) - for module in self.layers[1:-1]: - hidden_layer = module( - hidden_layer, - past_cache=past_key_values, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - lm_logits = self.layers[-1](hidden_layer) - - loss = None - if labels is not None: - loss = self.loss(lm_logits, labels) - - return CausalLMOutputWithPast( - loss=loss, logits=lm_logits, past_key_values=past_key_values - ) diff --git a/axolotl/src/axolotl/models/phi/modeling_phi.py b/axolotl/src/axolotl/models/phi/modeling_phi.py deleted file mode 100644 index 5b5c3ef6dc90855bdc62f3a9ce5753d2cb259df7..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/models/phi/modeling_phi.py +++ /dev/null @@ -1,1063 +0,0 @@ -# pylint: skip-file -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. -# -# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. -# Licensed under the BSD 3-Clause License. - -from __future__ import annotations - -import math -from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -from einops import rearrange, repeat -from transformers import PretrainedConfig, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import CausalLMOutputWithPast - -from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids -from .configuration_phi import PhiConfig - -try: - from flash_attn.bert_padding import pad_input, unpad_input - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding - from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention - from flash_attn.ops.fused_dense import FusedDense -except: # noqa: E722 - pad_input, unpad_input = None, None - FlashRotaryEmbedding = None - FlashSelfAttention, FlashCrossAttention = None, None - FusedDense = None - - -@dataclass -class InferenceParams: - """Inference parameters passed to model to efficiently calculate - and store context during inference. - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py. - - Args: - max_seqlen: Maximum sequence length. 
- max_batch_size: Maximum batch size. - seqlen_offset: Sequence length offset. - batch_size_offset: Batch size offset. - key_value_memory_dict: Key value memory dictionary. - lengths_per_sample: Lengths per sample. - - """ - - max_seqlen: int = field(metadata={"help": "Maximum sequence length."}) - - max_batch_size: int = field(metadata={"help": "Maximum batch size."}) - - seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."}) - - batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."}) - - key_value_memory_dict: Dict[str, Any] = field( - default_factory=dict, metadata={"help": "Key value memory dictionary."} - ) - - lengths_per_sample: torch.Tensor = field( - default=None, metadata={"help": "Lengths per sample."} - ) - - -class Embedding(nn.Module): - """Token embedding with dropout.""" - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.wte = nn.Embedding(config.vocab_size, config.n_embd) - self.drop = nn.Dropout(config.embd_pdrop) - - def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.wte(input_ids) - hidden_states = self.drop(hidden_states) - - return hidden_states - - -def _apply_rotary_emb( - x: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, -) -> torch.FloatTensor: - _, seqlen, _, _ = x.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - x_rot = x[:, :, :, :rotary_dim] - x_pass = x[:, :, :, rotary_dim:] - - x1, x2 = x_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] - - x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype) - - return torch.cat([x_rot, x_pass], axis=-1) - - -def _apply_rotary_emb_kv( - kv: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, - cos_k: Optional[torch.FloatTensor] = None, - sin_k: Optional[torch.FloatTensor] = None, -) -> torch.FloatTensor: - _, seqlen, _, _, _ = kv.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - k_rot = kv[:, :, 0, :, :rotary_dim] - k_pass = kv[:, :, 0, :, rotary_dim:] - - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]] - - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype) - - return torch.cat( - [ - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - kv[:, :, 1:2, :, :], - ], - axis=2, - ) - - -def _apply_rotary_emb_qkv( - qkv: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, - cos_k: Optional[torch.FloatTensor] = None, - sin_k: Optional[torch.FloatTensor] = None, -) -> torch.FloatTensor: - _, seqlen, _, _, _ = qkv.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - - q1, q2 = q_rot.chunk(2, dim=-1) - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange( - sin[:seqlen], "s d -> s 1 d" - ) - q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]] - - q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], 
axis=-1).to(qkv.dtype) - - return torch.cat( - [ - torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2), - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - qkv[:, :, 2:3, :, :], - ], - axis=2, - ) - - -class RotaryEmbedding(nn.Module): - """Rotary positional embedding (RoPE). - - Reference: - RoFormer: Enhanced Transformer with Rotary Position Embedding. - https://arxiv.org/pdf/2104.09864.pdf. - - """ - - def __init__( - self, - dim: int, - base: int = 10000, - scale_base: Optional[float] = None, - pos_idx_in_fp32: bool = True, - max_position_embeddings: int = 2048, - device: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__() - - if scale_base is not None: - raise NotImplementedError - - self.dim = dim - self.base = float(base) - self.scale_base = scale_base - self.pos_idx_in_fp32 = pos_idx_in_fp32 - self.max_position_embeddings = max_position_embeddings - self.device = device - - # Generate and save the inverse frequency buffer (non-trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Generate and save the scale buffer (non-trainable) - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) - / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - # Initialize cached attributes since ONNX can't rely on dynamic initialization - self._update_cos_sin_cache( - max_position_embeddings, device=device, dtype=torch.float32 - ) - - def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor: - return 1.0 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) - / self.dim - ) - ) - - def _update_cos_sin_cache( - self, - seqlen: int, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - self._seq_len_cached = seqlen - - # fp32 is preferred since the output of `torch.arange` can be quite large - # and bf16 would lose a lot of precision - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - inv_freq = self.inv_freq - - # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - - # Force the scale multiplication to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: int = 0, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - seq_start = seqlen_offset - seq_end = seq_start + qkv.shape[1] - - if ( - self._cos_cached.device != qkv.device - or self._cos_cached.dtype != qkv.dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._update_cos_sin_cache( - self.max_position_embeddings, 
device=qkv.device, dtype=qkv.dtype - ) - - if kv is None: - return _apply_rotary_emb_qkv( - qkv, - self._cos_cached[seq_start:seq_end], - self._sin_cached[seq_start:seq_end], - ) - else: - q = _apply_rotary_emb( - qkv, - self._cos_cached[seq_start:seq_end], - self._sin_cached[seq_start:seq_end], - ) - kv = _apply_rotary_emb_kv( - kv, - self._cos_cached[seq_start:seq_end], - self._sin_cached[seq_start:seq_end], - ) - - return q, kv - - -class MLP(nn.Module): - """Multi-Layer Perceptron. - - Reference: - Attention Is All You Need. - https://arxiv.org/pdf/1706.03762.pdf. - - """ - - def __init__( - self, - config: PretrainedConfig, - n_inner: Optional[int] = None, - act_fn: Optional[str] = None, - ) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - self.fc1 = nn.Linear(config.n_embd, n_inner) - self.fc2 = nn.Linear(n_inner, config.n_embd) - self.act = ACT2FN[act_fn] - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc2(hidden_states) - - return hidden_states - - -class SelfAttention(nn.Module): - """Self-attention layer (compatible with PyTorch). - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. - - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: Optional[float] = None, - attention_dropout: float = 0.0, - ) -> None: - super().__init__() - - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - @torch.autocast("cpu", enabled=False) - @torch.autocast("cuda", enabled=False) - def forward( - self, - qkv: torch.FloatTensor, - causal: bool = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - q, k, v = qkv.unbind(dim=2) - - q = q.to(torch.float32) - k = k.to(torch.float32) - - causal = self.causal if causal is None else causal - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - - # Autocast is manually disabled to avoid `torch.einsum` performing the operation - # using float16, which might lead to overflow - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - - if key_padding_mask is not None: - padding_mask = torch.full( - (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device - ) - padding_mask.masked_fill_(key_padding_mask, 0.0) - - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - - if causal: - causal_mask = torch.triu( - torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1 - ) - scores = scores + causal_mask.to(dtype=scores.dtype) - - attention = torch.softmax(scores, dim=-1).to(v.dtype) - attention = self.drop(attention) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - - return output - - -class CrossAttention(nn.Module): - """Cross-attention layer (compatible with PyTorch). - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. 
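
The pure-PyTorch `SelfAttention` above is the fallback used when flash-attn is unavailable: fp32 scores via einsum, an additive -10000.0 causal mask, softmax, dropout, then the value contraction. A condensed runnable sketch of the same math (no key-padding mask):

import math

import torch
import torch.nn.functional as F


def causal_attention(qkv: torch.Tensor, dropout_p: float = 0.0) -> torch.Tensor:
    """qkv: (batch, seqlen, 3, nheads, headdim) -> (batch, seqlen, nheads, headdim)."""
    q, k, v = qkv.unbind(dim=2)
    q, k = q.float(), k.float()                       # fp32 scores to avoid fp16 overflow
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q, k * scale)
    seqlen = scores.shape[-1]
    causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0), 1)
    scores = scores + causal_mask                     # mask out future positions
    attn = torch.softmax(scores, dim=-1).to(v.dtype)
    attn = F.dropout(attn, p=dropout_p, training=False)
    return torch.einsum("bhts,bshd->bthd", attn, v)


qkv = torch.randn(2, 5, 3, 4, 8)      # batch=2, seq=5, heads=4, head_dim=8
print(causal_attention(qkv).shape)    # torch.Size([2, 5, 4, 8])
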
- - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: Optional[float] = None, - attention_dropout: float = 0.0, - ) -> None: - super().__init__() - - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - @torch.autocast("cpu", enabled=False) - @torch.autocast("cuda", enabled=False) - def forward( - self, - q: torch.FloatTensor, - kv: torch.FloatTensor, - causal: bool = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = kv.shape[1] - - if kv.shape[3] != q.shape[2]: - kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) - k, v = kv.unbind(dim=2) - - q = q.to(torch.float32) - k = k.to(torch.float32) - - causal = self.causal if causal is None else causal - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - - # Autocast is manually disabled to avoid `torch.einsum` performing the operation - # using float16, which might lead to overflow - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - - if key_padding_mask is not None: - padding_mask = torch.full( - (batch_size, seqlen_k), - -10000.0, - dtype=scores.dtype, - device=scores.device, - ) - padding_mask.masked_fill_(key_padding_mask, 0.0) - - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - - if causal: - rows = rearrange( - torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1" - ) - cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long) - causal_mask = cols > rows + seqlen_k - seqlen_q - - scores = scores.masked_fill(causal_mask, -10000.0) - - attention = torch.softmax(scores, dim=-1).to(v.dtype) - attention = self.drop(attention) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - - return output - - -def _find_mha_dims( - config: PretrainedConfig, - n_head: Optional[int] = None, - n_head_kv: Optional[int] = None, - head_dim: Optional[int] = None, -) -> Tuple[int, int]: - if n_head is None and head_dim is None: - head_dim = config.n_embd // config.n_head - n_head = config.n_head - elif n_head is None or head_dim is None: - raise ValueError("`n_head` and `head_dim` must be both specified or `None`.") - - if n_head_kv is None: - n_head_kv = getattr(config, "n_head_kv", None) or n_head - - return n_head, n_head_kv, head_dim - - -def _update_kv_cache( - kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int -) -> torch.FloatTensor: - num_heads, head_dim = kv.shape[-2:] - - if layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, - inference_params.max_seqlen, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - inference_params.key_value_memory_dict[layer_idx] = kv_cache - else: - kv_cache = inference_params.key_value_memory_dict[layer_idx] - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - - sequence_start = inference_params.seqlen_offset - sequence_end = sequence_start + kv.shape[1] - - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = kv_cache[batch_start:batch_end, :sequence_end, ...] 
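
Both `_update_kv_cache` helpers in this diff preallocate a per-layer cache tensor, write the new key/value slice at the current sequence offset, and return everything cached so far. A minimal standalone sketch with a plain dict in place of `InferenceParams` (toy `max_seqlen` of 16):

import torch


def update_kv_cache(kv: torch.Tensor, cache: dict, layer_idx: int, seqlen_offset: int) -> torch.Tensor:
    """Write this step's keys/values into a preallocated per-layer cache and
    return the cached prefix, echoing _update_kv_cache above.

    kv: (batch, new_seqlen, 2, nheads, head_dim); cache maps layer_idx to a
    tensor of shape (batch, max_seqlen, 2, nheads, head_dim).
    """
    if layer_idx not in cache:
        batch, _, two, nheads, head_dim = kv.shape
        cache[layer_idx] = torch.empty(batch, 16, two, nheads, head_dim, dtype=kv.dtype)
    kv_cache = cache[layer_idx]
    start, end = seqlen_offset, seqlen_offset + kv.shape[1]
    kv_cache[:, start:end] = kv
    return kv_cache[:, :end]


cache: dict = {}
prompt_kv = torch.randn(1, 4, 2, 2, 8)                 # prefill with 4 tokens
step_kv = torch.randn(1, 1, 2, 2, 8)                   # then decode 1 token
print(update_kv_cache(prompt_kv, cache, 0, 0).shape)   # torch.Size([1, 4, 2, 2, 8])
print(update_kv_cache(step_kv, cache, 0, 4).shape)     # torch.Size([1, 5, 2, 2, 8])
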
- - return kv - - -class MHA(nn.Module): - """Multi-head attention layer.""" - - def __init__( - self, - config: PretrainedConfig, - dtype: Optional[torch.dtype] = None, - device: Optional[str] = None, - rotary_dim: Optional[int] = None, - rotary_base: float = 10000.0, - rotary_scale_base: Optional[float] = None, - n_head: Optional[int] = None, - n_head_kv: Optional[int] = None, - head_dim: Optional[int] = None, - bias: bool = True, - causal: bool = True, - softmax_scale: Optional[float] = None, - layer_idx: Optional[int] = None, - return_residual: bool = False, - checkpointing: bool = False, - ) -> None: - super().__init__() - - # Rotary embedding - self.rotary_dim = ( - rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) - ) - if self.rotary_dim > 0: - rotary_cls = ( - FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding - ) - if rotary_cls is None: - rotary_cls = RotaryEmbedding - - rotary_kwargs = {} - if rotary_cls is RotaryEmbedding: - rotary_kwargs["max_position_embeddings"] = config.n_positions - - self.rotary_emb = rotary_cls( - self.rotary_dim, - base=rotary_base, - scale_base=rotary_scale_base, - device=device, - **rotary_kwargs, - ) - - # MLP - self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims( - config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim - ) - op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv) - hidden_size = config.n_embd - - linear_cls = FusedDense if config.fused_dense else nn.Linear - if linear_cls is None: - linear_cls = nn.Linear - - self.Wqkv = linear_cls( - hidden_size, op_size, bias=bias, device=device, dtype=dtype - ) - self.out_proj = linear_cls( - hidden_size, hidden_size, bias=bias, device=device, dtype=dtype - ) - - # Attention - attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention - if attn_cls is None: - attn_cls = SelfAttention - - cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention - if cross_attn_cls is None: - cross_attn_cls = CrossAttention - - self.inner_attn = attn_cls( - causal=causal, - softmax_scale=softmax_scale, - attention_dropout=config.attn_pdrop, - ) - self.inner_cross_attn = cross_attn_cls( - causal=causal, - softmax_scale=softmax_scale, - attention_dropout=config.attn_pdrop, - ) - - self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention - self.layer_idx = layer_idx - self.return_residual = return_residual - self.checkpointing = checkpointing - - def _forward_self_attn( - self, - x: torch.FloatTensor, - key_padding_mask: Optional[torch.BoolTensor], - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - ) -> torch.FloatTensor: - qkv = self.Wqkv(x) - qkv = rearrange( - qkv, "... (three h d) -> ... 
three h d", three=3, d=self.head_dim - ) - - if self.rotary_dim > 0: - qkv = self.rotary_emb(qkv) - - if self.flash_attn: - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - - if ( - key_padding_mask is not None - and cu_seqlens is None - and max_seqlen is None - ): - # If `key_padding_mask` is supplied, we need to unpad the input and retrieve - # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn` - qkv, indices, cu_seqlens, max_seqlen = unpad_input( - qkv, key_padding_mask - ) - - if self.checkpointing: - attn_output = torch.utils.checkpoint.checkpoint( - self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - else: - attn_output = self.inner_attn( - qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ).to(qkv.device) - - # If `key_padding_mask` is supplied, we need to pad the output back to the original shape - return ( - pad_input(attn_output, indices, batch_size, seqlen) - if key_padding_mask is not None - else attn_output - ) - - if self.checkpointing: - return torch.utils.checkpoint.checkpoint( - self.inner_attn, qkv, key_padding_mask=key_padding_mask - ) - - return self.inner_attn(qkv, key_padding_mask=key_padding_mask) - - def _forward_cross_attn( - self, - x: torch.FloatTensor, - past_key_values: Optional[InferenceParams], - key_padding_mask: Optional[torch.BoolTensor], - ) -> torch.FloatTensor: - batch_size = x.shape[0] - - qkv = self.Wqkv(x) - - q = qkv[..., : self.n_head * self.head_dim] - q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) - - kv = qkv[..., self.n_head * self.head_dim :] - kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) - - seqlen_offset = ( - past_key_values.seqlen_offset if past_key_values is not None else 0 - ) - causal = None if seqlen_offset == 0 else False - if self.rotary_dim > 0: - q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset) - - if past_key_values is not None: - kv = _update_kv_cache(kv, past_key_values, self.layer_idx) - - if self.flash_attn: - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = kv.shape[1] - - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = ( - None, - None, - None, - None, - ) - if key_padding_mask is not None: - kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask) - - if seqlen_q == 1: - key_padding_mask = torch.ones(batch_size, 1, device=q.device) - elif seqlen_q != seqlen_k: - key_padding_mask = key_padding_mask[:, -seqlen_q:] - - q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, key_padding_mask - ) - - if self.checkpointing: - attn_output = torch.utils.checkpoint.checkpoint( - self.inner_cross_attn, - q, - kv, - causal=causal, - cu_seqlens=cu_seqlens_q, - max_seqlen=max_seqlen_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - ) - else: - attn_output = self.inner_cross_attn( - q, - kv, - causal=causal, - cu_seqlens=cu_seqlens_q, - max_seqlen=max_seqlen_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - ) - - return ( - pad_input(attn_output, indices_q, batch_size, max_seqlen_q) - if key_padding_mask is not None - else attn_output - ) - - if self.checkpointing: - return torch.utils.checkpoint.checkpoint( - self.inner_cross_attn, - q, - kv, - key_padding_mask=key_padding_mask, - causal=causal, - ) - - return self.inner_cross_attn( - q, kv, key_padding_mask=key_padding_mask, causal=causal - ) - - def forward( - self, - x: torch.FloatTensor, - past_key_values: Optional[InferenceParams] = None, - attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, - cu_seqlens: 
Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool()) - if attention_mask is not None: - attention_mask = attention_mask.bool() - else: - attention_mask = None - - # MHA - if self.n_head == self.n_head_kv: - if past_key_values is None: - # If `past_key_values` are not supplied, we run self-attention - attn_output = self._forward_self_attn( - x, attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - else: - # If `past_key_values` are supplied, it means that we might have cached values and - # could take advantage of cross-attention - attn_output = self._forward_cross_attn( - x, - past_key_values, - attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - # MQA / GQA - else: - # Regardless of `past_key_values` being supplied or not, it always use cross-attention - # because `q` and `kv` lengths might be different - attn_output = self._forward_cross_attn(x, past_key_values, attention_mask) - - output = rearrange(attn_output, "... h d -> ... (h d)") - output = self.out_proj(output) - - return output if not self.return_residual else (output, x) - - -class ParallelBlock(nn.Module): - """Parallel block. - - This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen). - - """ - - def __init__( - self, - config: PretrainedConfig, - block_idx: Optional[int] = None, - ) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.block_idx = block_idx - - self.mixer = MHA(config, layer_idx=block_idx) - self.mlp = MLP(config) - - def forward( - self, - hidden_states: torch.FloatTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - residual = hidden_states - hidden_states = self.ln(hidden_states) - - attn_outputs = self.mixer( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - ) - if isinstance(attn_outputs, tuple): - attn_outputs = attn_outputs[0] - - attn_outputs = self.resid_dropout(attn_outputs) - feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) - - hidden_states = attn_outputs + feed_forward_hidden_states + residual - - return hidden_states - - -class CausalLMHead(nn.Module): - """Causal Language Modeling head. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. - - """ - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.linear = nn.Linear(config.n_embd, config.vocab_size) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.ln(hidden_states) - logits = self.linear(hidden_states).to(torch.float32) - - return logits - - -class CausalLMLoss(nn.Module): - """Causal Language Modeling loss. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. 
- - """ - - def __init__(self, shift_labels: bool = True) -> None: - super().__init__() - - self.shift_labels = shift_labels - self.loss_fct = nn.CrossEntropyLoss() - - def forward( - self, logits: torch.FloatTensor, labels: torch.LongTensor - ) -> torch.FloatTensor: - if self.shift_labels: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - - loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - - return loss - - -class PhiPreTrainedModel(PreTrainedModel): - """Phi pre-trained model.""" - - config_class = PhiConfig - base_model_prefix = "transformer" - supports_gradient_checkpointing = False - _no_split_modules = ["ParallelBlock"] - - def __init__(self, *inputs, **kwargs) -> None: - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module) -> None: - if isinstance(module, (nn.Linear,)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.bias is not None: - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, - **kwargs, - ) -> Dict[str, Any]: - if past_key_values is None or not ( - isinstance(past_key_values, InferenceParams) - ): - past_key_values = InferenceParams( - max_seqlen=self.config.n_positions, - max_batch_size=input_ids.shape[0], - seqlen_offset=0, - batch_size_offset=0, - key_value_memory_dict={}, - lengths_per_sample=None, - ) - else: - # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids` - past_key_values.seqlen_offset = len(input_ids[0]) - 1 - input_ids = input_ids[:, -1].unsqueeze(-1) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - } - - -class PhiModel(PhiPreTrainedModel): - """Phi model.""" - - _keys_to_ignore_on_load_missing = [""] - _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] - - def __init__(self, config: PhiConfig) -> None: - super().__init__(config) - - self.embd = Embedding(config) - self.h = nn.ModuleList( - [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)] - ) - self.gradient_checkpointing = False - self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.embd.wte - - def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: - self.embd.wte = new_embeddings - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - max_seqlen: Optional[int] = None, - ) -> torch.FloatTensor: - hidden_states = self.embd(input_ids) - - for layer in self.h: - hidden_states = layer( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - return hidden_states - - -class PhiForCausalLM(PhiPreTrainedModel): - """Phi for Causal Language Modeling.""" - - _keys_to_ignore_on_load_missing = [""] 
- _keys_to_ignore_on_load_unexpected = [ - r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)" - ] - - def __init__(self, config: PhiConfig) -> None: - super().__init__(config) - - self.transformer = PhiModel(config) - self.lm_head = CausalLMHead(config) - self.loss = CausalLMLoss() - - self.post_init() - - def get_output_embeddings(self) -> nn.Linear: - return self.lm_head.linear - - def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: - self.lm_head.linear = new_embeddings - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - labels: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - **kwargs, - ) -> CausalLMOutputWithPast: - cu_seqlens: Optional[torch.LongTensor] = None - max_seqlen: Optional[int] = None - if position_ids is not None: - batch_size, seq_length = input_ids.shape - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - hidden_states = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - loss = self.loss(lm_logits, labels) - - return CausalLMOutputWithPast( - loss=loss, logits=lm_logits, past_key_values=past_key_values - ) diff --git a/axolotl/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/axolotl/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py deleted file mode 100644 index 1275906804b0f6908ef94be9b94d4a83960bf148..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Flash attention monkey patch for cerebras btlm model -""" - -import importlib -import logging -from typing import Optional, Tuple - -import torch -from accelerate import init_empty_weights -from flash_attn.flash_attn_interface import flash_attn_func -from transformers import AutoConfig, AutoModelForCausalLM - -LOG = logging.getLogger("axolotl") - - -def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"): - # this is a wonky hack to get the remotely loaded module - model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - # we need to load the model here in order for modeling_btlm to be available - with init_empty_weights(): - AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) - module_name = model_config.__class__.__module__.replace( - ".configuration_btlm", ".modeling_btlm" - ) - modeling_btlm = importlib.import_module(module_name) - modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access - flashattn_attn - ) - - -def flashattn_attn( - self, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - value: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - head_mask: Optional[torch.Tensor] = None, - position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - softmax_scale = ( - 1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None - ) - - query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - - # Perform Flash attention - attn_output = 
flash_attn_func( - query, - key, - value, - dropout_p=0.0, # Assuming you have this attribute - softmax_scale=softmax_scale, # Set this if you have specific scaling in mind - causal=not self.is_cross_attention, # Assuming you have this attribute - return_attn_probs=False, # Set this based on your needs - ) - - # Optional: Apply head mask if it's not None - if head_mask is not None: - attn_output *= head_mask - - attn_output = attn_output.permute(0, 2, 1, 3) - - return attn_output, None # We don't have explicit attn_weights in Flash attention diff --git a/axolotl/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/axolotl/src/axolotl/monkeypatch/fastchat_conversation_turns.py deleted file mode 100644 index 068261da3677539192849016e57b67c97e5c8b48..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -monkeypatch to add a get_turns method -""" - -import logging -from typing import Generator, Tuple - -from fastchat.conversation import SeparatorStyle - -LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns") - - -def get_prompt(self) -> str: - ret = "" - for role, msg in self.get_turns(): - ret += role + msg - return ret - - -def get_turns( # pylint: disable=too-many-return-statements - self, -) -> Generator[Tuple[str, str], None, None]: - """Get the prompt for generation.""" - system_prompt = self.system_template.format(system_message=self.system_message) - if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ": ", message + self.sep - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.ADD_COLON_TWO: - seps = [self.sep, self.sep2] - yield "", system_prompt + seps[0] - for i, (role, message) in enumerate(self.messages): - if message: - yield role + ": ", message + seps[i % 2] - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ": ", message + self.sep - else: - yield role + ": ", "" # must be end with a space - return - if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: - yield "", "" if system_prompt == "" else system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + "\n", message + self.sep - else: - yield role + "\n", "" - return - if self.sep_style == SeparatorStyle.NO_COLON_SINGLE: - yield "", system_prompt - for role, message in self.messages: - if message: - yield role, message + self.sep - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.NO_COLON_TWO: - seps = [self.sep, self.sep2] - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - yield role, message + seps[i % 2] - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.RWKV: - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - yield role + ": ", message.replace("\r\n", "\n").replace( - "\n\n", "\n" - ) + "\n\n" - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": - if self.system_message: - if self.messages: - # For llama, the system message is incorporated into the first human instruction - first_role, first_msg = self.messages[0] - if first_role == self.roles[0]: - system_prompt += first_msg - self.messages.pop(0) - 
yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - if (i % 2 == 0 and not self.system_message) or ( - i % 2 != 0 and self.system_message - ): - role = " " + role - yield role + " ", message - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": - contains_sys_msg = False - if self.system_message: - contains_sys_msg = True - if self.messages: - # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline - first_role, first_msg = self.messages[0] - if first_role == self.roles[0]: - system_prompt = self.system_template.format( - system_message=" " + self.system_message - ) - system_prompt += first_msg - self.messages.pop(0) - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message and i == 0 and not contains_sys_msg: - yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a ` [INST]` at the beginning of the first instruction. - elif message: - yield role + " ", message - else: - yield role, "" - return - if self.sep_style == SeparatorStyle.CHATGLM: - # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 - # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 - round_add_n = 1 if self.name == "chatglm2" else 0 - if system_prompt: - yield "", system_prompt + self.sep - - for i, (role, message) in enumerate(self.messages): - if i % 2 == 0: - yield "", f"[Round {i//2 + round_add_n}]{self.sep}" - - if message: - yield f"{role}:", f"{message}{self.sep}" - else: - yield f"{role}:", "" - return - if self.sep_style == SeparatorStyle.CHATML: - yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n" - for role, message in self.messages: - if message: - yield role + "\n", message + self.sep + "\n" - else: - yield role + "\n", "" - return - if self.sep_style == SeparatorStyle.CHATINTERN: - # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 - seps = [self.sep, self.sep2] - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - prefix = "" if i % 2 == 0 else "" - if message: - yield prefix + role + ":", message + seps[i % 2] + "\n" - else: - yield role + ":", "" - return - if self.sep_style == SeparatorStyle.DOLLY: - seps = [self.sep, self.sep2] - yield "", system_prompt - for i, (role, message) in enumerate(self.messages): - if message: - suffix = "\n\n" if i % 2 == 1 else "" - yield role + ":\n", message + seps[i % 2] + suffix - else: - yield role + ":\n", "" - return - if self.sep_style == SeparatorStyle.PHOENIX: - yield "", system_prompt - for role, message in self.messages: - if message: - yield role + ": ", "" + message + "" - else: - yield role + ": " + "", "" - return - if self.sep_style == SeparatorStyle.ROBIN: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ":\n", message + self.sep - else: - yield role + ":\n", "" - return - if self.sep_style == SeparatorStyle.FALCON_CHAT: - if self.system_message: - yield "", system_prompt + self.sep - for role, message in self.messages: - if message: - yield role + ": ", message + self.sep - else: - yield role + ":", "" - else: - raise ValueError(f"Invalid style: 
{self.sep_style}") - - -def add_get_turns_to_conversation(): - import fastchat.conversation - - fastchat.conversation.Conversation.get_turns = get_turns - fastchat.conversation.Conversation.get_prompt = get_prompt diff --git a/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py deleted file mode 100644 index f380c3f2ae6bdc9e3238de3f276dd66227197f0f..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ /dev/null @@ -1,782 +0,0 @@ -"""Flash attention monkey patch for llama model""" - -# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py - -import logging -import warnings -from functools import partial -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import transformers -from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention -from transformers.models.llama.modeling_llama import ( - LlamaDecoderLayer as OriginalLlamaDecoderLayer, -) -from transformers.models.llama.modeling_llama import ( - LlamaMLP, - apply_rotary_pos_emb, - repeat_kv, -) -from xformers.ops import SwiGLU - -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name - -try: - from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports - flash_attn_kvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - ) -except ImportError: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func, - ) - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, - ) - - -LOG = logging.getLogger("axolotl") - - -def replace_llama_mlp_with_swiglu(model): - for name, module in model.named_modules(): - if isinstance(module, LlamaMLP): - mlp = FusedMLP( - module.config, module.gate_proj, module.up_proj, module.down_proj - ) - set_module_name(model, name, mlp) - - -def replace_llama_qkv_with_fused(model): - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - qkv = FusedAttention( - module.config, - module.q_proj, - module.k_proj, - module.v_proj, - module.o_proj, - ) - set_module_name(model, name, qkv) - - -def replace_llama_attn_with_flash_attn( - packed: Optional[bool] = False, - cross_entropy: Optional[bool] = False, - rms_norm: Optional[bool] = False, -): - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access - _prepare_decoder_attention_mask - ) - transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward - if packed: - transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer - transformers.models.llama.modeling_llama.LlamaModel.forward = ( - llama_model_forward - ) - - # skip only if explicitly disabled - if cross_entropy: - try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss - - LOG.info("patching with flash_attn.losses.cross_entropy") - transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( - CrossEntropyLoss, inplace_backward=True - ) - except ImportError: - LOG.info( - "optimized flash-attention CrossEntropyLoss not found (run `pip install 
'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)" - ) - - # skip only if explicitly disabled - if rms_norm: - try: - from flash_attn.ops.rms_norm import RMSNorm - - class LlamaRMSNorm(RMSNorm): - """Patched LLamaRMSNorm""" - - def __init__(self, hidden_size, eps=1e-6): - super().__init__(hidden_size, eps=eps) - - LOG.info("patching with flash_attn.ops.rms_norm") - transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm - except ImportError: - LOG.info( - "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" - ) - - -class FusedAttention(LlamaAttention): - """ - Fused QKV Attention layer for incrementally improved training efficiency - """ - - def __init__( - self, - config, - q: torch.nn.Linear, # pylint: disable=invalid-name - k: torch.nn.Linear, # pylint: disable=invalid-name - v: torch.nn.Linear, # pylint: disable=invalid-name - o: torch.nn.Linear, # pylint: disable=invalid-name - ): - super().__init__(config) - self.config = config - self.init_device = next(iter(q.state_dict().values())).device - - # define equivalent fused qkv projection - self.out_features: List[int] = [q.out_features, k.out_features, v.out_features] - self.qkv_proj = torch.nn.Linear( - q.in_features, sum(self.out_features), device=self.init_device, bias=False - ) - self.o_proj = o - - # overwrite initialized weights with pretrained weights - self.qkv_proj.weight.data = torch.cat( - (q.weight.data, k.weight.data, v.weight.data), dim=0 - ) - - def _post_training(self, model, name): - q_proj, k_proj, v_proj = torch.split( - self.qkv_proj.weight.data, self.out_features, dim=0 - ) - - new_attn = LlamaAttention(self.config) - new_attn.q_proj.weight.data = q_proj - new_attn.k_proj.weight.data = k_proj - new_attn.v_proj.weight.data = v_proj - new_attn.o_proj.weight.data = self.o_proj.weight.data - - set_module_name(model, name, new_attn) - - -class FusedMLP(torch.nn.Module): - """ - Fused MLP layer for incrementally improved training efficiency - """ - - def __init__( - self, - config, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - ): - super().__init__() - self.config = config - self.swiglu = SwiGLU( - in_features=config.hidden_size, - hidden_features=config.intermediate_size, - bias=False, - _pack_weights=True, - ) - # overwrite initialized weights with pretrained weights - self.swiglu.w12.weight.data = torch.cat( - (gate_proj.weight.data, up_proj.weight.data), dim=0 - ) - self.swiglu.w3.weight.data = down_proj.weight.data - - def _post_training(self, model, name): - w1, w2 = torch.split( # pylint: disable=invalid-name - self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0 - ) - - # Assign the split weights back to the original layers - new_mlp = LlamaMLP(self.config) - new_mlp.gate_proj.weight.data = w1 - new_mlp.up_proj.weight.data = w2 - new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data - - set_module_name(model, name, new_mlp) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return self.swiglu(x) - - -# Disable the transformation of the attention mask in LlamaModel as the flash attention -# requires the attention mask to be the same as the key_padding_mask -def _prepare_decoder_attention_mask( - self, - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, -): # pylint: disable=unused-argument - # [bsz, 
seq_len] - return attention_mask - - -def flashattn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel - - attention_mask: [bsz, q_len] - """ - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - if not hasattr(self, "pretraining_tp"): - self.pretraining_tp = 1 - - if self.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - if isinstance(self, FusedAttention): - query_states, key_states, value_states = self.qkv_proj(hidden_states).split( - self.out_features, dim=-1 - ) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if output_attentions: - warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
- ) - - # - # flash-attn v2 start - # - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - - # - # flash-attn v2 end - # - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ) - else: - 
attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q_unpad.device, - ) - max_seqlen_q = seqlen_q - - output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) - - if key_padding_mask is not None: - k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=k_unpad.device, - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) - - if kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - return ( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - kv, - output_pad_fn, - ) - - return ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - ) - - -def llama_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids 
and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - padding_mask = None - else: - if 0 in attention_mask: - padding_mask = attention_mask - else: - padding_mask = None - - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - transformers.logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module( - *inputs, - ) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - None, - padding_mask, - cu_seqlens, - max_seqlen, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaDecoderLayer(OriginalLlamaDecoderLayer): - """ - patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs diff --git a/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py b/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py deleted file mode 100644 index cfed8cb1744b3d1c1f432fcc963846e715aab73f..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention -""" - -import warnings -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F -import transformers.models.llama.modeling_llama -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv - - -def hijack_llama_sdp_attention(): - transformers.models.llama.modeling_llama.LlamaAttention.forward = ( - sdp_attention_forward - ) - - -def sdp_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - if not hasattr(self, "pretraining_tp"): - self.pretraining_tp = 1 - - if self.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) 
- key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if output_attentions: - warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." - ) - - # - # sdp-attn start - # - - with torch.backends.cuda.sdp_kernel(): - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=False, - ) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - # - # sdp-attn end - # - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ) - else: - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value diff --git a/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py deleted file mode 100644 index 8143750f0050184609ea61711c35bdf33dcbe59a..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments -""" - -import logging -import warnings -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F -import transformers.models.llama.modeling_llama -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv - -try: - import xformers.ops -except ImportError: - logging.error("xformers not found! 
Please install it before trying to use it.") - - -def hijack_llama_attention(): - transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward - - -def xformers_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - if not hasattr(self, "pretraining_tp"): - self.pretraining_tp = 1 - - if self.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if output_attentions: - warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." - ) - - # - # xformers-attn start - # - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. - # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
- if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: - # input and output should be of form (bsz, q_len, num_heads, head_dim) - attn_output = xformers.ops.memory_efficient_attention( - query_states, key_states, value_states, attn_bias=None - ) - else: - # input and output should be of form (bsz, q_len, num_heads, head_dim) - attn_output = xformers.ops.memory_efficient_attention( - query_states, - key_states, - value_states, - # attn_bias=attention_mask, - attn_bias=xformers.ops.LowerTriangularMask(), - ) - - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - # - # xformers-attn end - # - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ) - else: - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value diff --git a/axolotl/src/axolotl/monkeypatch/llama_expand_mask.py b/axolotl/src/axolotl/monkeypatch/llama_expand_mask.py deleted file mode 100644 index d69433baa85be39287e2b02a36da973e367fcd41..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/llama_expand_mask.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf -""" -from typing import Optional - -import torch - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - This expansion handles packed sequences so that sequences share the same attention mask integer value - when they attend to each other within that sequence. - This expansion transforms the mask to lower triangular form to prevent future peeking. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - mask = mask.unsqueeze(1).unsqueeze(2) - mask = mask.expand(bsz, 1, tgt_len, src_len) - - # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one - binary_mask = torch.where( - mask != 0, - torch.tensor(1).to(dtype), - torch.tensor(0).to(dtype), - ) - - # Create a block-diagonal mask. 
- # we multiply by the binary mask so that 0's in the original mask are correctly excluded - zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask - - # Now let's create a lower triangular mask of ones that will zero out the upper triangular part - lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( - mask.device - ) - - # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask - masked_zero_one_mask = zero_one_mask * lower_triangular_ones - inverted_mask = 1.0 - masked_zero_one_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -def hijack_expand_mask(): - import transformers - - transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access - _expand_mask - ) diff --git a/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py deleted file mode 100644 index e31864b8387c13913f9d435f1c0335c1bb8b65d2..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ /dev/null @@ -1,643 +0,0 @@ -"""Flash attention monkey patch for mistral model""" -# pylint: disable=duplicate-code - -import logging -from typing import List, Optional, Tuple, Union - -import torch -import transformers -from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports - flash_attn_kvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, -) -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.mistral.modeling_mistral import ( - MistralAttention as OriginalMistralAttention, -) -from transformers.models.mistral.modeling_mistral import ( - MistralDecoderLayer as OriginalMistralDecoderLayer, -) -from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv - -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids - -LOG = logging.getLogger("axolotl.monkeypatch.mistral") - - -def replace_mistral_attn_with_flash_attn( - packed: Optional[bool] = False, -): - transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access - _prepare_decoder_attention_mask - ) - transformers.models.mistral.modeling_mistral.MistralAttention.forward = ( - flashattn_forward - ) - if packed: - transformers.models.mistral.modeling_mistral.MistralDecoderLayer = ( - MistralDecoderLayer - ) - transformers.models.mistral.modeling_mistral.MistralModel.forward = ( - mistral_model_forward - ) - - -@torch.jit.script -def _make_sliding_window_causal_mask( - bsz: int, - tgt_len: int, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: int = 4096, -): - """ - Make causal mask used for sliding window attention - """ - tensor = torch.full( - (tgt_len, tgt_len), - fill_value=1, - device=device, - ) - mask = torch.tril(tensor, diagonal=0) - # make the mask banded to account for sliding window - # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1 - mask = torch.triu(mask, diagonal=-sliding_window + 1) - mask = torch.log(mask).to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, 
:].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -# Disable the transformation of the attention mask in LlamaModel as the flash attention -# requires the attention mask to be the same as the key_padding_mask -def _prepare_decoder_attention_mask( - self, - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - sliding_window, -): # pylint: disable=unused-argument - # [bsz, seq_len] - if attention_mask is None: - return attention_mask - - # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios. - # Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled. - if input_shape[-1] > 1 and attention_mask.shape[0] == 1: - sliding_window_mask = _make_sliding_window_causal_mask( - bsz=input_shape[0], - tgt_len=input_shape[1], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - attention_mask = attention_mask + sliding_window_mask - else: - LOG.info("skipping sliding window mask, not broadcastable with attention mask") - - return attention_mask - - -def flashattn_forward( - self: OriginalMistralAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - use_sliding_windows = ( - hasattr(self.config, "sliding_window") is not None - and kv_seq_len > self.config.sliding_window - ) - - if use_sliding_windows: - window_size = (self.config.sliding_window, self.config.sliding_window) - else: - window_size = (-1, -1) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - if ( - hasattr(self.config, "sliding_window") - and kv_seq_len > self.config.sliding_window - ): - slicing_tokens = kv_seq_len - self.config.sliding_window - - past_key = past_key_value[0] - past_value = past_key_value[1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - past_key_value = (past_key, past_value) if use_cache else None - - if past_key_value is not None: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = 
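[Editorial note] `_make_sliding_window_causal_mask` above limits each token to the previous `sliding_window` positions by banding a lower-triangular mask and taking its log. A small sketch of the same banding trick on toy sizes (values illustrative):

import torch

tgt_len, sliding_window = 6, 3
ones = torch.full((tgt_len, tgt_len), 1.0)
band = torch.tril(ones)                                 # causal: no future positions
band = torch.triu(band, diagonal=-sliding_window + 1)   # drop positions older than the window
bias = torch.log(band)                                  # 1 -> 0.0, 0 -> -inf
print(bias)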
torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - window_size=window_size, - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - window_size=window_size, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None, - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h 
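[Editorial note] The sample-packing branch above stacks q/k/v into one `(total_tokens, 3, n_heads, head_dim)` tensor and hands it to `flash_attn_varlen_qkvpacked_func` with `cu_seqlens`. A hedged sketch of that call pattern; it assumes the `flash-attn` package and a CUDA device, and the shapes are illustrative:

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

bsz, seqlen, n_heads, head_dim = 1, 8, 4, 64
# Two packed sequences of lengths 5 and 3 inside the single row.
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
max_seqlen = 5

qkv = torch.randn(
    bsz * seqlen, 3, n_heads, head_dim, device="cuda", dtype=torch.float16
)
out = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen, dropout_p=0.0, causal=True
)
print(out.shape)  # (bsz * seqlen, n_heads, head_dim)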
d)") - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q_unpad.device, - ) - max_seqlen_q = seqlen_q - - output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) - - if key_padding_mask is not None: - k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=k_unpad.device, - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) - - if kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - return ( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - kv, - output_pad_fn, - ) - - return ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - ) - - -def mistral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not 
None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - transformers.logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - None, - cu_seqlens, - max_seqlen, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class MistralDecoderLayer(OriginalMistralDecoderLayer): - """ - patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs diff --git a/axolotl/src/axolotl/monkeypatch/mixtral/__init__.py b/axolotl/src/axolotl/monkeypatch/mixtral/__init__.py deleted file mode 100644 index 418814689267d97f8f6b11955f19494b7e81e126..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/mixtral/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Patches to support multipack for mixtral -""" -import transformers - - -def replace_mixtral_attn_with_multipack_flash_attn(): - from .modeling_mixtral import ( - MixtralMultipackFlashAttention2, - mixtral_decoder_layer_forward, - mixtral_model_forward, - ) - - transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = ( - mixtral_decoder_layer_forward - ) - transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = ( - mixtral_model_forward - ) - transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[ - "flash_attention_2" - ] = MixtralMultipackFlashAttention2 diff --git a/axolotl/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py b/axolotl/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py deleted file mode 100644 index 34f35015f9f9b06907b18a3cb9ec36b7b758baca..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py +++ /dev/null @@ -1,379 +0,0 @@ -""" -Mixtral modeling for multipack -""" -# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code -import logging -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from einops import rearrange -from flash_attn import flash_attn_varlen_qkvpacked_func -from transformers import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import MoeModelOutputWithPast -from transformers.models.mixtral.modeling_mixtral import ( - MixtralFlashAttention2, - apply_rotary_pos_emb, - repeat_kv, -) - -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids - -LOG = logging.getLogger("axolotl.monkeypatch.mixtral") - - -class MixtralMultipackFlashAttention2(MixtralFlashAttention2): - """ - Custom multipack implementation w flash attention 2 - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._flash_attn_uses_top_left_mask = True - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - 
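[Editorial note] `replace_mixtral_attn_with_multipack_flash_attn` above patches transformers in place by reassigning class attributes at import time. Stripped of the specifics, the pattern is plain attribute swapping (hypothetical class and function names):

# Illustration of the monkeypatch pattern used throughout these modules.
class Attention:
    def forward(self, x):
        return "original"

def patched_forward(self, x):
    # drop-in replacement with the same signature
    return "patched"

Attention.forward = patched_forward
print(Attention().forward(None))  # "patched"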
output_attentions: bool = False, - use_cache: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - attn_output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=self.attention_dropout, - softmax_scale=None, - causal=True, - ) - attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def mixtral_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, - use_cache: Optional[bool] = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_router_logits (`bool`, *optional*): - Whether or not to return the logits of all the routers. They are useful for computing the router loss, and - should not be returned during inference. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, router_logits = self.block_sparse_moe(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if output_router_logits: - outputs += (router_logits,) - - return outputs - - -def mixtral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, MoeModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, 
Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._use_flash_attention_2 and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = ( - attention_mask - if (attention_mask is not None and 0 in attention_mask) - else None - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - LOG.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
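[Editorial note] The forward above raises on right-padded batches when Flash Attention 2 is active with a KV cache; generation inputs are expected to be left-padded. If one were preparing such a batch, the tokenizer-side fix is just the following sketch (model id illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
tokenizer.padding_side = "left"  # required for batched generation with Flash Attention 2
batch = tokenizer(
    ["short prompt", "a somewhat longer prompt"],
    return_tensors="pt",
    padding=True,
)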
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cu_seqlens, - max_seqlen, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - if use_legacy_cache - else next_decoder_cache - ) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - all_router_logits, - ] - if v is not None - ) - - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) diff --git a/axolotl/src/axolotl/monkeypatch/relora.py b/axolotl/src/axolotl/monkeypatch/relora.py deleted file mode 100644 index 9dac77e18a60eb38486cf5b669b826aeb616f59f..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/relora.py +++ /dev/null @@ -1,393 +0,0 @@ -"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.""" -import glob -import json -import logging -import os.path -import shutil -from pathlib import Path -from typing import Dict, List, Sequence - -import bitsandbytes as bnb -import peft -import safetensors.torch as st -import torch -from huggingface_hub import snapshot_download -from torch.optim.lr_scheduler import LRScheduler -from torch.optim.optimizer import Optimizer -from transformers import ( - TrainerCallback, - TrainerControl, - TrainerState, - TrainingArguments, -) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR - -from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process - -LOG = logging.getLogger("axolotl.relora") - - -def reset_optimizer(optimizer: torch.optim.Optimizer): - for group in optimizer.param_groups: - for param in group["params"]: - param_state = optimizer.state[param] - for key in param_state: - if "qmap" in key: - continue - - if key == "step" and isinstance(param_state[key], int): - param_state[key] = 0 - else: - param_state[key] = torch.zeros_like(param_state[key]) - - -class ReLoRACallback(TrainerCallback): - """Callback to merge LoRA weights into the base model and save full-weight checkpoints""" - - def __init__(self, cfg: 
DictDefault): - self.relora_steps = cfg.relora_steps - self.cpu_offload = cfg.relora_cpu_offload - self.quantized = cfg.load_in_4bit or cfg.load_in_8bit - self.last_full_model = cfg.base_model - self.resume_from_checkpoint = cfg.resume_from_checkpoint - - if not os.path.exists(self.last_full_model): - self.last_full_model = str(Path(snapshot_download(cfg.base_model))) - - assert os.path.exists( - self.last_full_model - ), "for ReLORA base_model must be a local path" - - self.num_lora_restarts = 0 - self.need_full_save = False - - def on_train_begin( - self, - _args: TrainingArguments, - _state: TrainerState, - control: TrainerControl, - model: peft.LoraModel, - **_kwargs, - ): - if self.resume_from_checkpoint: - weight_path = os.path.join(self.resume_from_checkpoint, "relora") - if not os.path.exists(weight_path): - LOG.warning( - "Resuming ReLoRA from checkpoint, but no full-weight save found" - ) - else: - LOG.info(f"Loading adjusted base weights from {weight_path}") - load_weight_checkpoint(model, weight_path) - return control - - def on_step_begin( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - model: peft.LoraModel, - optimizer: torch.optim.Optimizer, - **_kwargs, - ): - if state.global_step > 0 and state.global_step % self.relora_steps == 0: - checkpoint_folder = os.path.join( - args.output_dir, - f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", - "relora", - ) - - with torch.no_grad(): - merge_and_save( - model, - self.last_full_model, - checkpoint_folder, - reinit=True, - quantized=self.quantized, - actually_save=is_main_process(), - cpu_offload=self.cpu_offload, - ) - reset_optimizer(optimizer) - - if self.quantized: - self.last_full_model = checkpoint_folder - self.num_lora_restarts += 1 - - return control - - def on_save( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - model: peft.LoraModel, - **_kwargs, - ): - checkpoint_folder = os.path.join( - args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora" - ) - if ( - state.global_step >= self.relora_steps - and state.global_step % self.relora_steps != 0 - ): - if self.quantized: - if is_main_process() and self.last_full_model != checkpoint_folder: - # ensure the latest full parameter save is in the latest checkpoint - # folder, so that automatic pruning of checkpoints does not remove it - LOG.info(f"moving last full parameter save to {checkpoint_folder}") - os.makedirs(checkpoint_folder, exist_ok=True) - chunks = glob.glob( - f"{self.last_full_model}/model*.safetensors" - ) + glob.glob(f"{self.last_full_model}/model*.index.json") - for path in chunks: - new_path = os.path.abspath(shutil.move(path, checkpoint_folder)) - try: - os.symlink(new_path, path) - except OSError: - # probably on windows without permission to symlink - pass - - self.last_full_model = checkpoint_folder - else: - model.model.save_pretrained(checkpoint_folder, safe_serialization=True) - - return control - - def on_log( - self, - _args: TrainingArguments, - _state: TrainerState, - control: TrainerControl, - logs: Dict[str, float], - **_kwargs, - ): - logs["num_lora_restarts"] = self.num_lora_restarts - return control - - def on_train_end( - self, - args: TrainingArguments, - _state: TrainerState, - control: TrainerControl, - model: peft.LoraModel, - **_kwargs, - ): - if self.quantized: - # perform final merge and save - with torch.no_grad(): - merge_and_save( - model, - self.last_full_model, - args.output_dir, - reinit=False, - quantized=self.quantized, - 
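[Editorial note] `on_step_begin` above merges the LoRA weights into the base model every `relora_steps` steps and then calls `reset_optimizer`, which zeroes the optimizer's moment estimates so the next cycle starts fresh. A toy demonstration of what that reset does to Adam state (illustrative, not the callback itself):

import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([param], lr=1e-3)

param.sum().backward()
opt.step()  # populates exp_avg / exp_avg_sq (and the step counter) for the parameter

state = opt.state[param]
for key, value in state.items():
    # zero every tensor in the state; reset plain counters to 0
    state[key] = torch.zeros_like(value) if torch.is_tensor(value) else 0

print(state)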
actually_save=is_main_process(), - cpu_offload=self.cpu_offload, - ) - # no need to save if unquantized, as finetune.py will call merge_and_unload() - return control - - -class ReLoRAScheduler(LRScheduler): - """Wraps another scheduler to apply per-lora-restart learning rate warmups.""" - - def __init__( - self, - optimizer: Optimizer, - inner_schedule: LRScheduler, - relora_steps: int, - warmup_steps: int, - min_lr_scale: float = 0.001, - ) -> None: - self.inner_schedule = inner_schedule - self.relora_steps = relora_steps - self.warmup_steps = warmup_steps - self.min_lr_scale = min_lr_scale - super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose) - - def get_lr(self) -> float: - self.inner_schedule.last_epoch = self.last_epoch - - original = self.inner_schedule.get_lr() - step = self.last_epoch - if step < self.relora_steps: - scale = 1 - else: - cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps) - scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale - - if isinstance(original, Sequence): - return [lr * scale for lr in original] - return original * scale - - -def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]: - model_name = "model.safetensors" - if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists( - str(Path(path) / f"{model_name}.index.json") - ): - model_name = "pytorch_model.bin" - - index_path = str(Path(path) / f"{model_name}.index.json") - if os.path.exists(index_path): - with open(index_path, "r", encoding="utf-8") as file: - data = json.load(file) - return data["weight_map"] - return {(module_name + ".weight"): model_name for module_name in module_names} - - -def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor: - if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)): - adapter = layer.active_adapter - return ( - peft.utils.transpose( - layer.lora_B[adapter].weight.detach().to(device) - @ layer.lora_A[adapter].weight.detach().to(device), - getattr(layer, "fan_in_fan_out", False), - ) - * layer.scaling[adapter] - ) - - return layer.get_delta_weight().to(device) - - -def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]: - modules: Dict[str, peft.tuners.lora.LoraLayer] = {} - - key_list = [key for key, _ in model.model.named_modules() if "lora" not in key] - for key in key_list: - try: - # pylint: disable=protected-access - _parent, target, _target_name = peft.utils._get_submodules(model.model, key) - except AttributeError: - continue - - if isinstance(target, peft.tuners.lora.LoraLayer): - modules[key] = target - - return modules - - -def update_weights( - target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device -): - if reinit: - for adapter_name in target.lora_A: - target.reset_lora_parameters(adapter_name) - for adapter_name in target.lora_embedding_A: - target.reset_lora_parameters(adapter_name) - - if isinstance(target, peft.tuners.lora.Linear4bit): - # This could be faster, but the quantization of Linear4bit weights occurs - # when the module is moved from cpu to gpu. Without meddling *too* deeply in - # PEFT's innards or maintaining a duplicate of that codepath, this is good - # enough for now. 
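[Editorial note] `ReLoRAScheduler` above rescales the inner scheduler's learning rate so each LoRA restart gets its own short warmup. The scale it applies can be written as a tiny standalone function (names illustrative):

def relora_lr_scale(step, relora_steps, warmup_steps, min_lr_scale=0.001):
    """Per-restart warmup factor applied on top of the inner schedule."""
    if step < relora_steps:
        return 1.0  # the first cycle runs at the inner schedule's full rate
    cycle_t = min(1.0, (step % relora_steps) / warmup_steps)
    return cycle_t * (1 - min_lr_scale) + min_lr_scale

# Right after a restart the rate drops, then ramps back up over warmup_steps.
print([relora_lr_scale(s, relora_steps=100, warmup_steps=10) for s in (100, 105, 110)])
# approximately [0.001, 0.5, 1.0]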
- target.weight.quant_state = None - target.weight.data = new_weight.cpu() - target.to(device) - elif isinstance(target, peft.tuners.lora.Linear8bitLt): - target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device) - else: - target.weight.data = new_weight.to(device) - - -def merge_and_save( - model: peft.LoraModel, - model_src: str, - model_dst: str, - reinit: bool = False, - quantized: bool = False, - cpu_offload: bool = False, - actually_save: bool = True, -): - modules = find_lora_modules(model) - - if not quantized: - for module_name, target in modules.items(): - update = target.get_delta_weight(target.active_adapter).detach() - target.weight.data += update - - if reinit: - for adapter_name in target.lora_A: - target.reset_lora_parameters(adapter_name) - for adapter_name in target.lora_embedding_A: - target.reset_lora_parameters(adapter_name) - return - - os.makedirs(model_dst, exist_ok=True) - shard_paths = sharded_paths(model_src, modules.keys()) - out_shard_paths = {} - - unique_shards = list(set(shard_paths.values())) - for shard_path in unique_shards: - out_tensors = {} - if shard_path.endswith(".safetensors"): - in_tensors = st.load_file(str(Path(model_src) / shard_path)) - else: - in_tensors = torch.load(Path(model_src) / shard_path) - if "state_dict" in in_tensors: - in_tensors = in_tensors["state_dict"] - - for module_name, target in modules.items(): - key = module_name + ".weight" - if key not in shard_paths or shard_paths[key] != shard_path: - continue - - orig_weight = in_tensors[key] - old_dev = target.weight.device - math_dev = "cpu" if cpu_offload else old_dev - - delta_weight = lora_delta_weight(target, math_dev) - new_weight = orig_weight.to(math_dev) + delta_weight - del delta_weight - - if actually_save: - out_tensors[key] = new_weight.half().cpu() - - update_weights(target, new_weight, reinit=reinit, device=old_dev) - - if actually_save: - out_shard_name = shard_path - if out_shard_name.startswith("pytorch_model"): - out_shard_name = ( - out_shard_name.replace("pytorch_model", "model").rstrip(".bin") - + ".safetensors" - ) - - for module_name in in_tensors: - if module_name not in out_tensors: - out_tensors[module_name] = in_tensors[module_name].half() - out_shard_paths[module_name] = out_shard_name - - shard_fn = str(Path(model_dst) / out_shard_name) - LOG.info(f"saving tensors to {shard_fn}") - st.save_file(out_tensors, shard_fn, metadata={"format": "pt"}) - - del in_tensors - del out_tensors - torch.cuda.empty_cache() - - if actually_save and len(unique_shards) > 1: - with open( - str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8" - ) as file: - json.dump({"metadata": {}, "weight_map": out_shard_paths}, file) - - -def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str): - modules = find_lora_modules(model) - shard_paths = sharded_paths(checkpoint_path, modules.keys()) - unique_shards = list(set(shard_paths.values())) - - for shard_path in unique_shards: - tensors = st.load_file(os.path.join(checkpoint_path, shard_path)) - - for module_name, target in modules.items(): - key = module_name + ".weight" - if key not in shard_paths or shard_paths[key] != shard_path: - continue - - new_weight = tensors[key] - update_weights( - target, new_weight, reinit=False, device=target.weight.device - ) diff --git a/axolotl/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/axolotl/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py deleted file mode 100644 index 
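[Editorial note] `lora_delta_weight` and `merge_and_save` above fold the low-rank update back into the frozen base weight as `W + B @ A * scaling`. A stripped-down sketch of that merge on bare tensors (shapes and scaling are illustrative):

import torch

hidden, rank = 16, 4
base_weight = torch.randn(hidden, hidden)
lora_A = torch.randn(rank, hidden) * 0.01   # down-projection
lora_B = torch.zeros(hidden, rank)          # up-projection (initialized to zero in LoRA)
scaling = 2.0                               # alpha / r

delta = (lora_B @ lora_A) * scaling
merged = base_weight + delta                # what the full-weight checkpoint ends up storing
print(torch.allclose(merged, base_weight))  # True here only because lora_B is still all zeros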
0269f90157b12c46a460545842d69da65afc394a..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ /dev/null @@ -1,415 +0,0 @@ -# coding=utf-8 -# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This code is based off the following work: -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py -""" PyTorch StableLM Epoch model. """ -import importlib -import math -from typing import Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from accelerate import init_empty_weights -from einops import rearrange -from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports - flash_attn_varlen_qkvpacked_func, -) -from torch import nn -from transformers import AutoConfig, AutoModelForCausalLM -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.utils import logging - -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids - -logger = logging.get_logger(__name__) - - -def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"): - # this is a wonky hack to get the remotely loaded module - model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - # we need to load the model here in order for modeling_stablelm_epoch to be available - with init_empty_weights(): - AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) - module_name = model_config.__class__.__module__.replace( - ".configuration_stablelm_epoch", ".modeling_stablelm_epoch" - ) - modeling_stablelm = importlib.import_module(module_name) - modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access - flashattn_attn - ) - modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access - stablelm_model_forward - ) - modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access - decoder_layer_forward - ) - - -def rotate_half(x: torch.Tensor): - """Rotates half the hidden dims of the input.""" - # pylint: disable=invalid-name - x1, x2 = torch.chunk(x, 2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
- # pylint: disable=invalid-name - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def flashattn_attn( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - position_ids: torch.LongTensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, # pylint: disable=unused-argument - use_cache: Optional[bool] = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - query_rot = query_states[..., : self.rotary_ndims] - query_pass = query_states[..., self.rotary_ndims :] - key_rot = key_states[..., : self.rotary_ndims] - key_pass = key_states[..., self.rotary_ndims :] - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_rot, key_rot, cos, sin, position_ids - ) - - # [batch_size, num_heads, seq_len, head_dim] - query_states = torch.cat((query_states, query_pass), dim=-1) - key_states = torch.cat((key_states, key_pass), dim=-1) - - if past_key_value is not None: - # Reuse k, v, self_attention - key_states = torch.cat((past_key_value[0], key_states), dim=2) - value_states = torch.cat((past_key_value[1], value_states), dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # Repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... 
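[Editorial note] The StableLM patch applies rotary embeddings only to the first `rotary_ndims` of each head (the `query_rot` / `query_pass` split above). A compact sketch of `rotate_half` plus the cos/sin rotation on a single head, assuming the standard 10000 base (shapes illustrative):

import torch

def rotate_half(x):
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, head_dim, rotary_ndims = 8, 64, 32
q = torch.randn(seq_len, head_dim)

# Rotate only the first rotary_ndims dims; pass the rest through untouched.
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]

inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
cos = torch.cat((angles, angles), dim=-1).cos()
sin = torch.cat((angles, angles), dim=-1).sin()

q_embed = torch.cat(((q_rot * cos) + (rotate_half(q_rot) * sin), q_pass), dim=-1)
print(q_embed.shape)  # (seq_len, head_dim)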
-> (b s) ...") - softmax_scale = None - - output = flash_attn_varlen_qkvpacked_func( - qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True - ) - - attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - else: - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # Upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - # Merge heads - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - # Final linear projection - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -def decoder_layer_forward( - self, - hidden_states: Optional[torch.FloatTensor], - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Union[ - Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]] -]: - # pylint: disable=duplicate-code - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -def stablelm_model_forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - # pylint: disable=duplicate-code - output_attentions = ( - output_attentions - if 
output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # Retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # Embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # Decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - None, - cu_seqlens, - max_seqlen, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # Add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) diff --git a/axolotl/src/axolotl/monkeypatch/utils.py b/axolotl/src/axolotl/monkeypatch/utils.py deleted file mode 100644 index b352cc55ece3fd2776e58fc4e9415189add93116..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/monkeypatch/utils.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -Shared utils for the monkeypatches -""" -import torch - - -def get_cu_seqlens(attn_mask): - """generate a cumulative sequence length mask for flash attention using attn mask""" - if len(attn_mask.shape) == 1: - attn_mask = attn_mask.unsqueeze(0) - - device = attn_mask.device - results = [] - max_seq_lens = [] - - for row in attn_mask: - # Exclude zeros to avoid adding their positions to the mask - t_non_zeros = row[row != 0] - # Find where the sequence number changes (including the first position) - seq_change = torch.cat( - [ - torch.tensor([1], dtype=torch.int32, device=device), - t_non_zeros[1:] != t_non_zeros[:-1], - ] - ) - # Get the indices where the sequence changes - change_indices = torch.cat( - [ - (seq_change == 1).nonzero(as_tuple=True)[0], - torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device), - ] - ) - # Calculate the sequence lengths - seq_lengths = change_indices[1:] - change_indices[:-1] - # Calculate the length of the final sequence or padding - final_seq_length = len(row) - change_indices[-1] - # Append the length of the final sequence or padding to seq_lengths - if final_seq_length.item(): - seq_lengths = torch.cat( - [ - seq_lengths, - torch.tensor( - [final_seq_length.item()], dtype=torch.int32, device=device - ), - ] - ) - # Calculate the cumulative sequence lengths - cu_seqlens = torch.cat( - [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] 
- ) - max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - results.append(cu_seqlens) - max_seq_lens.append(max_seq_len) - - return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) - - -def get_cu_seqlens_from_pos_ids(position_ids): - """generate a cumulative sequence length mask for flash attention using pos ids""" - if len(position_ids.shape) == 1: - position_ids = position_ids.unsqueeze(0) - - device = position_ids.device - results = [] - max_seq_lens = [] - - for row in position_ids: - # Count the number of consecutive zeros from the right side - padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item() - - # Adjust the row to exclude padding - adjusted_row = row[:-padding_length] if padding_length else row.clone() - - # Find where the position resets to 0 (indicating a new sequence) - seq_starts = torch.cat( - [ - torch.tensor([True], dtype=torch.bool, device=device), - adjusted_row[1:] == 0, - ] - ) - # Get the indices where the sequence starts - start_indices = torch.cat( - [ - (seq_starts).nonzero(as_tuple=True)[0], - torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device), - ] - ) - # Calculate the sequence lengths - seq_lengths = start_indices[1:] - start_indices[:-1] - # Calculate the cumulative sequence lengths - cu_seqlens = torch.cat( - [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] - ) - # Append the padding length to the cumulative sequence lengths - if padding_length: - cu_seqlens = torch.cat( - [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)] - ) - max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - results.append(cu_seqlens) - max_seq_lens.append(max_seq_len) - - return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) - - -def set_module_name(model, name, value): - if "." 
in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.get_submodule(parent_name) - else: - parent_name = "" - parent = model - child_name = name - - setattr(parent, child_name, value) diff --git a/axolotl/src/axolotl/prompt_strategies/__init__.py b/axolotl/src/axolotl/prompt_strategies/__init__.py deleted file mode 100644 index e62a5c20ce4dca7e25a47d73c2803f96f86b563e..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Module to load prompt strategies.""" - -import importlib -import inspect - -from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig - - -def load(strategy, tokenizer, cfg, ds_cfg): - try: - load_fn = "load" - if strategy.split(".")[-1].startswith("load_"): - load_fn = strategy.split(".")[-1] - strategy = ".".join(strategy.split(".")[:-1]) - mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies") - func = getattr(mod, load_fn) - load_kwargs = {} - if strategy == "user_defined": - load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg) - else: - sig = inspect.signature(func) - if "ds_cfg" in sig.parameters: - load_kwargs["ds_cfg"] = ds_cfg - return func(tokenizer, cfg, **load_kwargs) - except Exception: # pylint: disable=broad-exception-caught - return None diff --git a/axolotl/src/axolotl/prompt_strategies/alpaca_chat.py b/axolotl/src/axolotl/prompt_strategies/alpaca_chat.py deleted file mode 100644 index 975fee889e1a2168508f64611781ab9bb537b114..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/alpaca_chat.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Module for Alpaca prompt strategy classes""" - -from typing import Any, Dict, Optional, Tuple - -from axolotl.prompt_tokenizers import ( - AlpacaPromptTokenizingStrategy, - InstructionPromptTokenizingStrategy, -) -from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter - - -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - prompt_style = PromptStyle.CHAT.value - if ds_cfg and "conversation" in ds_cfg: - prompt_style = ds_cfg["conversation"] - - return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(prompt_style), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -class AlpacaConcisePrompter(AlpacaPrompter): - """ - Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers - """ - - system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n" - system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n" - - -class AlpacaChatPrompter(AlpacaPrompter): - """ - Alpaca Chat Prompter extending the system prompt to for chat-instruct answers - """ - - system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n" - system_no_input_prompt = "Below is an instruction from a USER that describes a task. 
The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n" - - def __init__(self): # pylint: disable=super-init-not-called - self.prompt_style = PromptStyle.CHAT.value - self.match_prompt_style() - - -class NoSystemPrompter(AlpacaPrompter): - """ - Null Prompter with no system prompts - """ - - system_prompt = "" - system_no_input_prompt = "" - turn_format = "{instruction} {input} " - turn_no_input_format = "{instruction} " - - def __init__(self): # pylint: disable=super-init-not-called - pass - - -class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for AlpacaQA - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["question"], - "", - prompt["answer"], - ) - - -class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for CamelAI datasets - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["message_1"], - "", - prompt["message_2"], - ) - - -def load_concise(tokenizer, cfg): - return AlpacaPromptTokenizingStrategy( - AlpacaConcisePrompter(PromptStyle.CHAT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_qa(tokenizer, cfg): - return AlpacaQAPromptTokenizingStrategy( - AlpacaChatPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_camel_ai(tokenizer, cfg): - return CamelAIPromptTokenizingStrategy( - AlpacaChatPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_no_prompt(tokenizer, cfg): - return AlpacaPromptTokenizingStrategy( - UnpromptedPrompter(PromptStyle.CHAT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) diff --git a/axolotl/src/axolotl/prompt_strategies/alpaca_instruct.py b/axolotl/src/axolotl/prompt_strategies/alpaca_instruct.py deleted file mode 100644 index 143f070f2a9f826f6ed83e64305760229f535173..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/alpaca_instruct.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Module loading the AlpacaInstructPromptTokenizingStrategy class""" - -from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter - - -def load(tokenizer, cfg): - return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.INSTRUCT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_no_prompt(tokenizer, cfg): - return AlpacaPromptTokenizingStrategy( - UnpromptedPrompter(PromptStyle.INSTRUCT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) diff --git a/axolotl/src/axolotl/prompt_strategies/alpaca_w_system.py b/axolotl/src/axolotl/prompt_strategies/alpaca_w_system.py deleted file mode 100644 index 8c8cc07435f9e65e5401588aded9c3791b1c6de9..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/alpaca_w_system.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -Prompt strategies loader for alpaca instruction datasets with system prompts -""" -from typing import Generator, Tuple, Union - -from axolotl.prompt_tokenizers import PromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter, PromptStyle - - -class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy): - """ - Tokenizing strategy for instruction-based prompts. 
- """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]: - return ( - prompt["instruction"], - prompt["input"] if "input" in prompt else "", - prompt["output"], - prompt["system"], - ) - - def tokenize_prompt(self, prompt): - # pylint: disable=duplicate-code - ( - instruction, - input, # pylint: disable=redefined-builtin - response, - system, - ) = self.parse_instruction_fields(prompt) - user_prompt = next( - iter( - self.prompter.build_prompt_w_system( - system, - instruction, - input, - ) - ) - ) - tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False) - if not self.train_on_inputs: - user_prompt_len = len(tokenized_prompt["input_ids"]) - # TODO this could be sped up using numpy array slicing - tokenized_prompt["labels"] = [-100] * user_prompt_len - tokenized_res_prompt = self._tokenize( - response, strip_bos_token=True, add_eos_token=True - ) - tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"] - tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"] - tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"] - - return tokenized_prompt - - -class SystemDataPrompter(AlpacaPrompter): - """ - Alpaca Style Prompter that uses system prompts from the dataset - """ - - system_format: str = "### System:\n{system}\n\n" - - def build_prompt_w_system( - self, - system: str, - instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin - output: Union[None, str] = None, - ) -> Generator[str, None, None]: - # returns the full prompt from instruction and optional input - # if a label (=response, =output) is provided, it's also appended. - formatted_sys_prompt = ( - self.system_format.format(system=system) - if system and self.system_format - else "" - ) - if input: - res = formatted_sys_prompt + self.turn_format.format( - instruction=instruction, input=input - ) - else: - res = formatted_sys_prompt + self.turn_no_input_format.format( - instruction=instruction - ) - if output: - res = f"{res}{output}" - yield res - - -class OpenOrcaSystemDataPrompter(SystemDataPrompter): - """ - Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts - """ - - def match_prompt_style(self): - # pylint: disable=duplicate-code - if self.prompt_style == PromptStyle.INSTRUCT.value: - self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n" - self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n" - self.system_format = "### System:\n{system}\n" - if self.prompt_style == PromptStyle.CHAT.value: - self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" - self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" - self.system_format = "SYSTEM: {system}\n" - if self.prompt_style == PromptStyle.CHATML.value: - self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" - self.turn_no_input_format = ( - "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" - ) - self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" - - -class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy): - """ - Tokenizing strategy for OpenOrca datasets - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]: - return ( - prompt["question"], - "", - prompt["response"], - prompt["system_prompt"], - ) - - -def load(tokenizer, cfg): - return load_chat(tokenizer, cfg) - - -def load_instruct(tokenizer, cfg): - return 
InstructionWSystemPromptTokenizingStrategy( - SystemDataPrompter(PromptStyle.INSTRUCT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_chat(tokenizer, cfg): - return InstructionWSystemPromptTokenizingStrategy( - SystemDataPrompter(PromptStyle.CHAT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_open_orca(tokenizer, cfg): - return OpenOrcaPromptTokenizingStrategy( - OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_open_orca_chatml(tokenizer, cfg): - return OpenOrcaPromptTokenizingStrategy( - OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) diff --git a/axolotl/src/axolotl/prompt_strategies/completion.py b/axolotl/src/axolotl/prompt_strategies/completion.py deleted file mode 100644 index 3285e667cbc36b03e596e548c0321352b1dc25a6..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/completion.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Basic completion text -""" -from collections import defaultdict -from typing import Any, Dict, Generator, Optional, Tuple - -from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy - - -class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Completion prompts. - """ - - _field: str = "text" - - def __init__(self, *args, max_length=None, **kwargs): - super().__init__(*args, **kwargs) - if max_length is not None: - self.max_length = max_length - - @property - def supports_batched(self): - return True - - @property - def field(self) -> str: - return self._field - - @field.setter - def field(self, new_field: str): - self._field = new_field - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt[self.field], - "", - "", - ) - - def tokenize_prompt(self, prompt): - res = defaultdict(lambda: []) - feature_names = list(prompt.keys()) - for row in zip(*prompt.values()): - prompt_row = dict(zip(feature_names, row)) - ( - instruction, - _, - _, - ) = self.parse_instruction_fields(prompt_row) - - full_prompt = self._build_full_prompt(instruction, None, None) - tokenized_full_prompt = self._tokenize(full_prompt) - - for key, val in tokenized_full_prompt.items(): - for i in range(0, len(val), self.sequence_len): - res[key].append(val[i : i + self.sequence_len]) - - return dict(res) - - def _build_full_prompt( - self, instruction, input, response - ): # pylint: disable=redefined-builtin - return next(iter(self.prompter.build_prompt(instruction, input, response))) - - -class CompletionPrompter: - """ - Prompter for completion - """ - - def build_prompt( - self, - instruction: str, - input=None, # pylint: disable=redefined-builtin, unused-argument - output=None, # pylint: disable=unused-argument - ) -> Generator[str, None, None]: - yield instruction - - -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - strat = CompletionPromptTokenizingStrategy( - CompletionPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - max_length=cfg.sequence_len * 64, - ) - if ds_cfg and "field" in ds_cfg: - strat.field = ds_cfg["field"] - - return strat diff --git a/axolotl/src/axolotl/prompt_strategies/context_qa.py b/axolotl/src/axolotl/prompt_strategies/context_qa.py deleted file mode 100644 index f87dd8b5cdd0311510f348302c16eb3a0a26c13a..0000000000000000000000000000000000000000 --- 
a/axolotl/src/axolotl/prompt_strategies/context_qa.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Module containing the classes for Context QA Prompt Tokenization Strategies""" -from typing import Tuple - -from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter, PromptStyle - - -# article, unanswerable_question, question, answer -def load_404(tokenizer, cfg): - return AlpacaMissingInfoContextPromptTokenizingStrategy( - AlpacaContextPrompter(PromptStyle.CHAT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load(tokenizer, cfg): - return AlpacaContextPromptTokenizingStrategy( - AlpacaContextPrompter(PromptStyle.CHAT.value), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_v2(tokenizer, cfg): - return ContextQaV2PromptTokenizingStrategy( - ContextV2Prompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -class AlpacaContextPrompter(AlpacaPrompter): - """ - Customized system prompted for concise QA - """ - - system_prompt = ( - "Use the following contextual information to concisely answer the question.\n" - ) - system_no_input_prompt = ( - "Use the following contextual information to concisely answer the question.\n" - ) - - -class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenization Strategy to combine in-context article with a question and answer - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["article"] + "\n===\n" + prompt["question"], - "", - prompt["answer"], - ) - - -class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenization Strategy to combine in-context article with a question and answer - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - "Context: " - + prompt["context"] - + "\nQuestion: " - + prompt["question"] - + "\n", - "", - "Answer: " + prompt["answer"], - ) - - -class ContextV2Prompter(AlpacaPrompter): - """ - Customized system prompted for concise QA - """ - - system_prompt = "" - system_no_input_prompt = "" - - def match_prompt_style(self): - # pylint: disable=duplicate-code - self.turn_format = "{instruction}\n{input}" - self.turn_no_input_format = "{instruction}" - self.system_format = "{system}" - - -class AlpacaMissingInfoContextPromptTokenizingStrategy( - InstructionPromptTokenizingStrategy -): - """ - Tokenization Strategy to combine in-context article with a question that can't be answered - from the context and a default response to that effect - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["article"] + "\n===\n" + prompt["unanswerable_question"], - "", - "The context provided does not contain any information about your inquiry. 
" - "Therefore, I'm unable to answer your question based on the given context.", - ) diff --git a/axolotl/src/axolotl/prompt_strategies/creative_acr.py b/axolotl/src/axolotl/prompt_strategies/creative_acr.py deleted file mode 100644 index ea67034b3b60cbef4227d12df1b4b32aecc6228a..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/creative_acr.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Module loading the CreativePromptTokenizingStrategy and similar classes""" - -from typing import Generator, Tuple, Union - -import yaml - -from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy - - -class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Creative Answering - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - question = prompt["instruction"] - answer = prompt[ - "revision" - ] # don't use prompt[answer], that's data we don't want in the dataset - return ( - question, - "", - answer, - ) - - -class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Creative Critique - """ - - user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria: -refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question. -prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias. -creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity -comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth -Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria. - -Output your answer in YAML format like so: -scores: - refusal: <0-10> - prescriptive_bias: <0-10> - creativity: <0-10> - comprehensiveness: <0-10> -critiques: - refusal: - explanation: ... - improvements: ... - prescriptive_bias: - explanation: ... - improvements: ... - creativity: - explanation: ... - improvements: ... - comprehensiveness: - explanation: ... - improvements: ... - -Question: {question} -Answer: {answer} -""" - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - scores = yaml.dump( - prompt["scores"], - default_flow_style=False, - Dumper=yaml.Dumper, - ) - critiques = yaml.dump( - prompt["critiques"], - default_flow_style=False, - Dumper=yaml.Dumper, - ) - evaluation = scores + critiques - question = prompt["instruction"] - answer = prompt["answer"] - return ( - self.user_prompt.format(question=question, answer=answer), - "", - evaluation, - ) - - -class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Creative Revise - """ - - user_prompt = """Definitions: -refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 
10 is no refusal, 0 is refuses to answer the question. -prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias. -creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity -comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth - -Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response. - -Question: {question} -Answer: {answer} -Evaluation: -{evaluation} -""" - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - scores = yaml.dump( - prompt["scores"], - default_flow_style=False, - Dumper=yaml.Dumper, - ) - critiques = yaml.dump( - prompt["critiques"], - default_flow_style=False, - Dumper=yaml.Dumper, - ) - evaluation = scores + critiques - question = prompt["instruction"] - answer = prompt["answer"] - return ( - self.user_prompt.format( - question=question, answer=answer, evaluation=evaluation - ), - "", - prompt["revision"], - ) - - -class CreativePrompterBase: - """ - Base class for Creative Prompters - """ - - system_prompt = "" - prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:" - - def build_prompt( - self, - instruction: str, - input: Union[ # pylint: disable=redefined-builtin, unused-argument - None, str - ] = None, - output: Union[None, str] = None, - ) -> Generator[str, None, None]: - if self.system_prompt: - res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:" - else: - res = f"USER: {instruction}\nASSISTANT:" - if output: - res = f"{res}{output}" - yield res - - -class CreativeAnswerPrompter(CreativePrompterBase): - """ - Prompter for Creative Answering - """ - - system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity." - - -class CreativeCritiquePrompter(CreativePrompterBase): - """ - Prompter for Creative Critique - """ - - system_prompt = "" - - -class CreativeRevisePrompter(CreativePrompterBase): - """ - Prompter for Creative Revise - """ - - system_prompt = "" - - -def load_answer(tokenizer, cfg): - return CreativeAnsweringPromptTokenizingStrategy( - CreativeAnswerPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_critique(tokenizer, cfg): - return CreativeCritiquePromptTokenizingStrategy( - CreativeCritiquePrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_revise(tokenizer, cfg): - return CreativeRevisePromptTokenizingStrategy( - CreativeRevisePrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) diff --git a/axolotl/src/axolotl/prompt_strategies/llama2_chat.py b/axolotl/src/axolotl/prompt_strategies/llama2_chat.py deleted file mode 100644 index a1f5ffefff3f941694bdb2ba7a9f3bbd9118b748..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/llama2_chat.py +++ /dev/null @@ -1,206 +0,0 @@ -""" -Prompt Strategy for finetuning Llama2 chat models -see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation. 
- -This implementation is based on the Vicuna PR and the fastchat repo, see also: -https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847 - -Use dataset type: "llama2_chat" in conig.yml to use this prompt style. - -E.g. in the config.yml: -``` -datasets: - - path: llama_finetune_train.jsonl - type: llama2_chat -``` - -The dataset itself should look like this: -``` -{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]} -``` -in a jsonl file. The first message should be from the human, the second from gpt. -For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns). - -Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing! -""" - -import logging -from dataclasses import dataclass, field -from typing import Generator, List, Sequence - -from axolotl.prompt_tokenizers import PromptTokenizingStrategy -from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE - - -@dataclass -class Llama2ChatConversation: - """A class that manages prompt templates and keeps all conversation history. - copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py""" - - name: str = "llama2" - # The system prompt - system: str = ( - "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " - "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " - "Please ensure that your responses are socially unbiased and positive in nature.\n\n" - "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " - "If you don't know the answer to a question, please don't share false information.\n<>\n\n" - ) - roles: Sequence[str] = ("[INST]", "[/INST]") - messages: List[List[str]] = field(default_factory=list) - offset: int = 0 - sep = " " - sep2 = " " - stop_token_ids = [2] - - def get_prompt(self) -> str: - """Get the prompt for generation.""" - seps = [self.sep, self.sep2] - ret = "" - for i, (role, message) in enumerate(self.messages): - if (i == len(self.messages) - 1) and (role == self.roles[0]): - # last message is from user (due to length), - # return prompt without it for training - return ret - if i == 0: - ret += self.system + message.strip() - else: - ret += role + " " + message.strip() + seps[i % 2] - return ret - - def append_message(self, role: str, message: str): - """Append a new message.""" - self.messages.append([role, message]) - - -class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy): - """ - Tokenizing strategy for ShareGPT prompts. - adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.tokenizer.add_special_tokens( - {"pad_token": getattr(self.tokenizer, "pad_token", "")} - ) - # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json - - def tokenize_prompt(self, prompt): - conv = next(self.prompter.build_prompt(prompt)) - conversation_str = conv.get_prompt() - - # Tokenize conversations - input_ids = self.tokenizer( - conversation_str, - return_tensors="pt", - padding="max_length", - max_length=self.sequence_len, - truncation=True, - ).input_ids[0] - target = input_ids.clone() - - # Mask targets. 
Only compute loss on the assistant outputs. - sep = conv.roles[1] - - total_len = int(target.ne(self.tokenizer.pad_token_id).sum()) - - turns = conversation_str.split(conv.sep2) - cur_len = 1 - target[:cur_len] = IGNORE_TOKEN_ID - for turn in turns: - if turn == "": - break - turn_len = len(self.tokenizer(turn).input_ids) - - parts = turn.split(sep) - if len(parts) != 2: - break - parts[0] += sep - # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct. - instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1 - - # Ignore the user instructions - target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID - cur_len += turn_len + 2 # due to length of role token - - target[cur_len:] = IGNORE_TOKEN_ID - - if cur_len < self.sequence_len: - if cur_len != total_len: - target[:] = IGNORE_TOKEN_ID - logging.warning( - f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." - f" (ignored)" - ) - - attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist() - input_ids = input_ids.tolist() - target = target.tolist() - # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and - # follows the original llama implementation - for i in range(2, total_len - 2): - if input_ids[i] == 29961: - input_ids[i] = 518 - if target[i] == 29961: - target[i] = 518 - return { - "input_ids": input_ids, - "labels": target, - "attention_mask": attention_mask, - } - - -class Llama2ChatPrompter: # pylint: disable=too-few-public-methods - """ - A prompter that generates prompts for Llama2 models. - """ - - system_prompt = ( - "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " - "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " - "Please ensure that your responses are socially unbiased and positive in nature.\n\n" - "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
" - "If you don't know the answer to a question, please don't share false information.\n<>\n\n" - ) - - def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]: - # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78 - source = source["conversations"] # fix data structure for datasets - - # if system prompt provided, use it - if source[0]["from"] == "system": - system = f"[INST] <>\n{source[0]['value']}\n<>\n\n" - source = source[1:] - else: - system = self.system_prompt - - conv = Llama2ChatConversation(system=system) - - if len(source) < 2: - # If there isn't a back and forth conversation, ignore it - # also happens on the data splitting leaving empty conversations - raise IndexError - - roles = {"human": conv.roles[0], "gpt": conv.roles[1]} - - if roles[source[0]["from"]] != conv.roles[0]: - # Skip the first one if it is not from human - source = source[1:] - - conv.messages = [] # pylint: disable=R0801 - for j, sentence in enumerate(source): - role = roles[sentence["from"]] - assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE - if sentence["value"]: - conv.append_message(role, sentence["value"]) - yield conv - - -def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy: - return LLama2ChatTokenizingStrategy( - Llama2ChatPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) diff --git a/axolotl/src/axolotl/prompt_strategies/metharme.py b/axolotl/src/axolotl/prompt_strategies/metharme.py deleted file mode 100644 index 52d77c00cf9333a7af09c018d0126bd8fbf8a1df..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/metharme.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class""" - -import logging -from typing import Tuple - -from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter - -LOG = logging.getLogger("axolotl") - -IGNORE_TOKEN_ID = -100 - -# pylint: disable=duplicate-code - - -class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for the Metharme models - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return (prompt["prompt"], "", prompt["generation"]) - - def _tokenize( - self, - prompt: str, - add_eos_token: bool = True, - strip_bos_token: bool = False, - num_eos_tokens: int = 3, - ): - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.sequence_len, - padding=False, - return_tensors=None, - ) - if len(result["input_ids"]) == 0: - LOG.warning("Tokenizer result is empty. You may want to audit your dataset") - # If there's already an EOS token there, subtract from the number added - if result["input_ids"][-1] == self.tokenizer.eos_token_id: - num_eos_tokens -= 1 - - if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0: - for _ in range(num_eos_tokens): - if len(result["input_ids"]) < self.sequence_len: - result["input_ids"].append(self.tokenizer.eos_token_id) - result["attention_mask"].append(1) - - if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: - result["input_ids"] = result["input_ids"][1:] - result["attention_mask"] = result["attention_mask"][1:] - - result["labels"] = result["input_ids"].copy() - return result - - -class MetharmePrompter(AlpacaPrompter): - """ - Prompter for the Metharme models. 
- """ - - system_prompt = "" - system_no_input_prompt = "" - system_format = "" - turn_format = "{instruction}" - turn_no_input_format = "{instruction}" - - def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called - pass - - -def load(tokenizer, cfg): - return MetharmePromptTokenizingStrategy( - MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len - ) diff --git a/axolotl/src/axolotl/prompt_strategies/orcamini.py b/axolotl/src/axolotl/prompt_strategies/orcamini.py deleted file mode 100644 index 04ce5767dd75a000007fa01b15744f4bc459595c..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/orcamini.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Prompt Strategy for finetuning Orca Mini (v2) models -see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information - -Use dataset type: orcamini in conig.yml to use this prompt style. - -Compared to the alpaca_w_system.open_orca dataset type, -this one specifies the system prompt with "### System:". - -Not suited/tested for multiple-turn conversations without further adjustments. -""" -from typing import Generator, Union - -from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter - - -class OrcaMiniPrompter(AlpacaPrompter): - """Adjusted Prompter for Orca Mini (v2) datasets""" - - def match_prompt_style(self): - self.turn_no_input_format = ( - "### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n" - ) - - def build_prompt_w_system( - self, - system: str, - instruction: str, - output: Union[None, str] = None, - ) -> Generator[str, None, None]: - # returns the full prompt from instruction and optional input - # if a label (=response, =output) is provided, it's also appended. - res = self.turn_no_input_format.format(system=system, instruction=instruction) - if output: - res = f"{res}{output}" - yield res - - -def load(tokenizer, cfg): - return OpenOrcaPromptTokenizingStrategy( - OrcaMiniPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) diff --git a/axolotl/src/axolotl/prompt_strategies/pygmalion.py b/axolotl/src/axolotl/prompt_strategies/pygmalion.py deleted file mode 100644 index 88208f6ec4329eb550344af9048d8d61d0d4d7e9..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/pygmalion.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" - -import copy -import logging -from collections import defaultdict -from typing import Generator, List, Tuple - -from axolotl.prompt_tokenizers import ( - PromptTokenizingStrategy, - parse_tokenized_to_result, - tokenize_prompt_default, -) - -LOG = logging.getLogger("axolotl") - -IGNORE_TOKEN_ID = -100 - - -class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): - """ - Tokenizing strategy for Pygmalion. 
- """ - - bot_prefix_token_ids: List[int] = [] - - def __init__(self, prompter, tokenizer, *args, **kwargs): - super().__init__(prompter, tokenizer, *args, **kwargs) - res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True) - self.bot_prefix_token_ids = res["input_ids"] - - def tokenize_prompt(self, prompt): - result, current_len = tokenize_prompt_default() - for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): - role, message = part - if role == "system": - prefix = "<|system|>" - # this should include a bos token, no eos token, strip trailing "\n" - if message.endswith("\n"): - message = message[:-8] - res = self._tokenize( - prefix + "Persona: " + message.strip(), - add_eos_token=False, - strip_bos_token=False, - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif role == "human": - prefix = "<|user|>" - res = self._tokenize( - prefix + " " + message.strip(), - add_eos_token=False, - strip_bos_token=True, - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif role == "bot": - prefix = "<|model|>" - res = self._tokenize( - prefix + " " + message.strip(), - add_eos_token=True, - strip_bos_token=True, - ) - # mask out the prefix token, rest is not masked out from labels - # make sure we create the labels first, otherwise we get incorrect lengths - labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [ - *copy.deepcopy(res["input_ids"]) - ][len(self.bot_prefix_token_ids) :] - else: - LOG.warning(f"unknown role in conversation: {role}") - res = defaultdict(lambda: []) - - # pylint: disable=duplicate-code - result, current_len = parse_tokenized_to_result( - result, - current_len, - res, - labels, - pad_token_id=self.tokenizer.pad_token_id, - ) - return result - - -class PygmalionPrompter: - """ - Prompter for Pygmalion. 
- """ - - def __init__(self, *args, **kwargs): - pass - - def build_prompt( - self, source, *args, **kwargs # pylint: disable=unused-argument - ) -> Generator[Tuple[str, str], None, None]: - for msg in source: - yield msg["role"], msg["value"] - - -def load(tokenizer, cfg): - return PygmalionPromptTokenizingStrategy( - PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len - ) diff --git a/axolotl/src/axolotl/prompt_strategies/sharegpt.py b/axolotl/src/axolotl/prompt_strategies/sharegpt.py deleted file mode 100644 index fbb44ccfae62f9426f7f19a28326154b81f1ce15..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/sharegpt.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" -from typing import Any, Dict, Optional - -from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template - -from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy -from axolotl.prompters import ShareGPTPrompterV2 - -register_conv_template( - Conversation( - name="chatml", - system_template="<|im_start|>system\n{system_message}", - system_message="You are a helpful assistant.", - roles=["<|im_start|>user", "<|im_start|>assistant"], - sep_style=SeparatorStyle.CHATML, - sep="<|im_end|>", - ) -) - - -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - conversation = ( - ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None - ) - field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None - field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None - strategy = SimpleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2( - conversation=conversation, - role_key_model=field_model, - role_key_human=field_human, - ), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - if ds_cfg and "strict" in ds_cfg: - strategy.strict = ds_cfg["strict"] - return strategy - - -def load_role(tokenizer, cfg): - return SimpleRoleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -def load_guanaco(tokenizer, cfg): - return GuanacoShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - basic sharegpt strategy to grab conversations from the sample row - """ - - _strict = True - - @property - def strict(self): - return self._strict - - @strict.setter - def strict(self, strict): - self._strict = strict - - def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] - if self.strict: - return conversations - # remap roles - allow for assistant turn - role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"} - turns = [ - {"from": role_map[t["from"]], "value": t["value"]} for t in conversations - ] - return turns - - -class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from - """ - - def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] - # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... 
- turns = [{"from": t["role"], "value": t["value"]} for t in conversations] - return turns - - -class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - sharegpt strategy that remaps oasst data to sharegpt format - """ - - def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] - # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... - role_map = {"prompter": "human", "assistant": "gpt"} - turns = [ - {"from": role_map[t["role"]], "value": t["text"]} for t in conversations - ] - return turns diff --git a/axolotl/src/axolotl/prompt_strategies/sharegpt_jokes.py b/axolotl/src/axolotl/prompt_strategies/sharegpt_jokes.py deleted file mode 100644 index 404302c81ea643aa841b6d1639ce6f95090ebfc9..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/sharegpt_jokes.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Module for Jokes prompts using sharegpt style """ -from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy -from axolotl.prompters import ShareGPTPrompterV2 - - -def load(tokenizer, cfg): - return SimpleJokesShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - -class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): - """ - Tokenization strategy for asking bot to tell a joke and then explain why its funny - """ - - # title, text, explanation - def get_conversation_thread(self, prompt): - title = "" if not prompt["title"] else prompt["title"] + " " - return [ - {"from": "human", "value": "Tell me a joke."}, - {"from": "gpt", "value": title + prompt["text"]}, - {"from": "human", "value": "Why is that joke funny?"}, - {"from": "gpt", "value": prompt["explanation"]}, - ] diff --git a/axolotl/src/axolotl/prompt_strategies/user_defined.py b/axolotl/src/axolotl/prompt_strategies/user_defined.py deleted file mode 100644 index e20e80c3a451fda218686128e3a32fc2a8ade369..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_strategies/user_defined.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -User Defined prompts with configuration from the YML config -""" - -from dataclasses import dataclass -from functools import partial -from typing import Optional, Tuple - -from axolotl.prompt_strategies.alpaca_w_system import ( - InstructionWSystemPromptTokenizingStrategy, - SystemDataPrompter, -) - - -@dataclass -class UserDefinedDatasetConfig: - """ - dataclass configuration representing a userdefined dataset type - """ - - system_prompt: str = "" - field_system: str = "system" - field_instruction: str = "instruction" - field_input: str = "input" - field_output: str = "output" - format: str = "{instruction} {input} " - no_input_format: str = "{instruction} " - system_format: str = "{system}" - - def __getitem__(self, item): - return getattr(self, item) - - -class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy): - """ - Prompt Tokenization Strategy for user defined prompts - """ - - -def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None): - if not ds_cfg: - raise ValueError("Missing dataset prompt configuration") - - system_prompt = "" - if ds_cfg.system_prompt: - system_prompt = ds_cfg.system_prompt - - def parse_instruction_fields( - field_instruction, - field_input, - field_output, - field_system, - system_prompt, - prompt, - ) -> Tuple[str, str, str, str]: - return ( - prompt[field_instruction], - prompt[field_input] if field_input 
in prompt else "", - prompt[field_output] if field_output in prompt else "", - prompt[field_system] if field_system in prompt else system_prompt, - ) - - turn_format = ds_cfg.format - turn_no_input_format = ds_cfg.no_input_format - system_format = ds_cfg.system_format - - class UserDefinedPrompter(SystemDataPrompter): - """ - Prompter for user defined prompts - """ - - def match_prompt_style(self): - self.turn_format = turn_format - self.turn_no_input_format = turn_no_input_format - self.system_format = system_format - - prompter = UserDefinedPrompter() - - strat = UserDefinedPromptTokenizationStrategy( - prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - - setattr( - strat, - "parse_instruction_fields", - partial( - parse_instruction_fields, - ds_cfg.field_instruction, - ds_cfg.field_input, - ds_cfg.field_output, - ds_cfg.field_system, - system_prompt, - ), - ) - return strat diff --git a/axolotl/src/axolotl/prompt_tokenizers.py b/axolotl/src/axolotl/prompt_tokenizers.py deleted file mode 100644 index fe4f3b62f79aa5046cb5e6f128bb26e2c98fbd40..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompt_tokenizers.py +++ /dev/null @@ -1,468 +0,0 @@ -"""Module containing PromptTokenizingStrategy and Prompter classes""" - -import abc -import copy -import logging -from typing import Dict, List, Tuple, Union - -from fastchat.conversation import Conversation -from transformers import BatchEncoding, PreTrainedTokenizer - -from axolotl.monkeypatch.fastchat_conversation_turns import ( - add_get_turns_to_conversation, -) -from axolotl.prompters import IGNORE_TOKEN_ID - -LOG = logging.getLogger("axolotl") - -IGNORE_INDEX = -100 -LLAMA_DEFAULT_PAD_TOKEN = "" # nosec -LLAMA_DEFAULT_EOS_TOKEN = "" # nosec -LLAMA_DEFAULT_BOS_TOKEN = "" # nosec -LLAMA_DEFAULT_UNK_TOKEN = "" # nosec - -add_get_turns_to_conversation() - - -class InvalidDataException(Exception): - """ - Exception raised when the data is invalid - """ - - -class PromptTokenizingStrategy(abc.ABC): - """ - Abstract class for tokenizing strategies - """ - - def __init__( - self, - prompter, - tokenizer, - train_on_inputs: bool = False, - sequence_len: int = 2048, - ): - self.prompter = prompter - self.tokenizer: PreTrainedTokenizer = tokenizer - self.train_on_inputs = train_on_inputs - # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy. - # TODO: Document how they are different. - self.sequence_len = sequence_len - self.max_length = sequence_len - - @abc.abstractmethod - def tokenize_prompt(self, prompt): - pass - - @property - def supports_batched(self): - return False - - def _tokenize( - self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False - ) -> BatchEncoding: - empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) - if not prompt: - LOG.warning("Empty text requested for tokenization.") - return empty - - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.max_length, - padding=False, - return_tensors=None, - ) - if len(result["input_ids"]) == 0: - LOG.warning("Tokenizer result is empty. 
You may want to audit your dataset") - return empty - - if ( - result["input_ids"][-1] != self.tokenizer.eos_token_id - and len(result["input_ids"]) < self.max_length - and add_eos_token - ): - result["input_ids"].append(self.tokenizer.eos_token_id) - result["attention_mask"].append(1) - - if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: - result["input_ids"] = result["input_ids"][1:] - result["attention_mask"] = result["attention_mask"][1:] - - result["labels"] = result["input_ids"].copy() - return result - - -class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): - """ - Tokenizing strategy for instruction-based prompts. - """ - - def parse_instruction_fields( - self, prompt - ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]: - raise NotImplementedError - - def tokenize_prompt(self, prompt): - ( - instruction, - input, # pylint: disable=redefined-builtin - response, - ) = self.parse_instruction_fields(prompt) - user_prompt = next( - iter( - self.prompter.build_prompt( - instruction, - input, - ) - ) - ) - tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False) - if not self.train_on_inputs: - user_prompt_len = len(tokenized_prompt["input_ids"]) - # TODO this could be sped up using numpy array slicing - tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len - tokenized_res_prompt = self._tokenize( - response, strip_bos_token=True, add_eos_token=True - ) - tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"] - tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"] - tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"] - - return tokenized_prompt - - def _build_full_prompt( - self, instruction, input, response # pylint: disable=redefined-builtin - ): - return next( - iter( - self.prompter.build_prompt( - instruction, - input, - response, - ) - ) - ) - - -class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Alpaca prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["instruction"], - prompt["input"] if "input" in prompt else "", - prompt["output"], - ) - - -class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Alpaca Multiple Choice prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["question"], - "\n".join(f'- "{choice}"' for choice in prompt["choices"]), - prompt["solution"] if "solution" in prompt else prompt["explanation"], - ) - - -class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for Jeopardy prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["question"], - prompt["category"], - "what is " + prompt["answer"], - ) - - -class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for OpenAssistant prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["INSTRUCTION"], - "", - prompt["RESPONSE"], - ) - - -class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for SummarizeTLDR prompts. 
- """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["article"], - "", - prompt["summary"], - ) - - -class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for GPTeacher prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["instruction"], - prompt["input"] if "input" in prompt else "", - prompt["response"], - ) - - -class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - """ - Tokenizing strategy for NomicGPT4All prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: - return ( - prompt["prompt"], - "", - prompt["response"], - ) - - -class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): - """ - Tokenizing strategy for Reflection prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: - raise NotImplementedError - - def tokenize_prompt(self, prompt): - # pylint: disable=duplicate-code - ( - instruction, - input, # pylint: disable=redefined-builtin - output, - reflection, - corrected, - ) = self.parse_instruction_fields(prompt) - full_prompt = self._build_full_prompt( - instruction, input, output, reflection, corrected - ) - tokenized_full_prompt = self._tokenize(full_prompt) - if not self.train_on_inputs: - user_prompt = next( - iter( - self.prompter.build_prompt( - instruction, - input, - ) - ) - ) - tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) - user_prompt_len = len(tokenized_user_prompt["input_ids"]) - # TODO this could be sped up using numpy array slicing - tokenized_full_prompt["labels"] = [ - IGNORE_INDEX - ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] - - return tokenized_full_prompt - - def _build_full_prompt( - self, instruction, input, output, reflection, corrected - ): # pylint: disable=redefined-builtin - return next( - iter( - self.prompter.build_prompt( - instruction, - input, - output, - reflection, - corrected, - ) - ) - ) - - def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.sequence_len, - padding=False, - return_tensors=None, - ) - if ( - result["input_ids"][-1] != self.tokenizer.eos_token_id - and len(result["input_ids"]) < self.sequence_len - and add_eos_token - ): - result["input_ids"].append(self.tokenizer.eos_token_id) - result["attention_mask"].append(1) - - result["labels"] = result["input_ids"].copy() - return result - - -class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): - """ - Tokenizing strategy for Alpaca Reflection prompts. - """ - - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: - return ( - prompt["instruction"], - prompt["input"] if "input" in prompt else "", - prompt["output"], - prompt["reflection"], - prompt["corrected"], - ) - - -class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): - """ - Tokenizing strategy for ShareGPT prompts. - """ - - def get_conversation_thread(self, prompt): - return prompt["conversations"] - - def tokenize_prompt(self, prompt): - # Initial values. We will append to these as we go through the conversation. 
- result, current_len = tokenize_prompt_default() - conversation: Conversation = ( - self.prompter._conversation.copy() # pylint: disable=protected-access - ) - - # support for custom roles from the dataset, only useful for vicuna style prompts/roles - role_remap = [] - if ( - conversation.name == "vicuna_v1.1" - and "roles" in prompt - and len(prompt["roles"]) >= 2 - ): - role_remap = [ - {"from": conversation.roles[0], "to": prompt["roles"][0]}, - {"from": conversation.roles[1], "to": prompt["roles"][1]}, - ] - - try: - for _, part in enumerate( - self.prompter.build_prompt(self.get_conversation_thread(prompt)) - ): - if not isinstance(part, tuple): - LOG.warning(f"expected tuple, got {part}") - continue - - user, assistant = conversation.roles - role, content = part - - # Uses "in" because role contains extra characters - if user in role: - role = ( - role.replace(role_remap[0]["from"], role_remap[0]["to"]) - if role_remap - else role - ) - turn = role + content - # this is still the user query, we should - if not content.strip(): - LOG.warning(f"user turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=False, - strip_bos_token=True, - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif assistant in role: - # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID - role = ( - role.replace(role_remap[1]["from"], role_remap[1]["to"]) - if role_remap - else role - ) - turn = role + content - # this should be the assistant response, should end with an eos token - if not content.strip(): - LOG.warning(f"assistant turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=True, - strip_bos_token=True, - ) - role_res = self._tokenize( - role.rstrip(), - add_eos_token=False, - strip_bos_token=True, - ) - # not masked out from labels - labels = copy.deepcopy(res["input_ids"]) - len_role = len(role_res["input_ids"]) - labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels)) - elif role == "": - turn = content - # this is only ever the first part, should include the bos token and the user query - res = self._tokenize( - turn, add_eos_token=False, strip_bos_token=False - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - else: - LOG.warning(f"unhandled role: {role}") - continue - - # pylint: disable=duplicate-code - result, current_len = parse_tokenized_to_result( - result, - current_len, - res, - labels, - pad_token_id=self.tokenizer.pad_token_id, - ) - return result - except (KeyError, AssertionError, IndexError) as err: - raise InvalidDataException(str(err)) from err - - -def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: - """ - Returns the default values for the tokenize prompt function - """ - - result: Dict[str, List[int]] = { - "input_ids": [], - "attention_mask": [], - "labels": [], - } - current_len = 0 - return result, current_len - - -def parse_tokenized_to_result( - result: Dict[str, List[int]], - current_len: int, - res: Dict[str, List[int]], - labels: List[int], - pad_token_id: Union[int, None] = None, -) -> Tuple[Dict[str, List[int]], int]: - """ - Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result - """ - - input_ids = res["input_ids"] - input_len = len(input_ids) - result["input_ids"][current_len : current_len + input_len] = input_ids - result["attention_mask"][current_len : current_len + input_len] = [ - 1 if x != pad_token_id else 0 
for x in input_ids - ] - result["labels"][current_len : current_len + input_len] = labels - current_len += input_len - - return result, current_len diff --git a/axolotl/src/axolotl/prompters.py b/axolotl/src/axolotl/prompters.py deleted file mode 100644 index 73966def3f3656eab27753aae013907d67583a79..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/prompters.py +++ /dev/null @@ -1,367 +0,0 @@ -"""Module containing prompters""" - -import logging -from enum import Enum -from typing import Generator, Optional, Union - -from colorama import Fore -from fastchat.conversation import Conversation, get_conv_template - -LOG = logging.getLogger("axolotl") -IGNORE_TOKEN_ID = -100 -REPR_TEMPLATE = "\n\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n\n" - - -class PromptStyle(Enum): - """ - Enum for prompt styles - """ - - INSTRUCT = "instruct" - CHAT = "chat" - CHATML = "chatml" - - -class Prompter: - """ - Base prompter class for all prompters - """ - - -class AlpacaPrompter(Prompter): - """ - Base class for alpaca prompters - """ - - system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." - system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." - system_format: str = "{system}" - turn_format: str - turn_no_input_format: str - prompt_style: Optional[PromptStyle] = None - - def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): - self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value - self.match_prompt_style() - - def match_prompt_style(self): - # pylint: disable=duplicate-code - if self.prompt_style == PromptStyle.INSTRUCT.value: - self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - self.turn_no_input_format = ( - "### Instruction:\n{instruction}\n\n### Response:\n" - ) - self.system_format = "### System:\n{system}\n\n" - if self.prompt_style == PromptStyle.CHAT.value: - self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" - self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" - self.system_format = "SYSTEM: {system}\n" - if self.prompt_style == PromptStyle.CHATML.value: - self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" - self.turn_no_input_format = ( - "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" - ) - self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" - - def _build_result(self, instruction, input_text, output): - # returns the full prompt from instruction and optional input - # if a label (=response, =output) is provided, it's also appended. 
- if input_text: - res = ( - self.system_format.format(system=self.system_prompt) - if self.system_prompt - else "" - ) + self.turn_format.format(instruction=instruction, input=input_text) - else: - res = ( - self.system_format.format(system=self.system_no_input_prompt) - if self.system_no_input_prompt - else "" - ) + self.turn_no_input_format.format(instruction=instruction) - if output: - res = f"{res}{output}" - - return res - - def build_prompt( - self, - instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin - output: Union[None, str] = None, - ) -> Generator[str, None, None]: - yield self._build_result(instruction, input, output) - - def __repr__(self) -> str: - return REPR_TEMPLATE.format( - full_prompt=self._build_result("{instruction}", "{input}", "{output}") - ) - - -class UnpromptedPrompter(AlpacaPrompter): - """ - Prompter for alpaca no system prompt - """ - - system_prompt = "" - system_no_input_prompt = "" - - -class JeopardyPrompter(AlpacaPrompter): - """ - Prompter for Jeopardy - """ - - prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - - -class MultipleChoiceExplainPrompter(AlpacaPrompter): - """ - Prompter for multiple choice explain - """ - - system_prompt = ( - "Choose the answer that best answers the question. Explain your reasoning.\n" - ) - system_no_input_prompt = ( - "Choose the answer that best answers the question. Explain your reasoning.\n" - ) - - -class MultipleChoiceConcisePrompter(AlpacaPrompter): - """ - Prompter for multiple choice concise - """ - - system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n" - system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n" - - def match_prompt_style(self): - self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" - self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" - - -class SummarizeTLDRPrompter(AlpacaPrompter): - """ - Prompter for summarize TLDR - """ - - system_prompt = "" - system_no_input_prompt = "" - - def match_prompt_style(self): - self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:" - self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" - - -class GPTeacherPrompter(AlpacaPrompter): - """ - Prompter for GPTeacher - """ - - -class NomicGPT4AllPrompter(AlpacaPrompter): - """ - Prompter for NomicGPT4All - """ - - -class ReflectAlpacaPrompter(Prompter): - """ - Prompter for ReflectAlpaca - """ - - system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. 
Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" - system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" - - prompt_input = ( - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - ) - prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n" - agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" - response_split = "### Response:" - - def __init__(self, prompt_style="instruct"): - self.prompt_style = prompt_style - self.match_prompt_style() - - def match_prompt_style(self): - if self.prompt_style == PromptStyle.INSTRUCT.value: - self.prompt_input = ( - self.system_prompt - + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - ) - self.prompt_no_input = ( - self.system_no_input_prompt - + "### Instruction:\n{instruction}\n\n### Response:\n" - ) - self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" - self.response_split = "### Final Response:" - if self.prompt_style == PromptStyle.CHAT.value: - self.prompt_input = ( - self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" - ) - self.prompt_no_input = ( - self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" - ) - self.agent_label = ( - "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:" - ) - self.response_split = "ASSISTANT:" - - def _build_result( - self, - instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin - output: Union[None, str] = None, - reflection: Union[None, str] = None, - corrected: Union[None, str] = None, - ): - # returns the full prompt from instruction and optional input - # if a label (=response, =output) is provided, it's also appended. - if input: - res = self.prompt_input.format(instruction=instruction, input=input) - else: - res = self.prompt_no_input.format(instruction=instruction) - if output and reflection and corrected: - label = self.agent_label.format( - output=output, - reflection=reflection, - corrected=corrected, - ) - res = f"{res}{label}" - - return res - - def build_prompt( - self, - instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin - output: Union[None, str] = None, - reflection: Union[None, str] = None, - corrected: Union[None, str] = None, - ) -> Generator[str, None, None]: - # pylint: disable=duplicate-code - yield self._build_result( - instruction, - input, - output, - reflection, - corrected, - ) - - def __repr__(self) -> str: - return REPR_TEMPLATE.format( - full_prompt=self._build_result("{instruction}", "{input}", "{output}") - ) - - -SHAREGPT_ASSERTION_FAILED_ROLE = ( - "Role did not alternate between turns (gpt and human). Please check your data." 
-) - - -class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods - """ - A prompter that generates prompts for the ShareGPT - """ - - role_key_human = "human" - role_key_model = "gpt" - - def __init__( - self, - prompt_style=None, # pylint: disable=unused-argument - conversation: Optional[Union[str, Conversation]] = None, - role_key_human: Optional[str] = None, - role_key_model: Optional[str] = None, - ): - if conversation: - if isinstance(conversation, Conversation): - self._conversation = conversation - else: - self._conversation = get_conv_template(conversation) - else: - self._conversation = get_conv_template("vicuna_v1.1") - if role_key_human: - self.role_key_human = role_key_human - if role_key_model: - self.role_key_model = role_key_model - - def _build_result(self, source): - if len(source) < 2: - # If there isn't a back and forth conversation, ignore it - # also happens on the data splitting leaving empty conversations - raise IndexError( - f"A conversation entry has less than 2 messages :\n{source}" - ) - - conv = self._conversation.copy() - - # Add the conversation system prompt if provided, otherwise use the default one - if source[0]["from"] == "system": - conv.set_system_message(source[0]["value"]) - source.pop(0) - - roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]} - - try: - # Apply prompt templates - if source[0]["from"] not in roles: - # Skip the first one if it is not from human - source = source[1:] - except IndexError as err: - # sometimes there is a bing or system chat - raise err - - conv.messages = [] - for _, sentence in enumerate(source): - role = roles[sentence["from"]] - if len(conv.messages) > 0 and ( - (role == conv.messages[-1][0]) or (role not in conv.roles) - ): - LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") - conv.append_message(role, sentence["value"]) - - return conv.get_turns() - - def build_prompt(self, source) -> Generator[str, None, None]: - turns = self._build_result(source) - - for part in turns: - if part[0] and not part[1]: - LOG.warning(f"role with empty message: {part[0]}") - yield part - - def __repr__(self) -> str: - turns = self._build_result([{"from": "{from}", "value": "{value}"}]) - return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns]) - - -class ShareGPTPrompterV2(ShareGPTPrompter): - """ - A V2 prompter that generates prompts for the ShareGPT - """ - - def __init__( - self, - conversation: Optional[Union[str, Conversation]] = None, - role_key_human: Optional[str] = None, - role_key_model: Optional[str] = None, - ): - super().__init__( - conversation=conversation, - role_key_human=role_key_human, - role_key_model=role_key_model, - ) - - -class UnsupportedPrompter(Prompter): - """ - A dummy class for custom prompters - """ - - def __init__(self) -> None: - pass - - def __repr__(self): - return "Pre-tokenized or custom dataset types are unsupported for logging" diff --git a/axolotl/src/axolotl/train.py b/axolotl/src/axolotl/train.py deleted file mode 100644 index 4e5241e4c84cf521eb3622457b65216365a83761..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/train.py +++ /dev/null @@ -1,204 +0,0 @@ -"""Prepare and train a model on a dataset. 
Can also infer from a model or merge lora""" - -import os -import signal -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import torch -import transformers.modelcard -from accelerate.logging import get_logger -from datasets import Dataset -from optimum.bettertransformer import BetterTransformer -from pkg_resources import get_distribution # type: ignore -from transformers.deepspeed import is_deepspeed_zero3_enabled - -from axolotl.common.cli import TrainerCliArgs -from axolotl.logging_config import configure_logging -from axolotl.utils.dict import DictDefault -from axolotl.utils.freeze import freeze_parameters_except -from axolotl.utils.models import load_model, load_tokenizer -from axolotl.utils.trainer import setup_trainer - -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -src_dir = os.path.join(project_root, "src") -sys.path.insert(0, src_dir) - -configure_logging() -LOG = get_logger("axolotl.train") - - -@dataclass -class TrainDatasetMeta: - """ - dataclass to capture the dataset specific options for training - """ - - train_dataset: Dataset - eval_dataset: Optional[Dataset] = None - total_num_steps: Optional[int] = None - - -def train( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta -): - # load the tokenizer first - LOG.debug( - f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", - main_process_only=True, - ) - tokenizer = load_tokenizer(cfg) - - train_dataset = dataset_meta.train_dataset - eval_dataset = dataset_meta.eval_dataset - total_num_steps = dataset_meta.total_num_steps - - # Load the model and tokenizer - msg = "loading model" - if cfg.adapter: - msg += " and peft_config..." - LOG.debug(msg) - model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) - - safe_serialization = cfg.save_safetensors is True - - if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: - possible_checkpoints = [ - str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") - ] - if len(possible_checkpoints) > 0: - sorted_paths = sorted( - possible_checkpoints, - key=lambda path: int(path.split("-")[-1]), - ) - cfg.resume_from_checkpoint = sorted_paths[-1] - LOG.info( - f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" - ) - resume_from_checkpoint = cfg.resume_from_checkpoint - - if cfg.unfrozen_parameters: - freeze_parameters_except(model, cfg.unfrozen_parameters) - - trainer = setup_trainer( - cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps - ) - - if hasattr(model, "config"): - model.config.use_cache = False - - # go ahead and presave, so we have the adapter config available to inspect - if peft_config: - LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") - peft_config.save_pretrained(cfg.output_dir) - # additionally presave the tokenizer and model configs - if not Path(cfg.output_dir).is_dir(): - os.makedirs(cfg.output_dir, exist_ok=True) - tokenizer.save_pretrained(str(Path(cfg.output_dir))) - if hasattr(model, "config"): - model.config.save_pretrained(str(Path(cfg.output_dir))) - - # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if cfg.local_rank == 0: - - def terminate_handler(_, __, model): - if cfg.flash_optimum: - model = BetterTransformer.reverse(model) - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) - sys.exit(0) - - signal.signal( - signal.SIGINT, lambda 
signum, frame: terminate_handler(signum, frame, model) - ) - - badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" - transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" - - if getattr(cfg, "axolotl_config_path"): - raw_axolotl_cfg = Path(cfg.axolotl_config_path) - version = get_distribution("axolotl").version - if raw_axolotl_cfg.is_file(): - transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>
\n" - - LOG.info("Starting trainer...") - if cfg.group_by_length: - LOG.info("hang tight... sorting dataset for group_by_length") - - pretrain_hooks(cfg, trainer) - if cfg.flash_optimum: - with torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=True, enable_mem_efficient=True - ): - trainer.train(resume_from_checkpoint=resume_from_checkpoint) - else: - trainer.train(resume_from_checkpoint=resume_from_checkpoint) - post_train_hooks(cfg, trainer) - - LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") - - # post training - for name, module in model.named_modules(): - if hasattr(module, "_post_training"): - module._post_training(model, name) # pylint: disable=protected-access - - if trainer.is_fsdp_enabled: - trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") - LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") - - if cfg.relora_steps: - if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): - model = model.merge_and_unload() - else: - # final model weights have already been saved by `ReLoRACallback.on_train_end` - return model, tokenizer - - # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading - # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file - if cfg.fsdp: - trainer.save_model(cfg.output_dir) - elif cfg.deepspeed and is_deepspeed_zero3_enabled(): - # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading - trainer.accelerator.wait_for_everyone() - unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) - - # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if - # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or - # `zero3_save_16bit_model` is True in DeepSpeed Plugin. - # For Zero Stages 1 and 2, models are saved as usual in the output directory. 
- # The model name saved is `pytorch_model.bin` - unwrapped_model.save_pretrained( - cfg.output_dir, - is_main_process=trainer.accelerator.is_main_process, - save_function=trainer.accelerator.save, - state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped), - ) - elif cfg.local_rank == 0: - if cfg.flash_optimum: - model = BetterTransformer.reverse(model) - - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) - - if not cfg.hub_model_id: - trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) - - return model, tokenizer - - -def pretrain_hooks(_cfg, _trainer): - """ - Run hooks right before kicking off the training - :param cfg: - :param trainer: - :return: - """ - - -def post_train_hooks(_cfg, _trainer): - """ - Run hooks right after training completes - :param cfg: - :param trainer: - :return: - """ diff --git a/axolotl/src/axolotl/utils/__init__.py b/axolotl/src/axolotl/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/axolotl/src/axolotl/utils/bench.py b/axolotl/src/axolotl/utils/bench.py deleted file mode 100644 index 40be0d9ac8665b99ded9d7b6c643ee2d1dc53e03..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/bench.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Benchmarking and measurement utilities""" -import functools - -import pynvml -import torch -from pynvml.nvml import NVMLError - - -def check_cuda_device(default_value): - """ - wraps a function and returns the default value instead of running the - wrapped function if cuda isn't available or the device is auto - :param default_value: - :return: - """ - - def deco(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - device = kwargs.get("device", args[0] if args else None) - - if ( - not torch.cuda.is_available() - or device == "auto" - or torch.device(device).type == "cpu" - ): - return default_value - - return func(*args, **kwargs) - - return wrapper - - return deco - - -@check_cuda_device(0.0) -def gpu_memory_usage(device=0): - return torch.cuda.memory_allocated(device) / 1024.0**3 - - -@check_cuda_device((0.0, 0.0, 0.0)) -def gpu_memory_usage_all(device=0): - usage = torch.cuda.memory_allocated(device) / 1024.0**3 - reserved = torch.cuda.memory_reserved(device) / 1024.0**3 - smi = gpu_memory_usage_smi(device) - return usage, reserved - usage, max(0, smi - reserved) - - -@check_cuda_device(0.0) -def gpu_memory_usage_smi(device=0): - if isinstance(device, torch.device): - device = device.index - if isinstance(device, str) and device.startswith("cuda:"): - device = int(device[5:]) - try: - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - info = pynvml.nvmlDeviceGetMemoryInfo(handle) - return info.used / 1024.0**3 - except NVMLError: - return 0.0 - - -def log_gpu_memory_usage(log, msg, device): - usage, cache, misc = gpu_memory_usage_all(device) - extras = [] - if cache > 0: - extras.append(f"+{cache:.03f}GB cache") - if misc > 0: - extras.append(f"+{misc:.03f}GB misc") - log.info( - f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2 - ) - return usage, cache, misc diff --git a/axolotl/src/axolotl/utils/callbacks.py b/axolotl/src/axolotl/utils/callbacks.py deleted file mode 100644 index 8599c0df0fd32ad0a5e65e96d7dcf59e955c5911..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/callbacks.py +++ /dev/null @@ -1,570 +0,0 @@ -"""Callbacks for Trainer class""" - -from __future__ import 
annotations - -import logging -import os -from typing import TYPE_CHECKING, Dict, List - -import evaluate -import numpy as np -import pandas as pd -import torch -import torch.distributed as dist -import wandb -from datasets import load_dataset -from optimum.bettertransformer import BetterTransformer -from tqdm import tqdm -from transformers import ( - GenerationConfig, - Trainer, - TrainerCallback, - TrainerControl, - TrainerState, - TrainingArguments, -) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy - -from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.distributed import ( - barrier, - broadcast_dict, - gather_scalar_from_all_ranks, - get_world_size, - is_distributed, - is_main_process, - zero_first, -) - -if TYPE_CHECKING: - from axolotl.core.trainer_builder import AxolotlTrainingArguments - -LOG = logging.getLogger("axolotl.callbacks") -IGNORE_INDEX = -100 - - -class EvalFirstStepCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods disable=unused-argument - """ - Callback to trigger evals on the first step - """ - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - if ( - args.evaluation_strategy == IntervalStrategy.STEPS - and args.eval_steps < 1.0 - and state.global_step == 1 - ): - control.should_evaluate = True - return control - - -class SaveBetterTransformerModelCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods - """Callback to save the BetterTransformer wrapped model""" - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - # Save - if ( - args.save_strategy == IntervalStrategy.STEPS - and args.save_steps > 0 - and state.global_step % args.save_steps == 0 - ): - control.should_save = True - - if control.should_save: - checkpoint_folder = os.path.join( - args.output_dir, - f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", - ) - - model = BetterTransformer.reverse(kwargs["model"]) - model.save_pretrained(checkpoint_folder) - # FIXME - need to cleanup old checkpoints - - # since we're saving here, we don't need the trainer loop to attempt to save too b/c - # the trainer will raise an exception since it can't save a BetterTransformer wrapped model - control.should_save = False - return control - - -class GPUStatsCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods disable=unused-argument - """Callback to track GPU utilization""" - - def __init__(self, cfg): - self.cfg = cfg - self.logged = False - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - if not self.logged and state.global_step > 1: - log_gpu_memory_usage(LOG, "while training", self.cfg.device) - self.logged = True - return control - - -class LossWatchDogCallback(TrainerCallback): - """Callback to track loss and stop training if loss is too high""" - - def __init__(self, cfg): - self.cfg = cfg - self.logged = False - self.violations = 0 - self.threshold = cfg.loss_watchdog_threshold - self.patience = cfg.loss_watchdog_patience or 3 - - def on_step_end( - self, - _args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **_kwargs, - ): - if len(state.log_history) > 0 and "loss" in state.log_history[-1]: - if state.log_history[-1]["loss"] > self.threshold: - self.violations += 1 - if self.violations >= self.patience: - LOG.warning( - "Loss is too high, stopping training 
(loss_watchdog_threshold)" - ) - control.should_training_stop = True - else: - self.violations = 0 - return control - - -def bench_eval_callback_factory(trainer, tokenizer): - accuracy = evaluate.load("accuracy") - abcd_idx = [ - tokenizer("A", add_special_tokens=False).input_ids[0], - tokenizer("B", add_special_tokens=False).input_ids[0], - tokenizer("C", add_special_tokens=False).input_ids[0], - tokenizer("D", add_special_tokens=False).input_ids[0], - tokenizer("E", add_special_tokens=False).input_ids[0], - tokenizer("F", add_special_tokens=False).input_ids[0], - tokenizer("G", add_special_tokens=False).input_ids[0], - ] - bench_split = "eval" - - def transform_bench_subject(example): - # Split on ':' and trim whitespace - parts = example["subject"].split(":") - first_part = ( - parts[0].strip().lower().replace("-", "_") - ) # Lowercase the first part - second_part = ( - parts[1].strip().replace("-", "_") if len(parts) > 1 else "all" - ) # Replace hyphens with underscores - - # Return the transformed values - return {"name": first_part, "subject": second_part} - - if trainer.args.bench_dataset == "mmlu-zs": - bench_dataset = load_dataset( - "openaccess-ai-collective/mmlu-evals", - data_files={ - "eval": "zero_shot_mmlu_val.json", - "test": "zero_shot_mmlu_test.json", - }, - ) - # bench_dataset = bench_dataset.remove_columns("subject") - # MMLU Five-shot (Eval/Test only) - elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]: - bench_dataset = load_dataset( - "openaccess-ai-collective/mmlu-evals", - data_files={ - "eval": "five_shot_mmlu_val.json", - "test": "five_shot_mmlu_test.json", - }, - ) - # bench_dataset = bench_dataset.remove_columns('subject') - elif "/" in trainer.args.bench_dataset: - bench_ds = trainer.args.bench_dataset - bench_ds_name = "/".join(bench_ds.split("/", 2)[:2]) - bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:]) - bench_dataset = load_dataset( - bench_ds_name, - data_files={ - "eval": bench_ds_data_file, - }, - ) - bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject) - else: - raise ValueError( - f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args" - ) - bench_dataset = bench_dataset[trainer.args.bench_split] - if trainer.args.max_bench_samples is not None: - bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples)) - - def tokenize_evals(example): - source = f"{tokenizer.bos_token}{example['input']}" - target = f"{example['output']}{tokenizer.eos_token}" - - tokenized_source = tokenizer( - source, - max_length=2048, - truncation=True, - add_special_tokens=False, - ) - tokenized_target = tokenizer( - target, - max_length=2048, - truncation=True, - add_special_tokens=False, - ) - input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"] - labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[ - "input_ids" - ] - - return { - "input_ids": input_ids, - "labels": labels, - "subject": example["subject"], - } - - with zero_first(is_main_process()): - bench_dataset = bench_dataset.map(tokenize_evals) - bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx) - - class BenchEvalCallback(TrainerCallback): - """ - TrainerCallback that runs the MMLU evals - """ - - def on_evaluate( - self, - args: AxolotlTrainingArguments, - state: TrainerState, # pylint: disable=unused-argument - control: TrainerControl, # pylint: disable=unused-argument - metrics: Dict[str, float], # pylint: disable=unused-argument - **kwargs, # pylint: 
disable=unused-argument - ): - data_loader = trainer.get_bench_dataloader( - bench_dataset.remove_columns(["input", "subject", "output", "name"]) - ) - trainer.model.eval() - preds, refs = [], [] - loss_bench = 0 - for batch in tqdm(data_loader, total=len(data_loader)): - (loss, logits, labels) = trainer.prediction_step( - trainer.model, - batch, - prediction_loss_only=False, - ) - # There are two tokens, the output, and eos token. - for i, logit in enumerate(logits): - label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[ - 0 - ][0] - logit_abcd = logit[label_non_zero_id - 1][abcd_idx] - preds.append(torch.argmax(logit_abcd).item()) - labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0] - refs += [ - abcd_idx.index(label) if label in abcd_idx else -1 - for label in labels.tolist() - ] - loss_bench += loss.item() - # Extract results by subject. - bench_name = bench_dataset["name"] - bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)} - for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name - bench_names[s]["preds"].append(p) - bench_names[s]["refs"].append(r) - barrier() - local_bench_names = bench_names - gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())] - # Gather results from all GPUs to GPU 0 - - loss_bench_ranks = gather_scalar_from_all_ranks( - lambda: loss_bench, get_world_size() - ) - len_data_loader_ranks = gather_scalar_from_all_ranks( - lambda: len(data_loader), get_world_size() - ) - - results = {} - if is_distributed() and not is_main_process(): - dist.gather_object(local_bench_names, dst=0) - else: - if is_distributed(): - dist.gather_object(local_bench_names, gathered_bench_names, dst=0) - else: - gathered_bench_names = [local_bench_names] - bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks) - results = {f"{bench_split}_bench_loss": bench_loss} - - # Combine results from all GPUs - combined_bench_names: Dict[str, Dict[str, List]] = {} - for bench_name in gathered_bench_names: - for name, data in bench_name.items(): - if name not in combined_bench_names: - combined_bench_names[name] = {"refs": [], "preds": []} - combined_bench_names[name]["refs"].extend(data["refs"]) - combined_bench_names[name]["preds"].extend(data["preds"]) - - bench_scores = [] - bench_refs = [] - bench_preds = [] - for ( - bench_name - ) in combined_bench_names: # pylint: disable=consider-using-dict-items - bench_score = accuracy.compute( - references=combined_bench_names[bench_name]["refs"], - predictions=combined_bench_names[bench_name]["preds"], - )["accuracy"] - bench_refs.extend(combined_bench_names[bench_name]["refs"]) - bench_preds.extend(combined_bench_names[bench_name]["preds"]) - if not pd.isna(bench_score): - results[ - f"{bench_split}_bench_accuracy_{bench_name}" - ] = bench_score - bench_scores.append(bench_score) - else: - results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0 - bench_scores.append(0.0) - results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores) - results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute( - references=bench_refs, predictions=bench_preds - )["accuracy"] - trainer.log(results) - - results = broadcast_dict(results) - for key, val in results.items(): - metrics[key] = val - - return BenchEvalCallback - - -def log_prediction_callback_factory(trainer: Trainer, tokenizer): - class LogPredictionCallback(TrainerCallback): - """Callback to log prediction values during each evaluation""" - - def __init__(self, cfg): - self.cfg = cfg - 
self.logged = False - - def on_evaluate( - self, - args: AxolotlTrainingArguments, # pylint: disable=unused-argument - state: TrainerState, - control: TrainerControl, - train_dataloader, # pylint: disable=unused-argument - eval_dataloader, - **kwargs, # pylint: disable=unused-argument - ): - eval_table_size = self.cfg.eval_table_size - - if eval_table_size <= 0: - return control - - trainer.model.eval() - device = torch.device(self.cfg.device) - - # pylint: disable=duplicate-code - generation_config = GenerationConfig( - max_new_tokens=self.cfg.eval_table_max_new_tokens, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=False, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - - def logits_to_tokens(logits) -> torch.Tensor: - probabilities = torch.softmax(logits, dim=-1) - # Get the predicted token ids (the ones with the highest probability) - predicted_token_ids = torch.argmax(probabilities, dim=-1) - return predicted_token_ids - - def find_ranges(lst): - ranges = [] - start = 0 - for i in range(1, len(lst)): - if lst[i] == 0: - ranges.append((start, i - 1)) - start = i - end = len(lst) - 1 - ranges.append((start, end)) - return ranges - - def log_table_from_dataloader(name: str, table_dataloader): - table = wandb.Table( # type: ignore[attr-defined] - columns=[ - "id", - "Prompt", - "Correct Completion", - "Predicted Completion (model.generate)", - "Predicted Completion (trainer.prediction_step)", - ] - ) - row_index = 0 - - for batch in tqdm(table_dataloader): - if row_index > eval_table_size: - break - - batch_labels = batch["labels"].to(device) - batch_input_ids = batch["input_ids"].to(device) - - if "position_ids" in batch: - batch_pos_ids = batch["position_ids"].tolist() - else: - batch_pos_ids = [None] * len(batch["input_ids"]) - - (_, batch_logits, _) = trainer.prediction_step( - trainer.model, - batch, - prediction_loss_only=False, - ) - - prompt_token_ids_list = [] - pred_step_token_ids_list = [] - completion_token_ids_list = [] - - for input_ids_all, labels_all, pos_ids, logits in zip( - batch_input_ids, - batch_labels, - batch_pos_ids, - batch_logits, - ): - if pos_ids is None: - pos_ranges = [(0, len(input_ids_all) - 1)] - else: - pos_ranges = find_ranges(pos_ids) - - for pos_range in pos_ranges: - start, end = pos_range - if start == end: - continue - - input_ids = input_ids_all[start : end + 1] - labels = labels_all[start : end + 1] - - tokens_without_loss = labels == IGNORE_INDEX - tokens_with_loss = labels != IGNORE_INDEX - tokens_exclude_padding = input_ids != tokenizer.pad_token_id - prompt_token_includes = ( - tokens_without_loss & tokens_exclude_padding - ) - - prompt_token_ids = input_ids[prompt_token_includes] - prompt_token_ids_list.append(prompt_token_ids) - - completion_token_ids = input_ids[tokens_with_loss] - completion_token_ids_list.append(completion_token_ids) - - pred_step_token_ids = logits_to_tokens( - logits[start : end + 1] - )[tokens_with_loss] - pred_step_token_ids_list.append(pred_step_token_ids) - - prompt_texts = tokenizer.batch_decode( - prompt_token_ids_list, skip_special_tokens=True - ) - completion_texts = tokenizer.batch_decode( - completion_token_ids_list, skip_special_tokens=True - ) - pred_step_texts = tokenizer.batch_decode( - pred_step_token_ids_list, skip_special_tokens=True - ) - - with torch.no_grad(): - prompt_encoding = tokenizer( - prompt_texts, padding=True, 
return_tensors="pt" - ).to(self.cfg.device) - predictions = trainer.model.generate( - **prompt_encoding, generation_config=generation_config - ) - - prediction_all_tokens = predictions["sequences"].cpu().tolist() - prediction_without_prompt_tokens_list = [] - for prompt_token_ids, prediction_tokens in zip( - prompt_token_ids_list, prediction_all_tokens - ): - prediction_without_prompt_tokens = prediction_tokens[ - len(prompt_token_ids) : - ] - prediction_without_prompt_tokens_list.append( - prediction_without_prompt_tokens - ) - - predicted_texts = tokenizer.batch_decode( - prediction_without_prompt_tokens_list, skip_special_tokens=True - ) - - for ( - prompt_text, - completion_text, - prediction_text, - pred_step_text, - ) in zip( - prompt_texts, completion_texts, predicted_texts, pred_step_texts - ): - table.add_data( - row_index, - prompt_text, - completion_text, - prediction_text, - pred_step_text, - ) - row_index += 1 - - wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined] - - if is_main_process(): - log_table_from_dataloader("Eval", eval_dataloader) - - return control - - return LogPredictionCallback - - -class SaveAxolotlConfigtoWandBCallback(TrainerCallback): - """Callback to save axolotl config to wandb""" - - def __init__(self, axolotl_config_path): - self.axolotl_config_path = axolotl_config_path - - def on_train_begin( - self, - args: AxolotlTrainingArguments, # pylint: disable=unused-argument - state: TrainerState, # pylint: disable=unused-argument - control: TrainerControl, - **kwargs, # pylint: disable=unused-argument - ): - if is_main_process(): - try: - artifact = wandb.Artifact(name="axolotl-config", type="config") - artifact.add_file(local_path=self.axolotl_config_path) - wandb.run.log_artifact(artifact) - LOG.info("Axolotl config has been saved to WandB as an artifact.") - except (FileNotFoundError, ConnectionError) as err: - LOG.warning(f"Error while saving Axolotl config to WandB: {err}") - return control diff --git a/axolotl/src/axolotl/utils/collators.py b/axolotl/src/axolotl/utils/collators.py deleted file mode 100644 index 0f0eb5a95a75cd8d86a5ade6ca78b0bbaac51c1a..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/collators.py +++ /dev/null @@ -1,180 +0,0 @@ -""" -DataCollator for axolotl to pad labels and position_ids for packed sequences -""" -from dataclasses import dataclass -from typing import Any, Dict, Optional, Sequence, Union - -import numpy as np -import torch -import transformers -from transformers import PreTrainedTokenizerBase -from transformers.utils import PaddingStrategy - -IGNORE_INDEX = -100 - - -@dataclass -class DataCollatorForSeq2Seq: - """ - Data collator that will dynamically pad the inputs received, as well as the labels and position_ids - - Args: - tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): - The tokenizer used for encoding the data. - model ([`PreTrainedModel`]): - The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to - prepare the *decoder_input_ids* - - This is useful when using *label_smoothing* to avoid calculating loss twice. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding index) - among: - - - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single - sequence is provided). 
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= - 7.5 (Volta). - label_pad_token_id (`int`, *optional*, defaults to -100): - The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). - return_tensors (`str`): - The type of Tensor to return. Allowable values are "np", "pt" and "tf". - """ - - tokenizer: PreTrainedTokenizerBase - model: Optional[Any] = None - padding: Union[bool, str, PaddingStrategy] = True - max_length: Optional[int] = None - pad_to_multiple_of: Optional[int] = None - label_pad_token_id: int = -100 - position_pad_token_id: int = 0 - return_tensors: str = "pt" - - def __call__(self, features, return_tensors=None): - labels = None - if return_tensors is None: - return_tensors = self.return_tensors - - for feature_name, pad_token_id in [ - ("labels", self.label_pad_token_id), - ("position_ids", self.position_pad_token_id), - ]: - feat = ( - [feature[feature_name] for feature in features] - if feature_name in features[0].keys() - else None - ) - labels = feat if feat and feature_name == "labels" else labels - # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the - # same length to return tensors. 
- if feat is not None: - max_feature_length = max(len(l) for l in feat) # noqa: E741 - if self.pad_to_multiple_of is not None: - max_feature_length = ( - (max_feature_length + self.pad_to_multiple_of - 1) - // self.pad_to_multiple_of - * self.pad_to_multiple_of - ) - - padding_side = self.tokenizer.padding_side - for feature in features: - remainder = [pad_token_id] * ( - max_feature_length - len(feature[feature_name]) - ) - if isinstance(feature[feature_name], list): - feature[feature_name] = ( - feature[feature_name] + remainder - if padding_side == "right" - else remainder + feature[feature_name] - ) - elif padding_side == "right": - feature[feature_name] = np.concatenate( - [feature[feature_name], remainder] - ).astype(np.int64) - else: - feature[feature_name] = np.concatenate( - [remainder, feature[feature_name]] - ).astype(np.int64) - - features = self.tokenizer.pad( - features, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors=return_tensors, - ) - - # prepare decoder_input_ids - if ( - labels is not None - and self.model is not None - and hasattr(self.model, "prepare_decoder_input_ids_from_labels") - ): - decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( - labels=features["labels"] - ) - features["decoder_input_ids"] = decoder_input_ids - - return features - - -@dataclass -class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): - """ - Collator for multipack specific to the using the BatchSampler - """ - - def __call__(self, features, return_tensors=None): - chunked_data = {} - for feature in features[0].keys(): - if feature == "length": - continue - if feature == "attention_mask": - arrays = [ - (1) * np.array(item[feature]) - for item in features - if feature in item - ] - chunked_data[feature] = np.concatenate(arrays) - else: - arrays = [ - np.array(item[feature]) for item in features if feature in item - ] - chunked_data[feature] = np.concatenate(arrays) - features = [chunked_data] - return super().__call__(features, return_tensors=return_tensors) - - -@dataclass -class MambaDataCollator: - """ - Collator for State Space Models (Mamba) - """ - - tokenizer: transformers.PreTrainedTokenizer - - def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple( - [torch.LongTensor(instance[key]) for instance in instances] - for key in ("input_ids", "labels") - ) - input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id, - ) - labels = torch.nn.utils.rnn.pad_sequence( - labels, batch_first=True, padding_value=IGNORE_INDEX - ) - - return { - "input_ids": input_ids, - "labels": labels, - } diff --git a/axolotl/src/axolotl/utils/config.py b/axolotl/src/axolotl/utils/config.py deleted file mode 100644 index d9e56b95a6724bdd8cfed70c0368f0ad76e78ebd..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/config.py +++ /dev/null @@ -1,475 +0,0 @@ -"""Module for working with config dicts""" - -import logging -import os - -import torch -from transformers.utils import is_torch_bf16_gpu_available - -from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.models import load_model_config - -LOG = logging.getLogger("axolotl") - - -def choose_device(cfg): - def get_device(): - try: - if torch.cuda.is_available(): - return f"cuda:{cfg.local_rank}" - - if torch.backends.mps.is_available(): - return "mps" - - raise SystemError("No CUDA/mps device found") - except 
Exception: # pylint: disable=broad-exception-caught - return "cpu" - - cfg.device = get_device() - if cfg.world_size == 1: - cfg.device_map = cfg.device_map or "auto" - else: - if cfg.device.startswith("cuda"): - cfg.device_map = {"": torch.cuda.current_device()} - else: - cfg.device_map = {"": cfg.device} - - # in `accelerate launch`, we need to not pass through any device map and let - # accelerate figure out which parts of the model to put on which gpu - accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")] - if accelerate_vars: - cfg.device_map = None - - -def normalize_config(cfg): - # setup some derived config / hyperparams - cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or ( - cfg.batch_size // cfg.micro_batch_size - ) - cfg.batch_size = ( - cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps - ) - if cfg.eval_batch_size is None: - cfg.eval_batch_size = cfg.micro_batch_size - cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) - cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) - cfg.eval_table_size = cfg.eval_table_size or 0 - cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128 - choose_device(cfg) - cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 - if cfg.ddp: - cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} - cfg.batch_size = cfg.batch_size * cfg.world_size - - if cfg.device == "mps": - cfg.load_in_8bit = False - cfg.tf32 = False - if cfg.bf16: - cfg.fp16 = True - cfg.bf16 = False - else: - torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False - - if cfg.bf16 or cfg.bfloat16: - cfg.torch_dtype = torch.bfloat16 - elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: - cfg.torch_dtype = torch.float16 - else: - cfg.torch_dtype = torch.float32 - - if cfg.saves_per_epoch: - save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs) - if save_steps < 1.0: # prevent saves on every step - cfg.save_steps = save_steps - if cfg.evals_per_epoch: - eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs) - if eval_steps < 1.0: # prevent evals on every step - cfg.eval_steps = eval_steps - - cfg.dataset_processes = cfg.dataset_processes or os.cpu_count() - - if not cfg.base_model_config: - cfg.base_model_config = cfg.base_model - - model_config = load_model_config(cfg) - cfg.model_config_type = model_config.model_type - - # figure out if the model is llama - cfg.is_llama_derived_model = ( - (hasattr(model_config, "model_type") and model_config.model_type == "llama") - or cfg.is_llama_derived_model - or "llama" in cfg.base_model.lower() - or (cfg.model_type and "llama" in cfg.model_type.lower()) - ) - - # figure out if the model is falcon - cfg.is_falcon_derived_model = ( - ( - hasattr(model_config, "model_type") - and model_config.model_type - in [ - "falcon", - "RefinedWebModel", - "RefinedWeb", - ] - ) - or cfg.is_falcon_derived_model - or "falcon" in cfg.base_model.lower() - or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower()) - ) - - cfg.is_mistral_derived_model = ( - ( - hasattr(model_config, "model_type") - and model_config.model_type - in [ - "mistral", - ] - ) - or cfg.is_mistral_derived_model - or "mistral" in cfg.base_model.lower() - or (cfg.model_type and "mistral" in cfg.model_type.lower()) - ) - - cfg.is_qwen_derived_model = ( - ( - hasattr(model_config, "model_type") - and model_config.model_type - in [ - "qwen", - ] - ) - or cfg.is_qwen_derived_model - or "qwen" in cfg.base_model.lower() - or (cfg.model_type and "qwen" in cfg.model_type.lower()) 
- ) - - if isinstance(cfg.learning_rate, str): - cfg.learning_rate = float(cfg.learning_rate) - - log_gpu_memory_usage(LOG, "baseline", cfg.device) - - -def validate_config(cfg): - if is_torch_bf16_gpu_available(): - if not cfg.bf16 and not cfg.bfloat16: - LOG.info("bf16 support detected, but not enabled for this configuration.") - else: - if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16): - raise ValueError( - "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." - ) - if cfg.max_packed_sequence_len and cfg.sample_packing: - raise ValueError( - "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing" - ) - if cfg.max_packed_sequence_len: - LOG.warning( - str( - PendingDeprecationWarning( - "max_packed_sequence_len will be deprecated in favor of sample_packing" - ) - ) - ) - - if cfg.sample_packing and not cfg.pad_to_sequence_len: - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using sample_packing" - ) - - if cfg.gradient_accumulation_steps and cfg.batch_size: - raise ValueError( - "please set only one of gradient_accumulation_steps or batch_size" - ) - if cfg.batch_size: - LOG.warning( - "%s\n%s", - "batch_size is not recommended. Please use gradient_accumulation_steps instead.", - "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", - ) - if ( - cfg.eval_batch_size - and cfg.micro_batch_size - and cfg.eval_batch_size != cfg.micro_batch_size - ): - LOG.warning( - "eval_batch_size != micro_batch_size. This can lead to VRAM instability." - ) - - if cfg.load_4bit: - raise ValueError("cfg.load_4bit parameter has been deprecated") - - if cfg.adapter == "qlora": - if cfg.merge_lora: - # can't merge qlora if loaded in 8bit or 4bit - if cfg.load_in_8bit: - raise ValueError("Can't merge qlora if loaded in 8bit") - - if cfg.gptq: - raise ValueError("Can't merge qlora if gptq") - - if cfg.load_in_4bit: - raise ValueError("Can't merge qlora if loaded in 4bit") - - else: - if cfg.load_in_8bit: - raise ValueError("Can't load qlora in 8bit") - - if cfg.gptq: - raise ValueError("Can't load qlora if gptq") - - if not cfg.load_in_4bit: - raise ValueError("Require cfg.load_in_4bit to be True for qlora") - - if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp: - raise ValueError("Fused modules are not supported with QLoRA") - - if not cfg.load_in_8bit and cfg.adapter == "lora": - LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") - - if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp): - raise ValueError("Fused modules are not supported with LoRA") - - if cfg.relora_steps: - if cfg.adapter not in ("lora", "qlora"): - raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") - - if cfg.fsdp: - raise ValueError("fsdp not supported with ReLoRA") - - if cfg.deepspeed: - raise ValueError("deepspeed not supported with ReLoRA") - - if cfg.lr_scheduler == "one_cycle": - raise ValueError("ReLoRA is not compatible with the one_cycle scheduler") - - if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp: - raise ValueError("Fused modules are not supported with ReLoRA") - - if cfg.trust_remote_code: - LOG.warning( - "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." 
- ) - - if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True: - raise ValueError( - "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" - ) - - if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp: - raise ValueError("FSDP is not supported for falcon models") - - if ( - cfg.base_model and "mpt" in cfg.base_model.lower() - ) and cfg.gradient_checkpointing: - raise ValueError("gradient_checkpointing is not supported for MPT models") - - if cfg.flash_optimum is True: - if cfg.adapter: - LOG.warning("BetterTransformers probably doesn't work with PEFT adapters") - if cfg.fp16 or cfg.bf16: - raise ValueError("AMP is not supported with BetterTransformer") - if cfg.float16 is not True and cfg.bloat16 is not True: - LOG.warning( - "You should probably set bfloat16 or float16 to true to " - "load the model in float16 for BetterTransformers" - ) - if int(torch.__version__.split(".", maxsplit=1)[0]) < 2: - LOG.warning("torch>=2.0.0 required") - raise ValueError( - f"flash_optimum for BetterTransformers may not be used with {torch.__version__}" - ) - - if cfg.pretraining_dataset and cfg.group_by_length: - LOG.warning( - "You probably want to disable group_by_length as it will force a streamed dataset to download completely." - ) - if cfg.pretraining_dataset and not cfg.max_steps: - raise ValueError( - "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" - ) - - if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and ( - not cfg.optimizer or "adamw" not in cfg.optimizer - ): - LOG.warning("adamw hyperparameters found, but no adamw optimizer set") - - if cfg.push_to_hub_model_id: - raise ValueError( - "push_to_hub_model_id is deprecated. Please use hub_model_id instead." - ) - - if cfg.gptq and cfg.model_revision: - raise ValueError( - "model_revision is not supported for GPTQ models. " - + "Please download the model from HuggingFace Hub manually for correct branch, " - + "point to its path, and remove model_revision from the config." - ) - - if cfg.sample_packing and cfg.sdp_attention: - # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 - raise ValueError( - "sample_packing not compatible with sdp_attention. Use flash_attention" - ) - - if cfg.sample_packing and cfg.xformers_attention: - raise ValueError( - "sample_packing not compatible with xformers_attention. Use flash_attention" - ) - - if cfg.early_stopping_patience: - if not cfg.save_steps or not cfg.eval_steps: - raise ValueError( - "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." - ) - if cfg.save_steps % cfg.eval_steps != 0: - raise ValueError( - "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." 
- ) - - if cfg.model_type == "MixFormerSequentialForCausalLM" and cfg.adapter is not None: - LOG.warning("Use AutoModelForCausalLM for phi/MixFormer models with qLoRA") - - if cfg.model_config_type == "mixformer-sequential": - if cfg.sample_packing: - if cfg.adapter is not None: - LOG.warning( - "phi/MixFormer models are not currently compatible with LoRA and sample_packing" - ) - if cfg.model_type == "AutoModelForCausalLM": - raise ValueError( - "`model_type: MixFormerSequentialForCausalLM` required for sample_packing" - ) - - if cfg.datasets: - for idx, ds_cfg in enumerate(cfg.datasets): - if not ds_cfg.type: - continue - if ds_cfg.type == "sharegpt:chat": - LOG.warning( - PendingDeprecationWarning( - "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead." - ) - ) - cfg.datasets[idx].type = "sharegpt" - if "sharegpt_simple" in ds_cfg.type: - LOG.warning( - PendingDeprecationWarning( - "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead." - ) - ) - cfg.datasets[idx].type = cfg.datasets[idx].type.replace( - "sharegpt_simple", "sharegpt" - ) - - if cfg.saves_per_epoch and cfg.save_steps: - raise ValueError( - "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." - ) - if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps": - raise ValueError( - "save_strategy must be empty or set to `steps` when used with saves_per_epoch." - ) - if cfg.evals_per_epoch and cfg.eval_steps: - raise ValueError( - "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." - ) - if ( - cfg.evals_per_epoch - and cfg.evaluation_strategy - and cfg.evaluation_strategy != "steps" - ): - raise ValueError( - "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." - ) - if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": - raise ValueError( - "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." - ) - - if ( - cfg.evaluation_strategy - and cfg.eval_steps - and cfg.evaluation_strategy != "steps" - ): - raise ValueError( - "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps." - ) - - if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy): - raise ValueError( - "eval_steps and evaluation_strategy are not supported with val_set_size == 0" - ) - - if ( - cfg.sample_packing - and cfg.eval_table_size - and cfg.eval_sample_packing is not False - ): - raise ValueError( - "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." - ) - - if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit): - raise ValueError( - "load_in_8bit and load_in_4bit are not supported without setting an adapter." - "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." - ) - - if cfg.rope_scaling: - LOG.warning("`rope_scaling` should now be be a key under `model_config`") - - if cfg.warmup_steps and cfg.warmup_ratio: - raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") - - if cfg.is_qwen_derived_model and cfg.gradient_checkpointing: - LOG.warning( - "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch." - ) - - if cfg.wandb_run_id and not cfg.wandb_name: - cfg.wandb_name = cfg.wandb_run_id - - LOG.warning( - "wandb_run_id sets the ID of the run. 
If you would like to set the name, please use wandb_name instead." - ) - - if cfg.noisy_embedding_alpha is not None: - # Deprecated, use neftune_noise_alpha - LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha") - if cfg.neftune_noise_alpha is None: - cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha - else: - # User is providing both; bail and have them sort out their settings - raise ValueError( - "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting" - ) - - if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0: - raise ValueError("neftune_noise_alpha must be > 0.0") - - if ( - cfg.adapter - and cfg.tokens - and ( - not cfg.lora_modules_to_save - or not all( - x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"] - ) - ) - ): - raise ValueError( - "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`." - ) - - # TODO - # MPT 7b - # https://github.com/facebookresearch/bitsandbytes/issues/25 - # no 8bit adaAmw w bf16 - - # GPT-NeoX - # evals broken when extending context len - # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product - # attention_mask = causal_mask + attention_mask - # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3 diff --git a/axolotl/src/axolotl/utils/data.py b/axolotl/src/axolotl/utils/data.py deleted file mode 100644 index 5c41d16fe4e14d9f29c1455e3b8c278fa798b2a1..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/data.py +++ /dev/null @@ -1,821 +0,0 @@ -"""Module containing data utilities""" -import functools -import hashlib -import logging -from pathlib import Path -from typing import Dict, List, Tuple, Union - -import torch -from datasets import ( - Dataset, - DatasetDict, - concatenate_datasets, - load_dataset, - load_from_disk, -) -from huggingface_hub import hf_hub_download -from transformers import PreTrainedTokenizerBase - -from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset -from axolotl.prompt_strategies import load -from axolotl.prompt_tokenizers import ( - AlpacaMultipleChoicePromptTokenizingStrategy, - AlpacaPromptTokenizingStrategy, - AlpacaReflectionPTStrategy, - GPTeacherPromptTokenizingStrategy, - JeopardyPromptTokenizingStrategy, - OpenAssistantPromptTokenizingStrategy, - SummarizeTLDRPromptTokenizingStrategy, -) -from axolotl.prompters import ( - AlpacaPrompter, - GPTeacherPrompter, - JeopardyPrompter, - MultipleChoiceConcisePrompter, - MultipleChoiceExplainPrompter, - Prompter, - ReflectAlpacaPrompter, - SummarizeTLDRPrompter, - UnsupportedPrompter, -) -from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process, zero_first -from axolotl.utils.trainer import ( - calculate_total_num_steps, - process_datasets_for_packing, -) - -LOG = logging.getLogger("axolotl") - - -def md5(to_hash: str, encoding: str = "utf-8") -> str: - try: - return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() - except 
TypeError: - return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec - - -def prepare_dataset(cfg, tokenizer): - prompters = [] - if not cfg.pretraining_dataset: - with zero_first(is_main_process()): - train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH - ) - else: - train_dataset = load_pretraining_dataset( - cfg.pretraining_dataset, - tokenizer, - max_tokens=cfg.sequence_len, - seed=cfg.seed or 42, - ) - # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 - train_dataset = train_dataset.with_format("torch") - eval_dataset = None - return train_dataset, eval_dataset, cfg.max_steps, prompters - - with zero_first(is_main_process()): - train_dataset, eval_dataset = process_datasets_for_packing( - cfg, train_dataset, eval_dataset, tokenizer - ) - - if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: - total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) - if total_eval_steps == 0: - raise ValueError( - "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. " - ) - - if cfg.max_steps: - total_num_steps = min( - calculate_total_num_steps(cfg, train_dataset), cfg.max_steps - ) - LOG.info(f"Maximum number of steps set at {total_num_steps}") - else: - total_num_steps = calculate_total_num_steps(cfg, train_dataset) - return train_dataset, eval_dataset, total_num_steps, prompters - - -def load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path -) -> Tuple[DatasetDict, List[Prompter]]: - tokenizer_name = tokenizer.__class__.__name__ - ds_hash = str( - md5( - ( - str(cfg.sequence_len) - + "@" - + "|".join( - sorted( - [ - f"{d.path}:{d.type}:{d.shards}:{d.conversation}" - for d in cfg.datasets - ] - ) - ) - + "|" - + tokenizer_name - ) - ) - ) - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(default_dataset_prepared_path) / ds_hash - ) - dataset = None - prompters = [] - use_auth_token = cfg.hf_use_auth_token - try: - if cfg.push_dataset_to_hub: - dataset = load_dataset( - f"{cfg.push_dataset_to_hub}/{ds_hash}", - token=use_auth_token, - ) - dataset = dataset["train"] - except Exception: # pylint: disable=broad-except # nosec - pass - - if dataset: - ... 
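# --- editor's illustrative sketch (not part of the original diff) ---
# The `ds_hash` computed above is just an md5 over the sequence length, the
# sorted "path:type:shards:conversation" spec of each configured dataset, and
# the tokenizer class name. A minimal standalone reproduction of that cache
# key (function and argument names here are hypothetical) might look like:
import hashlib

def example_prepared_dataset_hash(sequence_len, dataset_specs, tokenizer_name):
    """dataset_specs: iterable of "path:type:shards:conversation" strings."""
    to_hash = (
        str(sequence_len)
        + "@"
        + "|".join(sorted(dataset_specs))
        + "|"
        + tokenizer_name
    )
    return hashlib.md5(to_hash.encode("utf-8")).hexdigest()  # nosec - cache key only

# Changing any input (e.g. bumping sequence_len or swapping the tokenizer)
# yields a new hash, which is what invalidates the on-disk/hub copy of the
# prepared dataset.
# --- end editor's sketch ---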
- elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") - dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared dataset loaded from disk...") - else: - LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") - LOG.info("Loading raw datasets...") - - if cfg.seed: - seed = cfg.seed - else: - LOG.info("No seed provided, using default seed of 42") - seed = 42 - - datasets = [] - - def for_d_in_datasets(dataset_configs): - for dataset in dataset_configs: - if dataset.name and isinstance(dataset.name, list): - for name in dataset.name: - yield DictDefault({**dataset, "name": name}) - else: - yield dataset - - # pylint: disable=invalid-name - for config_dataset in for_d_in_datasets(cfg.datasets): - ds: Union[Dataset, DatasetDict] = None - ds_from_hub = False - try: - load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=True, - token=use_auth_token, - ) - ds_from_hub = True - except (FileNotFoundError, ConnectionError): - pass - - ds_from_cloud = False - storage_options = {} - remote_file_system = None - if config_dataset.path.startswith("s3://"): - try: - import aiobotocore.session # type: ignore - import s3fs # type: ignore - except ImportError as exc: - raise ImportError( - "s3:// paths require aiobotocore and s3fs to be installed" - ) from exc - - # Takes credentials from ~/.aws/credentials for default profile - s3_session = aiobotocore.session.AioSession(profile="default") - storage_options = {"session": s3_session} - remote_file_system = s3fs.S3FileSystem(**storage_options) - elif config_dataset.path.startswith( - "gs://" - ) or config_dataset.path.startswith("gcs://"): - try: - import gcsfs # type: ignore - except ImportError as exc: - raise ImportError( - "gs:// or gcs:// paths require gcsfs to be installed" - ) from exc - - # gcsfs will use default credentials from the environment else anon - # https://gcsfs.readthedocs.io/en/latest/#credentials - storage_options = {"token": None} - remote_file_system = gcsfs.GCSFileSystem(**storage_options) - # TODO: Figure out how to get auth creds passed - # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"): - # try: - # import adlfs - # except ImportError as exc: - # raise ImportError( - # "adl:// or abfs:// paths require adlfs to be installed" - # ) from exc - - # # Gen 1 - # storage_options = { - # "tenant_id": TENANT_ID, - # "client_id": CLIENT_ID, - # "client_secret": CLIENT_SECRET, - # } - # # Gen 2 - # storage_options = { - # "account_name": ACCOUNT_NAME, - # "account_key": ACCOUNT_KEY, - # } - - # remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) - try: - if remote_file_system and remote_file_system.exists( - config_dataset.path - ): - ds_from_cloud = True - except (FileNotFoundError, ConnectionError): - pass - - # prefer local dataset, even if hub exists - local_path = Path(config_dataset.path) - if local_path.exists(): - if local_path.is_dir(): - # TODO dirs with arrow or parquet files could be loaded with `load_from_disk` - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - data_files=config_dataset.data_files, - streaming=False, - split=None, - ) - elif local_path.is_file(): - ds_type = get_ds_type(config_dataset) - - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=False, - split=None, - ) - else: - raise ValueError( - "unhandled dataset load: local path exists, but is 
neither a directory or a file" - ) - elif ds_from_hub: - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=False, - data_files=config_dataset.data_files, - token=use_auth_token, - ) - elif ds_from_cloud and remote_file_system: - if remote_file_system.isdir(config_dataset.path): - ds = load_from_disk( - config_dataset.path, - storage_options=storage_options, - ) - elif remote_file_system.isfile(config_dataset.path): - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=False, - split=None, - storage_options=storage_options, - ) - else: - if isinstance(config_dataset.data_files, str): - fp = hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=config_dataset.data_files, - ) - elif isinstance(config_dataset.data_files, list): - fp = [] - for file in config_dataset.data_files: - fp.append( - hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=file, - ) - ) - else: - raise ValueError( - "data_files must be either a string or list of strings" - ) - ds = load_dataset( - "json", - name=config_dataset.name, - data_files=fp, - streaming=False, - split=None, - ) - if not ds: - raise ValueError("unhandled dataset load") - # support for using a subset of the data - if config_dataset.shards: - if "train" in ds: - ds = ds.shuffle(seed=seed)["train"].shard( - num_shards=config_dataset.shards, index=0 - ) - else: - ds = ds.shuffle(seed=seed).shard( - num_shards=config_dataset.shards, index=0 - ) - - d_base_type = d_prompt_style = None - d_type = config_dataset.type - if isinstance(d_type, str): - d_type_split = d_type.split(":") - d_base_type = d_type_split[0] - d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None - if "train" in ds: - ds = ds["train"] - elif ( - isinstance(ds, DatasetDict) - and config_dataset.train_on_split - and config_dataset.train_on_split in ds - ): - ds = ds[config_dataset.train_on_split] - elif isinstance(ds, DatasetDict): - raise ValueError( - f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `" - ) - - dataset_wrapper, dataset_prompter = get_dataset_wrapper( - config_dataset=config_dataset, - dataset=ds, - tokenizer=tokenizer, - cfg=cfg, - d_base_type=d_base_type, - d_prompt_style=d_prompt_style, - ) - datasets.append(dataset_wrapper) - prompters.append(dataset_prompter) - - LOG.info("merging datasets") - dataset = concatenate_datasets(datasets) - - if len(datasets) > 1: - LOG.info("shuffle merged datasets") - dataset = dataset.shuffle(seed=seed) - if cfg.local_rank == 0: - LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") - dataset.save_to_disk(prepared_ds_path) - if cfg.push_dataset_to_hub: - LOG.info( - f"Saving merged prepared dataset with push_to_hub... 
{cfg.push_dataset_to_hub}/{ds_hash}" - ) - dataset.push_to_hub( - f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True - ) - - return dataset, prompters - - -def get_ds_type(config_dataset: DictDefault): - """ - Get the dataset type from the path if it's not specified - """ - ds_type = "json" - if config_dataset.ds_type: - ds_type = config_dataset.ds_type - elif ".parquet" in config_dataset.path: - ds_type = "parquet" - elif ".arrow" in config_dataset.path: - ds_type = "arrow" - elif ".csv" in config_dataset.path: - ds_type = "csv" - elif ".txt" in config_dataset.path: - ds_type = "text" - return ds_type - - -def load_prepare_datasets( - tokenizer: PreTrainedTokenizerBase, - cfg, - default_dataset_prepared_path, -) -> Tuple[Dataset, Dataset, List[Prompter]]: - max_packed_sequence_len = ( - cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len - ) - max_packed_sequence_len = min( - max_packed_sequence_len, cfg.sequence_len - ) # make sure we don't accidentally set it larger than sequence_len - - tokenizer_name = tokenizer.__class__.__name__ - prompters: List[Prompter] = [] - if cfg.max_packed_sequence_len is not None: - # see if we can go ahead and load the stacked dataset - seed = f"@{str(cfg.seed)}" if cfg.seed else "" - ds_hash = str( - md5( - ( - str(cfg.sequence_len) - + "@" - + str(max_packed_sequence_len) - + seed - + "|".join( - sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) - ) - + "|" - + tokenizer_name - ) - ) - ) - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(default_dataset_prepared_path) / ds_hash - ) - - dataset = None - use_auth_token = cfg.hf_use_auth_token - try: - if cfg.push_dataset_to_hub: - LOG.info( - f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" - ) - dataset = load_dataset( - f"{cfg.push_dataset_to_hub}/{ds_hash}", - token=use_auth_token, - ) - dataset = dataset["train"] - except Exception: # pylint: disable=broad-except # nosec - pass - - if dataset: - ... - elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")): - LOG.info( - f"Loading prepared packed dataset from disk at {prepared_ds_path}..." - ) - dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared packed dataset loaded from disk...") - if cfg.push_dataset_to_hub: - LOG.info( - f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" - ) - dataset.push_to_hub( - f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True - ) - else: - dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path - ) - - if cfg.seed: - dataset = dataset.shuffle(seed=cfg.seed) - - constant_len_dataset = ConstantLengthDataset( - tokenizer, - [dataset], - seq_length=max_packed_sequence_len, - ) - LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}") - dataset = Dataset.from_list(list(constant_len_dataset)) - - # filter out bad data - # TODO convert to dataset.filter(...) - dataset = Dataset.from_list( - [ - d - for d in dataset - if len(d["input_ids"]) <= cfg.sequence_len - and len(d["input_ids"]) > 0 - and len(d["input_ids"]) == len(d["attention_mask"]) - and len(d["input_ids"]) == len(d["labels"]) - ] - ) - - if cfg.local_rank == 0: - LOG.info( - f"Saving packed prepared dataset to disk... {prepared_ds_path}" - ) - dataset.save_to_disk(prepared_ds_path) - if cfg.push_dataset_to_hub: - LOG.info( - f"Saving packed prepared dataset with push_to_hub... 
{cfg.push_dataset_to_hub}/{ds_hash}" - ) - dataset.push_to_hub( - f"{cfg.push_dataset_to_hub}/{ds_hash}", - private=True, - ) - else: - dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path - ) - - if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: - LOG.info( - f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" - ) - dataset = dataset.shard( - num_shards=cfg.dataset_shard_num, - index=cfg.dataset_shard_idx, - ) - - if cfg.val_set_size: - # ensure we end up with the same fingerprint by doing rank0 first and being able to cache - to_hash_train = ( - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "train" - + "|" - + str(cfg.seed or 42) - ) - to_hash_test = ( - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "test" - + "|" - + str(cfg.seed or 42) - ) - train_fingerprint = md5(to_hash_train) - test_fingerprint = md5(to_hash_test) - - dataset = dataset.train_test_split( - test_size=cfg.val_set_size, - shuffle=False, - seed=cfg.seed or 42, - train_new_fingerprint=train_fingerprint, - test_new_fingerprint=test_fingerprint, - ) - - train_dataset = dataset["train"] - eval_dataset = dataset["test"] - else: - train_dataset = dataset - eval_dataset = None - - return train_dataset, eval_dataset, prompters - - -def get_dataset_wrapper( - config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style -): - dataset_wrapper = None - dataset_prompter = None - - if ( - "input_ids" in dataset.features - and "attention_mask" in dataset.features - and "labels" in dataset.features - ): - # dataset is already tokenized, just drop it straight in - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = dataset - elif isinstance(config_dataset.type, DictDefault): - ds_strategy = load( - "user_defined", tokenizer, cfg, config_dataset.type.to_dict() - ) - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - elif d_base_type == "alpaca": - dataset_prompter = AlpacaPrompter(d_prompt_style) - ds_strategy = AlpacaPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "explainchoice": - dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style) - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "concisechoice": - dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style) - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "summarizetldr": - dataset_prompter = SummarizeTLDRPrompter(d_prompt_style) - 
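# --- editor's illustrative sketch (not part of the original diff) ---
# The "filter out bad data" block in load_prepare_datasets above carries a
# TODO to switch from a list comprehension to dataset.filter(...). One
# possible (untested) equivalent, keeping the same length/consistency checks:
def _keep_packed_example(example, max_len):
    n = len(example["input_ids"])
    return (
        0 < n <= max_len
        and n == len(example["attention_mask"])
        and n == len(example["labels"])
    )

# dataset = dataset.filter(lambda ex: _keep_packed_example(ex, cfg.sequence_len))
# --- end editor's sketch ---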
ds_strategy = SummarizeTLDRPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "jeopardy": - dataset_prompter = JeopardyPrompter(d_prompt_style) - ds_strategy = JeopardyPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "oasst": - dataset_prompter = AlpacaPrompter(d_prompt_style) - ds_strategy = OpenAssistantPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "gpteacher": - dataset_prompter = GPTeacherPrompter(d_prompt_style) - ds_strategy = GPTeacherPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "reflection": - dataset_prompter = ReflectAlpacaPrompter(d_prompt_style) - ds_strategy = AlpacaReflectionPTStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset( - ds_strategy, dataset, process_count=cfg.dataset_processes - ) - dataset_wrapper = ds_wrapper - else: - suffix = "" - if ":load_" in config_dataset.type: - suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?" - LOG.error( - f"unhandled prompt tokenization strategy: {config_dataset.type}. 
{suffix}" - ) - raise ValueError( - f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}" - ) - - return dataset_wrapper, dataset_prompter - - -def encode_pretraining( - tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str] -) -> Dict[str, List]: - res = tokenizer( - examples, - truncation=True, - max_length=max_tokens - 2, - add_special_tokens=True, - ) - # Convert to PyTorch tensors - input_ids = [torch.tensor(seq) for seq in res["input_ids"]] - attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] - new_input_ids = [] - new_attention_mask = [] - # Append EOS and PAD tokens to input_ids, and correct attention_mask - for i, _ in enumerate(input_ids): - input_ids[i] = torch.cat( - ( - input_ids[i], - torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]), - ), - dim=0, - ) - attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0) - - # Concatenate tokens so that their lengths are less than max_tokens - buffer_input_ids = torch.tensor([], dtype=torch.long) - buffer_attention_mask = torch.tensor([], dtype=torch.long) - - for ids, mask in zip(input_ids, attention_mask): - if buffer_input_ids.numel() == max_tokens: - new_input_ids.append(buffer_input_ids) - new_attention_mask.append(buffer_attention_mask) - buffer_input_ids = torch.tensor([], dtype=torch.long) - buffer_attention_mask = torch.tensor([], dtype=torch.long) - buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) - buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) - elif buffer_input_ids.numel() + ids.numel() <= max_tokens: - buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) - buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) - else: - buffer_input_ids = torch.cat( - ( - buffer_input_ids, - torch.full( - (max_tokens - buffer_input_ids.numel(),), - tokenizer.pad_token_id, - dtype=torch.long, - ), - ), - dim=0, - ) - buffer_attention_mask = torch.cat( - ( - buffer_attention_mask, - torch.full( - (max_tokens - buffer_attention_mask.numel(),), - 0, - dtype=torch.long, - ), - ), - dim=0, - ) - new_input_ids.append(buffer_input_ids) - new_attention_mask.append(buffer_attention_mask) - buffer_input_ids = torch.tensor([], dtype=torch.long) - buffer_attention_mask = torch.tensor([], dtype=torch.long) - - buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) - buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) - - if buffer_input_ids.numel() > 0: # for any leftover tokens - while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size - buffer_input_ids = torch.cat( - ( - buffer_input_ids, - torch.full( - (max_tokens - buffer_input_ids.numel(),), - tokenizer.pad_token_id, - dtype=torch.long, - ), - ), - dim=0, - ) - buffer_attention_mask = torch.cat( - ( - buffer_attention_mask, - torch.full( - (max_tokens - buffer_attention_mask.numel(),), - 0, - dtype=torch.long, - ), - ), - dim=0, - ) - new_input_ids.append(buffer_input_ids) - new_attention_mask.append(buffer_attention_mask) - - ret = { - "input_ids": [seq.tolist() for seq in new_input_ids], - "labels": [seq.tolist() for seq in new_input_ids], - "attention_mask": [seq.tolist() for seq in new_attention_mask], - } - - LOG.debug(len(ret["input_ids"])) - return ret - - -def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): - encode = functools.partial(encode_pretraining, tokenizer, max_tokens) - dataset = load_dataset(path, streaming=True, split="train") - dataset = 
dataset.shuffle(seed=seed, buffer_size=10_000) - dataset = dataset.map( - encode, - batched=True, - input_columns="text", - # remove all the existing columns after mapping since they end up having - # a different length than the encoded/tokenized column - remove_columns=dataset.features.keys(), - ) - return dataset diff --git a/axolotl/src/axolotl/utils/dict.py b/axolotl/src/axolotl/utils/dict.py deleted file mode 100644 index 69567c60421922d55bb6a53820946b0ebeff9d95..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/dict.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Module containing the DictDefault class""" - -from addict import Dict - - -class DictDefault(Dict): - """ - A Dict that returns None instead of returning empty Dict for missing keys. - """ - - def __missing__(self, key): - return None - - def __or__(self, other): - return DictDefault(super().__or__(other)) diff --git a/axolotl/src/axolotl/utils/distributed.py b/axolotl/src/axolotl/utils/distributed.py deleted file mode 100644 index 313dd24e8c18b0a6a3797677653f139902e8a3e4..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/distributed.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -utility helpers for distributed checks -""" -import os -import pickle # nosec -from contextlib import contextmanager - -import torch -import torch.distributed as dist -from accelerate import Accelerator - -accelerate = None # pylint: disable=invalid-name - - -def load_accelerate(): - global accelerate # pylint: disable=global-statement - accelerate = Accelerator() - - -def is_distributed(): - """ - Check if distributed training is initialized. - """ - global accelerate # pylint: disable=global-statement - if not accelerate: - accelerate = Accelerator() - return dist.is_available() and dist.is_initialized() - - -def barrier(): - """ - Acts as a barrier to wait for all processes. This ensures that all processes - reach the barrier before proceeding further. - """ - if is_distributed(): - dist.barrier() - - -def is_main_process(): - """ - Check if the current process is the main process. - If not in distributed mode, always return True. - """ - if not is_distributed(): - return True - return dist.get_rank() == 0 - - -def get_world_size(): - return int(os.getenv("WORLD_SIZE", "1")) - - -@contextmanager -def zero_only(): - """ - Context manager that only runs the enclosed block on the main rank. - """ - if is_main_process(): - yield - else: - yield None - - -@contextmanager -def zero_first(is_main): - """ - runs the wrapped context so that rank 0 runs first before other ranks - """ - if not is_main: # other ranks wait first - barrier() - yield - if is_main: # then rank 0 waits after it has run the context - barrier() - - -def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name - """ - Run a callable 'fn' on all ranks and gather the results on the specified rank. - - Args: - - fn (callable): A function that computes the value. This should not have any side effects. - - rank (int, optional): The rank that gathers the values. Default is 0. - - world_size (int, optional): Total number of processes in the current distributed setup. - - Returns: - - A list of computed values from all ranks if on the gathering rank, otherwise None. 
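    Example (editor's illustrative addition, not in the original docstring;
    `train_dataset` below is a hypothetical sized object):
        # must be called on every rank; rank 0 gets the list, others get None
        sizes = gather_scalar_from_all_ranks(
            lambda: float(len(train_dataset)), world_size=get_world_size()
        )
        if is_main_process():
            print(sizes)  # e.g. [12800.0, 12800.0] with 2 ranks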
- """ - value_scalar = fn() - if not is_distributed(): - return [value_scalar] - value_tensor = torch.tensor( - value_scalar, device=torch.cuda.current_device() - ).float() - - if not is_main_process(): - dist.gather(value_tensor, dst=0) - else: - gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] - dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) - - # Convert tensors back to their original type (int or float) - gathered_values = [] - for tensor in gathered_tensors: - if tensor == tensor.int(): - gathered_values.append(int(tensor.item())) - else: - gathered_values.append(float(tensor.item())) - return gathered_values - return None - - -def broadcast_dict(vals: dict): - if not is_distributed(): - return vals - - if is_main_process(): - data_byte = pickle.dumps(vals) - data_tensor = torch.ByteTensor(list(data_byte)).to("cuda") - data_size = torch.IntTensor([len(data_byte)]).to("cuda") - else: - data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda") - data_size = torch.IntTensor([0]).to("cuda") - - dist.broadcast(data_size, 0) - if not is_main_process(): - # resize - data_tensor = data_tensor.new_empty([data_size.item()]) - - dist.broadcast(data_tensor, 0) - - if not is_main_process(): - data_list = data_tensor.cpu().tolist() - data_byte = bytes(data_list[: data_size.item()]) - vals = pickle.loads(data_byte) # nosec - - return vals - - -def compute_and_broadcast(fn): # pylint: disable=invalid-name - """ - Compute a value using the function 'fn' only on the specified rank (default is 0). - The value is then broadcasted to all other ranks. - - Args: - - fn (callable): A function that computes the value. This should not have any side effects. - - rank (int, optional): The rank that computes the value. Default is 0. - - Returns: - - The computed value (int or float). - """ - if is_main_process(): - value_scalar = fn() - value_tensor = torch.tensor( - value_scalar, device=torch.cuda.current_device() - ).float() - else: - value_tensor = torch.tensor( - 0.0, device=torch.cuda.current_device() - ) # Placeholder tensor - - # Broadcast the tensor to all processes. - barrier() - dist.broadcast(value_tensor, src=0) - - # Convert the tensor back to its original type (int or float) - if value_tensor == value_tensor.int(): - return int(value_tensor.item()) - return float(value_tensor.item()) - - -def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name - """ - Run a callable 'fn' on all ranks and gather the results on the specified rank. - - Args: - - fn (callable): A function that computes the value. This should not have any side effects. - - rank (int, optional): The rank that gathers the values. Default is 0. - - world_size (int, optional): Total number of processes in the current distributed setup. - - Returns: - - A list of computed values from all ranks if on the gathering rank, otherwise None. 
- """ - value_scalar = fn() - value_tensor = torch.tensor( - value_scalar, device=torch.cuda.current_device() - ).float() - - # Placeholder tensor for gathering results - if is_main_process(): - gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] - else: - gathered_tensors = None - - dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) - - if is_main_process(): - # Convert tensors back to their original type (int or float) - gathered_values = [] - for tensor in gathered_tensors: - if tensor == tensor.int(): - gathered_values.append(int(tensor.item())) - else: - gathered_values.append(float(tensor.item())) - return gathered_values - return None - - -def reduce_and_broadcast(fn1, fn2): - """ - Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2', - and then broadcast the reduced result to all ranks. - - Args: - - fn1 (callable): A function that computes the value on each rank. - - fn2 (callable): A reduction function that takes a list of values and returns a single value. - - world_size (int, optional): Total number of processes in the current distributed setup. - - Returns: - - The reduced and broadcasted value. - """ - - # Gather values from all ranks using fn1 - if not is_distributed(): - return fn2([fn1()]) - - gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size()) - - # Use compute_and_broadcast to compute the reduced value on the main process - # and then broadcast it to all ranks - return compute_and_broadcast(lambda: fn2(gathered_values)) diff --git a/axolotl/src/axolotl/utils/freeze.py b/axolotl/src/axolotl/utils/freeze.py deleted file mode 100644 index 05beda1caa348f6f6e9755d291819c740bd05096..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/freeze.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -module to freeze/unfreeze parameters by name -""" -import logging -import re - -from axolotl.utils.distributed import is_main_process - -LOG = logging.getLogger("axolotl.utils.freeze") - - -def freeze_parameters_except(model, regex_patterns): - """ - Freezes all layers of the given model except for the layers that match given regex patterns. - Periods in the patterns are treated as literal periods, not as wildcard characters. - - Parameters: - - model (nn.Module): The PyTorch model to be modified. - - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen. - - Returns: - None; the model is modified in place. 
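    Example (editor's illustrative addition, not in the original docstring):
        # keep only the LM head and the token embeddings trainable
        freeze_parameters_except(model, ["lm_head", "model.embed_tokens"])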
- """ - # Escape periods and compile the regex patterns - compiled_patterns = [ - re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns - ] - - # First, freeze all parameters in the model - for param in model.parameters(): - param.requires_grad = False - - # Unfreeze layers that match the regex patterns - for name, param in model.named_parameters(): - if any(pattern.match(name) for pattern in compiled_patterns): - if is_main_process(): - LOG.debug(f"unfreezing {name}") - param.requires_grad = True diff --git a/axolotl/src/axolotl/utils/models.py b/axolotl/src/axolotl/utils/models.py deleted file mode 100644 index 872d530abd4750d3f7595fb56fe2de64e0169405..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/models.py +++ /dev/null @@ -1,667 +0,0 @@ -"""Module for models and model loading""" -import logging -import math -import os -from typing import Optional, Tuple # noqa: F401 - -import addict -import bitsandbytes as bnb -import torch -import transformers -from optimum.bettertransformer import BetterTransformer -from peft import PeftConfig, prepare_model_for_kbit_training -from peft.tuners.lora import QuantLinear -from transformers import ( # noqa: F401 - AddedToken, - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - GPTQConfig, - PreTrainedModel, - PreTrainedTokenizerBase, -) -from transformers.deepspeed import is_deepspeed_zero3_enabled - -from axolotl.models.mamba import fix_mamba_attn_for_loss -from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN -from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.dict import DictDefault - -LOG = logging.getLogger("axolotl") - - -def check_model_config(cfg: DictDefault, model_config: AutoConfig): - quant_config_exists = hasattr(model_config, "quantization_config") - quant_config_method_is_gptq = ( - quant_config_exists - and "quant_method" in model_config.quantization_config - and model_config.quantization_config["quant_method"] == "gptq" - ) - - if cfg.gptq and not quant_config_method_is_gptq: - raise ValueError( - "model_config.quantization_config is not set or quant_method is not set to gptq. " - "Please make sure to point to a GPTQ model." - ) - - if not cfg.gptq and quant_config_exists: - raise ValueError( - "model_config.quantization_config is set but `gptq` flag is not. " - "Please use the `gptq` flag to train quantized model or point to a non-quantized model." 
- ) - - -def load_model_config(cfg): - model_config_name = cfg.base_model_config or cfg.base_model - trust_remote_code = cfg.trust_remote_code is True - - try: - model_config = AutoConfig.from_pretrained( - model_config_name, trust_remote_code=trust_remote_code - ) - except ValueError as err: - if "mamba" in model_config_name: - return addict.Dict( - { - "model_type": "mamba", - } - ) - raise err - - if cfg.model_config: - for key, val in cfg.model_config.items(): - setattr(model_config, key, val) - - check_model_config(cfg, model_config) - - return model_config - - -def load_tokenizer(cfg): - tokenizer_kwargs = {} - use_fast = True # this is the default - - if cfg.tokenizer_use_fast is not None: - use_fast = cfg.tokenizer_use_fast - if cfg.tokenizer_legacy is not None: - # True is the default w/ https://github.com/huggingface/transformers/pull/25224 - tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy - - tokenizer_cls = AutoTokenizer - if cfg.tokenizer_type: - tokenizer_cls = getattr(transformers, cfg.tokenizer_type) - - tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model - tokenizer = tokenizer_cls.from_pretrained( - tokenizer_config, - trust_remote_code=cfg.trust_remote_code or False, - use_fast=use_fast, - **tokenizer_kwargs, - ) - - if ( - tokenizer.__class__.__name__ - in [ - "LlamaTokenizer", - "LlamaTokenizerFast", - "CodeLlamaTokenizer", - "CodeLlamaTokenizerFast", - ] - and hasattr(tokenizer, "pad_token") - and not tokenizer.pad_token - ): - # set a pad_token, but use eos_token so we don't add a new token - tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN - - if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # Mistral's official FA implementation requires left padding - if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: - tokenizer.padding_side = "left" - - # Qwen base only has single token, so we need to set the special tokens - if cfg.is_qwen_derived_model: - token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] - for attr_name in token_ids: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, tokenizer.eod_id) - - token_names = ["bos_token", "eos_token", "pad_token", "unk_token"] - for attr_name in token_names: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, "<|endoftext|>") - - if cfg.special_tokens: - for k, val in cfg.special_tokens.items(): - # check if new special token is not already in tokenizer and - # is adapter training to make sure lora_modules_to_save is set - if ( - (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) - and cfg.adapter - and ( - not cfg.lora_modules_to_save - or not all( - x in cfg.lora_modules_to_save - for x in ["embed_tokens", "lm_head"] - ) - ) - ): - raise ValueError( - "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens." - ) - - tokenizer.add_special_tokens( - {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} - ) - - # If we add bos_token and eos_token, we need to update the post processor to - # handle them correctly. 
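# --- editor's illustrative sketch (not part of the original diff) ---
# The special_tokens guard earlier in load_tokenizer reduces to: when training
# an adapter and changing a special token, `lora_modules_to_save` must contain
# both "embed_tokens" and "lm_head", since those weights change along with the
# tokens. A tiny standalone version of that predicate (names hypothetical):
def example_missing_modules_to_save(adapter, token_changed, lora_modules_to_save):
    required = {"embed_tokens", "lm_head"}
    return bool(adapter and token_changed) and not required.issubset(
        set(lora_modules_to_save or [])
    )

# example_missing_modules_to_save("lora", True, ["embed_tokens"])            -> True  (would raise)
# example_missing_modules_to_save("lora", True, ["embed_tokens", "lm_head"]) -> False
# --- end editor's sketch ---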
- # https://github.com/huggingface/transformers/pull/24132 - bos_or_eos_in_special_tokens = ( - "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens - ) - if ( - tokenizer.__class__.__name__ - in ( - "LlamaTokenizerFast", - "CodeLlamaTokenizerFast", - ) - and bos_or_eos_in_special_tokens - ): - tokenizer.update_post_processor() - - if cfg.tokens: - tokenizer.add_tokens( - [ - AddedToken(token, rstrip=False, lstrip=False, normalized=False) - for token in cfg.tokens - ] - ) - - LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") - - return tokenizer - - -def load_model( - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - inference: bool = False, -) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - """ - Load a model for a given configuration and tokenizer. - """ - base_model = cfg.base_model - model_type = cfg.model_type - model_config = load_model_config(cfg) - - # TODO refactor as a kwarg - load_in_8bit = cfg.load_in_8bit - - if hasattr(model_config, "model_type") and model_config.model_type == "btlm": - if cfg.flash_attention: - from axolotl.monkeypatch.btlm_attn_hijack_flash import ( - replace_btlm_attn_with_flash_attn, - ) - - replace_btlm_attn_with_flash_attn(cfg.base_model) - - if ( - hasattr(model_config, "model_type") - and model_config.model_type == "stablelm_epoch" - ): - if cfg.flash_attention and cfg.sample_packing: - from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( - replace_stablelm_attn_with_flash_attn, - ) - - replace_stablelm_attn_with_flash_attn(cfg.base_model) - - if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: - if cfg.device not in ["mps", "cpu"] and not inference: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - replace_llama_attn_with_flash_attn, - ) - - LOG.info("patching with flash attention for sample packing") - replace_llama_attn_with_flash_attn( - packed=cfg.sample_packing, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, - ) - elif cfg.is_llama_derived_model and cfg.xformers_attention: - from axolotl.monkeypatch.llama_attn_hijack_xformers import ( - hijack_llama_attention, - ) - - LOG.info("patching with xformers attention") - hijack_llama_attention() - elif cfg.is_llama_derived_model and cfg.sdp_attention: - from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention - - LOG.info("patching with sdp attention") - hijack_llama_sdp_attention() - - if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing: - from axolotl.monkeypatch.mistral_attn_hijack_flash import ( - replace_mistral_attn_with_flash_attn, - ) - - LOG.info("patching with flash attention") - replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) - - if ( - cfg.model_config_type == "mixtral" - and cfg.flash_attention - and cfg.sample_packing - ): - from axolotl.monkeypatch.mixtral import ( - replace_mixtral_attn_with_multipack_flash_attn, - ) - - LOG.info("patching with flash attention") - replace_mixtral_attn_with_multipack_flash_attn() - - if ( - cfg.is_llama_derived_model - and (cfg.max_packed_sequence_len or cfg.sample_packing) - and not inference - ): - from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask - - LOG.info("patching _expand_mask") - hijack_expand_mask() - - model_kwargs = {} - - 
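    # (editor note, not part of the original diff) The kwargs assembled below
    # are forwarded to *.from_pretrained(): device_map/max_memory drive
    # accelerate's weight placement, torch_dtype sets the load dtype, and
    # device_map is deleted again just below when DeepSpeed ZeRO-3 is enabled,
    # since ZeRO-3 manages parameter placement itself.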
model_kwargs["device_map"] = cfg.device_map - model_kwargs["max_memory"] = cfg.max_memory - model_kwargs["torch_dtype"] = cfg.torch_dtype - - if is_deepspeed_zero3_enabled(): - del model_kwargs["device_map"] - - if cfg.model_revision: - model_kwargs["revision"] = cfg.model_revision - if cfg.gptq: - if not hasattr(model_config, "quantization_config"): - LOG.warning("model config does not contain quantization_config information") - else: - if cfg.gptq_disable_exllama is not None: - model_config.quantization_config[ - "disable_exllama" - ] = cfg.gptq_disable_exllama - model_kwargs["quantization_config"] = GPTQConfig( - **model_config.quantization_config - ) - if cfg.adapter == "qlora" and cfg.load_in_4bit: - model_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=cfg.torch_dtype, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - ) - # sample packing uses custom FA2 patch - if cfg.flash_attention: - if not cfg.sample_packing: - if ( - cfg.is_llama_derived_model - or cfg.is_falcon_derived_model - or cfg.is_mistral_derived_model - or model_config.model_type == "mixtral" - ): - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - if model_config.model_type == "mixtral": - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) - - try: - if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: - from transformers import LlamaForCausalLM - - model = LlamaForCausalLM.from_pretrained( - base_model, - config=model_config, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - **model_kwargs, - ) - - if cfg.flash_attention and not inference: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - replace_llama_mlp_with_swiglu, - replace_llama_qkv_with_fused, - ) - - if cfg.flash_attn_fuse_mlp: - LOG.info("patching with SwiGLU") - replace_llama_mlp_with_swiglu(model) - - if cfg.flash_attn_fuse_qkv: - LOG.info("patching with fused QKV") - replace_llama_qkv_with_fused(model) - # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: - # This is a WIP, still an issue with the backward pass - # RuntimeError: grad can be implicitly created only for scalar outputs - # TODO: try config.sequence_parallel = False - # # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12 - # # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components - # # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442 - # from flash_attn.utils.pretrained import state_dict_from_pretrained - # from flash_attn.models.gpt import GPTLMHeadModel - # from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config - # from transformers import GPTNeoXConfig - # config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model)) - # config.use_flash_attn = True - # config.fused_bias_fc = True - # config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast" - # config.activation_function = "gelu_fast" - # config.fused_dropout_add_ln = True - # # config.residual_in_fp32 = True - # - # model: 
GPTLMHeadModel = GPTLMHeadModel.from_pretrained( - # base_model, - # config, - # dtype=torch_dtype, - # device=cfg.device, - # ) - # model.train() # sets to train instead of eval mode - elif model_type == "PhiForCausalLM": - from axolotl.models.phi import PhiForCausalLM - - model = PhiForCausalLM.from_pretrained( - base_model, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - **model_kwargs, - ) - elif model_type == "MambaLMHeadModel": - # FIXME this is janky at best and hacked together to make it work - MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name - - model_kwargs["dtype"] = model_kwargs["torch_dtype"] - model_kwargs["device"] = torch.cuda.current_device() - del model_kwargs["torch_dtype"] - del model_kwargs["device_map"] - del model_kwargs["max_memory"] - - model = MambaLMHeadModel.from_pretrained( - base_model, - **model_kwargs, - ) - elif model_type and not cfg.trust_remote_code: - if cfg.gptq: - model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, - ) - else: - model = getattr(transformers, model_type).from_pretrained( - base_model, - config=model_config, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, - ) - else: - # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this - # when training starts - if ( - hasattr(model_config, "max_seq_len") - and model_config.max_seq_len - and cfg.sequence_len > model_config.max_seq_len - ): - model_config.max_seq_len = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") - elif ( - hasattr(model_config, "max_sequence_length") - and model_config.max_sequence_length - and cfg.sequence_len > model_config.max_sequence_length - ): - model_config.max_sequence_length = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") - if cfg.gptq: - model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, - ) - else: - model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, - ) - except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err) - raise err - - embeddings_len = ( - math.ceil(len(tokenizer) / 32) * 32 - if cfg.resize_token_embeddings_to_32x - else len(tokenizer) - ) - if ( - hasattr(model, "get_input_embeddings") - and model.get_input_embeddings().num_embeddings < embeddings_len - ): - model.resize_token_embeddings(embeddings_len) - else: - model.tie_weights() - - if ( - hasattr(model, "config") - and hasattr(model.config, "max_position_embeddings") - and model.config.max_position_embeddings - and cfg.sequence_len > model.config.max_position_embeddings - ): - LOG.warning( - f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}" - ) - model.config.max_position_embeddings = cfg.sequence_len - - if ( - hasattr(model, "config") - and hasattr(model.config, "bos_token_id") - and model.config.bos_token_id - and 
model.config.bos_token_id != tokenizer.bos_token_id - ): - model.config.bos_token_id = tokenizer.bos_token_id - - if ( - hasattr(model, "config") - and hasattr(model.config, "eos_token_id") - and model.config.eos_token_id - and model.config.eos_token_id != tokenizer.eos_token_id - ): - model.config.eos_token_id = tokenizer.eos_token_id - - if hasattr(model, "device") and model.device.type == "cuda": - log_gpu_memory_usage(LOG, "after model load", model.device) - - # make sure these are fp32 per Ramesh et al. (2021) - for name, module in model.named_modules(): - if "norm" in name: - module.to(torch.float32) - if model_config.model_type == "btlm": - # don't upcast lm_head for btlm - continue - if "lm_head" in name or "embed_tokens" in name: - if hasattr(module, "weight"): - module.to(torch.float32) - - needs_fa2_dtype = cfg.adapter or cfg.fsdp - skip_prepare_model_for_kbit_training = False - - if cfg.model_config_type == "qwen" and cfg.adapter == "lora": - # Qwen doesn't play nicely with LoRA if this is enabled - skip_prepare_model_for_kbit_training = True - - if (cfg.adapter == "lora" and load_in_8bit) or ( - cfg.adapter == "qlora" and cfg.load_in_4bit - ): - LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") - if cfg.gradient_checkpointing: - model.gradient_checkpointing_enable() - if not skip_prepare_model_for_kbit_training: - model = prepare_model_for_kbit_training( - model, use_gradient_checkpointing=cfg.gradient_checkpointing - ) - needs_fa2_dtype = True - - # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to - # convert them back to fp16/bf16 for flash-attn compatibility. - if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model): - LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) - for name, module in model.named_modules(): - if "norm" in name: - module.to(cfg.torch_dtype) - if "lm_head" in name or "embed_tokens" in name: - if hasattr(module, "weight"): - module.to(cfg.torch_dtype) - - model, lora_config = load_adapter(model, cfg, cfg.adapter) - - if cfg.ddp and not load_in_8bit: - model.to(f"cuda:{cfg.local_rank}") - - if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: - setattr(model, "is_parallelizable", True) - setattr(model, "model_parallel", True) - - requires_grad = [] - for name, param in model.named_parameters(recurse=True): - if param.requires_grad: - requires_grad.append(f"{name}: {param.requires_grad}") - if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") - if hasattr(model, "config"): - model.config.use_cache = False - - if cfg.flash_optimum: - model = BetterTransformer.transform(model) - - if cfg.adapter is not None: - log_gpu_memory_usage(LOG, "after adapters", model.device) - - # TODO resume_from_checkpoint handling - return model, lora_config - - -def load_adapter(model, cfg, adapter, inference=False): - # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - - if adapter is None: - return model, None - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - if adapter in ["lora", "qlora"]: - return load_lora(model, cfg, inference=inference) - if adapter == "llama-adapter": - return load_llama_adapter(model, cfg) - - raise NotImplementedError(f"{adapter} peft adapter not available") - - -def load_llama_adapter(model, cfg): - # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - from 
peft import AdaptionPromptConfig, PeftModel, get_peft_model - - peft_config = AdaptionPromptConfig( - adapter_layers=cfg.peft_adapter.layers, # layers (L) - adapter_len=cfg.peft_adapter.len, # prompt length (K) - task_type="CAUSAL_LM", - ) - - if cfg.lora_model_dir: - LOG.debug("Loading pretained PEFT - llama_adapter") - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - torch_dtype=torch.float16, - ) - else: - model = get_peft_model(model, peft_config) - - model.print_trainable_parameters() - - return model, peft_config - - -def find_all_linear_names(model): - cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) - lora_module_names = set() - for name, module in model.named_modules(): - if ( - isinstance(module, cls) - or "Linear" in module.__class__.__name__ - and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",) - ): - names = name.split(".") - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - if "lm_head" in lora_module_names: # needed for 16-bit - lora_module_names.remove("lm_head") - - return list(lora_module_names) - - -def load_lora(model, cfg, inference=False): - # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - - from peft import LoraConfig, PeftModel, get_peft_model - - lora_target_modules = list(cfg.lora_target_modules or []) - - if cfg.lora_target_linear: - linear_names = find_all_linear_names(model) - LOG.info(f"found linear modules: {repr(linear_names)}") - lora_target_modules = list(set(lora_target_modules + linear_names)) - - lora_config = LoraConfig( - r=cfg.lora_r, - lora_alpha=cfg.lora_alpha, - target_modules=lora_target_modules, - lora_dropout=cfg.lora_dropout, - fan_in_fan_out=cfg.lora_fan_in_fan_out, - modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, - bias="none", - task_type="CAUSAL_LM", - ) - - if cfg.lora_model_dir: - LOG.debug("Loading pretained PEFT - LoRA") - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - is_trainable=(not inference), - ) - else: - model = get_peft_model(model, lora_config) - - model.print_trainable_parameters() - - return model, lora_config diff --git a/axolotl/src/axolotl/utils/samplers/__init__.py b/axolotl/src/axolotl/utils/samplers/__init__.py deleted file mode 100644 index 4c102826f8708ab9334092fe48db09748d21b3e2..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/samplers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -axolotl samplers module -""" -from .multipack import MultipackBatchSampler # noqa: F401 diff --git a/axolotl/src/axolotl/utils/samplers/multipack.py b/axolotl/src/axolotl/utils/samplers/multipack.py deleted file mode 100644 index 629a1a44c3e62fd18d3b48d4402e085838dadc2d..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/samplers/multipack.py +++ /dev/null @@ -1,196 +0,0 @@ -# pylint: skip-file -""" -Multipack Batch Sampler -""" -import logging -import math -import os -from typing import Any, Iterable, List, Union - -import numba -import numpy as np -from torch.utils.data import BatchSampler, Sampler - -LOG = logging.getLogger("axolotl.utils.samplers.multipack") - - -@numba.njit -def ffd_check(a: np.ndarray, c: int, n: int): - # First-fit-decreasing bin packing - # Check if a[] could fit in n bins with capacity c - # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing - - a = np.sort(a)[::-1] - bins = np.full((n,), c, dtype=a.dtype) - for size in a: - not_found = True - for idx in range(n): - 
if bins[idx] >= size: - bins[idx] -= size - not_found = False - break - - if not_found: - return False - - return True - - -@numba.njit -def ffd_with_result(a: np.ndarray, c: int, start_index: int): - # First-fit-decreasing bin packing (with result return) - - indices = np.argsort(a)[::-1] - a = a[indices] - - bins: List[Any] = [] - bins_result: List[Any] = [] - for a_id, size in enumerate(a): - add_new = True - for idx in range(len(bins)): - if bins[idx] >= size: - bins[idx] -= size - bins_result[idx].append(indices[a_id] + start_index) - add_new = False - break - - if add_new: - bins.append(c - size) - bins_result.append([indices[a_id] + start_index]) - - return bins_result - - -@numba.njit -def allocate( - lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int -): - # Dynamic batch allocator, similar to Multifit - # https://en.wikipedia.org/wiki/Multifit_algorithm - # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) - - s = 0 - start_index = 0 - result = [] - - while True: - # binary search [l, r) - left = 1 - right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") - - while right - left > 1: - mid = (left + right) // 2 - if ffd_check(lengths[start_index : start_index + mid], c, n): - left = mid - else: - right = mid - - # use length l - batch = ffd_with_result( - lengths[start_index : start_index + left], c, start_index - ) - assert len(batch) <= n - if len(batch) < n: - break - - start_index += left - s = lengths_cumsum[start_index - 1] - - # add local rank - result.append(batch[rank]) - - return result, s, len(result) * c * n - - -class MultipackBatchSampler(BatchSampler): - """ - Batch Sampler class for multipack - """ - - def __init__( - self, - sampler: Union[Sampler[int], Iterable[int]], - batch_size: int, - drop_last: bool, - batch_max_len: int, - lengths: np.ndarray, - packing_efficiency_estimate: float = 1.0, - ): - super().__init__(sampler, batch_size, drop_last) - self.batch_size = None - self.batch_max_len = batch_max_len - self.lengths: np.ndarray = lengths - self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 - - assert isinstance(self.lengths, np.ndarray) - - self.epoch = 0 - - # statistics - self.eff_total_used = 0 - self.eff_total_slots = 0 - - def set_epoch(self, epoch: int): - self.epoch = epoch - - def generate_batches(self, set_stats=False): - indices = [idx for idx in self.sampler] - - lengths = self.lengths[indices] - lengths_cumsum = np.cumsum(lengths) - - batches, total_used, total_slots = allocate( - lengths=lengths, - lengths_cumsum=lengths_cumsum, - rank=0, - c=self.batch_max_len, - n=1, - ) - - batches = [[indices[b_idx] for b_idx in batch] for batch in batches] - - # statistics - if set_stats: - self.eff_total_used += total_used - self.eff_total_slots += total_slots - - return batches - - def __iter__(self): - batches = self.generate_batches(set_stats=True) - return iter(batches) - - def num_batches(self): - batches = self.generate_batches(set_stats=True) - return len(batches) - - def efficiency(self): - return self.eff_total_used / self.eff_total_slots - - def __len__(self): - self.num_batches() - return self._len_est() - - def _len_est(self): - world_size = int(os.getenv("WORLD_SIZE", "1")) - lengths_sum = np.sum(self.lengths) - lengths_sum_per_device = lengths_sum // world_size - LOG.info( - f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " - f"total_num_tokens per device: {lengths_sum_per_device}" - ) - - # shave off 1% + 1 for dealing with variance in packing 
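Aside on the packing code above: ffd_check, ffd_with_result, and allocate implement first-fit-decreasing bin packing over per-sequence token lengths. The following is a minimal, dependency-free sketch of the same idea (plain Python in place of numba; the toy lengths and the helper name are illustrative assumptions, not axolotl code).

```python
from typing import List


def ffd_pack(lengths: List[int], capacity: int) -> List[List[int]]:
    """Pack sequence lengths into bins of fixed capacity, largest first,
    each item going into the first bin it fits in (first-fit-decreasing)."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    remaining: List[int] = []   # free space left in each bin
    bins: List[List[int]] = []  # original indices packed into each bin
    for idx in order:
        size = lengths[idx]
        for b, free in enumerate(remaining):
            if free >= size:
                remaining[b] -= size
                bins[b].append(idx)
                break
        else:
            remaining.append(capacity - size)
            bins.append([idx])
    return bins


if __name__ == "__main__":
    # Toy example: pack sequences of these token lengths into 2048-token bins.
    print(ffd_pack([1800, 900, 700, 400, 300, 100], capacity=2048))
    # -> [[0, 5], [1, 2, 3], [4]]
```

MultipackBatchSampler applies this same packing to the sampler's indices so that each batch is a near-full set of concatenated sequences rather than heavily padded rows.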
from random sampler to sampler - return max( - 0, - ( - world_size - * math.floor( - 0.99 - * lengths_sum_per_device - / self.packing_efficiency_estimate - // self.batch_max_len - ) - - 1 - ), - ) diff --git a/axolotl/src/axolotl/utils/schedulers.py b/axolotl/src/axolotl/utils/schedulers.py deleted file mode 100644 index 4c14a358a37e4fda512f8dd2946be24676a9717d..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/schedulers.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Module for custom LRScheduler class""" -import math -from functools import partial - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR, LRScheduler - - -class InterpolatingLogScheduler(LRScheduler): - """ - A scheduler that interpolates learning rates in a logarithmic fashion - """ - - def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1): - """A scheduler that interpolates learning rates in a logarithmic fashion - - Args: - - optimizer: pytorch optimizer - - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr - - min_lr: float, the minimum learning rate - - max_lr: float, the maximum learning rate - - Usage: - fc = nn.Linear(1,1) - optimizer = optim.Adam(fc.parameters()) - lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4) - """ - self.num_steps = num_steps - self.min_lr = min_lr - self.max_lr = max_lr - self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name - 1 / (num_steps - 1) - ) - super().__init__(optimizer, last_epoch) - - def get_lr(self): - if self.last_epoch <= 0: - lrs = [self.min_lr for base_lr in self.base_lrs] - elif self.last_epoch < self.num_steps: - lrs = [ - self.min_lr * (self.q ** (self.last_epoch - 1)) - for base_lr in self.base_lrs - ] - else: - lrs = [self.max_lr for base_lr in self.base_lrs] - - return lrs - - -def _get_cosine_schedule_with_quadratic_warmup_lr_lambda( - current_step: int, - *, - num_warmup_steps: int, - num_training_steps: int, - num_cycles: float -): - if current_step < num_warmup_steps: - return (float(current_step) / float(max(1, num_warmup_steps))) ** 2 - progress = float(current_step - num_warmup_steps) / float( - max(1, num_training_steps - num_warmup_steps) - ) - return max( - 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) - ) - - -def get_cosine_schedule_with_quadratic_warmup( - optimizer: Optimizer, - num_warmup_steps: int, - num_training_steps: int, - num_cycles: float = 0.5, - last_epoch: int = -1, -): - """ - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the - initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_cycles (`float`, *optional*, defaults to 0.5): - The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 - following a half-cosine). - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
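Aside on the scheduler code above: the quadratic-warmup variant ramps the learning rate as (step/warmup)^2 and then follows a standard cosine decay toward zero. Here is a small self-contained sketch of that lambda wired into PyTorch's stock LambdaLR; the toy parameter, base LR, and step counts are assumptions for illustration.

```python
import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def quad_warmup_cosine(step: int, *, warmup: int, total: int, cycles: float = 0.5) -> float:
    # Quadratic ramp from 0 to 1 over `warmup` steps, then cosine decay toward 0.
    if step < warmup:
        return (step / max(1, warmup)) ** 2
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))


param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-4)  # base LR scaled by the lambda
scheduler = LambdaLR(optimizer, partial(quad_warmup_cosine, warmup=10, total=100))

for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # after 5 warmup steps: 1e-4 * (5/10)**2 = 2.5e-05
```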
- """ - - lr_lambda = partial( - _get_cosine_schedule_with_quadratic_warmup_lr_lambda, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - num_cycles=num_cycles, - ) - return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/axolotl/src/axolotl/utils/tokenization.py b/axolotl/src/axolotl/utils/tokenization.py deleted file mode 100644 index 7f63a92feaa3441342fdb13cc8b8e71fc406c21b..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/tokenization.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Module for tokenization utilities""" - - -import logging - -from termcolor import colored - -LOG = logging.getLogger("axolotl") - - -def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False): - # the dataset is already shuffled, so let's just check the first 5 elements - for idx in range(num_examples): - check_example_labels(dataset[idx], tokenizer, text_only=text_only) - - -def check_example_labels(example, tokenizer, text_only=False): - # Get the input_ids, labels, and attention_mask from the dataset - input_ids = example["input_ids"] - labels = example["labels"] - - # You can compare the input_ids and labels element-wise - # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 - colored_tokens = [] - for _, (input_id, label_id) in enumerate(zip(input_ids, labels)): - decoded_input_token = tokenizer.decode(input_id) - # Choose the color based on whether the label has the ignore value or not - color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") - colored_token = colored(decoded_input_token, color) + ( - not text_only and colored(f"({label_id}, {input_id})", "white") or "" - ) - colored_tokens.append(colored_token) - - delimiter = "" if text_only else " " - LOG.info(delimiter.join(colored_tokens)) - LOG.info("\n\n\n") - - return " ".join(colored_tokens) diff --git a/axolotl/src/axolotl/utils/trainer.py b/axolotl/src/axolotl/utils/trainer.py deleted file mode 100644 index f046dd7be89a7a152ffb16a98392e781b672c241..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/trainer.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Module containing the Trainer class and related functions""" -import math -import os -from contextlib import contextmanager -from functools import partial -from typing import List - -import numpy as np -import torch -import torch.cuda -from accelerate.logging import get_logger -from datasets import set_caching_enabled -from torch.utils.data import DataLoader, RandomSampler - -from axolotl.core.trainer_builder import HFCausalTrainerBuilder -from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first -from axolotl.utils.samplers import MultipackBatchSampler - -LOG = get_logger("axolotl") - - -@torch.jit.script -def weighted_cross_entropy( - logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor -): - # Flatten the logits, labels, and weights tensors - logits = logits.view( - -1, logits.size(-1) - ) # logits becomes of shape [batch_size*sequence_length, vocab_size] - labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length] - weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length] - - # Compute the unweighted cross entropy loss - losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none") - - # Apply the weights to the losses and compute their sum - return (weights * losses).sum() - - -@torch.jit.script -def create_weighted_mask(labels: 
torch.Tensor): - # Check if the tensor is 2D. If not, unsqueeze it to make it 2D - if len(labels.shape) == 1: - labels = labels.unsqueeze(0) - - weights = torch.zeros_like(labels).float() - for i in range(labels.shape[0]): - mask = labels[i] != -100 - - # Create a tensor to track group ids - group_ids = torch.zeros_like(labels[i]).int() - curr_group_id = 0 - - for j in range(1, len(labels[i])): - if mask[j] and not mask[j - 1]: # switch from masked to unmasked label - curr_group_id += 1 # start new group - group_ids[j] = ( - curr_group_id if mask[j] else 0 - ) # assign group id if unmasked label - - # Count only unmasked labels in each group - group_counts = torch.bincount(group_ids[mask]) - - mask_weights = torch.zeros_like(labels[i]).float() - mask_weights[mask] = 1.0 / group_counts[group_ids[mask]] - - weights[i] = mask_weights - - return weights.squeeze() # squeeze the output to match the input dimension - - -def trainer_weighted_loss(model_output, labels, shift_labels=True): - logits = ( - model_output["logits"] if isinstance(model_output, dict) else model_output[0] - ) - if shift_labels: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - - weights = create_weighted_mask(labels) - return weighted_cross_entropy(logits, labels, weights) - - -def add_position_ids(sample): - sample_len = len(sample["input_ids"]) - sample["position_ids"] = torch.arange(len(sample["input_ids"])) - sample["length"] = sample_len - return sample - - -def add_length(sample): - sample["length"] = len(sample["input_ids"]) - return sample - - -def drop_long_seq(sample, sequence_len=2048): - return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 - - -@contextmanager -def disable_datasets_caching(): - try: - set_caching_enabled(False) - yield - finally: - set_caching_enabled(True) - - -def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): - drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) - with zero_first(is_main_process()): - train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes) - if eval_dataset: - eval_dataset = eval_dataset.filter( - drop_long, num_proc=cfg.dataset_processes - ) - - if cfg.group_by_length: - train_dataset = train_dataset.map( - add_length, num_proc=cfg.dataset_processes - ) - - if cfg.sample_packing: - train_dataset = train_dataset.map( - add_position_ids, num_proc=cfg.dataset_processes - ) - if cfg.eval_sample_packing is not False: - if eval_dataset: - eval_dataset = eval_dataset.map( - add_position_ids, num_proc=cfg.dataset_processes - ) - - # Phi doesn't want the attention_mask feature when training - if ( - "CodeGenTokenizer" in tokenizer.__class__.__name__ - or (cfg.is_mistral_derived_model and cfg.flash_attention) - or cfg.model_config_type == "mamba" - ): - train_dataset = train_dataset.remove_columns("attention_mask") - if eval_dataset: - eval_dataset = eval_dataset.remove_columns("attention_mask") - - return train_dataset, eval_dataset - - -def calculate_total_num_steps(cfg, train_dataset, update=True): - if not cfg.total_num_tokens: - total_num_tokens = np.sum( - train_dataset.data.column("input_ids") - .to_pandas() - .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda - .values - ) - LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True) - if update: - cfg.total_num_tokens = total_num_tokens - - skip_estimates = cfg.model_config_type == "mamba" - - if not skip_estimates and not cfg.total_supervised_tokens: - 
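Aside on weighted_cross_entropy and create_weighted_mask above: the trainer computes a per-token cross entropy (reduction="none") and scales it by a weight mask so each supervised span contributes comparably. A tiny numeric sketch of that weighting pattern follows; the toy logits, labels, and hand-written weights are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Toy batch: 1 sequence, 4 positions, vocab size 3. Label -100 marks ignored positions.
logits = torch.randn(1, 4, 3)
labels = torch.tensor([[-100, 2, 0, 1]])
# Zero weight on the ignored position, equal weight across the 3 supervised tokens.
weights = torch.tensor([[0.0, 1 / 3, 1 / 3, 1 / 3]])

flat_logits = logits.view(-1, logits.size(-1))
flat_labels = labels.view(-1)
flat_weights = weights.view(-1)

# reduction="none" keeps one loss per position so the weights can be applied afterwards;
# the default ignore_index of -100 already zeroes the masked position.
per_token = F.cross_entropy(flat_logits, flat_labels, reduction="none", ignore_index=-100)
loss = (flat_weights * per_token).sum()
print(loss)
```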
total_supervised_tokens = ( - train_dataset.data.column("labels") - .to_pandas() - .apply(lambda x: np.sum(np.array(x) != -100)) - .sum() - ) - LOG.debug( - f"`total_supervised_tokens: {total_supervised_tokens}`", - main_process_only=True, - ) - if update: - cfg.total_supervised_tokens = total_supervised_tokens - - if not skip_estimates and cfg.sample_packing: - # we have to drop anything longer then sequence len otherwise - # flash attention with position ids fails - - if cfg.sample_packing_eff_est: - total_num_steps = ( - # match count to len est in dataloader - ( - math.floor( - 0.99 - * cfg.total_num_tokens - / cfg.sample_packing_eff_est - / cfg.sequence_len - // cfg.batch_size - // int(os.environ.get("WORLD_SIZE", 1)) - ) - - 1 - ) - * cfg.num_epochs - ) - LOG.debug( - f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}", - main_process_only=True, - ) - else: - sampler = MultipackBatchSampler( - sampler=RandomSampler(train_dataset), - batch_size=cfg.micro_batch_size, - drop_last=True, - batch_max_len=cfg.micro_batch_size - * (cfg.max_packed_sequence_len or cfg.sequence_len), - lengths=( - train_dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ), - ) - - data_loader = DataLoader( - train_dataset.remove_columns(["length"]), - batch_sampler=sampler, - ) - data_loader_len = len(data_loader) - actual_eff = sampler.efficiency() - LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) - # FIXME: is there a bug here somewhere? the total num steps depends - # on the agreed on value for sample_packing_eff_est - total_num_steps = int( - math.floor( - data_loader_len - * cfg.num_epochs - / int(os.environ.get("WORLD_SIZE", 1)) - ) - ) - - def calc_sample_packing_eff_est(estimates: List[float]): - LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") - return max(estimates) - - sample_packing_actual_eff_all = reduce_and_broadcast( - lambda: actual_eff, - calc_sample_packing_eff_est, - ) - sample_packing_eff_est = ( - math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0 - ) - if update: - cfg.sample_packing_eff_est = sample_packing_eff_est - LOG.debug( - f"sample_packing_eff_est: {cfg.sample_packing_eff_est}", - main_process_only=True, - ) - else: - total_num_steps = int( - math.ceil( - len(train_dataset) - * cfg.num_epochs - / int(os.environ.get("WORLD_SIZE", 1)) - / cfg.batch_size - ) - ) - LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) - return total_num_steps - - -def setup_fsdp_envs(cfg): - os.environ["ACCELERATE_USE_FSDP"] = "true" - if cfg.fsdp_config.fsdp_offload_params: - os.environ["FSDP_OFFLOAD_PARAMS"] = "true" - if cfg.fsdp_config.fsdp_sync_module_states: - os.environ["FSDP_SYNC_MODULE_STATES"] = "true" - if cfg.fsdp_config.fsdp_state_dict_type: - os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type - if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: - os.environ[ - "FSDP_TRANSFORMER_CLS_TO_WRAP" - ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap - - -def prepare_optim_env(cfg): - if cfg.fsdp: - setup_fsdp_envs(cfg) - elif cfg.deepspeed: - os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed - - -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) - trainer_builder.train_dataset = train_dataset - trainer_builder.eval_dataset = eval_dataset - - return 
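Aside on the sample-packing branch of calculate_total_num_steps above: when a packing efficiency estimate is available, the step count is derived from the total token count rather than the number of rows. A worked example of that arithmetic with assumed numbers (2M tokens, 0.95 efficiency, 2048 sequence length, batch size 4, one GPU, 2 epochs), not taken from a real run:

```python
import math

# Assumed values for illustration only.
total_num_tokens = 2_000_000
sample_packing_eff_est = 0.95
sequence_len = 2048
batch_size = 4
world_size = 1
num_epochs = 2

steps_per_epoch = (
    math.floor(
        0.99 * total_num_tokens
        / sample_packing_eff_est
        / sequence_len
        // batch_size
        // world_size
    )
    - 1  # mirrors the "shave off 1% + 1" safety margin in the length estimate
)
total_num_steps = steps_per_epoch * num_epochs
print(total_num_steps)  # 506 with the numbers above
```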
trainer_builder.build(total_num_steps) diff --git a/axolotl/src/axolotl/utils/wandb_.py b/axolotl/src/axolotl/utils/wandb_.py deleted file mode 100644 index 327dd9b6348159942de6d991cd7fd1730d2f5792..0000000000000000000000000000000000000000 --- a/axolotl/src/axolotl/utils/wandb_.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Module for wandb utilities""" - -import os - -from axolotl.utils.dict import DictDefault - - -def setup_wandb_env_vars(cfg: DictDefault): - for key in cfg.keys(): - if key.startswith("wandb_"): - value = cfg.get(key, "") - - if value and isinstance(value, str) and len(value) > 0: - os.environ[key.upper()] = value - - # Enable wandb if project name is present - if cfg.wandb_project and len(cfg.wandb_project) > 0: - cfg.use_wandb = True - os.environ.pop("WANDB_DISABLED", None) # Remove if present - else: - os.environ["WANDB_DISABLED"] = "true" diff --git a/axolotl/tests/e2e/.gitignore b/axolotl/tests/e2e/.gitignore deleted file mode 100644 index ad1727ec598318511c9b0620ec5a85e4a243243c..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/.gitignore +++ /dev/null @@ -1 +0,0 @@ -last_run_prepared diff --git a/axolotl/tests/e2e/__init__.py b/axolotl/tests/e2e/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/axolotl/tests/e2e/test_fused_llama.py b/axolotl/tests/e2e/test_fused_llama.py deleted file mode 100644 index 513df69f9d5ffac7dd3b93a3383fd147a99090c8..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/test_fused_llama.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -E2E tests for lora llama -""" - -import logging -import os -import unittest -from pathlib import Path - -from transformers.utils import is_torch_bf16_gpu_available - -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - -from .utils import with_temp_dir - -LOG = logging.getLogger("axolotl.tests.e2e") -os.environ["WANDB_DISABLED"] = "true" - - -class TestFusedLlama(unittest.TestCase): - """ - Test case for Llama models using Fused layers - """ - - @with_temp_dir - def test_fft_packing(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "JackFram/llama-68m", - "flash_attention": True, - "flash_attn_fuse_qkv": True, - "flash_attn_fuse_mlp": True, - "sample_packing": True, - "sequence_len": 1024, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 2, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, - } - ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/axolotl/tests/e2e/test_lora_llama.py b/axolotl/tests/e2e/test_lora_llama.py deleted file mode 100644 index 9d795601a42a5758db875edcfcddc270d6055fa9..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/test_lora_llama.py +++ /dev/null @@ -1,160 
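Aside on the deleted wandb_.py above: setup_wandb_env_vars promotes config keys prefixed with wandb_ to upper-case environment variables and toggles WANDB_DISABLED based on whether a project is set. A standalone sketch of that pattern, with a plain dict standing in for axolotl's DictDefault and made-up config values:

```python
import os

# Plain dict stands in for DictDefault; the values are illustrative.
cfg = {"wandb_project": "my-project", "wandb_entity": "my-team", "learning_rate": 1e-5}

for key, value in cfg.items():
    if key.startswith("wandb_") and isinstance(value, str) and value:
        os.environ[key.upper()] = value  # e.g. wandb_project -> WANDB_PROJECT

if cfg.get("wandb_project"):
    os.environ.pop("WANDB_DISABLED", None)   # project set: allow wandb reporting
else:
    os.environ["WANDB_DISABLED"] = "true"    # no project: disable reporting

print(os.environ.get("WANDB_PROJECT"), os.environ.get("WANDB_DISABLED"))
```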
+0,0 @@ -""" -E2E tests for lora llama -""" - -import logging -import os -import unittest -from pathlib import Path - -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - -from .utils import with_temp_dir - -LOG = logging.getLogger("axolotl.tests.e2e") -os.environ["WANDB_DISABLED"] = "true" - - -class TestLoraLlama(unittest.TestCase): - """ - Test case for Llama models using LoRA - """ - - @with_temp_dir - def test_lora(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "JackFram/llama-68m", - "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, - "load_in_8bit": True, - "adapter": "lora", - "lora_r": 32, - "lora_alpha": 64, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 8, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() - - @with_temp_dir - def test_lora_packing(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "JackFram/llama-68m", - "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, - "sample_packing": True, - "flash_attention": True, - "load_in_8bit": True, - "adapter": "lora", - "lora_r": 32, - "lora_alpha": 64, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 8, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "bf16": True, - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() - - @with_temp_dir - def test_lora_gptq(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ", - "model_type": "AutoModelForCausalLM", - "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, - "sample_packing": True, - "flash_attention": True, - "load_in_8bit": True, - "adapter": "lora", - "gptq": True, - "gptq_disable_exllama": True, - "lora_r": 32, - "lora_alpha": 64, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "save_steps": 0.5, - "micro_batch_size": 8, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - } - ) - normalize_config(cfg) - cli_args = 
TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() diff --git a/axolotl/tests/e2e/test_mamba.py b/axolotl/tests/e2e/test_mamba.py deleted file mode 100644 index 463b0ddac09be65d91e5b0919600d8b19895585c..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/test_mamba.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -E2E tests for lora llama -""" - -import logging -import os -import unittest -from pathlib import Path - -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - -from .utils import with_temp_dir - -LOG = logging.getLogger("axolotl.tests.e2e") -os.environ["WANDB_DISABLED"] = "true" - - -class TestMistral(unittest.TestCase): - """ - Test case for Llama models using LoRA - """ - - @with_temp_dir - def test_fft(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "state-spaces/mamba-130m", - "model_type": "MambaLMHeadModel", - "tokenizer_type": "AutoTokenizer", - "tokenizer_config": "EleutherAI/gpt-neox-20b", - "flash_attention": False, - "sequence_len": 1024, - "load_in_8bit": False, - "val_set_size": 0.0, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "gradient_checkpointing": False, - "num_epochs": 2, - "micro_batch_size": 2, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": None, - "save_safetensors": False, - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/axolotl/tests/e2e/test_mistral.py b/axolotl/tests/e2e/test_mistral.py deleted file mode 100644 index 57d85e51eb9137f3032757f7da8972e9b69f779e..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/test_mistral.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -E2E tests for lora llama -""" - -import logging -import os -import unittest -from pathlib import Path - -from transformers.utils import is_torch_bf16_gpu_available - -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - -from .utils import with_temp_dir - -LOG = logging.getLogger("axolotl.tests.e2e") -os.environ["WANDB_DISABLED"] = "true" - - -class TestMistral(unittest.TestCase): - """ - Test case for Llama models using LoRA - """ - - @with_temp_dir - def test_lora(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "openaccess-ai-collective/tiny-mistral", - "flash_attention": True, - "sequence_len": 1024, - "load_in_8bit": True, - "adapter": "lora", - "lora_r": 32, - "lora_alpha": 64, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 2, - "gradient_accumulation_steps": 1, - "output_dir": 
temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() - - @with_temp_dir - def test_ft(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "openaccess-ai-collective/tiny-mistral", - "flash_attention": True, - "sequence_len": 1024, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 2, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, - } - ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/axolotl/tests/e2e/test_mistral_samplepack.py b/axolotl/tests/e2e/test_mistral_samplepack.py deleted file mode 100644 index cefbd7dc0605c83097e9d24a5f64954bd46e19b6..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/test_mistral_samplepack.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -E2E tests for lora llama -""" - -import logging -import os -import unittest -from pathlib import Path - -from transformers.utils import is_torch_bf16_gpu_available - -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - -from .utils import with_temp_dir - -LOG = logging.getLogger("axolotl.tests.e2e") -os.environ["WANDB_DISABLED"] = "true" - - -class TestMistral(unittest.TestCase): - """ - Test case for Llama models using LoRA - """ - - @with_temp_dir - def test_lora_packing(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "openaccess-ai-collective/tiny-mistral", - "flash_attention": True, - "sample_packing": True, - "sequence_len": 1024, - "load_in_8bit": True, - "adapter": "lora", - "lora_r": 32, - "lora_alpha": 64, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 2, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() - - @with_temp_dir - def test_ft_packing(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "openaccess-ai-collective/tiny-mistral", - 
"flash_attention": True, - "sample_packing": True, - "sequence_len": 1024, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "", - "bos_token": "", - "eos_token": "", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 2, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, - } - ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/axolotl/tests/e2e/test_phi.py b/axolotl/tests/e2e/test_phi.py deleted file mode 100644 index b735236ebfe8d820476396b32ab4ce9531a1baa0..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/test_phi.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -E2E tests for lora llama -""" - -import logging -import os -import unittest -from pathlib import Path - -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - -from .utils import with_temp_dir - -LOG = logging.getLogger("axolotl.tests.e2e") -os.environ["WANDB_DISABLED"] = "true" - - -class TestPhi(unittest.TestCase): - """ - Test case for Llama models using LoRA - """ - - @with_temp_dir - def test_ft(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "microsoft/phi-1_5", - "trust_remote_code": True, - "model_type": "PhiForCausalLM", - "tokenizer_type": "AutoTokenizer", - "sequence_len": 512, - "sample_packing": False, - "load_in_8bit": False, - "adapter": None, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "<|endoftext|>", - "bos_token": "<|endoftext|>", - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "dataset_shard_num": 10, - "dataset_shard_idx": 0, - "num_epochs": 1, - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_bnb_8bit", - "lr_scheduler": "cosine", - "bf16": True, - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() - - @with_temp_dir - def test_ft_packed(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "microsoft/phi-1_5", - "trust_remote_code": True, - "model_type": "PhiForCausalLM", - "tokenizer_type": "AutoTokenizer", - "sequence_len": 512, - "sample_packing": True, - "load_in_8bit": False, - "adapter": None, - "val_set_size": 0.1, - "special_tokens": { - "unk_token": "<|endoftext|>", - "bos_token": "<|endoftext|>", - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "dataset_shard_num": 10, - "dataset_shard_idx": 0, - "num_epochs": 1, - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 
0.00001, - "optimizer": "adamw_bnb_8bit", - "lr_scheduler": "cosine", - "bf16": True, - } - ) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/axolotl/tests/e2e/test_resume.py b/axolotl/tests/e2e/test_resume.py deleted file mode 100644 index 98ec3ac6bfec2700e2e5cb1ba81ced73ab5fe1f4..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/test_resume.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -E2E tests for resuming training -""" - -import logging -import os -import re -import subprocess -import unittest -from pathlib import Path - -from transformers.utils import is_torch_bf16_gpu_available - -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - -from .utils import most_recent_subdir, with_temp_dir - -LOG = logging.getLogger("axolotl.tests.e2e") -os.environ["WANDB_DISABLED"] = "true" - - -class TestResumeLlama(unittest.TestCase): - """ - Test case for resuming training of llama models - """ - - @with_temp_dir - def test_resume_qlora(self, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": "JackFram/llama-68m", - "tokenizer_type": "LlamaTokenizer", - "sequence_len": 1024, - "sample_packing": True, - "flash_attention": True, - "load_in_4bit": True, - "adapter": "qlora", - "lora_r": 32, - "lora_alpha": 64, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.1, - "special_tokens": {}, - "datasets": [ - { - "path": "vicgalle/alpaca-gpt4", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch", - "lr_scheduler": "cosine", - "save_steps": 10, - "save_total_limit": 5, - "max_steps": 40, - } - ) - if is_torch_bf16_gpu_available(): - cfg.bf16 = True - else: - cfg.fp16 = True - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - - resume_cfg = cfg | DictDefault( - { - "resume_from_checkpoint": f"{temp_dir}/checkpoint-30/", - } - ) - normalize_config(resume_cfg) - cli_args = TrainerCliArgs() - - train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() - - tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") - cmd = f"tensorboard --inspect --logdir {tb_log_path_1}" - res = subprocess.run( - cmd, shell=True, text=True, capture_output=True, check=True - ) - pattern = r"first_step\s+(\d+)" - first_steps = int(re.findall(pattern, res.stdout)[0]) - assert first_steps == 31 diff --git a/axolotl/tests/e2e/utils.py b/axolotl/tests/e2e/utils.py deleted file mode 100644 index 203824fc9d0dd7457ea5d78cb37b2b7e7b9ab9ac..0000000000000000000000000000000000000000 --- a/axolotl/tests/e2e/utils.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -helper utils for tests -""" -import os -import shutil -import tempfile -from functools import wraps -from pathlib import Path - - -def with_temp_dir(test_func): - @wraps(test_func) - def wrapper(*args, **kwargs): - # Create a temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Pass the temporary directory to the test function - 
test_func(*args, temp_dir=temp_dir, **kwargs) - finally: - # Clean up the directory after the test - shutil.rmtree(temp_dir) - - return wrapper - - -def most_recent_subdir(path): - base_path = Path(path) - subdirectories = [d for d in base_path.iterdir() if d.is_dir()] - if not subdirectories: - return None - subdir = max(subdirectories, key=os.path.getctime) - - return subdir diff --git a/axolotl/tests/fixtures/alpaca/alpaca.json b/axolotl/tests/fixtures/alpaca/alpaca.json deleted file mode 100644 index 912643d17fa230421cc64f1a1326a12b03ff4f4e..0000000000000000000000000000000000000000 --- a/axolotl/tests/fixtures/alpaca/alpaca.json +++ /dev/null @@ -1,12 +0,0 @@ -[ - { - "instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.", - "input": "Words: ['Hello', 'world'].", - "output": "['world', 'Hello']" - }, - { - "instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.", - "input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.", - "output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar." - } -] diff --git a/axolotl/tests/fixtures/conversation.json b/axolotl/tests/fixtures/conversation.json deleted file mode 100644 index bffb118a4196fcaa2e629c59c51e4c2c4e127c52..0000000000000000000000000000000000000000 --- a/axolotl/tests/fixtures/conversation.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "conversations": [ - { - "from": "human", - "value": "how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?" - }, - { - "from": "gpt", - "value": "It's possible to use algebraic topology to solve Minesweeper, but it would likely be quite complex. Algebraic topology is a branch of mathematics that studies the properties of spaces that are preserved under continuous deformations, such as stretching and bending. In the context of Minesweeper, this could be used to analyze the structure of the game board and determine the locations of mines based on certain topological invariants. However, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper. It's not something that can be explained in a simple answer." - }, - { - "from": "human", - "value": "could you give me the gist of how it could though?" - }, - { - "from": "gpt", - "value": "Sure. The basic idea would be to use algebraic topology to analyze the structure of the Minesweeper game board and determine the locations of mines based on certain topological invariants. This could involve representing the game board as a topological space and using topological invariants such as homotopy groups and Betti numbers to infer information about the location of mines on the board. However, as I mentioned before, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper, so it's not something that can be explained in a simple answer." - }, - { - "from": "human", - "value": "give me a specific example" - }, - { - "from": "gpt", - "value": "As a specific example, suppose we have a Minesweeper game board with the following configuration:\n```\n 1 1 1 1 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 ? ? ? 
1\n 1 1 1 1 1\n```\nIn this case, the game board can be represented as a topological space, where each square on the board corresponds to a point in the space. The mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space. Using algebraic topology, we can analyze the structure of this space and determine the locations of the mines based on certain topological invariants.\n\nFor instance, one way to do this is to compute the homotopy groups of the space. Homotopy groups are a way of describing the holes in a space, and they can be used to infer information about the locations of the mines on the Minesweeper board. In the case of the board above, we can compute the first homotopy group of the space, which describes the number and \"shape\" of the holes in the space. In this case, the first homotopy group would tell us that there are four holes in the space, arranged in a square shape.\n\nBased on this information, we can infer that the mines in the game are located in the four corners of the board, as shown below:\n```\n 1 1 1 1 1\n 1 M ? ? 1\n 1 ? ? ? 1\n 1 ? ? M 1\n 1 1 1 1 1\n```\nOf course, this is just one example of how algebraic topology could be used to solve Minesweeper. There are many other possible approaches, and the specific details of how to implement them would depend on the specific game board and the topological invariants being used. As I mentioned before, implementing a solution like this would require a deep understanding of both algebraic topology and the game of Minesweeper." - }, - { - "from": "human", - "value": "assuming i was someone with a deep understanding of both, could you explain the same concept to me with that assumption?" - }, - { - "from": "gpt", - "value": "Sure. Assuming you have a deep understanding of both algebraic topology and the game of Minesweeper, here is how you might use algebraic topology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Use topological invariants to infer information about the locations of the mines on the board. There are many possible topological invariants that could be used for this, such as homotopy groups, Betti numbers, or persistent homology.\n3. Use the information inferred from the topological invariants to determine the locations of the mines on the board. For example, if the first homotopy group of the space indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using algebraic topology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the topological invariants being used. However, this should give you a general idea of how algebraic topology can be used to solve Minesweeper." - }, - { - "from": "human", - "value": "how would i do it with persistent homology or any of the other topological invariants?" 
- }, - { - "from": "gpt", - "value": "To use persistent homology or any of the other topological invariants to solve a Minesweeper game board, you would follow the same basic steps as described above, but with some differences in the specific details.\n\nFor instance, here is how you might use persistent homology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Compute the persistent homology of the space. Persistent homology is a way of describing the holes in a space over time, and it can be used to infer information about the locations of the mines on the Minesweeper board.\n3. Use the information from the persistent homology to determine the locations of the mines on the board. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using persistent homology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the specific details of the persistent homology computation. However, this should give you a general idea of how persistent homology can be used to solve Minesweeper." - }, - { - "from": "human", - "value": "how could i implement this in python?" - }, - { - "from": "gpt", - "value": "To implement a Minesweeper solver that uses algebraic topology in Python, you would need to do the following:\n\n1. Install a Python package for algebraic topology. There are several options available, such as Gudhi, Dionysus, or scikit-tda. These packages provide algorithms and data structures for computing topological invariants, such as homotopy groups, Betti numbers, or persistent homology.\n2. Represent the Minesweeper game board as a topological space. This could be done using a data structure provided by the algebraic topology package, such as a simplicial complex or a Vietoris-Rips complex. Each square on the board would correspond to a point in the space, and the mines on the board would be represented as \"holes\" in the space.\n3. Compute the topological invariants of the space. This could be done using algorithms provided by the algebraic topology package. For example, you could compute the persistent homology of the space using the persistent homology algorithm.\n4. Use the information from the topological invariants to determine the locations of the mines on the board. This could be done using a combination of mathematical reasoning and programming. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then you could use a piece of code to infer that the mines are located in the four corners of the board.\n5. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you could use a piece of code to automatically click on the safe squares and clear the board.\n\nOf course, this is just one possible approach to implementing a Minesweeper solver that uses algebraic topology in Python. 
There may be other ways to do it, depending on the specific details of the implementation. However, this should give you a general idea of how to get started with such a project." - } - ] -} diff --git a/axolotl/tests/fixtures/conversation.missingturns.json b/axolotl/tests/fixtures/conversation.missingturns.json deleted file mode 100644 index b5b962d859bc94804885d2fbd6aca0910413b4c3..0000000000000000000000000000000000000000 --- a/axolotl/tests/fixtures/conversation.missingturns.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "conversations": [ - { - "from": "human", - "value": "how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?" - }, - { - "from": "gpt", - "value": "" - }, - { - "from": "human", - "value": "could you give me the gist of how it could though?" - }, - { - "from": "gpt", - "value": "Sure. The basic idea would be to use algebraic topology to analyze the structure of the Minesweeper game board and determine the locations of mines based on certain topological invariants. This could involve representing the game board as a topological space and using topological invariants such as homotopy groups and Betti numbers to infer information about the location of mines on the board. However, as I mentioned before, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper, so it's not something that can be explained in a simple answer." - }, - { - "from": "human", - "value": "give me a specific example" - }, - { - "from": "gpt", - "value": "As a specific example, suppose we have a Minesweeper game board with the following configuration:\n```\n 1 1 1 1 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 1 1 1 1\n```\nIn this case, the game board can be represented as a topological space, where each square on the board corresponds to a point in the space. The mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space. Using algebraic topology, we can analyze the structure of this space and determine the locations of the mines based on certain topological invariants.\n\nFor instance, one way to do this is to compute the homotopy groups of the space. Homotopy groups are a way of describing the holes in a space, and they can be used to infer information about the locations of the mines on the Minesweeper board. In the case of the board above, we can compute the first homotopy group of the space, which describes the number and \"shape\" of the holes in the space. In this case, the first homotopy group would tell us that there are four holes in the space, arranged in a square shape.\n\nBased on this information, we can infer that the mines in the game are located in the four corners of the board, as shown below:\n```\n 1 1 1 1 1\n 1 M ? ? 1\n 1 ? ? ? 1\n 1 ? ? M 1\n 1 1 1 1 1\n```\nOf course, this is just one example of how algebraic topology could be used to solve Minesweeper. There are many other possible approaches, and the specific details of how to implement them would depend on the specific game board and the topological invariants being used. As I mentioned before, implementing a solution like this would require a deep understanding of both algebraic topology and the game of Minesweeper." - }, - { - "from": "human", - "value": "assuming i was someone with a deep understanding of both, could you explain the same concept to me with that assumption?" - }, - { - "from": "gpt", - "value": "Sure. 
Assuming you have a deep understanding of both algebraic topology and the game of Minesweeper, here is how you might use algebraic topology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Use topological invariants to infer information about the locations of the mines on the board. There are many possible topological invariants that could be used for this, such as homotopy groups, Betti numbers, or persistent homology.\n3. Use the information inferred from the topological invariants to determine the locations of the mines on the board. For example, if the first homotopy group of the space indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using algebraic topology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the topological invariants being used. However, this should give you a general idea of how algebraic topology can be used to solve Minesweeper." - }, - { - "from": "human", - "value": "how would i do it with persistent homology or any of the other topological invariants?" - }, - { - "from": "gpt", - "value": "To use persistent homology or any of the other topological invariants to solve a Minesweeper game board, you would follow the same basic steps as described above, but with some differences in the specific details.\n\nFor instance, here is how you might use persistent homology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Compute the persistent homology of the space. Persistent homology is a way of describing the holes in a space over time, and it can be used to infer information about the locations of the mines on the Minesweeper board.\n3. Use the information from the persistent homology to determine the locations of the mines on the board. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using persistent homology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the specific details of the persistent homology computation. However, this should give you a general idea of how persistent homology can be used to solve Minesweeper." - }, - { - "from": "human", - "value": "how could i implement this in python?" - }, - { - "from": "gpt", - "value": "To implement a Minesweeper solver that uses algebraic topology in Python, you would need to do the following:\n\n1. Install a Python package for algebraic topology. 
There are several options available, such as Gudhi, Dionysus, or scikit-tda. These packages provide algorithms and data structures for computing topological invariants, such as homotopy groups, Betti numbers, or persistent homology.\n2. Represent the Minesweeper game board as a topological space. This could be done using a data structure provided by the algebraic topology package, such as a simplicial complex or a Vietoris-Rips complex. Each square on the board would correspond to a point in the space, and the mines on the board would be represented as \"holes\" in the space.\n3. Compute the topological invariants of the space. This could be done using algorithms provided by the algebraic topology package. For example, you could compute the persistent homology of the space using the persistent homology algorithm.\n4. Use the information from the topological invariants to determine the locations of the mines on the board. This could be done using a combination of mathematical reasoning and programming. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then you could use a piece of code to infer that the mines are located in the four corners of the board.\n5. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you could use a piece of code to automatically click on the safe squares and clear the board.\n\nOf course, this is just one possible approach to implementing a Minesweeper solver that uses algebraic topology in Python. There may be other ways to do it, depending on the specific details of the implementation. However, this should give you a general idea of how to get started with such a project." - } - ] -} diff --git a/axolotl/tests/fixtures/conversation.tokenized.json b/axolotl/tests/fixtures/conversation.tokenized.json deleted file mode 100644 index 4b7334cc154310025e28325e249f81ab0ac65652..0000000000000000000000000000000000000000 --- a/axolotl/tests/fixtures/conversation.tokenized.json +++ /dev/null @@ -1 +0,0 @@ -{"input_ids": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 29871, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 29871, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 29871, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 
278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, 3148, 1001, 29901, 2367, 592, 263, 2702, 1342, 29871, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 29871, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 
1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 29871, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 
304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 29871, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 21106, 29879, 29958, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 
263, 2560, 1234, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 
29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 21106, 29879, 29958, 2]} diff --git a/axolotl/tests/fixtures/conversation.tokenized_llama2chat.json b/axolotl/tests/fixtures/conversation.tokenized_llama2chat.json deleted file mode 100644 index a4109fc3e206a455b02f9607c7f325afa71435cb..0000000000000000000000000000000000000000 --- a/axolotl/tests/fixtures/conversation.tokenized_llama2chat.json +++ /dev/null @@ -1 +0,0 @@ -{"input_ids": [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 
29903, 6778, 13, 13, 3525, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 518, 29914, 25580, 29962, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, 1, 518, 25580, 29962, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 518, 29914, 25580, 29962, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, 1, 518, 25580, 29962, 2367, 592, 263, 2702, 1342, 518, 29914, 25580, 29962, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 
26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 29871, 2, 1, 518, 25580, 29962, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 518, 29914, 25580, 29962, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, 1, 518, 25580, 29962, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 518, 29914, 25580, 29962, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 
408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, 1, 518, 25580, 29962, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 518, 29914, 25580, 29962, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 
263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 29871, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 
1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 
408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 29871, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 
769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 29871, 2, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], "attention_mask": [true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, 
false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false]} diff --git a/axolotl/tests/monkeypatch/test_llama_attn_hijack_flash.py b/axolotl/tests/monkeypatch/test_llama_attn_hijack_flash.py deleted file mode 100644 index 289c01a8632bbf0cc55b3ae42f7ecd08bcbeeeca..0000000000000000000000000000000000000000 --- a/axolotl/tests/monkeypatch/test_llama_attn_hijack_flash.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Unit tests for the monkeypatch utils -""" -import unittest - -import torch - -from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids - - -class TestMonkeyPatchUtils(unittest.TestCase): - """ - Unit test class for monkeypatch utils - """ - - def test_get_cu_seqlens_1d(self): - attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]) - target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32) - self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res)) - - def test_get_cu_seqlens_from_pos_ids_1d(self): - position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]]) - target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32) - self.assertTrue( - torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res) - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/axolotl/tests/test_data.py b/axolotl/tests/test_data.py deleted file mode 100644 index 16af089a063880b75b8ad1451de1349b9ffdee63..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_data.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -test module for the axolotl.utis.data module -""" -import unittest - -from transformers import LlamaTokenizer - -from axolotl.utils.data import encode_pretraining, md5 - - -class TestEncodePretraining(unittest.TestCase): - """ - test class for encode pretraining and md5 helper - """ - - def setUp(self): - self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "eos_token": "
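For context on the deleted tests/monkeypatch/test_llama_attn_hijack_flash.py above: get_cu_seqlens turns a packed attention mask, in which each packed sub-sequence is tagged 1, 2, 3, ... and padding is 0, into the cumulative sequence-length boundaries that flash attention consumes. A minimal sketch of the 1-D case follows; it assumes the tagging convention used in the test and is not the axolotl implementation itself.

import torch

def cu_seqlens_from_packed_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    # attn_mask: shape (1, seq_len); values 1..N label packed sub-sequences, 0 marks padding
    mask = attn_mask[0]
    # every position where the label changes starts a new segment
    change = torch.nonzero(mask[1:] != mask[:-1]).flatten() + 1
    boundaries = (torch.tensor([0]), change, torch.tensor([mask.numel()]))
    return torch.cat(boundaries).to(torch.int32)

# cu_seqlens_from_packed_mask(torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]))
# -> tensor([ 0,  4,  7, 12, 14, 16], dtype=torch.int32), matching target_res in the test above.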
", - "bos_token": "", - "unk_token": "", - "pad_token": "", - } - ) - self.max_tokens = 15 # set a small number for easy inspection - - def test_encode_pretraining(self): - examples = { - "text": [ - "Hello, world!", - "Nice to meet you.", - "lorem ipsum dolor sit amet.", - "Nice to meet you again!.", - "hello, hello", - ] - } - result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"]) - - self.assertEqual(len(result["input_ids"]), 3) - - # Assert the length of input_ids and attention_mask is correct - self.assertEqual(len(result["input_ids"][0]), self.max_tokens) - self.assertEqual(len(result["attention_mask"][0]), self.max_tokens) - - # Assert EOS and PAD tokens are correctly added - # hello world! is 4 tokens - self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id) - self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id) - self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id) - # second part, 5 tokens - self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id) - self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id) - self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id) - - def test_md5(self): - self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3") - self.assertEqual( - md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3" - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/axolotl/tests/test_dict.py b/axolotl/tests/test_dict.py deleted file mode 100644 index 8367e7c2a981f73157ef2174b7425fe696731acb..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_dict.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Module for testing DictDefault class""" - - -import unittest - -import pytest - -from axolotl.utils.dict import DictDefault - - -class DictDefaultTest(unittest.TestCase): - """ - Test DictDefault class - """ - - def test_dict_default(self): - cfg = DictDefault( - { - "key_a": {"key_b": "value_a"}, - "key_c": "value_c", - "key_d": ["value_d", "value_e"], - } - ) - - assert ( - cfg.key_a.key_b == "value_a" - ), "DictDefault should return value for existing nested keys" - - assert ( - cfg.key_c == "value_c" - ), "DictDefault should return value for existing keys" - - assert ( - cfg.key_d[0] == "value_d" - ), "DictDefault should return value for existing keys in list" - - assert ( - "value_e" in cfg.key_d - ), "DictDefault should support in operator for existing keys in list" - - def test_dict_or_operator(self): - cfg = DictDefault( - { - "key_a": {"key_b": "value_a"}, - "key_c": "value_c", - "key_d": ["value_d", "value_e"], - "key_f": "value_f", - } - ) - - cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation - {"key_a": {"key_b": "value_b"}, "key_f": "value_g"} - ) - - assert ( - cfg.key_a.key_b == "value_b" - ), "DictDefault should support OR operator for existing nested keys" - - assert cfg.key_c == "value_c", "DictDefault should not delete existing key" - - assert cfg.key_d == [ - "value_d", - "value_e", - ], "DictDefault should not overwrite existing keys in list" - - assert ( - cfg.key_f == "value_g" - ), "DictDefault should support OR operator for existing key" - - def test_dict_missingkey(self): - cfg = DictDefault({}) - - assert cfg.random_key is None, "DictDefault should return None for missing keys" - - def test_dict_or(self): - cfg = DictDefault({}) | DictDefault({}) - - assert ( - cfg.random_key is None - ), "DictDefault should return None for missing keys after | operation" - - def 
test_dict_nested_missingparentkey(self): - """ - Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist. - """ - cfg = DictDefault({}) - - with pytest.raises( - AttributeError, - match=r"'NoneType' object has no attribute 'another_random_key'", - ): - cfg.random_key.another_random_key = "value" - - def test_dict_shorthand_assignment(self): - """ - Shorthand assignment is said to not be supported if subclassed. However, their example raises error instead of None. - This test ensures that it is supported for current implementation. - - Ref: https://github.com/mewwts/addict#default-values - """ - - cfg = DictDefault({"key_a": {"key_b": "value_a"}}) - - cfg.key_a.key_b = "value_b" - - assert cfg.key_a.key_b == "value_b", "Shorthand assignment should be supported" diff --git a/axolotl/tests/test_expand_mask.py b/axolotl/tests/test_expand_mask.py deleted file mode 100644 index 01241c2958b14460cb5cd3ba91671fc05759dda3..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_expand_mask.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Unit tests for the monkey patch for expand mask to handle packed sequences -""" -import unittest - -import torch - -from axolotl.monkeypatch.llama_expand_mask import _expand_mask - - -class TestExpandMask(unittest.TestCase): - """ - Test class for attention mask expansion for packed sequences - """ - - def test_output(self): - mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]]) - dtype = torch.float32 - expected_output = torch.tensor( - [ - [ - [ - [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], - [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38], - [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], - [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00], - ] - ], - [ - [ - [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], - [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38], - [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38], - [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38], - ] - ], - ] - ) - # Check that the output matches the expected output - self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output)) - - -if __name__ == "__main__": - unittest.main() diff --git a/axolotl/tests/test_normalize_config.py b/axolotl/tests/test_normalize_config.py deleted file mode 100644 index 1397b23af752c4f63a3107f03f1bef2a72e8e20a..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_normalize_config.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Test classes for checking functionality of the cfg normalization -""" -import unittest - -from axolotl.utils.config import normalize_config -from axolotl.utils.dict import DictDefault - - -class NormalizeConfigTestCase(unittest.TestCase): - """ - test class for normalize_config checks - """ - - def _get_base_cfg(self): - return DictDefault( - { - "base_model": "JackFram/llama-68m", - "base_model_config": "JackFram/llama-68m", - "tokenizer_type": "LlamaTokenizer", - "num_epochs": 1, - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - } - ) - - def test_lr_as_float(self): - cfg = ( - self._get_base_cfg() - | DictDefault( # pylint: disable=unsupported-binary-operation - { - "learning_rate": "5e-5", - } - ) - ) - - normalize_config(cfg) - - assert cfg.learning_rate == 0.00005 - - def test_base_model_config_set_when_empty(self): - cfg = self._get_base_cfg() - del cfg.base_model_config - normalize_config(cfg) - - assert cfg.base_model_config == cfg.base_model diff --git a/axolotl/tests/test_packed_dataset.py b/axolotl/tests/test_packed_dataset.py deleted file mode 100644 
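The deleted tests/test_expand_mask.py above pins down the attention rule for packed sequences: a query token may attend only to positions at or before it that carry the same non-zero sequence id, and everything else, including padding (id 0), is filled with the dtype minimum. A rough sketch of that rule, under those assumptions and not the actual _expand_mask monkeypatch:

import torch

def expand_packed_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask: (batch, seq_len) of packed-sequence ids, 0 for padding
    bsz, seq_len = mask.shape
    min_val = torch.finfo(dtype).min
    # query i and key j must share the same non-zero sequence id ...
    same_seq = (mask[:, None, :] == mask[:, :, None]) & (mask[:, :, None] != 0)
    # ... and satisfy j <= i (causal)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = same_seq & causal
    out = torch.full((bsz, 1, seq_len, seq_len), min_val, dtype=dtype)
    return out.masked_fill(allowed[:, None, :, :], 0.0)

# expand_packed_mask(torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]]), torch.float32)
# reproduces the expected_output tensor asserted in the deleted test.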
index da8fb7a937c37b77ac4acf034397cf49ef4bb0dc..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_packed_dataset.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Module for testing dataset sequence packing""" - -import unittest -from pathlib import Path - -from datasets import Dataset, load_dataset -from transformers import AutoTokenizer - -from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset -from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter - - -class TestPacking(unittest.TestCase): - """ - Test class for packing dataset sequences - """ - - def setUp(self) -> None: - # pylint: disable=duplicate-code - self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - } - ) - - def test_increments_attention(self): - prompter = AlpacaPrompter("chat") - strat = AlpacaPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - dateset = load_dataset( - "json", - data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"), - )["train"] - dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset))) - - constant_len_dataset = ConstantLengthDataset( - self.tokenizer, - [dataset], - seq_length=2048, - ) - packed_dataset = Dataset.from_list(list(constant_len_dataset)) - example = packed_dataset[0] - next_bos_index = ( - example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1 - ) # add one since we sliced - - # first example doesn't have mask reset - assert example["input_ids"][0] == self.tokenizer.bos_token_id - assert example["attention_mask"][0] == 1 - assert example["position_ids"][0] == 0 - assert example["position_ids"][1] == 1 - - # but subsequent one does - assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id - assert example["attention_mask"][next_bos_index] == 2 - assert example["position_ids"][next_bos_index] == 0 - assert example["position_ids"][next_bos_index + 1] == 1 - - -if __name__ == "__main__": - unittest.main() diff --git a/axolotl/tests/test_prompt_tokenizers.py b/axolotl/tests/test_prompt_tokenizers.py deleted file mode 100644 index cea39d0adf9e6efe8267c79af410a3738e676c23..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_prompt_tokenizers.py +++ /dev/null @@ -1,431 +0,0 @@ -"""Module for testing prompt tokenizers.""" -import json -import logging -import unittest -from copy import deepcopy -from pathlib import Path -from typing import Optional - -import pytest -from transformers import AutoTokenizer, LlamaTokenizer - -from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter -from axolotl.prompt_strategies.alpaca_w_system import ( - InstructionWSystemPromptTokenizingStrategy, - SystemDataPrompter, -) -from axolotl.prompt_strategies.llama2_chat import ( - Llama2ChatPrompter, - LLama2ChatTokenizingStrategy, -) -from axolotl.prompt_tokenizers import ( - AlpacaPromptTokenizingStrategy, - ShareGPTPromptTokenizingStrategy, -) -from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2 - -LOG = logging.getLogger("axolotl") - -test_data = { - "multi_turn_sys": { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "123"}, - {"from": "gpt", "value": "sit"}, - ] - }, - "single_turn_sys": { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "human", 
"value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - ] - }, - "single_turn_no_sys": { - "conversations": [ - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - ] - }, - "multi_turn_no_sys": { - "conversations": [ - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "123"}, - {"from": "gpt", "value": "sit"}, - ] - }, -} - - -def prompt_strat(conversation, tokenizer): - "Helper function to create a prompt strategy for testing." - prompter = ShareGPTPrompterV2(conversation=conversation) - return ShareGPTPromptTokenizingStrategy( - prompter, - tokenizer, - False, - 2048, - ) - - -class TestPromptTokenizationStrategies(unittest.TestCase): - """ - Test class for prompt tokenization strategies. - """ - - _caplog: Optional[pytest.LogCaptureFixture] = None - - @pytest.fixture(autouse=True) - def inject_fixtures(self, caplog): - self._caplog = caplog - - def setUp(self) -> None: - # pylint: disable=duplicate-code - self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - } - ) - - def test_sharegpt_integration(self): - with open( - Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" - ) as fin: - data = fin.read() - conversation = json.loads(data) - with open( - Path(__file__).parent / "fixtures/conversation.tokenized.json", - encoding="utf-8", - ) as fin: - data = fin.read() - tokenized_conversation = json.loads(data) - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - example = strat.tokenize_prompt(conversation) - for fields in ["input_ids", "attention_mask", "labels"]: - self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) - self.assertEqual(example[fields], tokenized_conversation[fields]) - - def test_sharegpt_warnings_integration(self): - with open( - Path(__file__).parent / "fixtures/conversation.missingturns.json", - encoding="utf-8", - ) as fin: - data = fin.read() - conversation = json.loads(data) - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - strat.tokenize_prompt(conversation) - assert "assistant turn has empty text" in self._caplog.records[1].message - - def test_sharegpt_warnings_turns(self): - conversation = { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "dolor"}, - {"from": "human", "value": "dolor"}, - {"from": "gpt", "value": "sit"}, - ] - } - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - strat.tokenize_prompt(conversation) - assert ( - "Role did not alternate between turns (gpt and human)" - in self._caplog.records[0].message - ) - - def test_sharegpt_llama(self): - "Make sure the sharegpt/llama is tokenized and formatted correctly." 
- strat = prompt_strat("llama-2", self.tokenizer) - - def tokenize(conv): - return strat.tokenize_prompt(deepcopy(conv))["input_ids"] - - def decode(ids): - return strat.tokenizer.decode(ids) - - # fmt: off - # System message, multi-turn conversations - mt_ids = tokenize(test_data['multi_turn_sys']) - assert decode(mt_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum [INST] 123 [/INST] sit' - assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - - # System message, single-turn conversations - st_ids = tokenize(test_data['single_turn_sys']) - assert decode(st_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum' - assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, single-turn - ns_ids = tokenize(test_data['single_turn_no_sys']) - assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' - assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, multi-turn - ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) - assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' - assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - # fmt: on - - def test_sharegpt_mistral(self): - "Make sure the sharegpt/mistral is tokenized and formatted correctly." - strat = prompt_strat("mistral", self.tokenizer) - - def tokenize(conv): - return strat.tokenize_prompt(deepcopy(conv))["input_ids"] - - def decode(ids): - return strat.tokenizer.decode(ids) - - # fmt: off - # System message, multi-turn conversations - mt_ids = tokenize(test_data['multi_turn_sys']) - assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' - assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - - # System message, single-turn conversations - st_ids = tokenize(test_data['single_turn_sys']) - assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum' - assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, single-turn - ns_ids = tokenize(test_data['single_turn_no_sys']) - assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' - assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] - - # No system message, multi-turn - ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) - assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit
' - assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - # fmt: on - - def test_sharegpt_changes_roles(self): - conversation = { - "roles": ["USER", "CHARACTER"], - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "dolor"}, - {"from": "gpt", "value": "sit"}, - ], - } - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - res = strat.tokenize_prompt(conversation) - assert "CHARACTER" in self.tokenizer.decode(res["input_ids"]) - - def test_sharegpt_assistant_label_ignore(self): - conversation = { - "roles": ["user", "assistant"], - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "dolor"}, - {"from": "gpt", "value": "sit"}, - ], - } - prompter = ShareGPTPrompterV2() - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - with self._caplog.at_level(logging.WARNING): - res = strat.tokenize_prompt(conversation) - idx = res["input_ids"].index(20255) # assistant token - assert res["labels"][idx] == -100 - - def test_no_sys_prompt(self): - """ - tests the interface between the user and assistant parts - """ - prompter = NoSystemPrompter() - # pylint: disable=duplicate-code - strat = AlpacaPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - sample = { - "instruction": "hello cruel. lorem ipsum dolor sit amet.", - "output": "world!", - } - example = strat.tokenize_prompt(sample) - world_idx = example["input_ids"].index(3186) - assert example["labels"][world_idx] == 3186 - assert example["labels"][world_idx - 1] == -100 - - def test_alpaca(self): - """ - tests the interface between the user and assistant parts - """ - # pylint: disable=duplicate-code - prompter = AlpacaPrompter() - strat = AlpacaPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - sample = {"instruction": "hello!", "output": "Hi! How can I help?"} - example = strat.tokenize_prompt(sample) - world_idx = example["input_ids"].index(6324) - assert example["labels"][world_idx] == 6324 - assert example["labels"][world_idx - 1] == -100 - - -class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): - """ - Test class for prompt tokenization strategies with sys prompt from the dataset - """ - - def setUp(self) -> None: - # pylint: disable=duplicate-code - self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - } - ) - - def test_system_alpaca(self): - prompter = SystemDataPrompter(PromptStyle.CHAT.value) - strat = InstructionWSystemPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - sample = { - "system": "use cot", - "instruction": "hello!", - "output": "Hi! 
How can I help?", - } - example = strat.tokenize_prompt(sample) - assert example["input_ids"][0:5] == [ - 1, - 28962, - 1254, - 12665, - 29901, - ] # "SYSTEM:" - assert example["input_ids"][5:7] == [671, 20118] # " use cot" - assert example["input_ids"][8] == 11889 # USER - - -class Llama2ChatTokenizationTest(unittest.TestCase): - """ - Test class for prompt tokenization strategies with sys prompt from the dataset - """ - - def setUp(self) -> None: - # pylint: disable=duplicate-code - self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf") - # woraround because official Meta repos are not open - - def test_llama2_chat_integration(self): - with open( - Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" - ) as fin: - data = fin.read() - conversation = json.loads(data) - with open( - Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json", - encoding="utf-8", - ) as fin: - data = fin.read() - tokenized_conversation = json.loads(data) - prompter = Llama2ChatPrompter() - strat = LLama2ChatTokenizingStrategy( - prompter, - self.tokenizer, - False, - 4096, - ) - example = strat.tokenize_prompt(conversation) - for fields in ["input_ids", "attention_mask", "labels"]: - self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) - self.assertEqual(example[fields], tokenized_conversation[fields]) - - def compare_with_transformers_integration(self): - # this needs transformers >= v4.31.0 - from transformers.models.llama.tokenization_llama import B_SYS, E_SYS - from transformers.pipelines.conversational import Conversation - - # from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT - # broken as of 23/7/20 - # see https://github.com/huggingface/transformers/pull/24935 - # pylint: disable=C0103 - DEFAULT_SYSTEM_PROMPT = """\ -You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.""" - with open( - Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" - ) as fin: - data = fin.read() - conversation = json.loads(data) - with open( - Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json", - encoding="utf-8", - ) as fin: - data = fin.read() - tokenized_conversation = json.loads(data) - - user_input = [] - answers = [] - for msg in conversation["conversations"]: - if msg["from"] == "human": - user_input.append(msg["value"]) - else: - answers.append(msg["value"]) - hf_conf = Conversation( - text=user_input[-1], - past_user_inputs=[B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + user_input[0]] - + user_input[1:-1], - generated_responses=answers, - ) - # pylint: disable=W0212 - hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf) - - self.assertEqual( - hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)] - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/axolotl/tests/test_prompters.py b/axolotl/tests/test_prompters.py deleted file mode 100644 index 6c5b8f27c2e9c0582adbe152bfca1e4b581f35b3..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_prompters.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Module testing prompters""" - -import unittest - -from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter -from axolotl.prompters import ( - AlpacaPrompter, - MultipleChoiceExplainPrompter, - PromptStyle, - UnpromptedPrompter, -) - - -class AlpacaPrompterTest(unittest.TestCase): - """ - Test AlpacaPrompter - """ - - def test_prompt_style_w_none(self): - prompter = AlpacaPrompter(prompt_style=None) - res = next(prompter.build_prompt("tell me a joke")) - # just testing that it uses instruct style - assert "### Instruction:" in res - - def test_prompt_style_w_instruct(self): - prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value) - res = next( - prompter.build_prompt("tell me a joke about the following", "alpacas") - ) - assert "Below is an instruction" in res - assert "### Instruction:" in res - assert "### Input:" in res - assert "alpacas" in res - assert "### Response:" in res - assert "USER:" not in res - assert "ASSISTANT:" not in res - res = next(prompter.build_prompt("tell me a joke about the following")) - assert "Below is an instruction" in res - assert "### Instruction:" in res - assert "### Input:" not in res - assert "### Response:" in res - assert "USER:" not in res - assert "ASSISTANT:" not in res - - def test_prompt_style_w_chat(self): - prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value) - res = next( - prompter.build_prompt("tell me a joke about the following", "alpacas") - ) - assert "Below is an instruction" in res - assert "### Instruction:" not in res - assert "### Input:" not in res - assert "alpacas" in res - assert "### Response:" not in res - assert "USER:" in res - assert "ASSISTANT:" in res - res = next(prompter.build_prompt("tell me a joke about the following")) - assert "Below is an instruction" in res - assert "### Instruction:" not in res - assert "### Input:" not in res - assert "### Response:" not in res - assert "USER:" in res - assert "ASSISTANT:" in res - - def test_system_prompt(self): - prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value) - res = next( - prompter.build_prompt_w_system( - "use cot", "tell me a joke about the following", "alpacas" - ) - ) - assert "use cot" in res - assert res.startswith("SYSTEM:") - assert "### 
Instruction:" not in res - assert "### Input:" not in res - assert "alpacas" in res - assert "### Response:" not in res - assert "USER:" in res - assert "ASSISTANT:" in res - - -class UnpromptedPrompterTest(unittest.TestCase): - """ - Test class for UnpromptedPrompter with no system prompts - """ - - def test_prompt_style_w_none(self): - prompter = UnpromptedPrompter(prompt_style=None) - res = next(prompter.build_prompt("tell me a joke")) - assert "### Instruction:" in res - assert "tell me a joke" in res - assert res.startswith("###") - - def test_prompt_style_w_instruct(self): - prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value) - res = next( - prompter.build_prompt("tell me a joke about the following", "alpacas") - ) - assert "### Instruction:" in res - assert "tell me a joke" in res - assert res.startswith("###") - - def test_prompt_style_w_chat(self): - prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value) - res = next( - prompter.build_prompt("tell me a joke about the following", "alpacas") - ) - assert "USER:" in res - assert "tell me a joke" in res - assert res.startswith("USER:") - - -class MultipleChoiceExplainPrompterTest(unittest.TestCase): - """ - Test class for MultipleChoiceExplainPrompter - """ - - def test_prompt_style_w_chat(self): - prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value) - res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C")) - assert "USER:" in res - assert "choose one" in res - assert "Choose the answer that best answers the question." in res - assert "- A\n- B\n- C" in res diff --git a/axolotl/tests/test_tokenizers.py b/axolotl/tests/test_tokenizers.py deleted file mode 100644 index bfe4f06af9d3bba3abeba0eccb769cba3874a740..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_tokenizers.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Test cases for the tokenizer loading -""" -import unittest - -import pytest - -from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_tokenizer - - -class TestTokenizers(unittest.TestCase): - """ - test class for the load_tokenizer fn - """ - - def test_default_use_fast(self): - cfg = DictDefault( - { - "tokenizer_config": "huggyllama/llama-7b", - } - ) - tokenizer = load_tokenizer(cfg) - assert "Fast" in tokenizer.__class__.__name__ - - def test_dont_use_fast(self): - cfg = DictDefault( - { - "tokenizer_config": "huggyllama/llama-7b", - "tokenizer_use_fast": False, - } - ) - tokenizer = load_tokenizer(cfg) - assert "Fast" not in tokenizer.__class__.__name__ - - def test_special_tokens_modules_to_save(self): - # setting special_tokens to new token - cfg = DictDefault( - { - "tokenizer_config": "huggyllama/llama-7b", - "adapter": "lora", - "special_tokens": {"bos_token": "[INST]"}, - } - ) - with pytest.raises( - ValueError, - match=r".*Please set lora_modules_to_save*", - ): - load_tokenizer(cfg) - - # setting special_tokens but not changing from default - cfg = DictDefault( - { - "tokenizer_config": "huggyllama/llama-7b", - "adapter": "lora", - "special_tokens": {"bos_token": ""}, - } - ) - load_tokenizer(cfg) - - # non-adapter setting special_tokens - cfg = DictDefault( - { - "tokenizer_config": "huggyllama/llama-7b", - "special_tokens": {"bos_token": "[INST]"}, - } - ) - load_tokenizer(cfg) - - -if __name__ == "__main__": - unittest.main() diff --git a/axolotl/tests/test_validation.py b/axolotl/tests/test_validation.py deleted file mode 100644 index 
12997b023b5c83eacbbb5d62c97398e7ca5392a8..0000000000000000000000000000000000000000 --- a/axolotl/tests/test_validation.py +++ /dev/null @@ -1,800 +0,0 @@ -"""Module for testing the validation module""" - -import logging -import os -import unittest -from typing import Optional - -import pytest - -from axolotl.utils.config import validate_config -from axolotl.utils.dict import DictDefault -from axolotl.utils.wandb_ import setup_wandb_env_vars - - -class ValidationTest(unittest.TestCase): - """ - Test the validation module - """ - - _caplog: Optional[pytest.LogCaptureFixture] = None - - @pytest.fixture(autouse=True) - def inject_fixtures(self, caplog): - self._caplog = caplog - - def test_load_4bit_deprecate(self): - cfg = DictDefault( - { - "load_4bit": True, - } - ) - - with pytest.raises(ValueError): - validate_config(cfg) - - def test_batch_size_unused_warning(self): - cfg = DictDefault( - { - "batch_size": 32, - } - ) - - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert "batch_size is not recommended" in self._caplog.records[0].message - - def test_qlora(self): - base_cfg = DictDefault( - { - "adapter": "qlora", - } - ) - - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_8bit": True, - } - ) - - with pytest.raises(ValueError, match=r".*8bit.*"): - validate_config(cfg) - - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "gptq": True, - } - ) - - with pytest.raises(ValueError, match=r".*gptq.*"): - validate_config(cfg) - - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_4bit": False, - } - ) - - with pytest.raises(ValueError, match=r".*4bit.*"): - validate_config(cfg) - - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_4bit": True, - } - ) - - validate_config(cfg) - - def test_qlora_merge(self): - base_cfg = DictDefault( - { - "adapter": "qlora", - "merge_lora": True, - } - ) - - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_8bit": True, - } - ) - - with pytest.raises(ValueError, match=r".*8bit.*"): - validate_config(cfg) - - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "gptq": True, - } - ) - - with pytest.raises(ValueError, match=r".*gptq.*"): - validate_config(cfg) - - cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation - { - "load_in_4bit": True, - } - ) - - with pytest.raises(ValueError, match=r".*4bit.*"): - validate_config(cfg) - - def test_hf_use_auth_token(self): - cfg = DictDefault( - { - "push_dataset_to_hub": "namespace/repo", - } - ) - - with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"): - validate_config(cfg) - - cfg = DictDefault( - { - "push_dataset_to_hub": "namespace/repo", - "hf_use_auth_token": True, - } - ) - validate_config(cfg) - - def test_gradient_accumulations_or_batch_size(self): - cfg = DictDefault( - { - "gradient_accumulation_steps": 1, - "batch_size": 1, - } - ) - - with pytest.raises( - ValueError, match=r".*gradient_accumulation_steps or batch_size.*" - ): - validate_config(cfg) - - cfg = DictDefault( - { - "batch_size": 1, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "gradient_accumulation_steps": 1, - } - ) - - validate_config(cfg) - - def test_falcon_fsdp(self): - regex_exp = r".*FSDP is not supported for falcon models.*" - - # Check for lower-case - cfg = DictDefault( - { - "base_model": "tiiuae/falcon-7b", - "fsdp": 
["full_shard", "auto_wrap"], - } - ) - - with pytest.raises(ValueError, match=regex_exp): - validate_config(cfg) - - # Check for upper-case - cfg = DictDefault( - { - "base_model": "Falcon-7b", - "fsdp": ["full_shard", "auto_wrap"], - } - ) - - with pytest.raises(ValueError, match=regex_exp): - validate_config(cfg) - - cfg = DictDefault( - { - "base_model": "tiiuae/falcon-7b", - } - ) - - validate_config(cfg) - - def test_mpt_gradient_checkpointing(self): - regex_exp = r".*gradient_checkpointing is not supported for MPT models*" - - # Check for lower-case - cfg = DictDefault( - { - "base_model": "mosaicml/mpt-7b", - "gradient_checkpointing": True, - } - ) - - with pytest.raises(ValueError, match=regex_exp): - validate_config(cfg) - - def test_flash_optimum(self): - cfg = DictDefault( - { - "flash_optimum": True, - "adapter": "lora", - } - ) - - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "BetterTransformers probably doesn't work with PEFT adapters" - in record.message - for record in self._caplog.records - ) - - cfg = DictDefault( - { - "flash_optimum": True, - } - ) - - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "probably set bfloat16 or float16" in record.message - for record in self._caplog.records - ) - - cfg = DictDefault( - { - "flash_optimum": True, - "fp16": True, - } - ) - regex_exp = r".*AMP is not supported.*" - - with pytest.raises(ValueError, match=regex_exp): - validate_config(cfg) - - cfg = DictDefault( - { - "flash_optimum": True, - "bf16": True, - } - ) - regex_exp = r".*AMP is not supported.*" - - with pytest.raises(ValueError, match=regex_exp): - validate_config(cfg) - - def test_adamw_hyperparams(self): - cfg = DictDefault( - { - "optimizer": None, - "adam_epsilon": 0.0001, - } - ) - - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "adamw hyperparameters found, but no adamw optimizer set" - in record.message - for record in self._caplog.records - ) - - cfg = DictDefault( - { - "optimizer": "adafactor", - "adam_beta1": 0.0001, - } - ) - - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "adamw hyperparameters found, but no adamw optimizer set" - in record.message - for record in self._caplog.records - ) - - cfg = DictDefault( - { - "optimizer": "adamw_bnb_8bit", - "adam_beta1": 0.9, - "adam_beta2": 0.99, - "adam_epsilon": 0.0001, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "optimizer": "adafactor", - } - ) - - validate_config(cfg) - - def test_packing(self): - cfg = DictDefault( - { - "max_packed_sequence_len": 2048, - } - ) - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "max_packed_sequence_len will be deprecated in favor of sample_packing" - in record.message - for record in self._caplog.records - ) - - cfg = DictDefault( - { - "sample_packing": True, - "pad_to_sequence_len": None, - } - ) - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "`pad_to_sequence_len: true` is recommended when using sample_packing" - in record.message - for record in self._caplog.records - ) - - cfg = DictDefault( - { - "max_packed_sequence_len": 2048, - "sample_packing": True, - } - ) - regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*" - with pytest.raises(ValueError, match=regex_exp): - validate_config(cfg) - - def test_merge_lora_no_bf16_fail(self): - """ - This is assumed to be run on a CPU machine, 
so bf16 is not supported. - """ - - cfg = DictDefault( - { - "bf16": True, - } - ) - - with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"): - validate_config(cfg) - - cfg = DictDefault( - { - "bf16": True, - "merge_lora": True, - } - ) - - validate_config(cfg) - - def test_sharegpt_deprecation(self): - cfg = DictDefault( - {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]} - ) - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "`type: sharegpt:chat` will soon be deprecated." in record.message - for record in self._caplog.records - ) - assert cfg.datasets[0].type == "sharegpt" - - cfg = DictDefault( - {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]} - ) - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "`type: sharegpt_simple` will soon be deprecated." in record.message - for record in self._caplog.records - ) - assert cfg.datasets[0].type == "sharegpt:load_role" - - def test_no_conflict_save_strategy(self): - cfg = DictDefault( - { - "save_strategy": "epoch", - "save_steps": 10, - } - ) - - with pytest.raises( - ValueError, match=r".*save_strategy and save_steps mismatch.*" - ): - validate_config(cfg) - - cfg = DictDefault( - { - "save_strategy": "no", - "save_steps": 10, - } - ) - - with pytest.raises( - ValueError, match=r".*save_strategy and save_steps mismatch.*" - ): - validate_config(cfg) - - cfg = DictDefault( - { - "save_strategy": "steps", - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "save_strategy": "steps", - "save_steps": 10, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "save_steps": 10, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "save_strategy": "no", - } - ) - - validate_config(cfg) - - def test_no_conflict_eval_strategy(self): - cfg = DictDefault( - { - "evaluation_strategy": "epoch", - "eval_steps": 10, - } - ) - - with pytest.raises( - ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" - ): - validate_config(cfg) - - cfg = DictDefault( - { - "evaluation_strategy": "no", - "eval_steps": 10, - } - ) - - with pytest.raises( - ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" - ): - validate_config(cfg) - - cfg = DictDefault( - { - "evaluation_strategy": "steps", - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "evaluation_strategy": "steps", - "eval_steps": 10, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "eval_steps": 10, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "evaluation_strategy": "no", - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "evaluation_strategy": "epoch", - "val_set_size": 0, - } - ) - - with pytest.raises( - ValueError, - match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", - ): - validate_config(cfg) - - cfg = DictDefault( - { - "eval_steps": 10, - "val_set_size": 0, - } - ) - - with pytest.raises( - ValueError, - match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", - ): - validate_config(cfg) - - cfg = DictDefault( - { - "val_set_size": 0, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "eval_steps": 10, - "val_set_size": 0.01, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "evaluation_strategy": "epoch", - "val_set_size": 0.01, - } - ) - - validate_config(cfg) - - def test_eval_table_size_conflict_eval_packing(self): - cfg = DictDefault( - { - "sample_packing": True, - 
"eval_table_size": 100, - } - ) - - with pytest.raises( - ValueError, match=r".*Please set 'eval_sample_packing' to false.*" - ): - validate_config(cfg) - - cfg = DictDefault( - { - "sample_packing": True, - "eval_sample_packing": False, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "sample_packing": False, - "eval_table_size": 100, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "sample_packing": True, - "eval_table_size": 100, - "eval_sample_packing": False, - } - ) - - validate_config(cfg) - - def test_load_in_x_bit_without_adapter(self): - cfg = DictDefault( - { - "load_in_4bit": True, - } - ) - - with pytest.raises( - ValueError, - match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*", - ): - validate_config(cfg) - - cfg = DictDefault( - { - "load_in_8bit": True, - } - ) - - with pytest.raises( - ValueError, - match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*", - ): - validate_config(cfg) - - cfg = DictDefault( - { - "load_in_4bit": True, - "adapter": "qlora", - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "load_in_8bit": True, - "adapter": "lora", - } - ) - - validate_config(cfg) - - def test_warmup_step_no_conflict(self): - cfg = DictDefault( - { - "warmup_steps": 10, - "warmup_ratio": 0.1, - } - ) - - with pytest.raises( - ValueError, - match=r".*warmup_steps and warmup_ratio are mutually exclusive*", - ): - validate_config(cfg) - - cfg = DictDefault( - { - "warmup_steps": 10, - } - ) - - validate_config(cfg) - - cfg = DictDefault( - { - "warmup_ratio": 0.1, - } - ) - - validate_config(cfg) - - def test_add_tokens_adapter(self): - cfg = DictDefault( - {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]} - ) - - with pytest.raises( - ValueError, - match=r".*lora_modules_to_save not properly set yet adding new tokens*", - ): - validate_config(cfg) - - cfg = DictDefault( - { - "adapter": "qlora", - "load_in_4bit": True, - "tokens": ["<|imstart|>"], - "lora_modules_to_save": ["embed_tokens"], - } - ) - - with pytest.raises( - ValueError, - match=r".*lora_modules_to_save not properly set yet adding new tokens*", - ): - validate_config(cfg) - - cfg = DictDefault( - { - "adapter": "qlora", - "load_in_4bit": True, - "tokens": ["<|imstart|>"], - "lora_modules_to_save": ["embed_tokens", "lm_head"], - } - ) - - validate_config(cfg) - - -class ValidationWandbTest(ValidationTest): - """ - Validation test for wandb - """ - - def test_wandb_set_run_id_to_name(self): - cfg = DictDefault( - { - "wandb_run_id": "foo", - } - ) - - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." 
- in record.message - for record in self._caplog.records - ) - - assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo" - - cfg = DictDefault( - { - "wandb_name": "foo", - } - ) - - validate_config(cfg) - - assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None - - def test_wandb_sets_env(self): - cfg = DictDefault( - { - "wandb_project": "foo", - "wandb_name": "bar", - "wandb_run_id": "bat", - "wandb_entity": "baz", - "wandb_mode": "online", - "wandb_watch": "false", - "wandb_log_model": "checkpoint", - } - ) - - validate_config(cfg) - - setup_wandb_env_vars(cfg) - - assert os.environ.get("WANDB_PROJECT", "") == "foo" - assert os.environ.get("WANDB_NAME", "") == "bar" - assert os.environ.get("WANDB_RUN_ID", "") == "bat" - assert os.environ.get("WANDB_ENTITY", "") == "baz" - assert os.environ.get("WANDB_MODE", "") == "online" - assert os.environ.get("WANDB_WATCH", "") == "false" - assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint" - assert os.environ.get("WANDB_DISABLED", "") != "true" - - def test_wandb_set_disabled(self): - cfg = DictDefault({}) - - validate_config(cfg) - - setup_wandb_env_vars(cfg) - - assert os.environ.get("WANDB_DISABLED", "") == "true" - - cfg = DictDefault( - { - "wandb_project": "foo", - } - ) - - validate_config(cfg) - - setup_wandb_env_vars(cfg) - - assert os.environ.get("WANDB_DISABLED", "") != "true"
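The removed axolotl/tests/test_validation.py covered, among other things, how validate_config plus setup_wandb_env_vars mirror the wandb_* config keys into WANDB_* environment variables. A minimal sketch of that check, kept only as an illustration and assuming the axolotl package is still installed as a dependency and still exposes axolotl.utils.config.validate_config, axolotl.utils.dict.DictDefault, and axolotl.utils.wandb_.setup_wandb_env_vars with the signatures used in the deleted tests:

# Minimal sketch (not part of this diff): re-creates the wandb env-var wiring
# check from the removed test_validation.py. Assumes the axolotl package and
# the imports below are still available; treat this as illustrative only.
import os

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars


def check_wandb_env_wiring():
    cfg = DictDefault(
        {
            "wandb_project": "foo",
            "wandb_name": "bar",
            "wandb_entity": "baz",
        }
    )
    validate_config(cfg)       # validates/normalises the config dict
    setup_wandb_env_vars(cfg)  # mirrors wandb_* keys into WANDB_* env vars

    assert os.environ.get("WANDB_PROJECT", "") == "foo"
    assert os.environ.get("WANDB_NAME", "") == "bar"
    assert os.environ.get("WANDB_ENTITY", "") == "baz"
    assert os.environ.get("WANDB_DISABLED", "") != "true"


if __name__ == "__main__":
    check_wandb_env_wiring()
    print("wandb env wiring OK")

Run directly with python to get a quick smoke check of the behaviour the deleted ValidationWandbTest asserted; the values "foo"/"bar"/"baz" are the same placeholders used in that test.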