diff --git a/Prism/Dream/Dream_Baseline/LICENSE b/Prism/Dream/Dream_Baseline/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Prism/Dream/Dream_Prism/LICENSE b/Prism/Dream/Dream_Prism/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/Prism/Dream/Dream_Prism/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Prism/Dream/Dream_Prism/eval_instruct/.gitignore b/Prism/Dream/Dream_Prism/eval_instruct/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..18b56efe2b0943aea0e6dc856f57f3a18f9eceb2 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/.gitignore @@ -0,0 +1,26 @@ +env +*.pyc +output/ +output5/ +data/ +lm_cache +.idea +build +dist +*.egg-info +venv +.venv/ +.vscode/ +temp +__pycache__ +.ipynb_checkpoints +temp +test_logs/ +# IPython +profile_default/ +ipython_config.py +# don't track (the default location of) the cached requests +lm_eval/caching/.cache +# don't track files created by wandb +wandb +examples/wandb diff --git a/Prism/Dream/Dream_Prism/eval_instruct/README.md b/Prism/Dream/Dream_Prism/eval_instruct/README.md new file mode 100644 index 0000000000000000000000000000000000000000..aebe1795bdb18fb0cbea231d2262fcb91fe9304c --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/README.md @@ -0,0 +1,16 @@ +# Dream-Instruct Evaluation Toolkit +This toolkit contains the code Dream-Instruct models make use of for evaluation. + +## Quickstart +To install the toolkit, run: +``` +pip install -e ".[ifeval,math]" +``` + +We provide a script to evaluate [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B): +``` +bash eval.sh +``` + +## Acknowledgement +This is a fork of [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main). 
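+
+## Example invocation
+For reference, `eval.sh` is expected to wrap a plain harness call. A minimal sketch follows; the `--model` backend name (`dream` here) and the task list are illustrative placeholders, so check `eval.sh` and this fork's model registry for the exact names. The remaining flags are the standard CLI arguments defined in `lm_eval/__main__.py`:
+```
+python -m lm_eval \
+    --model dream \
+    --model_args pretrained=Dream-org/Dream-v0-Instruct-7B \
+    --tasks ifeval \
+    --num_fewshot 0 \
+    --batch_size 1 \
+    --output_path output/
+```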
diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__init__.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fece9162482d018c2969a9e67603096d0ad21713 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__init__.py @@ -0,0 +1,7 @@ +import logging +import os + +from .evaluator import evaluate, simple_evaluate + + +__version__ = "0.4.8" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__main__.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..dda41ebe43d39522c4961fc7bfdf2d730e2eec26 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/__main__.py @@ -0,0 +1,512 @@ +import argparse +import json +import logging +import os +import sys +from functools import partial +from typing import Union + +from lm_eval import evaluator, utils +from lm_eval.evaluator import request_caching_arg_to_dict +from lm_eval.loggers import EvaluationTracker, WandbLogger +from lm_eval.tasks import TaskManager +from lm_eval.utils import ( + handle_non_serializable, + make_table, + simple_parse_args_string, +) + + +def try_parse_json(value: str) -> Union[str, dict, None]: + if value is None: + return None + try: + return json.loads(value) + except json.JSONDecodeError: + if "{" in value: + raise argparse.ArgumentTypeError( + f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings." + ) + return value + + +def _int_or_none_list_arg_type( + min_len: int, max_len: int, defaults: str, value: str, split_char: str = "," +): + def parse_value(item): + item = item.strip().lower() + if item == "none": + return None + try: + return int(item) + except ValueError: + raise argparse.ArgumentTypeError(f"{item} is not an integer or None") + + items = [parse_value(v) for v in value.split(split_char)] + num_items = len(items) + + if num_items == 1: + # Makes downstream handling the same for single and multiple values + items = items * max_len + elif num_items < min_len or num_items > max_len: + raise argparse.ArgumentTypeError( + f"Argument requires {max_len} integers or None, separated by '{split_char}'" + ) + elif num_items != max_len: + logging.warning( + f"Argument requires {max_len} integers or None, separated by '{split_char}'. " + "Missing values will be filled with defaults." + ) + default_items = [parse_value(v) for v in defaults.split(split_char)] + items.extend( + default_items[num_items:] + ) # extend items list with missing defaults + + return items + + +def check_argument_types(parser: argparse.ArgumentParser): + """ + Check to make sure all CLI args are typed, raises error if not + """ + for action in parser._actions: + if action.dest != "help" and not action.const: + if action.type is None: + raise ValueError( + f"Argument '{action.dest}' doesn't have a type specified." + ) + else: + continue + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument( + "--model", "-m", type=str, default="hf", help="Name of model e.g.
`hf`" + ) + parser.add_argument( + "--tasks", + "-t", + default=None, + type=str, + metavar="task1,task2", + help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above", + ) + parser.add_argument( + "--model_args", + "-a", + default="", + type=try_parse_json, + help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""", + ) + parser.add_argument( + "--num_fewshot", + "-f", + type=int, + default=None, + metavar="N", + help="Number of examples in few-shot context", + ) + parser.add_argument( + "--batch_size", + "-b", + type=str, + default=1, + metavar="auto|auto:N|N", + help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.", + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=None, + metavar="N", + help="Maximal batch size to try with --batch_size auto.", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help="Device to use (e.g. cuda, cuda:0, cpu).", + ) + parser.add_argument( + "--output_path", + "-o", + default=None, + type=str, + metavar="DIR|DIR/file.json", + help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.", + ) + parser.add_argument( + "--limit", + "-L", + type=float, + default=None, + metavar="N|0 argparse.Namespace: + check_argument_types(parser) + return parser.parse_args() + + +def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: + if not args: + # we allow for args to be passed externally, else we parse them ourselves + parser = setup_parser() + args = parse_eval_args(parser) + + if args.wandb_args: + wandb_args_dict = simple_parse_args_string(args.wandb_args) + wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args) + wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict) + + utils.setup_logging(args.verbosity) + eval_logger = logging.getLogger(__name__) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # update the evaluation tracker args with the output path and the HF token + if args.output_path: + args.hf_hub_log_args += f",output_path={args.output_path}" + if os.environ.get("HF_TOKEN", None): + args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}" + evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args) + evaluation_tracker = EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError( + "Specify --output_path if providing --log_samples or --predict_only" + ) + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." 
+ ) + + if args.include_path is not None: + eval_logger.info(f"Including path: {args.include_path}") + metadata = ( + simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args + if isinstance(args.model_args, dict) + else {} + ) | ( + args.metadata + if isinstance(args.metadata, dict) + else simple_parse_args_string(args.metadata) + ) + + task_manager = TaskManager(include_path=args.include_path, metadata=metadata) + + if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples: + eval_logger.warning( + "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub." + ) + + if args.limit: + eval_logger.warning( + " --limit SHOULD ONLY BE USED FOR TESTING." + "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." + ) + + if args.tasks is None: + eval_logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + eval_logger.error( + f"Tasks were not found: {missing}\n" + f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues." + ) + + # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args + if args.trust_remote_code: + eval_logger.info( + "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`" + ) + # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally, + # because it's already been determined based on the prior env var before launching our + # script--`datasets` gets imported by lm_eval internally before these lines can update the env. 
+ import datasets + + datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True + + args.model_args = args.model_args + ",trust_remote_code=True" + eval_logger.info( + f"Selected Tasks: {task_names}" + ) if eval_logger.getEffectiveLevel() >= logging.INFO else print( + f"Selected Tasks: {task_names}" + ) + + request_caching_args = request_caching_arg_to_dict( + cache_requests=args.cache_requests + ) + + results = evaluator.simple_evaluate( + model=args.model, + model_args=args.model_args, + tasks=task_names, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + max_batch_size=args.max_batch_size, + device=args.device, + use_cache=args.use_cache, + limit=args.limit, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + if results is not None: + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps( + results, indent=2, default=handle_non_serializable, ensure_ascii=False + ) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging + if args.wandb_args: + try: + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated( + results=results, samples=samples if args.log_samples else None + ) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples( + task_name=task_name, samples=samples[task_name] + ) + + if ( + evaluation_tracker.push_results_to_hub + or evaluation_tracker.push_samples_to_hub + ): + evaluation_tracker.recreate_metadata_card() + + print( + f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " + f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" + ) + print(make_table(results)) + if "groups" in results: + print(make_table(results, "groups")) + + if args.wandb_args: + # Tear down wandb run once all the logging is done. + wandb_logger.run.finish() + + +if __name__ == "__main__": + cli_evaluate() diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/filter.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..8d9db6821724c497c4a27116a1238e3b8d32ae29 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/filter.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Iterable, List, Union + +from lm_eval.api.instance import Instance + + +class Filter(ABC): + """ + Filter classes operate on a per-task level. + They take all model outputs (`instance.resps` for all `task.instances`) + across all instances of a task, and perform operations. 
+ In a single run, one can configure any number of separate filters or lists of filters. + + """ + + def __init__(self, **kwargs) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + @abstractmethod + def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable: + """ + Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. + Should return the list of (filtered) response lists *in the same order as they were input*, e.g. + if pass in [, ] should return + [, ] + """ + return resps + + +@dataclass +class FilterEnsemble: + """ + FilterEnsemble creates a pipeline applying multiple filters. + Its intended usage is to stack multiple post-processing steps in order. + `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each + pipeline separately. + """ + + name: str + filters: List[Callable[[], Filter]] + + def apply(self, instances: List[Instance]) -> None: + resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) + resps, docs = list(resps), list(docs) + + for f in self.filters: + # apply filters in sequence + resps = f().apply(resps, docs) + + # add the end results after filtering to filtered_requests of their respective source instances. + # has key `self.name`: each FilterEnsemble applied in a given run should use a different name. + for inst, resp in zip(instances, resps): + inst.filtered_resps[self.name] = resp diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/model.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1debc9b4da2ed307951e50bf2351af971937e59a --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/model.py @@ -0,0 +1,493 @@ +import abc +import hashlib +import json +import logging +import os +from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union + +import transformers +from sqlitedict import SqliteDict +from tqdm import tqdm + +from lm_eval import utils + + +eval_logger = logging.getLogger(__name__) + +T = TypeVar("T", bound="LM") + + +class LM(abc.ABC): + def __init__(self) -> None: + """Defines the interface that should be implemented by all LM subclasses. + LMs are assumed to take text (strings) as input and yield strings as output + (inputs/outputs should be tokenization-agnostic.) + + """ + # set rank and world size to a single process, by default. + self._rank = 0 + self._world_size = 1 + self.cache_hook = CacheHook(None) + + @abc.abstractmethod + def loglikelihood(self, requests) -> List[Tuple[float, bool]]: + """Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. + + :param requests: list[Instance] + A list of Instance objects, with property `args` which returns a tuple (context, continuation). + `context: str` + Context string. Implementations of LM must be able to handle an + empty context string. + `continuation: str` + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. + For example, context="hello" continuation=" world" is correct. + + :return: list[tuple[float, bool]] + A list of pairs (logprob, isgreedy) + `logprob: float` + The log probability of `continuation`. + `isgreedy`: + Whether `continuation` would be generated by greedy sampling from `context`. 
+ """ + pass + + @abc.abstractmethod + def loglikelihood_rolling(self, requests) -> List[float]: + """Compute full log-likelihood of a string, with no truncation, for perplexity computation + - We will use the full max context length of the model. + - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to + the max context length. + - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations + which may simply concatenate multiple documents together. + - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into + multiple chunks, the last input will still a full-sized context. + Example: + Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] + Prefix: BOS/EOS + Max context length: 4 + Resulting input/prediction pairs: + + INPUT: BOS 0 1 2 + PRED: 0 1 2 3 + + INPUT: 3 4 5 6 + PRED: 4 5 6 7 + + INPUT: 5 6 7 8 + PRED: 8 9 + + Observe that: + 1. Each token is predicted exactly once + 2. For the last pair, we provide the full context, but only score the last two tokens + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context,). + string: str + String for which we are computing overall loglikelihood + :return: list[tuple[float]] + A list of tuples (logprob,) + logprob: float + The log probability of `context` conditioned on the BOS/EOS token. + Can also be overridden for custom cases by `prefix_token_id`. + """ + pass + + # TODO: Add an optional max length + @abc.abstractmethod + def generate_until(self, requests) -> List[str]: + """Generate greedily until a stopping sequence + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs). + context: str + Context string + gen_kwargs: dict + A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc. + :return: list[str] + A list of model generated continuations. + continuation: str + The generated continuation. + """ + pass + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt=True + ) -> str: + """ + Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM. + + :param chat_history: list[dict[str, str]] + A list of dictionaries with keys 'role' and 'content'. + Values are strings representing the role name and the content of the message, respectively. + :param add_generation_prompt: bool + Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message. + :return: str + A string representing the chat history in a format that can be used as input to the LM. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type." + ) + + @classmethod + def create_from_arg_string( + cls: Type[T], arg_string: str, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given argument string and additional config. + + Parameters: + - arg_string: A string containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. 
+ """ + additional_config = {} if additional_config is None else additional_config + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + @classmethod + def create_from_arg_obj( + cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given arg_obj + + Parameters: + - arg_obj: A dict containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. + """ + + additional_config = {} if additional_config is None else additional_config + additional_config = { + k: v for k, v in additional_config.items() if v is not None + } + + return cls(**arg_dict, **additional_config) + + @property + def rank(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._rank + + @property + def world_size(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._world_size + + @property + def tokenizer_name(self) -> str: + """Must be defined for LM subclasses which implement Chat Templating. + Should return the name of the tokenizer or chat template used. + Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'tokenizer_name' property." + ) + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """Returns the chat template structure for user/assistant messages if a template is provided. + This method is intended to be overridden in a subclass to define a specific chat template format. + For models that do not support chat templates, this method returns None by default. + """ + + return "" + + def set_cache_hook(self, cache_hook) -> None: + self.cache_hook = cache_hook + + +### SQLite-based caching of LM responses +def hash_args(attr, args): + dat = json.dumps([attr] + list(args)) + return hashlib.sha256(dat.encode("utf-8")).hexdigest() + + +class CacheHook: + def __init__(self, cachinglm) -> None: + if cachinglm is None: + self.dbdict = None + return + + self.dbdict = cachinglm.dbdict + + def add_partial(self, attr, req, res) -> None: + if self.dbdict is None: + return + hsh = hash_args(attr, req) + self.dbdict[hsh] = res + + +class CachingLM: + def __init__(self, lm, cache_db) -> None: + """LM wrapper that returns cached results if they exist, and uses the underlying LM if not. 
+ + :param lm: LM + Underlying LM + :param cache_db: str + Path to cache db + """ + self.lm = lm + self.cache_db = cache_db + if os.path.dirname(cache_db): + os.makedirs(os.path.dirname(cache_db), exist_ok=True) + self.dbdict = SqliteDict(cache_db, autocommit=True) + + # add hook to lm + lm.set_cache_hook(self.get_cache_hook()) + + def __getattr__(self, attr: str): + lm_attr = getattr(self.lm, attr) + if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]: + eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM") + return lm_attr + + def fn(requests): + res = [] + remaining_reqs = [] + warned = False + # figure out which ones are cached and which ones are new + eval_logger.info( + f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..." + ) + for req in tqdm(requests, desc="Checking cached requests"): + hsh = hash_args(attr, req.args) + if attr == "generate_until" and req.args[1].get("do_sample", False): + # when we are doing non-greedy generation, don't use the cache + # (else every "randomly sampled" generation would be identical for repeats > 1). + if not warned: + eval_logger.warning( + f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests." + ) + warned = True + res.append(None) + remaining_reqs.append(req) + elif hsh in self.dbdict: + ob = self.dbdict[hsh] + + assert ob is not None + + res.append(ob) + else: + res.append(None) + remaining_reqs.append(req) + eval_logger.info( + f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" + ) + if remaining_reqs: + # actually run the LM on the requests that do not have cached results + rem_res = getattr(self.lm, attr)(remaining_reqs) + else: + rem_res = [] + + # stick the new ones back into the list and also cache any of the new ones + resptr = 0 + for req, r in zip(remaining_reqs, rem_res): + while res[resptr] is not None: + resptr += 1 + + res[resptr] = r + + # caching + hsh = hash_args(attr, req.args) + self.dbdict[hsh] = r + self.dbdict.commit() + + return res + + return fn + + def get_cache_hook(self): + return CacheHook(self) + + +class TemplateLM(LM): + """ + A class acting as intermediary between the LM base class + and boilerplate often included in other LM subclasses. + """ + + tokenizer = None + + @property + @abc.abstractmethod + def eot_token_id(self): + pass + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + return self.eot_token_id + + @abc.abstractmethod + def tok_encode(self, string: str, **kwargs) -> List[int]: + """ + Tokenize a string using the model's tokenizer and return a list of token IDs. 
+ """ + pass + + @abc.abstractmethod + def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: + pass + + def _encode_pair( + self, context: str, continuation: str + ) -> Tuple[List[int], List[int]]: + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + model_class = getattr(self, "AUTO_MODEL_CLASS", None) + + if model_class == transformers.AutoModelForSeq2SeqLM: + context_enc = self.tok_encode(context) + continuation_enc = self.tok_encode(continuation, add_special_tokens=False) + else: + whole_enc = self.tok_encode(context + continuation) + context_enc = self.tok_encode(context) + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood( + self, requests, disable_tqdm: bool = False + ) -> List[Tuple[float, bool]]: + new_reqs = [] + for context, continuation in [req.args for req in requests]: + if context == "": + # BOS or EOS as context + context_enc, continuation_enc = ( + [self.prefix_token_id], + self.tok_encode(continuation), + ) + else: + context_enc, continuation_enc = self._encode_pair(context, continuation) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm) + + @abc.abstractmethod + def loglikelihood_rolling( + self, requests, disable_tqdm: bool = False + ) -> List[float]: + pass + + @abc.abstractmethod + def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: + pass + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """ + Set and get the appropriate chat template for the model. + This method sets the tokenizer's chat_template and returns the template string for reproducibility. + + The template selection logic is adapted from the Transformers library's `apply_chat_template` + method in the Tokenizer class. The original implementation can be found at: + https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687 + + This method ensures that the right template is chosen based on the following: + 0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string. + 1. If the model's tokenizer has multiple templates: + a. Use the specified template if it exists in the dictionary. + b. Use the default template from the list if no specific template is provided. + c. Raise an error if no default template exists and no specific template is provided. + 2. If the model's tokenizer has a single template or no template: + a. Use the tokenizer's chat template if available. + b. Fall back to the default chat template if no tokenizer chat template exists. + + Args: + chat_template (Union[bool, str]): Specifies the chat template to use. + - If False or None, no template is applied. + - If True, the default or only available template is used. + - If a string, the template with the matching name is used. + + Returns: + Optional[str]: The selected chat template, or None if no template is applied. + """ + if self.tokenizer is None: + return "" + + if chat_template is False or chat_template is None: + eval_logger.warning( + "model.chat_template was called with the chat_template set to False or None. 
" + "Therefore no chat template will be applied. Make sure this is an intended behavior." + ) + return None + + # Convert boolean chat_template to None to ensure compatibility with the adapted logic + if isinstance(chat_template, bool): + chat_template = None + using_default_template = False + + # First, handle the cases when the model has a dict of multiple templates + try: + template = ( + self.tokenizer.chat_template or self.tokenizer.default_chat_template + ) + except AttributeError: + return None + + if isinstance(template, dict): + using_default_dict = self.tokenizer.chat_template is None + + if chat_template is not None: + if chat_template in template: + selected_template = template[chat_template] + if using_default_dict: + using_default_template = True + else: + raise ValueError( + f"The specified chat template '{chat_template}' is not available. " + f"Available template names are {sorted(template.keys())}." + ) + else: + # If user didn't pass a chat template, use the default template from the dict + if "default" in template: + selected_template = template["default"] + using_default_template = True + else: + raise ValueError( + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template.keys())}." + ) + + # Cases when the model has a single template or no template + else: + # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template + if isinstance(chat_template, str): + eval_logger.warning( + "Chat template name provided, but the tokenizer's chat template is not a dictionary. " + "Using the tokenizer's chat template or the default template instead." + ) + if self.tokenizer.chat_template is not None: + selected_template = self.tokenizer.chat_template + else: + selected_template = self.tokenizer.default_chat_template + using_default_template = True + + if using_default_template: + eval_logger.warning( + "No chat template is set for this tokenizer, falling back to a default class-level template. This is " + "very error-prone, because models are often trained with templates different from the class default! " + "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " + "point any code depending on them will stop working. We recommend setting a valid chat template before " + "then to ensure that this model continues working without issues." 
+ ) + + return selected_template diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/samplers.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1791bdb4f8ae06cf4168dcdfa4c6a5a9bbc823 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/samplers.py @@ -0,0 +1,232 @@ +import logging +import warnings +from functools import partial +from typing import TYPE_CHECKING, Iterable, Optional, Union + +import datasets + + +if TYPE_CHECKING: + from random import Random + + from lm_eval.api.task import ConfigurableTask, Task + +eval_logger = logging.getLogger("lm-eval") + + +class ContextSampler: + def __init__( + self, + docs: list[dict], + task: Union["Task", "ConfigurableTask"], + fewshot_indices: Optional[Iterable] = None, + rnd: Optional["Random"] = None, + ) -> None: + self.rnd = rnd + if not self.rnd: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!" + ) + + self.task = task + self.config = task._config + + self.target_delimiter = self.config.target_delimiter + self.fewshot_delimiter = self.config.fewshot_delimiter + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_text", None) is not None + ): + self.doc_to_text = partial( + self.task.doc_to_text, + doc_to_text=self.config.fewshot_config.get("doc_to_text", None), + ) + else: + self.doc_to_text = self.task.doc_to_text + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_target", None) is not None + ): + self.doc_to_target = partial( + self.task.doc_to_target, + doc_to_target=self.config.fewshot_config.get("doc_to_target", None), + ) + else: + self.doc_to_target = self.task.doc_to_target + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_choice", None) is not None + ): + self.doc_to_choice = partial( + self.task.doc_to_choice, + doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None), + ) + else: + self.doc_to_choice = self.task.doc_to_choice + + self.docs = docs # HF dataset split, provided by task._fewshot_docs() + if fewshot_indices: # subset few-shot docs from + if not isinstance(self.docs, datasets.Dataset): + raise ValueError( + "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously" + ) + self.docs = self.docs.select(fewshot_indices) + + def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None): + # draw an extra fewshot sample if using same split as evaluating on + prefix = gen_prefix + " " if gen_prefix else "" + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? 
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = "" + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + if self.config.doc_to_choice is None or isinstance(doc_content, str): + labeled_examples += doc_content + else: + labeled_examples += self.doc_to_choice(doc)[doc_content] + + if doc_target != "": + if self.target_delimiter.isspace() and str(doc_target)[0].isspace(): + # TODO: add logger warn once here. + warnings.warn( + "Both target_delimiter and target start with a space. This may cause issues.", + Warning, + stacklevel=2, + ) + labeled_examples += self.target_delimiter + labeled_examples += prefix + labeled_examples += ( + str(doc_target[0]) + if isinstance(doc_target, list) + else doc_target + if self.config.doc_to_choice is None or isinstance(doc_target, str) + else str(self.doc_to_choice(doc)[doc_target]) + ) + labeled_examples += self.fewshot_delimiter + + return labeled_examples + + def get_chat_context( + self, + doc: dict, + num_fewshot: int, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ): + # TODO: Do we need any other delimiter + prefix = gen_prefix + " " if gen_prefix else "" + chat_history = [] + # draw an extra fewshot sample if using same split as evaluating on + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + if fewshot_as_multiturn: + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + chat_history.append( + { + "role": "user", + "content": doc_content + if self.config.doc_to_choice is None + or isinstance(doc_content, str) + else self.doc_to_choice(doc)[doc_content], + } + ) + chat_history.append( + { + "role": "assistant", + "content": prefix + str(doc_target[0]) + if isinstance(doc_target, list) + else prefix + doc_target + if self.config.doc_to_choice is None + or isinstance(doc_target, str) + else prefix + str(self.doc_to_choice(doc)[doc_target]), + } + ) + else: + # get fewshot context as one user turn + chat_history.append( + { + "role": "user", + "content": self.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ), + } + ) + + return chat_history + + def sample(self, n: int): + """ + Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. + """ + + return self.rnd.sample(self.docs, n) + + +class FirstNSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + Draw the first `n` samples in order from the specified split. + Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. + """ + assert n <= len(self.docs), ( + f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available." + ) + return self.docs[:n] + + +class BalancedSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + TODO: this should return approximately class-balanced samples from our fewshot examples. + TODO: what order should they be in? maybe random? 
+ """ + + pass + + +class ManualSampler(ContextSampler): + def sample(self, n: int) -> None: + """ """ + pass + + +SAMPLER_REGISTRY = { + "default": ContextSampler, + "first_n": FirstNSampler, +} + + +def get_sampler(name: str): + try: + return SAMPLER_REGISTRY[name] + except KeyError: + raise ValueError( + f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}" + ) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/task.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/task.py new file mode 100644 index 0000000000000000000000000000000000000000..cf865a167725fb1da7dcdd229aa1b75bc38362b1 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/api/task.py @@ -0,0 +1,1839 @@ +import abc +import ast +import logging +import random +import re +from collections.abc import Callable +from copy import deepcopy +from dataclasses import asdict, dataclass +from inspect import getsource +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + Union, +) + +import datasets +import numpy as np +from tqdm import tqdm + +from lm_eval import utils +from lm_eval.api import samplers +from lm_eval.api.instance import Instance, OutputType +from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity +from lm_eval.api.registry import ( + AGGREGATION_REGISTRY, + DEFAULT_METRIC_REGISTRY, + get_aggregation, + get_metric, + get_metric_aggregation, + is_higher_better, +) +from lm_eval.caching.cache import load_from_cache, save_to_cache +from lm_eval.filters import build_filter_ensemble +from lm_eval.prompts import get_prompt + + +ALL_OUTPUT_TYPES = [ + "loglikelihood", + "multiple_choice", + "loglikelihood_rolling", + "generate_until", +] + +eval_logger = logging.getLogger(__name__) + + +@dataclass +class TaskConfig(dict): + # task naming/registry + task: Optional[str] = None + task_alias: Optional[str] = None + tag: Optional[Union[str, list]] = None + # HF dataset options. + # which dataset to use, + # and what splits for what purpose + custom_dataset: Optional[Callable] = None + dataset_path: Optional[str] = None + dataset_name: Optional[str] = None + dataset_kwargs: Optional[dict] = None + training_split: Optional[str] = None + validation_split: Optional[str] = None + test_split: Optional[str] = None + fewshot_split: Optional[str] = ( + None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?) + ) + # formatting / prompting options. 
+ # see docs/advanced_task_guide.md for more info + process_docs: Optional[Callable] = None + doc_to_text: Optional[Union[Callable, str]] = None + doc_to_target: Optional[Union[Callable, str]] = None + doc_to_image: Union[Callable, str] = None + doc_to_audio: Union[Callable, str] = None + unsafe_code: bool = False + doc_to_choice: Optional[Union[Callable, str, dict, list]] = None + process_results: Optional[Union[Callable, str]] = None + use_prompt: Optional[str] = None + description: str = "" + target_delimiter: str = " " + fewshot_delimiter: str = "\n\n" + fewshot_config: Optional[dict] = None + # runtime configuration options + num_fewshot: Optional[int] = None + # scoring options + metric_list: Optional[list] = None + output_type: OutputType = "generate_until" + generation_kwargs: Optional[dict] = None + repeats: int = 1 + filter_list: Optional[Union[str, list]] = None + should_decontaminate: bool = False + doc_to_decontamination_query: Optional[str] = None + gen_prefix: Optional[str] = None + metadata: Optional[dict] = ( + None # by default, not used in the code. allows for users to pass arbitrary info to tasks + ) + + def __post_init__(self) -> None: + if self.generation_kwargs is not None: + if self.output_type != "generate_until": + eval_logger.warning( + f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!" + ) + + if "temperature" in self.generation_kwargs: + self.generation_kwargs["temperature"] = float( + self.generation_kwargs["temperature"] + ) + + if "until" not in self.generation_kwargs: + self.generation_kwargs["until"] = [self.fewshot_delimiter] + else: + if self.output_type == "generate_until": + # ensure that we greedily generate in absence of explicit arguments otherwise + self.generation_kwargs = { + "until": ( + None + if self.fewshot_delimiter is None + else [self.fewshot_delimiter] + ), + "do_sample": False, + } + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, item, value): + return setattr(self, item, value) + + def to_dict(self, keep_callable: bool = False) -> dict: + """dumps the current config as a dictionary object, as a printable format. + null fields will not be printed. + Used for dumping results alongside full task configuration + + :return: dict + A printable dictionary version of the TaskConfig object. + + # TODO: should any default value in the TaskConfig not be printed? + """ + cfg_dict = asdict(self) + # remove values that are `None` + for k, v in list(cfg_dict.items()): + if v is None: + cfg_dict.pop(k) + elif k == "metric_list": + for metric_dict in v: + for metric_key, metric_value in metric_dict.items(): + if callable(metric_value): + metric_dict[metric_key] = self.serialize_function( + metric_value, keep_callable=keep_callable + ) + cfg_dict[k] = v + elif callable(v): + cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable) + return cfg_dict + + def serialize_function( + self, value: Union[Callable, str], keep_callable=False + ) -> Union[Callable, str]: + """Serializes a given function or string. + + If 'keep_callable' is True, the original callable is returned. + Otherwise, attempts to return the source code of the callable using 'getsource'. + """ + if keep_callable: + return value + else: + try: + return getsource(value) + except (TypeError, OSError): + return str(value) + + +class Task(abc.ABC): + """A task represents an entire benchmark including its dataset, problems, + answers, and evaluation methods. 
See BoolQ for a simple example implementation + + A `doc` can be any python object which represents one instance of evaluation. + This is usually a dictionary e.g. + {"question": ..., "answer": ...} or + {"question": ..., question, answer) + """ + + VERSION: Optional[Union[int, str]] = None + + # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub + # or a path to a custom `datasets` loading script. + DATASET_PATH: Optional[str] = None + + # The name of a subset within `DATASET_PATH`. + DATASET_NAME: Optional[str] = None + + OUTPUT_TYPE: Optional[OutputType] = None + + def __init__( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode: Optional[datasets.DownloadMode] = None, + config: Optional[Mapping] = None, # Union[dict, TaskConfig] + ) -> None: + """ + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. + - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. + """ + self.download(data_dir, cache_dir, download_mode) + self._training_docs: Optional[list] = None + self._fewshot_docs: Optional[list] = None + self._instances: Optional[List[Instance]] = None + + self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig() + + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + self.fewshot_rnd: Optional[random.Random] = ( + None # purposely induce errors in case of improper usage + ) + + def download( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode=None, + ) -> None: + """Downloads and returns the task dataset. + Override this method to download the dataset from a custom API. + + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + by setting the shell environment variable, `HF_DATASETS_CACHE`, + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. + - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. 
+ """ + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + data_dir=data_dir, + cache_dir=cache_dir, + download_mode=download_mode, + ) + + @property + def config(self) -> TaskConfig: + """Returns the TaskConfig associated with this class.""" + return self._config + + @abc.abstractmethod + def has_training_docs(self): + """Whether the task has a training set""" + pass + + @abc.abstractmethod + def has_validation_docs(self): + """Whether the task has a validation set""" + pass + + @abc.abstractmethod + def has_test_docs(self): + """Whether the task has a test set""" + pass + + def training_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def validation_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def test_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def fewshot_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + if self.has_training_docs(): + return self.training_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + if self.config.get("num_fewshot", 0) > 0: + eval_logger.warning( + f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False" + ", using test_docs as fewshot_docs but this is not recommended." + ) + return self.test_docs() + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. + """ + return doc + + @property + def instances(self) -> List[Instance]: + """After calling `task.build_all_requests()`, tasks + maintain a list of the dataset instances which will be evaluated. + """ + return self._instances + + def fewshot_examples(self, k, rnd): + if self._training_docs is None: + self._training_docs = list(self.training_docs()) + + return rnd.sample(self._training_docs, k) + + def doc_to_decontamination_query(self, doc): + raise NotImplementedError( + "Override doc_to_decontamination_query with document specific decontamination query." 
+ ) + + @abc.abstractmethod + def doc_to_text(self, doc): + pass + + @abc.abstractmethod + def doc_to_target(self, doc): + pass + + # not an abstractmethod because not every language-only task has to implement this + def doc_to_image(self, doc): + raise NotImplementedError + + def doc_to_audio(self, doc): + raise NotImplementedError + + def doc_to_prefix(self, doc): + return "" + + def build_all_requests( + self, + *, + limit: Union[int, None] = None, + rank: int = 0, + world_size: int = 1, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + tokenizer_name: str = "", + ) -> None: + """Build a set of Instances for a task, and store them in task.instances""" + + # used with caching + og_limit = limit + + cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" + cache_key += "-chat_template" if apply_chat_template else "" + cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" + cache_key += ( + f"-system_prompt_hash{utils.hash_string(system_instruction)}" + if system_instruction is not None + else "" + ) + cache_key += f"-tokenizer{tokenizer_name}" + + cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests) + + if cache_requests and cached_instances and not rewrite_requests_cache: + cached_instances = cached_instances[:limit] + + flattened_instances = [ + instance + for instance_group in cached_instances + for instance in instance_group + ] + + self._instances = flattened_instances + return + + eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...") + + instances = [] + + # process all documents when caching is specified for simplicity + if ( + cache_requests + and (not cached_instances or rewrite_requests_cache) + and limit is not None + ): + limit = None + + doc_id_docs = list( + self.doc_iterator(rank=rank, limit=limit, world_size=world_size) + ) + + num_docs = len(doc_id_docs) + + for doc_id, doc in tqdm( + doc_id_docs, + total=num_docs, + ): + # sample fewshot context #TODO: need to offset doc_id by rank now! 
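+            # render the full prompt for this doc: system instruction, few-shot examples,
+            # and the doc's own query (chat-templated when `apply_chat_template` is set)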
+ fewshot_ctx = self.fewshot_context( + doc, + 0 if self.config.num_fewshot is None else self.config.num_fewshot, + system_instruction, + apply_chat_template, + fewshot_as_multiturn, + chat_template, + gen_prefix=self.doc_to_prefix(doc), + ) + + # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute + inst = self.construct_requests( + doc=doc, + ctx=fewshot_ctx, + metadata=(self.config["task"], doc_id, self.config.repeats), + apply_chat_template=apply_chat_template, + chat_template=chat_template, + ) + + if not isinstance(inst, list): + inst = [inst] + + instances.append(inst) + + # now flatten, this is to allow slicing to work with pickles + + sliced_instances = instances[:og_limit] + + flattened_instances = [ + instance + for instance_group in sliced_instances + for instance in instance_group + ] + + self._instances = flattened_instances + + if len(self._instances) == 0: + raise ValueError("task.build_requests() did not find any docs!") + + if cache_requests and (not cached_instances or rewrite_requests_cache): + save_to_cache(file_name=cache_key, obj=instances) + + @abc.abstractmethod + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + :param doc_idx: int + The index of a document within `self.test_docs()` or `self.validation_docs()`, + whichever is the main split used. + :param repeats: int + TODO: update this docstring + The number of times each instance in a dataset is inferred on. Defaults to 1, + can be increased for techniques like majority voting. + """ + pass + + @abc.abstractmethod + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. 
+ """ + pass + + @abc.abstractmethod + def aggregation(self): + """ + :returns: {str: [metric_score] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metric scores + """ + pass + + @abc.abstractmethod + def higher_is_better(self): + """ + :returns: {str: bool} + A dictionary where keys are the names of submetrics and values are + whether a higher value of the submetric is better + """ + pass + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @classmethod + def count_bytes(cls, doc): + """Used for byte-level perplexity metrics in rolling loglikelihood""" + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc): + """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) + + @utils.positional_deprecated + def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs): + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param rnd: random.Random + The pseudo-random number generator used to randomly sample examples. + WARNING: This is currently a required arg although it's optionalized with a default `None`. + :param description: str + The task's description that will be prepended to the fewshot examples. + :returns: str + The fewshot context. + """ + if rnd is None: + if self.fewshot_rnd is not None: + rnd = self.fewshot_rnd + else: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd`" + ) + + description = description if description else "" + + if num_fewshot == 0: + labeled_examples = "" + else: + # for sets with no training docs, draw from other set *but ensure no overlap with current doc* + if self.has_training_docs(): + fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) + else: + if self._fewshot_docs is None: + self._fewshot_docs = list( + self.validation_docs() + if self.has_validation_docs() + else self.test_docs() + ) + + fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = ( + "\n\n".join( + [ + self.doc_to_text(doc) + self.doc_to_target(doc) + for doc in fewshotex + ] + ) + + "\n\n" + ) + + example = self.doc_to_text(doc) + return description + labeled_examples + example + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def dump_config(self) -> dict: + """Returns the config as a dictionary.""" + # TODO: this should only return the overrides applied to a non-YAML task's configuration. 
+ # (num_fewshot) + return self.config.to_dict() + + def set_config(self, key: str, value: Any, update: bool = False) -> None: + """Set or update the configuration for a given key.""" + if key is None: + raise ValueError("Key must be provided.") + + if update: + current_value = getattr(self._config, key, {}) + if not isinstance(current_value, dict): + raise TypeError( + f"Expected a dict for key '{key}', got {type(current_value).__name__} instead." + ) + current_value.update(value) + else: + setattr(self._config, key, value) + + def override_metric(self, metric_name: str) -> None: + """ + Override the default metrics used for evaluation with custom metrics. + + Parameters: + - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics. + """ + ( + self._metric_fn_list, + self._aggregation_list, + self._metric_fn_kwargs, + self._higher_is_better, + ) = ({}, {}, {}, {}) + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._aggregation_list[metric_name] = get_metric_aggregation(metric_name) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + self._metric_fn_kwargs[metric_name] = {} + if not isinstance(self, ConfigurableTask): + self.process_results = lambda x, y: {metric_name: get_metric(metric_name)} + self.aggregation = lambda: { + metric_name: get_metric_aggregation(metric_name) + } + setattr(self._config, "metric_list", [{"metric": metric_name}]) + setattr(self._config, "process_results", None) + + def set_fewshot_seed(self, seed: Optional[int] = None) -> None: + self.fewshot_rnd = random.Random(seed) + if hasattr(self, "sampler"): + self.sampler.rnd = self.fewshot_rnd + + @property + def eval_docs(self) -> Union[datasets.Dataset, List[dict]]: + if self.has_test_docs(): + return self.test_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + raise ValueError( + f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" + ) + + def doc_iterator( + self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1 + ) -> Iterator[Tuple[int, Any]]: + limit = int(limit) if limit else None + doc_iterator = utils.create_iterator( + enumerate(self.eval_docs), + rank=int(rank), + limit=limit, + world_size=int(world_size), + ) + return doc_iterator + + +class ConfigurableTask(Task): + VERSION = "Yaml" + OUTPUT_TYPE = None + CONFIG = None + + def __init__( + self, + data_dir=None, + cache_dir=None, + download_mode=None, + config: Optional[dict] = None, + ) -> None: # TODO no super() call here + # Get pre-configured attributes + self._config = self.CONFIG + + # Use new configurations if there was no preconfiguration + if self.config is None: + self._config = TaskConfig(**config) + # Overwrite configs + else: + if config is not None: + self._config.__dict__.update(config) + + if self.config is None: + raise ValueError( + "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg" + ) + + if isinstance(self.config.metadata, dict): + if "version" in self.config.metadata: + self.VERSION = self.config.metadata["version"] + + if self.config.output_type is not None: + if self.config.output_type not in ALL_OUTPUT_TYPES: + raise ValueError( + f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'" + ) + self.OUTPUT_TYPE = self.config.output_type + + if self.config.doc_to_image is not None: + # mark the task as requiring multimodality. 
+ self.MULTIMODAL = True + + if self.config.doc_to_audio: + # mark the task as requiring multimodality. + self.MULTIMODAL = True + + if self.config.unsafe_code is not False: + self.UNSAFE_CODE = True + + if self.config.dataset_path is not None: + self.DATASET_PATH = self.config.dataset_path + + if self.config.dataset_name is not None: + self.DATASET_NAME = self.config.dataset_name + + self._metric_fn_list = {} + self._metric_fn_kwargs = {} + self._aggregation_list = {} + self._higher_is_better = {} + + if self.config.metric_list is None: + # TODO: handle this in TaskConfig.__post_init__ ? + _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type] + + for metric_name in _metric_list: + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._metric_fn_kwargs[metric_name] = {} + self._aggregation_list[metric_name] = get_metric_aggregation( + metric_name + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + else: + for metric_config in self.config.metric_list: + if "metric" not in metric_config: + raise ValueError( + "'metric' key not provided for an entry in 'metric_list', must be specified!" + ) + metric_name = metric_config["metric"] + kwargs = { + key: metric_config[key] + for key in metric_config + if key + not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"] + } + hf_evaluate_metric = ( + "hf_evaluate" in metric_config + and metric_config["hf_evaluate"] is True + ) + + if self.config.process_results is not None: + self._metric_fn_list[metric_name] = None + self._metric_fn_kwargs[metric_name] = {} + elif callable(metric_name): + metric_fn = metric_name.__call__ + metric_name = metric_name.__name__ + self._metric_fn_list[metric_name] = metric_fn + self._metric_fn_kwargs[metric_name] = kwargs + else: + self._metric_fn_list[metric_name] = get_metric( + metric_name, hf_evaluate_metric + ) + self._metric_fn_kwargs[metric_name] = kwargs + + if "aggregation" in metric_config: + agg_name = metric_config["aggregation"] + if isinstance(agg_name, str): + self._aggregation_list[metric_name] = get_aggregation(agg_name) + elif callable(agg_name): # noqa: E721 + self._aggregation_list[metric_name] = metric_config[ + "aggregation" + ] + else: + INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()} + metric_agg = get_metric_aggregation(metric_name) + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. " + f"using default " + f"aggregation={INV_AGG_REGISTRY[metric_agg]}" + ) + self._aggregation_list[metric_name] = metric_agg + + if "higher_is_better" in metric_config: + self._higher_is_better[metric_name] = metric_config[ + "higher_is_better" + ] + else: + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. 
" + f"using default " + f"higher_is_better={is_higher_better(metric_name)}" + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + + self.download(self.config.dataset_kwargs) + self._training_docs = None + self._fewshot_docs = None + + if self.config.filter_list is not None: + self._filters = [] + for filter_config in self.config.filter_list: + filter_name = filter_config["name"] + filter_functions = filter_config["filter"] + components = [] + for function in filter_functions: + kwargs = { + key: function[key] for key in function if key != "function" + } + components.append([function["function"], kwargs]) + filter_pipeline = build_filter_ensemble(filter_name, components) + self._filters.append(filter_pipeline) + else: + # TODO: handle repeats in a more general way rather than just discarding + eval_logger.debug( + "No custom filters defined. Using default 'take_first' filter for handling repeats." + ) + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + + if self.config.use_prompt is not None: + eval_logger.info(f"loading prompt {self.config.use_prompt}") + self.prompt = get_prompt( + self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME + ) + else: + self.prompt = None + + if self.fewshot_docs() is not None: + self.fewshot_rnd = ( + random.Random() + ) # setting with no seed, to be overridden at a later time + config_sampler: Union[str, Callable] = ( + self.config.fewshot_config.get("sampler", "default") + if self.config.fewshot_config + else "default" + ) + if isinstance(config_sampler, str): + self.sampler = samplers.get_sampler(config_sampler)( + list(self.fewshot_docs()), self, rnd=self.fewshot_rnd + ) + elif callable(config_sampler) and issubclass( + config_sampler, samplers.ContextSampler + ): + self.sampler = config_sampler( + docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd + ) + else: + raise TypeError( + f"fewshot_config.sampler should be a string or callable of ContextSampler type, " + f"not {type(config_sampler)}" + ) + + self.task_docs = self.eval_docs + + # Test One Doc + self.features = list(self.task_docs.features.keys()) + self.multiple_input = 0 + self.multiple_target = 0 + test_doc = self.task_docs[0] + test_text = self.doc_to_text(test_doc) + test_target = self.doc_to_target(test_doc) + + if self.config.doc_to_choice is not None: + test_choice = self.doc_to_choice(test_doc) + if not isinstance(test_choice, list): + eval_logger.error("doc_to_choice must return list") + else: + num_choice = len(test_choice) + + if isinstance(test_text, int): + self.multiple_input = num_choice + else: + test_choice = None + + if isinstance(test_target, list): + self.multiple_target = len(test_target) + else: + if (isinstance(test_target, int)) and (test_choice is not None): + test_target = test_choice[test_target] + else: + test_target = str(test_target) + + if test_choice is not None: + check_choices = test_choice + else: + check_choices = [test_target] + if self.config.doc_to_choice is not None: + for choice in check_choices: + choice_has_whitespace = True if choice[0].isspace() else False + delimiter_has_whitespace = ( + True + if self.config.target_delimiter.rstrip() + != self.config.target_delimiter + else False + ) + + if delimiter_has_whitespace and choice_has_whitespace: + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace' + ) + elif (not delimiter_has_whitespace) and (not choice_has_whitespace): + eval_logger.debug( + f'Both target_delimiter 
"{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' + ) + + def download( + self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs + ) -> None: + if isinstance(self.config.custom_dataset, Callable): + eval_logger.warning( + f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager." + + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme." + ) + self.dataset = self.config.custom_dataset( + **(self.config.metadata or {}), **(self.config.dataset_kwargs or {}) + ) + else: + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + **dataset_kwargs if dataset_kwargs is not None else {}, + ) + + def has_training_docs(self) -> bool: + if self.config.training_split is not None: + return True + else: + return False + + def has_validation_docs(self) -> bool: + if self.config.validation_split is not None: + return True + else: + return False + + def has_test_docs(self) -> bool: + if self.config.test_split is not None: + return True + else: + return False + + def training_docs(self) -> datasets.Dataset: + if self.has_training_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.training_split] + ) + return self.dataset[self.config.training_split] + + def validation_docs(self) -> datasets.Dataset: + if self.has_validation_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.validation_split] + ) + return self.dataset[self.config.validation_split] + + def test_docs(self) -> datasets.Dataset: + if self.has_test_docs(): + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.test_split]) + return self.dataset[self.config.test_split] + + def fewshot_docs(self): + if self.config.fewshot_split is not None: + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.fewshot_split]) + return self.dataset[self.config.fewshot_split] + elif ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("samples", None) is not None + ): + if isinstance(self.config.fewshot_config["samples"], list): + return self.config.fewshot_config["samples"] + elif callable(self.config.fewshot_config["samples"]): + return self.config.fewshot_config["samples"]() + else: + raise Exception( + "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list." + ) + else: + if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): + eval_logger.warning( + f"[Task: {self.config.task}] " + "num_fewshot > 0 but fewshot_split is None. " + "using preconfigured rule." + ) + return super().fewshot_docs() + + @staticmethod + def append_target_question( + labeled_examples: List[Dict[str, str]], + question: str, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ) -> None: + """Adds a target question to the labeled examples list. + If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry. + Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant. 
+ """ + if not fewshot_as_multiturn: + # if no messages or last message is system, append as new user entry + if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system": + labeled_examples.append({"role": "user", "content": question}) + # if last message is user, append to it to avoid two user messages in a row + else: + labeled_examples[-1]["content"] += question + else: + # if fewshot_as_multiturn is True, append as next user entry (last is always assistant) + labeled_examples.append({"role": "user", "content": question}) + if gen_prefix: + labeled_examples.append({"role": "assistant", "content": gen_prefix}) + + @utils.positional_deprecated + def fewshot_context( + self, + doc: dict, + num_fewshot: int, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + gen_prefix: Optional[str] = None, + ) -> Union[str, List[str]]: + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param system_instruction: str + System instruction to be applied to the prompt. + :param apply_chat_template: bool + Whether to apply the chat template to the fewshot context. + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param chat_template: + callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. + :param gen_prefix: + String to append after the <|assistant|> token. + :returns: str + The fewshot context. + """ + if apply_chat_template: + labeled_examples = [] + else: + labeled_examples = "" + + # get task description + if description := self.config.description: + description = utils.apply_template(self.config.description, doc) + + # create system prompt based on the provided system instruction and description + if system_instruction is not None and description: + system_prompt = ( + f"{system_instruction}{self.sampler.fewshot_delimiter}{description}" + ) + elif system_instruction is not None: + system_prompt = system_instruction + elif description: + system_prompt = description + else: + system_prompt = "" + + # add system prompt if specified + if system_prompt: + if apply_chat_template: + labeled_examples.append({"role": "system", "content": system_prompt}) + else: + labeled_examples = system_prompt + # if few-shot - append examples after the system prompt + if num_fewshot > 0: + if apply_chat_template: + labeled_examples.extend( + self.sampler.get_chat_context( + doc, + num_fewshot, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + ) + else: + labeled_examples += self.sampler.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ) + + example = self.doc_to_text(doc) + if apply_chat_template: + if self.multiple_input: + # TODO: append prefill? 
+ if not labeled_examples: + return "" + return chat_template(labeled_examples) + if isinstance(example, str): + self.append_target_question( + labeled_examples, + example, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # for loglikelihood create a list of questions with appended choices + elif isinstance(example, list): + labeled_examples_list = [] + # copy chat history for each example and append the answer + for ex in example: + chat = deepcopy(labeled_examples) + self.append_target_question( + chat, + ex, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # TODO: append prefill? + labeled_examples_list.append( + chat_template( + chat, + add_generation_prompt=False if gen_prefix else True, + ) + ) + return labeled_examples_list + # if example is an integer, append the choice or convert to string + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + self.append_target_question( + labeled_examples, + choices[example], + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + else: + self.append_target_question( + labeled_examples, + str(example), + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # return lm.apply_chat_template(labeled_examples) + return chat_template( + labeled_examples, + add_generation_prompt=False if gen_prefix else True, + ) + else: + prefix = ( + self.config.target_delimiter + gen_prefix + if gen_prefix is not None + else "" + ) + if self.multiple_input: + return labeled_examples + if isinstance(example, str): + return labeled_examples + example + prefix + elif isinstance(example, list): + return [labeled_examples + ex + prefix for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example] + prefix + else: + return labeled_examples + str(example) + prefix + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def should_decontaminate(self): + return self.config.should_decontaminate + + def doc_to_decontamination_query(self, doc: dict): + if self.config.should_decontaminate: + if self.config.doc_to_decontamination_query is None: + return self.doc_to_text(doc) + else: + doc_to_decontamination_query = self.config.doc_to_decontamination_query + if doc_to_decontamination_query in self.features: + return doc[doc_to_decontamination_query] + elif callable(doc_to_decontamination_query): + return doc_to_decontamination_query(doc) + else: + return ast.literal_eval( + utils.apply_template( + self.config.doc_to_decontamination_query, doc + ) + ) + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. 
+ """ + return doc + + def doc_to_text(self, doc, doc_to_text=None): + if self.prompt is not None: + doc_to_text = self.prompt + elif doc_to_text is not None: + doc_to_text = doc_to_text + else: + doc_to_text = self.config.doc_to_text + + if isinstance(doc_to_text, int): + return doc_to_text + elif isinstance(doc_to_text, str): + if doc_to_text in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_text]] + # else: + return doc[doc_to_text] + else: + text_string = utils.apply_template(doc_to_text, doc) + if text_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(text_string) + else: + return text_string + elif callable(doc_to_text): + return doc_to_text(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_text, "apply"): + applied_prompt = doc_to_text.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[0] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + print(type(doc_to_text)) + raise TypeError + + def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]: + if self.prompt is not None: + doc_to_target = self.prompt + elif doc_to_target is not None: + doc_to_target = doc_to_target + else: + doc_to_target = self.config.doc_to_target + + if isinstance(doc_to_target, int): + return doc_to_target + elif isinstance(doc_to_target, str): + if doc_to_target in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_target]] + # else: + return doc[doc_to_target] + else: + target_string = utils.apply_template(doc_to_target, doc) + if target_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(target_string) + elif ( + len(target_string) >= 2 + and (target_string[0] == "[") + and (target_string[-1] == "]") + ): + try: + return ast.literal_eval(target_string) + except (SyntaxError, ValueError): + return target_string + else: + return target_string + elif isinstance(doc_to_target, list): + return doc_to_target + elif callable(doc_to_target): + return doc_to_target(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_target, "apply"): + applied_prompt = doc_to_target.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[1] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + raise TypeError + + def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]: + if self.prompt is not None: + doc_to_choice = self.prompt + elif doc_to_choice is not None: + doc_to_choice = doc_to_choice + elif self.config.doc_to_choice is None: + eval_logger.error("doc_to_choice was called but not set in config") + else: + doc_to_choice = self.config.doc_to_choice + + if isinstance(doc_to_choice, str): + if doc_to_choice in self.features: + return doc[doc_to_choice] + else: + return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) + elif isinstance(doc_to_choice, list): + return doc_to_choice + elif isinstance(doc_to_choice, dict): + return list(doc_to_choice.values()) + elif callable(doc_to_choice): + return doc_to_choice(doc) + elif hasattr(doc_to_choice, "get_answer_choices_list"): + return doc_to_choice.get_answer_choices_list(doc) + else: + raise TypeError + + def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: + if doc_to_image is not None: + doc_to_image = doc_to_image + elif 
self.config.doc_to_image is not None: + doc_to_image = self.config.doc_to_image + else: + return None + + if isinstance(doc_to_image, list): + image_feature = [ + self.doc_to_image(doc, feature) for feature in doc_to_image + ] + return [feature for feature in image_feature if feature is not None] + elif isinstance(doc_to_image, str): + if doc_to_image in self.features: + return doc[doc_to_image] + else: + return ast.literal_eval(utils.apply_template(doc_to_image, doc)) + elif callable(doc_to_image): + return doc_to_image(doc) + else: + return None + + def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]: + if doc_to_audio is not None: + doc_to_audio = doc_to_audio + elif self.config.doc_to_audio is not None: + doc_to_audio = self.config.doc_to_audio + else: + return None + + if isinstance(doc_to_audio, list): + audio_feature = [ + self.doc_to_audio(doc, feature) for feature in doc_to_audio + ] + return [feature for feature in audio_feature if feature is not None] + elif isinstance(doc_to_audio, str): + if doc_to_audio in self.features: + return doc[doc_to_audio] + else: + return ast.literal_eval(utils.apply_template(doc_to_audio, doc)) + elif callable(doc_to_audio): + return doc_to_audio(doc) + else: + return None + + def doc_to_prefix(self, doc): + if (gen_prefix := self.config.gen_prefix) is not None: + if gen_prefix in self.features: + return doc[gen_prefix] + else: + return utils.apply_template(gen_prefix, doc) + return None + + def construct_requests( + self, doc: dict, ctx: str, **kwargs + ) -> Union[List[Instance], Instance]: + apply_chat_template = kwargs.pop("apply_chat_template", False) + chat_template: Callable | None = kwargs.pop("chat_template", None) + + aux_arguments = None + + if self.OUTPUT_TYPE == "loglikelihood": + arguments = (ctx, self.doc_to_target(doc)) + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + arguments = (self.doc_to_target(doc),) + elif self.OUTPUT_TYPE == "multiple_choice": + choices = self.doc_to_choice(doc) + target_delimiter = self.config.target_delimiter + if apply_chat_template: + target_delimiter = "" + if self.multiple_input: + # If there are multiple inputs, choices are placed in the ctx + # apply chat_template to choices if apply_chat_template + cont = self.doc_to_target(doc) + + arguments = [ + ( + ctx + + ( + chat_template([{"role": "user", "content": choice}]) + if apply_chat_template + else choice + ), + f"{target_delimiter}{cont}", + ) + for choice in choices + ] + else: + # Otherwise they are placed in the continuation + arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] + + # TODO: we should raise a warning telling users this will at most ~2x runtime. + if "acc_mutual_info" in self._metric_fn_list.keys(): + # if we are calculating multiple choice accuracy + # using mutual information instead of raw loglikelihood as metric, need unconditional lls. + + # here mutual info refers to calculating + # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) + # in other words normalizing by subtracting the unconditional logprob of each choice. 
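+                # request the unconditional loglikelihood of each choice (empty context);
+                # process_results subtracts it from the conditional loglikelihood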
+ aux_arguments = [("", f"{choice}") for choice in choices] + + arguments.extend(aux_arguments) + + elif self.OUTPUT_TYPE == "generate_until": + arguments = (ctx, deepcopy(self.config.generation_kwargs)) + + multimodal_arg = {} + if ( + self.config.doc_to_image + ): # TODO: ensure that non-multimodal tasks aren't getting visual args + multimodal_arg = { + **multimodal_arg, + **{"visual": self.doc_to_image(doc)}, + } + + if ( + self.config.doc_to_audio + ): # TODO: ensure that non-multimodal tasks aren't getting audio args + multimodal_arg = { + **multimodal_arg, + **{"audio": self.doc_to_audio(doc)}, + } + + if bool(multimodal_arg): + if isinstance(arguments, list): + arguments = [arg + (multimodal_arg,) for arg in arguments] + else: + arguments = arguments + (multimodal_arg,) + + if self.OUTPUT_TYPE == "multiple_choice": + request_list = [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=arg, + idx=i, + **kwargs, + ) + for i, arg in enumerate(arguments) + ] + + return request_list + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=arguments, + idx=0, + **kwargs, + ) + + def process_results(self, doc, results): + if callable(self.config.process_results): + return self.config.process_results(doc, results) + + result_dict = {} + use_metric = list(self._metric_fn_list.keys()) + if self.OUTPUT_TYPE == "loglikelihood": + results = results[0] + ll, is_greedy = results + return { + **({"perplexity": ll} if "perplexity" in use_metric else {}), + **({"acc": int(is_greedy)} if "acc" in use_metric else {}), + } + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + (loglikelihood,) = results + _words = self.count_words(self.doc_to_target(doc)) + _bytes = self.count_bytes(self.doc_to_target(doc)) + return { + **( + {"word_perplexity": (loglikelihood, _words)} + if "word_perplexity" in use_metric + else {} + ), + **( + {"byte_perplexity": (loglikelihood, _bytes)} + if "byte_perplexity" in use_metric + else {} + ), + **( + {"bits_per_byte": (loglikelihood, _bytes)} + if "bits_per_byte" in use_metric + else {} + ), + } + elif self.OUTPUT_TYPE == "multiple_choice": + lls, is_greedy = zip(*results) + + # retrieve choices in List[str] form, to compute choice lengths, etc. + choices = self.doc_to_choice(doc) + completion_len = np.array([float(len(i)) for i in choices]) + + if ( + 2 * len(choices) == len(lls) + and "acc_mutual_info" in self._metric_fn_list.keys() + ): + # then we are doing mutual info. 
+ # this stores the "dryrun" / unconditional answer loglikelihoods + lls_unconditional = lls[1::2] + if len(lls_unconditional) != len(choices): + raise ValueError + # and this stores our "regular" conditional loglikelihoods + lls = lls[::2] + + pred = np.argmax(lls) + pred_norm = np.argmax(lls / completion_len) + + if self.multiple_input: + gold = self.doc_to_text(doc) + else: + gold = self.doc_to_target(doc) + + gold_index_error = False + if isinstance(gold, list): + gold = [i if i < len(choices) else -100 for i in gold] + if -100 in gold: + gold_index_error = True + else: + if isinstance(gold, int): + gold = gold if gold < len(choices) else -100 + elif isinstance(gold, str): + gold = choices.index(gold) if gold in choices else -100 + + if gold == -100: + gold_index_error = True + + if gold_index_error: + eval_logger.warning( + f"Label index was not in within range of available choices," + f"Sample:\n\n{doc}\n\n" + ) + + if self.multiple_target: + acc = 1.0 if pred in gold else 0.0 + acc_norm = 1.0 if pred_norm in gold else 0.0 + exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold])) + else: + acc = 1.0 if pred == gold else 0.0 + acc_norm = 1.0 if pred_norm == gold else 0.0 + # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly + exact_match = int(is_greedy[gold]) if gold != -100 else 0 + + prob_norm = utils.softmax(lls) + + # TODO use keyword arguments to the metric? + # gold, pred, norm stuff, the original lls, + result_dict = { + **({"acc": acc} if "acc" in use_metric else {}), + **({"f1": (gold, pred)} if "f1" in use_metric else {}), + **({"mcc": (gold, pred)} if "mcc" in use_metric else {}), + **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), + **({"exact_match": exact_match} if "exact_match" in use_metric else {}), + **( + {"brier_score": (gold, prob_norm)} + if "brier_score" in use_metric + else {} + ), + } + + if "acc_mutual_info" in use_metric: + lls_mutual_info = [ + ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional) + ] + acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 + result_dict["acc_mutual_info"] = acc_mutual_info + + elif self.OUTPUT_TYPE == "generate_until": + gold = self.doc_to_target(doc) + result = results[0] + if self.config.doc_to_choice is not None: + # If you set doc_to_choice, + # it assumes that doc_to_target returns a number. + choices = self.doc_to_choice(doc) + gold = choices[gold] + # we expect multiple_targets to be a list. 
+ elif self.multiple_target: + gold = list(gold) + # TODO: handle this better + elif type(gold) is not type(result) and not ( + "bypass" in self._metric_fn_list.keys() or isinstance(result, list) + ): + # cast gold to the same type as result + gold = type(result)(gold) + + for metric in self._metric_fn_list.keys(): + if self.multiple_target: + # in the case where we have multiple targets, + # return true if any are true + # TODO: this may break for multipLe_target, non zero-or-1 metrics + scores = [] + if not isinstance(gold, list): + # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer + # print(gold) + gold = [gold] + if metric == "exact_match": + result = [result for _ in range(len(gold))] + scores = self._metric_fn_list[metric]( + references=gold, + predictions=result, + **self._metric_fn_kwargs[metric], + )[metric] + result_score = 1.0 if scores > 0.0 else 0.0 + else: + for gold_option in gold: + try: + result_score = self._metric_fn_list[metric]( + references=[gold_option], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except ( + TypeError + ): # TODO: this is hacky and I don't want to do it + result_score = self._metric_fn_list[metric]( + [gold_option, result] + ) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + result_score = result_score[metric] + scores.append(result_score) + if any(scores): + result_score = 1.0 + else: + result_score = 0.0 + else: + try: + result_score = self._metric_fn_list[metric]( + references=[gold], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics + result_score = self._metric_fn_list[metric]([gold, result]) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + # This allows for multiple metrics to be returned from the same function + for k, v in result_score.items(): + result_dict[k] = v + else: + result_dict[metric] = result_score + else: + raise ValueError( + f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", + "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'", + ) + + return result_dict + + def aggregation(self) -> dict: + return self._aggregation_list + + def higher_is_better(self) -> dict: + return self._higher_is_better + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @property + def task_name(self) -> Any: + return getattr(self.config, "task", None) + + def __repr__(self): + return ( + f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," + f"output_type={self.OUTPUT_TYPE}," + f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," + f"num_samples={len(self.eval_docs)})" + ) + + +class MultipleChoiceTask(Task): + OUTPUT_TYPE = "loglikelihood" + + def doc_to_target(self, doc: dict) -> str: + return " " + doc["choices"][doc["gold"]] + + def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]: + # TODO: add mutual info here? + return [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=(ctx, " {}".format(choice)), + idx=i, + **kwargs, + ) + for i, choice in enumerate(doc["choices"]) + ] + + def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict: + results = [ + res[0] for res in results + ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere? 
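+        # `acc` picks the choice with the highest raw loglikelihood; `acc_norm`
+        # normalizes each loglikelihood by the character length of its choice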
+ gold = doc["gold"] + + acc = 1.0 if np.argmax(results) == gold else 0.0 + completion_len = np.array([float(len(i)) for i in doc["choices"]]) + acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 + + return { + "acc": acc, + "acc_norm": acc_norm, + } + + def higher_is_better(self) -> dict: + return { + "acc": True, + "acc_norm": True, + } + + def aggregation(self) -> dict: + return { + "acc": mean, + "acc_norm": mean, + } + + +class PerplexityTask(Task): + OUTPUT_TYPE = "loglikelihood_rolling" + + def has_training_docs(self) -> bool: + return False + + def fewshot_examples(self, k: int, rnd) -> List: + if k != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + return [] + + def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]: + if num_fewshot != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + + return "" + + def higher_is_better(self) -> dict: + return { + "word_perplexity": False, + "byte_perplexity": False, + "bits_per_byte": False, + } + + def doc_to_decontamination_query(self, doc): + return doc + + def doc_to_text(self, doc) -> str: + return "" + + def doc_to_target(self, doc): + return doc + + def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs): + if bool(ctx): + raise ValueError + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=(self.doc_to_target(doc),), + idx=0, + **kwargs, + ) + + def process_results(self, doc: dict, results: Tuple[float]) -> dict: + (loglikelihood,) = results + words = self.count_words(self.doc_to_target(doc)) + bytes_ = self.count_bytes(self.doc_to_target(doc)) + return { + "word_perplexity": (loglikelihood, words), + "byte_perplexity": (loglikelihood, bytes_), + "bits_per_byte": (loglikelihood, bytes_), + } + + def aggregation(self) -> dict: + return { + "word_perplexity": weighted_perplexity, + "byte_perplexity": weighted_perplexity, + "bits_per_byte": bits_per_byte, + } + + @classmethod + def count_bytes(cls, doc) -> int: + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc) -> int: + """Downstream tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/__init__.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/cache.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d293b0ff8b1ebac186f5ac078cdb49227562db --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/caching/cache.py @@ -0,0 +1,59 @@ +import hashlib +import logging +import os + +import dill + + +eval_logger = logging.getLogger(__name__) + + +MODULE_DIR = os.path.dirname(os.path.realpath(__file__)) + +OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH") + + +PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache" + +# This should be sufficient for uniqueness +HASH_INPUT = "EleutherAI-lm-evaluation-harness" + +HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest() + +FILE_SUFFIX = f".{HASH_PREFIX}.pickle" + + +def load_from_cache(file_name: str, cache: bool = False): + if not cache: + return + try: + path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + with open(path, "rb") as file: + 
cached_task_dict = dill.loads(file.read()) + return cached_task_dict + + except Exception: + eval_logger.debug(f"{file_name} is not cached, generating...") + pass + + +def save_to_cache(file_name, obj): + if not os.path.exists(PATH): + os.mkdir(PATH) + + file_path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + eval_logger.debug(f"Saving {file_path} to cache...") + with open(file_path, "wb") as file: + file.write(dill.dumps(obj)) + + +# NOTE the "key" param is to allow for flexibility +def delete_cache(key: str = ""): + files = os.listdir(PATH) + + for file in files: + if file.startswith(key) and file.endswith(FILE_SUFFIX): + file_path = f"{PATH}/{file}" + os.unlink(file_path) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/__init__.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/archiver.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/archiver.py new file mode 100644 index 0000000000000000000000000000000000000000..c132232116c2ae5f5ab1dc3a2a0afc0dbd4ef1bd --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/archiver.py @@ -0,0 +1,174 @@ +import datetime +import io +import json +import mmap +import os +from pathlib import Path +from typing import Any + +import jsonlines +import tqdm +import zstandard + + +def json_serial(obj: Any) -> str: + """JSON serializer for objects not serializable by default json code""" + + if isinstance(obj, (datetime.datetime,)): + return obj.isoformat() + raise TypeError("Type %s not serializable" % type(obj)) + + +# Modified version of lm_dataformat Archive for single file. +class Archive: + def __init__(self, file_path: str, compression_level: int = 3) -> None: + self.file_path = file_path + dir_name = os.path.dirname(file_path) + if dir_name: + os.makedirs(dir_name, exist_ok=True) + self.fh = open(self.file_path, "wb") + self.cctx = zstandard.ZstdCompressor(level=compression_level) + self.compressor = self.cctx.stream_writer(self.fh) + + def add_data(self, data, meta=None) -> None: + if meta is None: + meta = {} + self.compressor.write( + json.dumps({"text": data, "meta": meta}, default=json_serial).encode( + "UTF-8" + ) + + b"\n" + ) + + def commit(self) -> None: + self.compressor.flush(zstandard.FLUSH_FRAME) + self.fh.flush() + self.fh.close() + + +# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm. +class Reader: + def __init__(self) -> None: + pass + + def read( + self, + file, + get_meta: bool = False, + autojoin_paragraphs: bool = True, + para_joiner: str = "\n\n", + ): + with open(file, "rb") as fh: + self.fh = fh + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(fh)) + rdr = jsonlines.Reader(reader) + for ob in rdr: + # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility. 
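+                # each record is either a bare string or an object shaped like
+                # {"text": ..., "meta": {...}}, i.e. what Archive.add_data writes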
+ if isinstance(ob, str): + assert not get_meta + yield ob + continue + + text = ob["text"] + + if autojoin_paragraphs and isinstance(text, list): + text = para_joiner.join(text) + + if get_meta: + yield text, (ob["meta"] if "meta" in ob else {}) + else: + yield text + + +class TextArchive: + def __init__(self, file_path, mode: str = "rb+") -> None: + self.file_path = file_path + dir_name = os.path.dirname(file_path) + if dir_name: + os.makedirs(dir_name, exist_ok=True) + + if not os.path.exists(file_path): + Path(file_path).touch() + + self.fh = open(self.file_path, mode) + + def add_data(self, data) -> None: + self.fh.write(data.encode("UTF-8") + b"\n") + + def commit(self) -> None: + self.fh.flush() + self.fh.close() + + +class TextReader: + def __init__(self, file_path) -> None: + self.file_path = file_path + + # Optimized mmap read with infrequent tqdm updates to maintain speed + # Tested up to 250MB/s. + def read_tqdm(self, update_frequency: int = 10000): + current_file_position = 0 + line_counter = 0 + with ( + open(self.file_path, "r", encoding="utf-8") as fh, + tqdm.tqdm( + total=os.path.getsize(self.file_path), + dynamic_ncols=True, + unit="byte", + unit_scale=1, + ) as progress, + ): + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + line_counter += 1 + if line_counter == update_frequency: + new_file_pos = mmap_obj.tell() + bytes_read = new_file_pos - current_file_position + current_file_position = new_file_pos + progress.update(bytes_read) + line_counter = 0 + yield line[:-1] + + def read_and_tell(self): + current_file_position = 0 + with open(self.file_path, "r", encoding="utf8") as fh: + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + new_file_pos = mmap_obj.tell() + raw_bytes_read = new_file_pos - current_file_position + current_file_position = new_file_pos + yield line[:-1], raw_bytes_read + + def read(self): + with open(self.file_path, "r", encoding="utf8") as fh: + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + yield line[:-1] + + def read_slow(self): + with open(self.file_path, "r", encoding="utf8") as fh: + while True: + line = fh.readline() + if line == -1 or line == "": + break + else: + yield line[:-1] + + +# Optimized for speed. Decompresses the archive in shell before +# using the mmap'd TextReader. 
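+# Note: ZStdTextReader shells out to the `zstd` CLI, so that binary must be on
+# PATH; the decompressed copy is written next to the archive and removed again
+# once it has been read.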
+class ZStdTextReader: + def __init__(self, file) -> None: + self.file = file + + def read_tqdm(self): + decompressed_file = self.file[:-4] + print("Decompressing file, please wait...") + os.system(f"zstd -d {self.file}") # linux decompress is faster + reader = TextReader(decompressed_file) + yield from reader.read_tqdm() + os.remove(decompressed_file) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/decontaminate.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/decontaminate.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1250d39bf7cd0272e412452d970ec7c52992c5 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/decontaminate.py @@ -0,0 +1,166 @@ +import collections +import glob +import json +import os +import pickle +import random +import time + +from .archiver import ZStdTextReader +from .janitor import Janitor, word_ngrams + + +# Was used for testing the evaluator decoupled from the full logic below +def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str): + simulated_overlap = 0.1 + contaminated = int(len(docs) * simulated_overlap) + return random.sample(range(len(docs)), contaminated) + + +# Returns a dictionary containing all overlapping documents in each +# task. In the standard use case, an overlap occurs when any of the 13-grams +# found in the task document exist in the training set documents. +# +# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these +# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst" +# files. These should exist in the "ngrams_path" provided to this function. + + +# Algorithm: +# 1. Build lookups for each dataset {ngram: list(document_ids)} +# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]} +# 3. Full scan the 13-grams from the training set against the merged lookup, +# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)} +# 4. Strip the task_set from the dictionary keys and return +# +# We cache the task+set lookups as well as the overlaps. +def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict: + # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size) + + info_dict_path = os.path.join(ngrams_path, "info.json") + info_dict = json.load(open(info_dict_path, "r", encoding="utf-8")) + ngrams_n_size = info_dict["ngram_size"] + + janitor = Janitor() + + # Build lookup for each dataset first in case we use different task combinations later + print("Building Lookups...") + start = time.perf_counter() + + def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str: + return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps" + + lookups = {} + duplicates = {} # (task_name, task_set): set(doc_ids)} + sets_to_decontaminate = len(docs_by_task_set.keys()) + + for (task_name, task_set), docs in docs_by_task_set.items(): + if not os.path.exists(f"data/{task_name}"): + os.mkdir(f"data/{task_name}") + + # Check if we've decontaminated this combination before + overlaps_dump_path = get_overlaps_dump_path( + task_name, task_set, ngrams_n_size, limit + ) + if os.path.exists(overlaps_dump_path): + duplicates[(task_name, task_set)] = pickle.load( + open(overlaps_dump_path, "rb") + ) + sets_to_decontaminate -= 1 + continue + else: + duplicates[(task_name, task_set)] = set() + + # Build/load the task lookup {ngram: set(documents)}. 
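+        # the lookup maps each normalized n-gram (13-gram in the standard
+        # setup) to the ids of the task documents containing it,
+        # e.g. {"<13 normalized words>": {0, 17}, ...}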
+ task_set_lookup_path = ( + f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup" + ) + if os.path.exists(task_set_lookup_path): + print(f"{task_set_lookup_path} available, loading...") + lookups[(task_name, task_set)] = pickle.load( + open(task_set_lookup_path, "rb") + ) + else: + print(f"{task_set_lookup_path} not available, building...") + lookup = collections.defaultdict(set) + + for doc_id, document in enumerate(docs): + ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size) + for ngram in ngrams: + lookup[ngram].add(doc_id) + + pickle.dump(lookup, open(task_set_lookup_path, "wb")) + lookups[(task_name, task_set)] = lookup + + elapsed = time.perf_counter() - start + print(f"Building lookups took {elapsed:0.5f} seconds.") + + matched_ngrams = [] + + if sets_to_decontaminate > 0: + print("Merging lookups...") + start = time.perf_counter() + merged_lookup = collections.defaultdict(list) + for (task_name, task_set), lookup in lookups.items(): + for ngram, doc_ids in lookup.items(): + merged_lookup[ngram].append((task_name, task_set, doc_ids)) + + elapsed = time.perf_counter() - start + print(f"Merging lookups took {elapsed:0.5f} seconds.") + + print(f"{ngrams_n_size} grams files found in {ngrams_path}:") + files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst")) + print(files) + + for file in files: + start = time.perf_counter() + print(f"Scanning {file}") + reader = ZStdTextReader(file) + total_ngrams = 0 + unique_ngrams = 0 + matching_unique = 0 + non_matching_unique = 0 + + current_ngram = "" + for line in reader.read_tqdm(): # Scan training set ngrams file + total_ngrams += 1 + [ngram, document_id] = line.rsplit(" ", 1) + if ( + ngram != current_ngram + ): # Only need to match the ngram once in training set + unique_ngrams += 1 + current_ngram = ngram + if ngram in merged_lookup: + matched_ngrams.append(ngram) # For logging + matching_unique += 1 + for task_name, task_set, doc_ids in merged_lookup[ngram]: + task_doc_set = duplicates[(task_name, task_set)] + for doc_id in doc_ids: # Record contamination across all relevant task/set combos + task_doc_set.add(doc_id) + del merged_lookup[ngram] # No point matching again + else: + non_matching_unique += 1 + + print(f"Total Ngrams: {total_ngrams}") + print(f"Unique Ngrams: {unique_ngrams}") + print(f"Unique Matching: {matching_unique}") + print(f"Unique Non Matching: {non_matching_unique}") + print("Matched ngrams:") + for ngram in matched_ngrams: + print(ngram) + + elapsed = time.perf_counter() - start + print(f"Read took {elapsed:0.5f} seconds.") + print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second") + + print(duplicates) + + # Dump overlaps separately + for (task_name, task_set), doc_ids in duplicates.items(): + overlaps_dump_path = get_overlaps_dump_path( + task_name, task_set, ngrams_n_size, limit + ) + pickle.dump(doc_ids, open(overlaps_dump_path, "wb")) + + # Strip task set and return + return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()} diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/janitor.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/janitor.py new file mode 100644 index 0000000000000000000000000000000000000000..cedf8a5717aa8156674836ba236fdcabf36e0487 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/decontamination/janitor.py @@ -0,0 +1,328 @@ +import pickle +import re +import string +import traceback +from typing import Iterator, List, Sequence, Tuple, TypeVar + + +# This 
is a cpp module. Compile janitor_util.cpp with: +# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup +try: + import janitor_util + + JANITOR_CPP = True +except Exception: + print("WARNING: C++ module could not be loaded. Janitor running in python mode") + traceback.print_exc() + JANITOR_CPP = False + +T = TypeVar("T") + + +# Implementation from nltk source +# https://www.nltk.org/_modules/nltk/util.html +def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]: + history = [] + while n > 1: + # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator + try: + next_item = next(sequence) + except StopIteration: + # no more data, terminate the generator + return + history.append(next_item) + n -= 1 + for item in sequence: + history.append(item) + yield tuple(history) + del history[0] + + +def word_ngrams(s: str, n: int) -> Iterator[str]: + """Splits a string into ngram words""" + tokens = s.split() # not a generator :( + ngram_seqs = form_ngrams(iter(tokens), n) + return (" ".join(ngram) for ngram in ngram_seqs) + + +# Does character sequences only - combined faster function to play around with later +# def word_ngrams_indices_combined(sequence, n): +# current_word = "" +# history = [] +# gap = False; +# start = 0 +# end = 0 +# for character in sequence: +# if character == " ": +# if not gap: +# gap = True +# history.append(current_word) +# end += len(current_word) - 1 +# current_word = "" +# if len(history) == n: +# yield (tuple(history), start, end) +# del history[0] +# start = end + 1 +# end = start +# else: +# gap = False +# current_word += character + + +# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python +def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string on whitespaces and records the indices of each in the original string. + @:return generator((word, (start_idx, end_idx)), ...) + """ + return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s)) + + +def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string into pairs of (ngram words, their start/end indices)""" + tokens_with_indices = split_indices(s) + + # Generator of ngrams of (word, idx_pairs) + # ( + # [(word, (start,end)), (word, (start, end))...], + # [(word, (start, end)), ...], + # ... + # ) + ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n) + + # Generator of pairs of word and index ngrams + # ( + # ([word, word, ...], [(start,end), (start,end), ...]), + # ... + # ) + ngram_indices_pairs = ( + zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices + ) + + # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...) + return ( + (" ".join(ngram_seq), (indices[0][0], indices[-1][1])) + for ngram_seq, indices in ngram_indices_pairs + ) + + +class Janitor: + # FIXME delete_chars: Should anything else go here? Special chars? 
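+    # Parameter summary (inferred from _split_chunks / normalize_string below):
+    #   ngram_n              - size of the contamination n-grams (13 by default)
+    #   window_to_remove     - extra characters stripped on each side of a dirty n-gram
+    #   too_dirty_cutoff     - if a document contains more than this many dirty
+    #                          n-grams, clean() drops it entirely and returns []
+    #   minimum_slice_length - clean chunks not longer than this are discarded
+    #   delete_chars         - characters removed during normalization (punctuation)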
+ def __init__( + self, + ngram_n: int = 13, + window_to_remove: int = 200, + too_dirty_cutoff: int = 10, + minimum_slice_length: int = 200, + delete_chars: str = string.punctuation, + ) -> None: + self.ngram_n = ngram_n + self.window_to_remove = window_to_remove + self.too_dirty_cutoff = too_dirty_cutoff + self.minimum_slice_length = minimum_slice_length + self.delete_chars = delete_chars + + self.dirt_ngrams = set() + + # If in python, we'll translate uppercase to lowercase and delete naughty characters. + # This is fast by python standards + # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st + self.translation_table = str.maketrans( + string.ascii_lowercase + string.ascii_uppercase, # These characters + string.ascii_lowercase * 2, # Become these characters + self.delete_chars, # These are deleted + ) + + ############## + # I/O for saving contamination ngrams + ############## + + def save_contamination_ngrams(self, filename: str) -> None: + with open(filename, "wb") as fp: + pickle.dump(filename, fp) + + def load_contamination_ngrams(self, filename: str) -> None: + with open(filename, "rb") as fp: + self.dirt_ngrams = pickle.load(fp) + + ############## + # Call these :) + ############## + + def register_contaminant(self, dirt_string: str) -> None: + """Register a string as contamination to be removed, e.g. a test set + This breaks the dirt_string into ngrams to store for future cleaning""" + if JANITOR_CPP: + return self.register_contaminant_cpp(dirt_string) + else: + print("WARNING: Janitor running in python mode") + return self.register_contaminant_python(dirt_string) + + def clean(self, dirty_string: str) -> List[str]: + """Clean a string (e.g. a training set) by removing all ngrams previously + registered as contaminants. 
Returns a list of clean chunks, or empty if + the string was too dirty""" + if JANITOR_CPP: + return self.clean_cpp(dirty_string) + else: + print("WARNING: Janitor running in python mode") + return self.clean_python(dirty_string) + + def _split_chunks( + self, dirty_string: str, dirty_parts: Sequence[Tuple] + ) -> List[str]: + clean_chunks = [] + splice_idx = 0 + end = -1 + for i, (ngram, start, end) in enumerate(dirty_parts): + if i >= self.too_dirty_cutoff: + return [] + start = max(0, start - self.window_to_remove) + end = min(len(dirty_string), end + self.window_to_remove) + + if start - splice_idx > self.minimum_slice_length: + clean_chunks.append(dirty_string[splice_idx:start]) + splice_idx = end + + if end < len(dirty_string) - self.minimum_slice_length: + clean_chunks.append(dirty_string[end + 1 :]) + + return clean_chunks + + ############## + # Fast C++ + ############## + + def register_contaminant_cpp(self, dirt_string) -> None: + self.dirt_ngrams.update( + janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n) + ) + + def clean_cpp(self, dirty_string: str) -> List[str]: + contamination_indices = janitor_util.clean_ngram_with_indices( + dirty_string, self.delete_chars, self.ngram_n + ) + return self._split_chunks(dirty_string, contamination_indices) + + ############## + # Slow python + ############## + + def normalize_string(self, s: str) -> str: + return s.translate(self.translation_table) + + def register_contaminant_python(self, dirt_string: str) -> None: + self.dirt_ngrams.update( + word_ngrams(self.normalize_string(dirt_string), self.ngram_n) + ) + + def clean_python(self, dirty_string: str) -> List[str]: + contamination_indices = ( + (None, *idx_pair) + for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n) + if self.normalize_string(dirty_ngram) in self.dirt_ngrams + ) + return self._split_chunks(dirty_string, contamination_indices) + + +################################################################## +# Tests +################################################################# + +# def print_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 + +# for i in range(1, 10, 2): +# pprint(janitor_util.clean_ngram(source, string.punctuation, i)) +# for ngram, start, end in \ +# janitor_util.clean_ngram_with_indices(source, string.punctuation, i): +# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n")) + + +# def test_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. Clean he he" + +# jan_python = Janitor() +# jan_cpp = Janitor() + +# jan_python.register_contaminant_python(contaminant) +# jan_cpp.register_contaminant(contaminant) + +# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams) + +# assert jan_python.clean_python(source) == jan_cpp.clean(source), \ +# (jan_python.clean_python(source), jan_cpp.clean(source)) + +# print("Passed test, python==cpp") + + +# def benchmark(): +# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html +# setup = \ +# """ +# with open("data/enwik8", "r") as f: +# data = f.read() +# jan = Janitor(too_dirty_cutoff=1000) +# jan.register_contaminant(''' +# theories is that there is a connection between "geekdom" and autism. 
+# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled " +# The [[Geek]] Syndrome", which is a point argued by many in the autism rights +# movement{{ref|Wired}}. This article, many professionals assert, is just one example of +# the media's application of mental disease labels to what is actually variant normal behavior +# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual +# interests, even when they seem unusual to others, are not in themselves signs of autism or +# Asperger's syndrome. Others assert that it is actually the medical profession which is applying +# mental disease labels to children who in the past would have simply been accepted as a little +# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue. +# Due to the recent publicity surrounding autism and autis +# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first, +# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first +# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties +# would last, took a cautious approach, preferring to save the revenue rather than investing it in +# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential +# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his +# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]], +# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M, +# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995), +# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the +# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the +# [[United Arab Emirates]]. After the Emirates gained independence in 1971, +# ''') +# """ + +# n = 1 +# print(f"Timing {n} run on 100 MB") +# print("Register contaminant") +# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n)) + +# print("Clean") +# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n)) + + +# def test_janitor_general(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. 
Clean he he" + +# jan = Janitor(ngram_n=3) +# jan.register_contaminant(contaminant) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + +# filename = "data/saved_contam" +# jan.save_contamination_ngrams(filename) + +# jan = Janitor(ngram_n=3) +# jan.load_contamination_ngrams(filename) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + + +# if __name__ == "__main__": +# test() +# # print_cpp() +# # test_cpp() +# # benchmark() diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..0d26a64c6f999690e8810419ae930a3cd74a34f5 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator.py @@ -0,0 +1,736 @@ +import itertools +import json +import logging +import random +import time +from collections import defaultdict +from typing import TYPE_CHECKING, List, Optional, Union + +import numpy as np +import torch + +import lm_eval.api.metrics +import lm_eval.api.registry +import lm_eval.api.task +import lm_eval.models +from lm_eval.caching.cache import delete_cache +from lm_eval.evaluator_utils import ( + consolidate_group_results, + consolidate_results, + get_sample_size, + get_subtask_list, + get_task_list, + prepare_print_tasks, + print_writeout, + run_task_tests, +) +from lm_eval.loggers import EvaluationTracker +from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash +from lm_eval.tasks import ( + TaskManager, + get_task_dict, +) +from lm_eval.utils import ( + handle_non_serializable, + hash_string, + positional_deprecated, + setup_logging, + simple_parse_args_string, +) + + +if TYPE_CHECKING: + from lm_eval.api.model import LM + from lm_eval.api.task import Task + +eval_logger = logging.getLogger(__name__) + + +@positional_deprecated +def simple_evaluate( + model, + model_args: Optional[Union[str, dict]] = None, + tasks: Optional[List[Union[str, dict, object]]] = None, + num_fewshot: Optional[int] = None, + batch_size: Optional[Union[int, str]] = None, + max_batch_size: Optional[int] = None, + device: Optional[str] = None, + use_cache: Optional[str] = None, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + delete_requests_cache: bool = False, + limit: Optional[Union[int, float]] = None, + bootstrap_iters: int = 100000, + check_integrity: bool = False, + write_out: bool = False, + log_samples: bool = True, + evaluation_tracker: Optional[EvaluationTracker] = None, + system_instruction: Optional[str] = None, + apply_chat_template: Union[bool, str] = False, + fewshot_as_multiturn: bool = False, + gen_kwargs: Union[str, dict, None] = None, + task_manager: Optional[TaskManager] = None, + verbosity=None, + predict_only: bool = False, + random_seed: int = 0, + numpy_random_seed: int = 1234, + torch_random_seed: int = 1234, + fewshot_random_seed: int = 1234, + confirm_run_unsafe_code: bool = False, + metadata: Optional[dict] = None, +): + """Instantiate and evaluate a model on a list of tasks. + + :param model: Union[str, LM] + Name of model or LM object, see lm_eval.models.get_model + :param model_args: Optional[str, dict] + String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object. + Ignored if `model` argument is a LM object. + :param tasks: list[Union[str, dict, Task]] + List of task names or Task objects. 
Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. + :param num_fewshot: int + Number of examples in few-shot context + :param batch_size: int or str, optional + Batch size for model + :param max_batch_size: int, optional + Maximal batch size to try with automatic batch size detection + :param device: str, optional + PyTorch device (e.g. "cpu" or "cuda:0") for running models + :param use_cache: str, optional + A path to a sqlite db file for caching model responses. `None` if not caching. + :param cache_requests: bool, optional + Speed up evaluation by caching the building of dataset requests. `None` if not caching. + :param rewrite_requests_cache: bool, optional + Rewrites all the request cache if set to `True`. `None` if not desired. + :param delete_requests_cache: bool, optional + Deletes all the request cache if set to `True`. `None` if not desired. + :param limit: int or float, optional + Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples. + :param bootstrap_iters: + Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed. + :param check_integrity: bool + Whether to run the relevant part of the test suite for the tasks + :param write_out: bool + If True, write out an example document and model input for checking task integrity + :param log_samples: bool + If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis + :param system_instruction: str + System instruction to be applied to the prompt + :param apply_chat_template: Union[bool, str] + Specifies whether to apply a chat template to the prompt. + - If set to True, the default chat template is applied. + - If set to a string, applies the specified chat template by name. + Defaults to False (no chat template applied). + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param gen_kwargs: dict or comma-separated string + Arguments for model generation + Ignored for all tasks with loglikelihood output_type + :param verbosity: str + Verbosity level for logging + :param predict_only: bool + If true only model outputs will be generated and returned. Metrics will not be evaluated + :param random_seed: int + Random seed for python's random module. If set to None, the seed will not be set. + :param numpy_random_seed: int + Random seed for numpy. If set to None, the seed will not be set. + :param torch_random_seed: int + Random seed for torch. If set to None, the seed will not be set. + :param fewshot_random_seed: int + Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None. + :param metadata: dict + Additional metadata to be added to the task manager. Will get passed to the download function of the task. + + return + Dictionary of results + """ + if verbosity is not None: + setup_logging(verbosity=verbosity) + start_date = time.time() + + if isinstance(model_args, str) and ( + "instruct" in model_args and not apply_chat_template + ): + eval_logger.warning( + "Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)." 
+ ) + + if delete_requests_cache: + eval_logger.info("Deleting requests cache...") + delete_cache() + + seed_message = [] + if random_seed is not None: + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412 + seed_message.append(f"Setting random seed to {random_seed}") + random.seed(random_seed) + + if numpy_random_seed is not None: + seed_message.append(f"Setting numpy seed to {numpy_random_seed}") + np.random.seed(numpy_random_seed) + + if torch_random_seed is not None: + seed_message.append(f"Setting torch manual seed to {torch_random_seed}") + torch.manual_seed(torch_random_seed) + + if fewshot_random_seed is not None: + seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}") + + if seed_message: + eval_logger.info(" | ".join(seed_message)) + + if tasks is None: + tasks = [] + if len(tasks) == 0: + raise ValueError( + "No tasks specified, or no tasks found. Please verify the task names." + ) + + if gen_kwargs is not None: + if isinstance(gen_kwargs, str): + gen_kwargs = simple_parse_args_string(gen_kwargs) + eval_logger.warning( + f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. " + "Ensure 'do_sample=True' for non-greedy decoding!" + ) + if not gen_kwargs: + gen_kwargs = None + + if isinstance(model, str): + if model_args is None: + eval_logger.warning("model_args not specified. Using defaults.") + model_args = "" + + if isinstance(model_args, dict): + eval_logger.info( + f"Initializing {model} model, with arguments: {model_args}" + ) + lm = lm_eval.api.registry.get_model(model).create_from_arg_obj( + model_args, + { + "batch_size": batch_size, + "max_batch_size": max_batch_size, + "device": device, + }, + ) + + else: + eval_logger.info( + f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}" + ) + lm = lm_eval.api.registry.get_model(model).create_from_arg_string( + model_args, + { + "batch_size": batch_size, + "max_batch_size": max_batch_size, + "device": device, + }, + ) + else: + if not isinstance(model, lm_eval.api.model.LM): + raise TypeError( + f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first." + ) + eval_logger.info("Using pre-initialized model") + lm = model + + if use_cache is not None: + eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}") + lm = lm_eval.api.model.CachingLM( + lm, + use_cache + # each rank receives a different cache db. + # necessary to avoid multiple writes to cache at once + + "_rank" + + str(lm.rank) + + ".db", + ) + + if task_manager is None: + metadata = ( + simple_parse_args_string(model_args) + if isinstance(model_args, str) + else model_args + if isinstance(model_args, dict) + else {} + ) | (metadata or {}) + task_manager = TaskManager(metadata=metadata) + + task_dict = get_task_dict( + tasks, + task_manager, + ) + + # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups. 
+ # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed) + def _adjust_config(task_dict): + adjusted_task_dict = {} + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + adjusted_task_dict = { + **adjusted_task_dict, + **{task_name: _adjust_config(task_obj)}, + } + + else: + if task_obj.get_config("output_type") == "generate_until": + if gen_kwargs is not None: + task_obj.set_config( + key="generation_kwargs", value=gen_kwargs, update=True + ) + eval_logger.info( + f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}" + ) + + if predict_only: + eval_logger.info( + f"Processing {task_name} in output-only mode. Metrics will not be calculated!" + ) + # we have to change the class properties post-hoc. This is pretty hacky. + task_obj.override_metric(metric_name="bypass") + + # override tasks' fewshot values to the provided num_fewshot arg value + # except if tasks have it set to 0 manually in their configs--then we should never overwrite that + if num_fewshot is not None: + if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0: + eval_logger.info( + f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored." + ) + else: + eval_logger.warning( + f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" + ) + task_obj.set_config(key="num_fewshot", value=num_fewshot) + else: + # if num_fewshot not provided, and the task does not define a default one, default to 0 + if ( + default_num_fewshot := task_obj.get_config("num_fewshot") + ) is None: + task_obj.set_config(key="num_fewshot", value=0) + # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) + task_obj.set_fewshot_seed(seed=fewshot_random_seed) + + adjusted_task_dict[task_name] = task_obj + + return adjusted_task_dict + + task_dict = _adjust_config(task_dict) + + if check_integrity: + run_task_tests(task_list=tasks) + + if evaluation_tracker is not None: + evaluation_tracker.general_config_tracker.log_experiment_args( + model_source=model, + model_args=model_args, + system_instruction=system_instruction, + chat_template=lm.chat_template(apply_chat_template) + if apply_chat_template + else None, + fewshot_as_multiturn=fewshot_as_multiturn, + ) + + results = evaluate( + lm=lm, + task_dict=task_dict, + limit=limit, + cache_requests=cache_requests, + rewrite_requests_cache=rewrite_requests_cache, + bootstrap_iters=bootstrap_iters, + write_out=write_out, + log_samples=True if predict_only else log_samples, + system_instruction=system_instruction, + apply_chat_template=apply_chat_template, + fewshot_as_multiturn=fewshot_as_multiturn, + verbosity=verbosity, + confirm_run_unsafe_code=confirm_run_unsafe_code, + ) + if verbosity is not None: + setup_logging(verbosity=verbosity) + + if lm.rank == 0: + if isinstance(model, str): + model_name = model + elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"): + model_name = model.config._name_or_path + else: + model_name = type(model).__name__ + + # add info about the model and few shot config + results["config"] = { + "model": model_name, + "model_args": model_args, + } + # add more detailed model info if available + if isinstance(lm, lm_eval.models.huggingface.HFLM): + results["config"].update(lm.get_model_info()) + # add info about execution + results["config"].update( + { + "batch_size": batch_size, + "batch_sizes": ( + list(lm.batch_sizes.values()) if hasattr(lm, 
"batch_sizes") else [] + ), + "device": device, + "use_cache": use_cache, + "limit": limit, + "bootstrap_iters": bootstrap_iters, + "gen_kwargs": gen_kwargs, + "random_seed": random_seed, + "numpy_seed": numpy_random_seed, + "torch_seed": torch_random_seed, + "fewshot_seed": fewshot_random_seed, + } + ) + results["git_hash"] = get_git_commit_hash() + results["date"] = start_date + add_env_info(results) # additional environment info to results + add_tokenizer_info(results, lm) # additional info about tokenizer + return results + else: + return None + + +@positional_deprecated +def evaluate( + lm: "LM", + task_dict, + limit: Optional[int] = None, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + bootstrap_iters: Optional[int] = 100000, + write_out: bool = False, + log_samples: bool = True, + system_instruction: Optional[str] = None, + apply_chat_template: Union[bool, str] = False, + fewshot_as_multiturn: bool = False, + verbosity: str = "INFO", + confirm_run_unsafe_code: bool = False, +): + """Instantiate and evaluate a model on a list of tasks. + + :param lm: obj + Language Model + :param task_dict: dict[str, Task] + Dictionary of tasks. Tasks will be taken to have name type(task).config.task . + :param limit: int, optional + Limit the number of examples per task (only use this for testing) + :param cache_requests: bool, optional + Speed up evaluation by caching the building of dataset requests. + :param rewrite_requests_cache: bool, optional + Rewrites all the request cache if set to `True`. + :param bootstrap_iters: + Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations. + :param write_out: bool + If True, write out an example document and model input for checking task integrity + :param log_samples: bool + If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis + :param system_instruction: str + System instruction to be applied to the prompt + :param apply_chat_template: Union[bool, str] + Specifies whether to apply a chat template to the prompt. + - If set to True, the default chat template is applied. + - If set to a string, applies the specified chat template by name. + Defaults to False (no chat template applied). + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param verbosity: str + Verbosity level for logging + :param confirm_run_unsafe_code: bool + Whether to confirm running tasks marked as unsafe. + :return + Dictionary of results + """ + + if apply_chat_template: + eval_logger.warning( + "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details." + ) + + # tracks all Instances/requests a model must generate output on. + requests = defaultdict(list) + # stores the amount to pad out reqs per req. type so that + # number of fwd passes per distributed rank is equal + padding_requests = defaultdict(int) + + # get lists of group hierarchy and each type of request + eval_tasks = get_task_list(task_dict) + if not log_samples: + if not all( + "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() + for task_output in eval_tasks + ): + raise ValueError("log_samples must be True for 'bypass' metric-only tasks") + + # validation checks: + # 1.are we running multimodal task <-> non-multimodal model class, or vice-versa. + # 2.are we running code that is marked as unsafe. 
+ incompatible_tasks = [] + for task_output in eval_tasks: + task: Task = task_output.task + + if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False): + incompatible_tasks.append(task_output.task_name) + elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code: + raise ValueError( + f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task." + ) + if len(incompatible_tasks) > 0: + if not getattr(lm, "MULTIMODAL", False): + raise ValueError( + f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type." + ) + else: + raise ValueError( + f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks." + ) + # end validation check + + # Cache the limit arg. + limit_arg = limit + limits = [] + for task_output in eval_tasks: + task: Task = task_output.task + + limit = get_sample_size(task, limit_arg) + limits.append(limit) + task.build_all_requests( + limit=limit, + rank=lm.rank, + world_size=lm.world_size, + cache_requests=cache_requests, + rewrite_requests_cache=rewrite_requests_cache, + system_instruction=system_instruction, + apply_chat_template=bool(apply_chat_template), + fewshot_as_multiturn=fewshot_as_multiturn, + chat_template=getattr(lm, "apply_chat_template") + if apply_chat_template + else None, + tokenizer_name=getattr(lm, "tokenizer_name", "") + if apply_chat_template + else "", + ) + eval_logger.debug( + f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" + ) + if write_out: + print_writeout(task) + # aggregate Instances by LM method requested to get output. + for instance in task.instances: + reqtype = instance.request_type + requests[reqtype].append(instance) + + if lm.world_size > 1: + instances_rnk = torch.tensor(len(task._instances), device=lm.device) + gathered_item = ( + lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() + ) + # "multiple_choice" task types dispatch (several) "loglikelihood" request types + reqtype = ( + "loglikelihood" + if task.OUTPUT_TYPE == "multiple_choice" + else task.OUTPUT_TYPE + ) + # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks) + numpad = max(gathered_item) - gathered_item[lm.rank] + # todo: may not account for padding in cases like SquadV2 which has multiple req types + padding_requests[reqtype] += numpad + + ### Run LM on inputs, get all outputs ### + # execute each type of request + for reqtype, reqs in requests.items(): + eval_logger.info(f"Running {reqtype} requests") + # create `K` copies of each request `req` based off `K = req.repeats` + cloned_reqs = [] + for req in reqs: + cloned_reqs.extend([req] * req.repeats) + + if (lm.world_size > 1) and (padding_requests[reqtype] > 0): + for _ in range(padding_requests[reqtype]): + cloned_reqs.extend([req] * req.repeats) + + # run requests through model + resps = getattr(lm, reqtype)(cloned_reqs) + + # put responses from model into a list of length K for each request. + for x, req in zip(resps, cloned_reqs): + req.resps.append(x) + + if lm.world_size > 1: + lm.accelerator.wait_for_everyone() + + RANK = lm.rank + WORLD_SIZE = lm.world_size + ### Postprocess outputs ### + # TODO: del model here, maybe (idea: allow user to specify device of e.g. 
reward model separately) + for task_output, limit in zip(eval_tasks, limits): + task = task_output.task + task.apply_filters() + + ### Collect values of metrics on all datapoints ### + # # unpack results and sort back in order and return control to Task + # TODO: make it possible to use a different metric per filter + # Pre-process task.instances to group by doc_id + instances_by_doc_id = defaultdict(list) + for instance in task.instances: + instances_by_doc_id[instance.doc_id].append(instance) + # Sort instances within each group + for instances in instances_by_doc_id.values(): + instances.sort(key=lambda x: x.idx) + # iterate over different filters used + for filter_key in task.instances[0].filtered_resps.keys(): + doc_iterator = task.doc_iterator( + rank=RANK, limit=limit, world_size=WORLD_SIZE + ) + for doc_id, doc in doc_iterator: + requests = instances_by_doc_id[doc_id] + metrics = task.process_results( + doc, [req.filtered_resps[filter_key] for req in requests] + ) + if log_samples: + target = task.doc_to_target(doc) + example = { + "doc_id": doc_id, + "doc": doc, + "target": target, + "arguments": [req.args for req in requests], + "resps": [req.resps for req in requests], + "filtered_resps": [ + req.filtered_resps[filter_key] for req in requests + ], + "filter": filter_key, + "metrics": list(metrics.keys()), + "doc_hash": hash_string( + json.dumps( + requests[0].doc, + indent=2, + default=handle_non_serializable, + ensure_ascii=False, + ) + ), + "prompt_hash": hash_string(requests[0].arguments[0]), + "target_hash": hash_string(str(target)), + } + example.update(metrics) + task_output.logged_samples.append(example) + for metric, value in metrics.items(): + task_output.sample_metrics[(metric, filter_key)].append(value) + + if WORLD_SIZE > 1: + # if multigpu, then gather data across all ranks to rank 0 + # first gather logged samples across all ranks + for task_output in eval_tasks: + if log_samples: + # for task_name, task_samples in list(samples.items()): + full_samples = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.logged_samples, + object_gather_list=full_samples, + dst=0, + ) + + if RANK == 0: + task_output.logged_samples = list( + itertools.chain.from_iterable(full_samples) + ) + + # then collect metrics across all ranks + for metrics in task_output.sample_metrics: + metric_list = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.sample_metrics[metrics], + object_gather_list=metric_list, + dst=0, + ) + if RANK == 0: + task_output.sample_metrics[metrics] = list( + itertools.chain.from_iterable(metric_list) + ) + + if RANK == 0: + ### Aggregate results over all datapoints ### + # aggregate results ; run bootstrap CIs + for task_output in eval_tasks: + task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters) + ( + results, + samples, + configs, + versions, + num_fewshot, + higher_is_better, + ) = consolidate_results(eval_tasks) + + ### Calculate group metrics ### + if bool(results): + results, versions, show_group_table, *_ = consolidate_group_results( + results, versions, task_dict + ) + + results_agg, group_agg = prepare_print_tasks(task_dict, results) + subtask_list = get_subtask_list(task_dict) + + # collect all higher_is_better values for metrics + # in the group's subtasks. + # TODO: clean this up ; unify with the below metric_list loop? 
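+            # when the subtasks of a group disagree on the direction of a
+            # metric, the group-level higher_is_better entry falls back to
+            # None (see the warning just below)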
+ _higher_is_better = {} + for group, task_list in subtask_list.items(): + if ( + len(task_list) != 0 + ): # subtask list will list "task_name": [] for solo tasks + for task in task_list: + for m, h in higher_is_better[task].items(): + if m not in _higher_is_better.keys(): + _higher_is_better[m] = h + + if ( + m in _higher_is_better + and _higher_is_better[m] is not None + and _higher_is_better[m] != h + ): + eval_logger.warning( + f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None." + ) + _higher_is_better[m] = None + higher_is_better[group] = _higher_is_better + + results_dict = { + "results": dict(results_agg.items()), + **( + {"groups": dict(group_agg.items())} + if (bool(group_agg) & show_group_table) + else {} + ), + "group_subtasks": dict(reversed(subtask_list.items())), + "configs": dict(sorted(configs.items())), + "versions": dict(sorted(versions.items())), + "n-shot": dict(sorted(num_fewshot.items())), + "higher_is_better": dict(sorted(higher_is_better.items())), + "n-samples": { + task_output.task_name: { + "original": len(task_output.task.eval_docs), + "effective": min( + limit if limit else len(task_output.task.eval_docs), + len(task_output.task.eval_docs), + ), + } + for task_output, limit in zip(eval_tasks, limits) + }, + } + if log_samples: + results_dict["samples"] = dict(samples) + + return results_dict + + else: + return None + + +def request_caching_arg_to_dict(cache_requests: str) -> dict: + request_caching_args = { + "cache_requests": cache_requests in {"true", "refresh"}, + "rewrite_requests_cache": cache_requests == "refresh", + "delete_requests_cache": cache_requests == "delete", + } + + return request_caching_args diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator_utils.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd87b6c7a3923d7e2cdf71894ca87b10137ac9c --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/evaluator_utils.py @@ -0,0 +1,554 @@ +import collections +import logging +import math +import pathlib +import sys +from typing import List, Optional, Tuple, Union + +from lm_eval.api.group import ConfigurableGroup +from lm_eval.api.metrics import ( + aggregate_subtask_metrics, + mean, + pooled_sample_stderr, + stderr_for_metric, +) +from lm_eval.api.task import Task +from lm_eval.utils import positional_deprecated + + +eval_logger = logging.getLogger(__name__) + + +class TaskOutput: + """ + Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task. + + Attributes: + task (object): The task object. + task_name (str): The name of the task. + task_config (dict): The configuration of the task. + version (str): The version of the task. + group_name (str): The name of the task group. + n_shot (int): The number of shots for the task. + task_alias (str): The alias of the task. + group_alias (str): The alias of the task group. + is_group (bool): Indicates if the task is a group. + logged_samples (list): The list of logged samples. + sample_len (int): The length of the samples. + sample_metrics (defaultdict): The dictionary of samples' metrics. + agg_metrics (defaultdict): The dictionary of aggregate metrics. + + Methods: + from_taskdict(cls, task_name: str, task): + Creates a TaskOutput instance from a task dictionary. + + calculate_aggregate_metric(bootstrap_iters=100000) -> None: + Calculates the aggregate metrics for the task. 
+ """ + + def __init__( + self, + task=None, + task_name=None, + task_config=None, + version=None, + group_name=None, + n_shot=None, + task_alias=None, + group_alias=None, + is_group=None, + ): + self.task = task + self.task_config = task_config + self.task_name = task_name + self.group_name = group_name + self.version = version + self.n_shot = n_shot + self.task_alias = task_alias + self.group_alias = group_alias + self.is_group = is_group + self.logged_samples = [] + self.sample_len = None + self.sample_metrics = collections.defaultdict(list) + self.agg_metrics = collections.defaultdict(list) + + @classmethod + def from_taskdict(cls, task_name: str, task): + if isinstance(task, tuple): + group_name, task = task + else: + group_name = None + if not task: + # these gets filtered out in get_task_list + # once they are added to group hierarchy + is_group = True + return cls( + task=task, task_name=task_name, is_group=is_group, group_name=group_name + ) + version = task.VERSION + task_config = dict(task.dump_config()) + if (n_shot := task_config.get("num_fewshot")) == 0: + n_shot = task_config.get("metadata", {}).get("num_fewshot", 0) + task_alias = task_config.get("alias") + group_alias = task_config.get("group_alias") + return cls( + task=task, + task_name=task_name, + task_config=task_config, + group_name=group_name, + version=version, + n_shot=n_shot, + task_alias=task_alias, + group_alias=group_alias, + ) + + def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None: + for (metric, filter_key), items in self.sample_metrics.items(): + try: + agg_fn = self.task.aggregation()[metric] + except KeyError: + # This is when process results output an arbitrary metric + # TODO: Handle this better and allow other aggregate functions other than mean. + agg_fn = mean + metric_key = f"{metric},{filter_key}" + self.agg_metrics[metric_key] = agg_fn(items) + self.sample_len = len(items) # TODO: same sample size for each metric? + if isinstance(bootstrap_iters, int): + stderr_fn = stderr_for_metric( + metric=agg_fn, + bootstrap_iters=min(bootstrap_iters, 100) + if metric in ["bleu", "chrf", "ter"] + else bootstrap_iters, + ) + self.agg_metrics[f"{metric}_stderr,{filter_key}"] = ( + stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A" + ) + else: + raise ValueError( + f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations." 
+ ) + + def __repr__(self): + return ( + f"TaskOutput(task_name={self.task_name}, " + f"group_name={self.group_name}, " + f"version={self.version}, " + f"n_shot={self.n_shot}, " + f"task_alias={self.task_alias}, " + f"group_alias={self.group_alias})" + ) + + +def get_task_list(task_dict: dict) -> List[TaskOutput]: + outputs = [] + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + _outputs = get_task_list(task_obj) + outputs.extend(_outputs) + else: + task_output = TaskOutput.from_taskdict(task_name, task_obj) + outputs.append(task_output) + + return outputs + + +def get_subtask_list(task_dict, task_root=None, depth=0): + subtask_list = {} + for group_obj, task_obj in task_dict.items(): + if isinstance(group_obj, ConfigurableGroup): + # group_name = group_obj.group_name + group_name = group_obj.group_name + else: + group_name = group_obj + if isinstance(task_obj, dict): + _subtask_list = get_subtask_list( + task_obj, task_root=group_name, depth=depth + 1 + ) + if task_root: + subtask_list.setdefault((task_root, depth), []).extend( + [ + _task + for (_task, _depth) in _subtask_list.keys() + if (_depth - 1) == depth + ] + ) + + subtask_list = {**subtask_list, **_subtask_list} + else: + if isinstance(task_obj, ConfigurableGroup): + # group_or_task_name = task_obj.group_name + group_or_task_name = task_obj.group_name + elif isinstance(task_obj, Task): + # group_or_task_name = task_obj.task_name + group_or_task_name = task_obj.task_name + + if task_root is None: + subtask_list.setdefault((group_or_task_name, depth), []) + else: + subtask_list.setdefault((task_root, depth), []).append( + group_or_task_name + ) + + if depth == 0: + _subtask_list = {} + for group_key, task_list in subtask_list.items(): + group_name, depth = group_key + _subtask_list[group_name] = task_list + subtask_list = _subtask_list + + return subtask_list + + +def print_writeout(task) -> None: + for inst in task.instances: + # print the prompt for the first few documents + if inst.doc_id < 1: + eval_logger.info( + f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\ + \n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)" + ) + eval_logger.info(f"Request: {str(inst)}") + + +def get_sample_size(task, limit: Optional[int]) -> Union[int, None]: + if limit is not None: + limit = ( + int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit) + ) + return limit + + +def prepare_print_tasks( + task_dict: dict, + results: dict, + task_depth=0, + group_depth=0, +) -> Tuple[dict, dict]: + """ + @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its + value is a list of task names. + @param results: Dictionary containing the results of each task. Each key is a + group name and its value is a dictionary of task results. + @param task_depth: The indentation level for printing the task + hierarchy. Default is 0. + @param group_depth: The indentation level for printing the group + hierarchy. Default is 0. + @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains + aggregated results for each task, and groups_agg contains aggregated results for each group. + + Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. + """ + + def _sort_task_dict(task_dict): + """ + Helper utility. 
Sorts the task dict at the current level of the hierarchy based on alphabetized task name. + Required so that we end up sorting within each sub-header correctly. + """ + + return dict( + sorted( + task_dict.items(), + key=lambda item: item[0].group_name + if isinstance(item[0], ConfigurableGroup) + else item[0], + ) + ) + + task_agg = collections.defaultdict(dict) + group_agg = collections.defaultdict(dict) + task_dict = _sort_task_dict(task_dict) + for task_or_group_name, task_or_group_obj in task_dict.items(): + tab_string = " " * task_depth + "- " if task_depth > 0 else "" + if isinstance(task_or_group_name, ConfigurableGroup): + # string_name = task_or_group_name.group_name + name = task_or_group_name.group_name + from_configurable_group = True + task_or_group_obj = _sort_task_dict(task_or_group_obj) + elif isinstance(task_or_group_name, str): + name = task_or_group_name + if isinstance(task_or_group_obj, Task): + # string_name = task_or_group_obj.task_name + name = task_or_group_obj.task_name + from_configurable_group = False + + task_agg[name] = results[name].copy() + if from_configurable_group: + if task_or_group_name.group_alias is not None: + alias = task_or_group_name.group_alias + else: + alias = task_or_group_name.group + else: + if "alias" in task_agg[name]: + alias = task_agg[name]["alias"] + else: + alias = name + + task_agg[name]["alias"] = tab_string + alias + if "samples" in task_agg[name]: + task_agg[name].pop("samples") + + if from_configurable_group and (" " not in results[name]): + group_tab_string = " " * group_depth + "- " if group_depth > 0 else "" + group_agg[name] = results[name].copy() + group_agg[name]["alias"] = group_tab_string + alias + if "samples" in group_agg[name]: + group_agg[name].pop("samples") + + if isinstance(task_or_group_obj, dict): + task_depth += 1 + group_depth += 1 + _task_agg, _group_agg = prepare_print_tasks( + task_or_group_obj, results, task_depth, group_depth + ) + task_agg = { + **task_agg, + **_task_agg, + } + group_agg = {**group_agg, **_group_agg} + task_depth -= 1 + group_depth -= 1 + return task_agg, group_agg + + +def consolidate_results( + eval_tasks: List[TaskOutput], +) -> Tuple[dict, dict, dict, dict, dict, dict]: + """ + @param eval_tasks: list(TaskOutput). + @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot. + + Consolidates the results of multiple evaluation tasks into a single structure. + + The method iterates over each evaluation instance and extracts relevant information to create the consolidated + results structure. The consolidated results structure has the following properties: + + - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains + metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task + aliases specified in the task configuration. + - samples: A defaultdict with task names as keys and lists of log samples as values. + - configs: A defaultdict with task names as keys and task configurations as values. + - versions: A defaultdict with task names as keys and task versions as values. + - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values. + - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better + for each metric as values. + + The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple. 
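+
+    Illustrative shape of a single entry in the returned `results` mapping
+    (hypothetical task name, metric, and values; actual keys depend on the task config):
+
+        results["hypothetical_task"] == {
+            "alias": "hypothetical_task",
+            "acc,none": 0.75,
+            "acc_stderr,none": 0.01,
+            "samples": 100,
+        }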
+ """ + # stores the final result for each task, for each metric/filter pair. + results = collections.defaultdict(dict) + # logs info about each document evaluated. + samples = collections.defaultdict(list) + # store num-fewshot value per task + num_fewshot = collections.defaultdict(int) + # Tracks the YAML configs of all chosen task + configs = collections.defaultdict(dict) + # Tracks each task's version. + versions = collections.defaultdict(dict) + # Track `higher_is_better` for each metric + higher_is_better = collections.defaultdict(dict) + + for task_output in eval_tasks: + if "task_alias" in (task_config := task_output.task_config): + results[task_output.task_name]["alias"] = task_config["task_alias"] + else: + results[task_output.task_name]["alias"] = task_output.task_name + if group_alias := task_output.group_alias: + if group_alias not in results and (group_name := task_output.group_name): + results[group_name]["alias"] = group_alias + num_fewshot[task_output.task_name] = task_output.n_shot + configs[task_output.task_name] = task_output.task_config + versions[task_output.task_name] = task_output.version + samples[task_output.task_name] = task_output.logged_samples + higher_is_better[task_output.task_name] = task_output.task.higher_is_better() + for (metric, filter_key), items in task_output.sample_metrics.items(): + metric_key = f"{metric},{filter_key}" + results[task_output.task_name][metric_key] = task_output.agg_metrics[ + metric_key + ] + results[task_output.task_name]["samples"] = task_output.sample_len + results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = ( + task_output.agg_metrics[f"{metric}_stderr,{filter_key}"] + ) + return results, samples, configs, versions, num_fewshot, higher_is_better + + +def consolidate_group_results( + results, + versions, + task_dict, + task_root=None, + show_group_table=False, + task_aggregation_list=None, +) -> Tuple[dict, dict, bool, Union[None,]]: + """ + (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info. + + @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below: + + - results: A defaultdict with task names (and, after this function is called, group names of + groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys. + - versions: A defaultdict with task names (and, after this function is called, group names of + groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None). + - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table. + - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric. + + The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple. + In the top-level invocation of this function, task_aggregation_list is ignored. 
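+
+    Illustrative `task_aggregation_list` produced by the recursion (hypothetical
+    group and subtask names):
+
+        {"hypothetical_group": ["hypothetical_subtask_a", "hypothetical_subtask_b"]}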
+ """ + if task_root is None: + task_root = {} + + if task_aggregation_list is None: + task_aggregation_list = {} + + for group_or_task, group_or_task_info in task_dict.items(): + # Convert to string + if isinstance(group_or_task, ConfigurableGroup): + group_config = group_or_task.config + group_or_task = group_or_task.group_name + else: + group_config = None + + if isinstance(group_or_task_info, Task): + if task_root: + task_aggregation_list.setdefault(task_root, []).append( + group_or_task_info.task_name + ) + else: + ( + results, + versions, + show_group_table, + _task_aggregation_list, + ) = consolidate_group_results( + results, + versions, + group_or_task_info, + group_or_task, + show_group_table, + task_aggregation_list, + ) + if task_root: + task_aggregation_list.setdefault(task_root, []).extend( + task_aggregation_list.get(group_or_task, []) + ) + + if (group_config is None) or ( + group_config["aggregate_metric_list"] is None + ): + results[group_or_task][" "] = " " + continue + + if "aggregate_metric_list" in group_config: + agg_metric_list = group_config["aggregate_metric_list"] + + show_group_table = show_group_table | bool( + group_config["aggregate_metric_list"] + ) + + task_list = _task_aggregation_list[group_or_task] + + metric_list = list( + { + key + for task in task_list + for key in results[task].keys() + if "_stderr" not in key and key not in ["task", "alias", "samples"] + } + ) + for metric in metric_list: + stderr = "_stderr,".join(metric.split(",")) + + # gather metrics, sizes, and stderrs from subtasks + metrics = [ + results[task][metric] + for task in task_list + if metric in results[task] + ] # TODO: copy? + stderrs = [ + results[task][stderr] + for task in task_list + if stderr in results[task] + ] + sizes = [ + results[task]["samples"] + for task in task_list + if metric in results[task] + ] + + for metric_config in agg_metric_list: + for filter_name in metric_config["filter_list"]: + if metric != ",".join([metric_config["metric"], filter_name]): + continue + + # compute group's pooled metric and stderr + if metric_config["aggregation"] == "mean": + aggregate_fn = aggregate_subtask_metrics + elif callable(metric_config["aggregation"]): + aggregate_fn = metric_config["aggregation"] + else: + raise ValueError( + f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'" + ) + + results[group_or_task][metric] = aggregate_fn( + metrics, + sizes, + metric_config["weight_by_size"], + ) + # TODO: calculate groups' metrics using arbitrary agg fns + if "N/A" in stderrs: + results[group_or_task][stderr] = "N/A" + else: + # NOTE: this assumes we are using the mean to aggregate. 
There are warnings about this elsewhere + results[group_or_task][stderr] = pooled_sample_stderr( + stderrs, sizes + ) + + results[group_or_task]["samples"] = sum(sizes) + group_metadata = group_config.get("metadata", None) + if group_metadata is not None: + versions[group_or_task] = group_metadata.get("version", None) + # print(results) + return results, versions, show_group_table, task_aggregation_list + + +@positional_deprecated +def find_test_root(start_path: pathlib.Path) -> pathlib.Path: + """ + Search upward in the directory tree to a maximum of three layers + to find and return the package root (containing the 'tests' folder) + """ + cur_path = start_path.resolve() + max_layers = 3 + for _ in range(max_layers): + if (cur_path / "tests" / "test_version_stable.py").exists(): + return cur_path + else: + cur_path = cur_path.parent.resolve() + raise FileNotFoundError( + f"Unable to find package root within {max_layers} upwards" + f"of {start_path}" + ) + + +@positional_deprecated +def run_task_tests(task_list: List[str]): + """ + Find the package root and run the tests for the given tasks + """ + import pytest + + package_root = find_test_root(start_path=pathlib.Path(__file__)) + task_string = " or ".join(task_list) + args = [ + f"{package_root}/tests/test_version_stable.py", + f"--rootdir={package_root}", + "-k", + f"{task_string}", + ] + sys.path.append(str(package_root)) + pytest_return_val = pytest.main(args) + if pytest_return_val: + raise ValueError( + f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}" + ) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/__init__.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be5c9d43624ea901cc578c65689be5bd263209a5 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/__init__.py @@ -0,0 +1,25 @@ +from functools import partial +from typing import List + +from lm_eval.api.filter import FilterEnsemble +from lm_eval.api.registry import get_filter + +from . import custom, extraction, selection, transformation + + +def build_filter_ensemble( + filter_name: str, components: List[List[str]] +) -> FilterEnsemble: + """ + Create a filtering pipeline. + """ + filters = [] + for function, kwargs in components: + if kwargs is None: + kwargs = {} + # create a filter given its name in the registry + f = partial(get_filter(function), **kwargs) + # add the filter as a pipeline step + filters.append(f) + + return FilterEnsemble(name=filter_name, filters=filters) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/custom.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..ab22c51eda74670aaea6699fc68992994c41932d --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/custom.py @@ -0,0 +1,17 @@ +from lm_eval.api.filter import Filter +from lm_eval.api.registry import register_filter + + +@register_filter("custom") +class CustomFilter(Filter): + """ + Custom filter that applies a custom, user-defined function to the model responses. 
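+
+    Example (illustrative; `strip_responses` is a hypothetical user-defined callable
+    with the same (resps, docs) signature that `apply` forwards to):
+
+        def strip_responses(resps, docs):
+            return [[r.strip() for r in resp] for resp in resps]
+
+        custom_filter = CustomFilter(filter_fn=strip_responses)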
+ """ + + def __init__(self, **kwargs) -> None: + self.filter_fn = kwargs.pop("filter_fn") + + super().__init__(**kwargs) + + def apply(self, resps, docs): + return self.filter_fn(resps, docs) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/decontamination.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/decontamination.py new file mode 100644 index 0000000000000000000000000000000000000000..4eda4e022445355f191926790b2edf8f0cfa4bbd --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/decontamination.py @@ -0,0 +1,25 @@ +from lm_eval.api.filter import Filter +from lm_eval.api.registry import register_filter + + +@register_filter("decontaminate") +class DecontaminationFilter(Filter): + """ + A filter which evaluates + """ + + name = "track_decontamination" + + def __init__(self, path) -> None: + """ + + TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path"). + should further cache result on a given (task_name, doc_id) + """ + self._decontam_results = None + + def apply(self, resps, docs) -> None: + """ + Return {"no_contamination", "only_contamination"} keys for the 2 different subsets + """ + pass diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/extraction.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8d796b6099d89fb6b6e5b2e17444cfa66f1b06 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/extraction.py @@ -0,0 +1,188 @@ +import re +import sys +import unicodedata + +from lm_eval.api.filter import Filter +from lm_eval.api.registry import register_filter + + +@register_filter("regex") +class RegexFilter(Filter): + """A filter that extracts values from text using regex pattern matching. + + This filter applies a regex pattern to each model response and extracts matched values. + If no match is found, returns a fallback value. Useful for extracting structured data + (like numbers) from unstructured model outputs. + """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select: int = 0, + fallback: str = "[invalid]", + ) -> None: + """ + pass a string `regex` to run `re.compile(r"regex")` on. + `fallback` defines the output returned if no matches for the regex are located. + """ + self.regex_pattern = regex_pattern + self.regex = re.compile(regex_pattern) + self.group_select = group_select + self.fallback = fallback + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) 
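+        # Illustrative behaviour (assumed inputs) with the default pattern
+        # r"#### (\-?[0-9\.\,]+)":
+        #   resps = [["... #### 42", "no digits here"]] -> [["42", "[invalid]"]]
+        # i.e. the first regex hit is kept and unmatched responses fall back
+        # to `self.fallback`.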
+ def filter_set(inst): + filtered = [] + for resp in inst: + match = self.regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m] + if match: + match = match[0] + else: + match = self.fallback + match = match.strip() + else: + match = self.fallback + filtered.append(match) + return filtered + + filtered_resps = list(map(lambda x: filter_set(x), resps)) + + return filtered_resps + + +@register_filter("remove_whitespace") +class WhitespaceFilter(Filter): + """Filters out leading whitespace from responses.""" + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + def filter_set(inst): + filtered_resp = [] + for resp in inst: + resp = resp.lstrip() + filtered_resp.append(resp) + return filtered_resp + + filtered_resps = [filter_set(resp) for resp in resps] + + return filtered_resps + + +@register_filter("multi_choice_regex") +class MultiChoiceRegexFilter(RegexFilter): + """ + A filter used to extract a model's answer on multiple choice questions with + letter answers. assumes each document has a "choices" field + containing the list of answer choices and that the answer label symbols + are of the form (A), (B), (C), ... or A, B, C. + """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, + ) -> None: + """ + regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure + - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. + - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. + group_select: Selects the (group_select)th match from the findall result. + ignore_case: Ignores the case during step 1 matching + ignore_punctuation: Remove the punctuation during step 1 matching + regexes_to_ignore: Remove these regexes during step 1 matching + """ + super().__init__(regex_pattern, group_select, fallback) + self.ignore_case = ignore_case + self.ignore_punctuation = ignore_punctuation + self.regexes_to_ignore = regexes_to_ignore + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) 
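+        # Illustrative fallback flow (assumed doc): with doc["choices"] = ["cat", "dog"],
+        # the escaped choice text maps back to its letter, so a response whose text
+        # contains "dog" resolves to "(B)", and a response like "Answer: B" is caught
+        # by the ":[\s]*(A|B)" fallback and also resolves to "(B)".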
+ + def find_match(regex, resp, convert_dict={}): + match = regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m][0] + match = match.strip() + if match and match in convert_dict: + match = convert_dict[match] + return match + + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) + + def filter_ignores(st): + if self.regexes_to_ignore is not None: + for s in self.regexes_to_ignore: + st = re.sub(s, "", st) + + if self.ignore_case: + st = st.lower() + + if self.ignore_punctuation: + # https://stackoverflow.com/a/266162 + st = st.translate(punct_tbl) + return st + + filtered_resps = [] + + for r, doc in zip(resps, docs): + fallback_regexes = [] + choice_to_alpha = {} + next_alpha = "A" + + without_paren_fallback_regexes = [] + without_paren_to_target = {} + + choices = doc["choices"] + for c in choices: + m = filter_ignores(c.strip()) + fallback_regexes.append(f"{re.escape(m)}") + choice_to_alpha[m] = f"({next_alpha})" + + without_paren_fallback_regexes.append(next_alpha) + without_paren_to_target[next_alpha] = f"({next_alpha})" + + next_alpha = chr(ord(next_alpha) + 1) + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + rf":[\s]*({without_paren_fallback_regex})" + ) + + filtered = [] + for resp in r: + match = find_match(self.regex, resp) + if not match: + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) + if not match: + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) + if not match: + match = self.fallback + filtered.append(match) + filtered_resps.append(filtered) + + return filtered_resps diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/selection.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/selection.py new file mode 100644 index 0000000000000000000000000000000000000000..8c670ed74d00655441cc45181fba1265f0db5290 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/selection.py @@ -0,0 +1,61 @@ +from collections import Counter + +from lm_eval.api.filter import Filter +from lm_eval.api.registry import register_filter + + +# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function +# that takes an input and returns a scalar and then should select the max reward, +# or should implement different filters for different ways of handling a reward model's inference. + + +@register_filter("take_first") +class TakeFirstFilter(Filter): + def __init__(self) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + def apply(self, resps, docs): + """ + Assuming each entry of `resps` is a list of model responses, we discard all but the first response. + """ + return map(lambda r: r[0], resps) + + +@register_filter("take_first_k") +class TakeKFilter(Filter): + def __init__(self, **kwargs) -> None: + self.k = kwargs.pop("k") + + super().__init__(**kwargs) + + def apply(self, resps, docs): + # need resp to be subscriptable to check below + resps = list(resps) + # check we have at least k responses per doc, else we can't take the first k + assert len(resps[0]) >= self.k, ( + f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ." 
+ ) + return map(lambda r: r[: self.k], resps) + + +@register_filter("majority_vote") +class MajorityVoteFilter(Filter): + def __init__(self) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + def apply(self, resps, docs): + """ + Each entry of `resps` is a list of model responses. + We select the response that occurs most frequently in each entry of `resps`. + """ + + def select_majority(resp): + counts = Counter(resp) + vote = counts.most_common(1)[0][0] + return vote + + return map(lambda r: [select_majority(r)], resps) diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/transformation.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3592b6dd4811dcef39ff090dfa42e926613b5c --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/filters/transformation.py @@ -0,0 +1,56 @@ +from lm_eval.api.filter import Filter +from lm_eval.api.registry import register_filter + + +@register_filter("lowercase") +class LowercaseFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def filter_set(inst): + return [resp.lower() for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("uppercase") +class UppercaseFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def filter_set(inst): + return [resp.upper() for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("map") +class MapFilter(Filter): + def __init__(self, mapping_dict: dict = None, default_value=None) -> None: + """ + Initializes the MapFilter with a given mapping dictionary and default value. + + Args: + - mapping_dict (dict): A dictionary containing the key-value mappings. + Default is an empty dictionary. + - default_value (Any): The value to be returned when a key is not found in the mapping_dict. + Default is None. + + Example: + mapper = MapFilter({'A': 1, 'B': 2}, default_value=0) + """ + if mapping_dict is None: + mapping_dict = {} + assert isinstance(mapping_dict, dict), ( + "Provided mapping_dict is not a dictionary" + ) + self.mapping_dict = mapping_dict + self.default_value = default_value + + def apply(self, resps, docs): + def filter_set(inst): + return [self.mapping_dict.get(resp, self.default_value) for resp in inst] + + return [filter_set(resp) for resp in resps] diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/__init__.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e46d196a6bce3df235310b278be7e7e11cb950c3 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/__init__.py @@ -0,0 +1,17 @@ +from . 
import ( + diffllm, + huggingface, +) + + +# TODO: implement __all__ + + +try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + import huggingface_hub.constants # type: ignore + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True +except ImportError: + pass diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/diffllm.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/diffllm.py new file mode 100644 index 0000000000000000000000000000000000000000..9fbc269818264e48062761987cc90ff4e383313f --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/diffllm.py @@ -0,0 +1,563 @@ +import logging +import gc +import random +import json +import os +import time +from datetime import timedelta +from typing import List, Optional, Tuple, Type, TypeVar, Union + +import torch +import torch.nn.functional as F +import transformers +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, + find_executable_batch_size, +) +from datasets import Dataset +from packaging import version +from tqdm import tqdm + +from lm_eval import utils +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from lm_eval.models.utils import Collator, get_dtype + +eval_logger = logging.getLogger(__name__) +T = TypeVar("T", bound="LM") + + +def empty_cache_by_memory(threshold_gb=70): + """ + Empty CUDA cache if allocated memory exceeds threshold + Args: + threshold_gb: Memory threshold in GB + """ + if torch.cuda.is_available(): + # Get current memory allocated + allocated = torch.cuda.memory_allocated() / 1024**3 # Convert to GB + + if allocated > threshold_gb: + # Clear cache + gc.collect() + torch.cuda.empty_cache() + print(f"Cache cleared. 
Memory freed: {allocated:.2f} GB") + +@register_model("diffllm") +class DiffLLM(LM): + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + batch_size: Optional[Union[int, str]] = 1, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + max_prompt_len: Optional[int] = 1024, + max_new_tokens: Optional[int] = 128, + nll_type: Optional[str] = "mc", + log_type: Optional[str] = "ftb", + classifier_free_guidance: Optional[float] = 1.0, + pad_to_max_len: Optional[bool] = False, + sampling_eps: Optional[float] = 1e-3, + diffusion_steps: Optional[int] = 32, + trust_remote_code: Optional[bool] = True, + parallelize: Optional[bool] = False, + autogptq: Optional[Union[bool, str]] = False, + **kwargs, + ) -> None: + super().__init__() + + # prepare for parallelism + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + + self.accelerator = accelerator + + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + + # using one process with no model parallelism + if not (parallelize or accelerator.num_processes > 1): + # use user-passed device + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." + ) + self._device = self.accelerator.device + + self.batch_size_per_gpu = batch_size + if isinstance(batch_size, str): + self.batch_size_per_gpu = int(batch_size) + self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code) + + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + if not (parallelize or autogptq or (hasattr(self, "accelerator") and self.accelerator.num_processes > 1)): + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." 
+ ) + if gpus > 1: + if self.accelerator.num_processes > 1: + self._device = torch.device(f"{accelerator.device}") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self._rank = 0 + self._world_size = 1 + else: + self._rank = 0 + self._world_size = 1 + else: + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" + ) + self._rank = 0 + self._world_size = 1 + + self.max_prompt_len = max_prompt_len + self.max_new_tokens = max_new_tokens + self.diffusion_steps = diffusion_steps + self.temperature = kwargs.get("temperature", 0.7) + self.top_p = kwargs.get("top_p", 0.95) + self.alg = kwargs.get("alg", "entropy") + self.alg_temp = kwargs.get("alg_temp", 0.0) + self.top_k = kwargs.get("top_k", None) + + self.nll_type = nll_type + self.log_type = log_type + self.classifier_free_guidance = classifier_free_guidance + self.pad_to_max_len = pad_to_max_len + self.sampling_eps = sampling_eps + + self.mask_id = 151666 + self.eos_id = 151643 + + raw_use_hts = kwargs.get("use_hts", False) + if isinstance(raw_use_hts, str): + self.use_hts = raw_use_hts.lower() == "true" + else: + self.use_hts = bool(raw_use_hts) + + self.realtime_output = kwargs.get("realtime_output", "eval_results.jsonl") + + if self.use_hts: + from .hts_sampler import HTSSampler + self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device) + eval_logger.info(f"Rank {self.rank}: HTS Sampler initialized for Dream.") + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code): + self.model = ( + transformers.AutoModel.from_pretrained( + pretrained, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + ) + .eval() + ).to(self.device) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained, trust_remote_code=trust_remote_code + ) + + def tok_decode(self, tokens, skip_special_tokens=True): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def tok_encode(self, text, add_special_tokens=True): + return self.tokenizer( + text, return_tensors="pt", add_special_tokens=add_special_tokens + ).input_ids + + @classmethod + def create_from_arg_string( + cls: Type[T], arg_string: str, additional_config: Optional[dict] = None + ) -> T: + additional_config = {} if additional_config is None else additional_config + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + def apply_chat_template( + self, chat_history, add_generation_prompt: bool = True + ) -> str: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + return chat_templated + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def _generate_batch(self, prompts: List[str], gen_kwargs: dict = None) -> Tuple[List[str], List[dict]]: + raw_val = gen_kwargs.get("use_hts", self.use_hts) + use_hts_now = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val + + all_stats = [] + 
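+        # Two generation paths follow: the standard batched diffusion_generate
+        # call, or per-prompt HTS sampling via HTSSampler when use_hts is enabled.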
if not use_hts_now: + prompt_ids = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").input_ids + prompt_ids = prompt_ids[:, -self.max_prompt_len:] + attn_mask = prompt_ids.ne(self.tokenizer.pad_token_id).to(self.device) + prompt_ids = prompt_ids.to(device=self.device) + + generation_ids = self.model.diffusion_generate( + prompt_ids, + attention_mask=attn_mask, + max_new_tokens=self.max_new_tokens, + output_history=False, + return_dict_in_generate=True, + steps=self.diffusion_steps, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + alg=self.alg, + alg_temp=self.alg_temp, + ) + responses = [ + self.tokenizer.decode(g[len(p) :].tolist()).split(self.tokenizer.eos_token)[0] + for p, g in zip(prompt_ids, generation_ids.sequences) + ] + all_stats = [{} for _ in responses] + return responses, all_stats + else: + if not hasattr(self, "hts_sampler"): + from .hts_sampler import HTSSampler + self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device) + + results = [] + for prompt in prompts: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + + final_codes, stats = self.hts_sampler.generate_hts( + prompt_text=prompt, + input_ids=input_ids, + initial_N=int(gen_kwargs.get("initial_N", 4)), + final_K=int(gen_kwargs.get("final_K", 1)), + hts_survivor_k=int(gen_kwargs.get("hts_survivor_k", 4)), + reward_mode=gen_kwargs.get("reward_mode", "svf"), + task_type=gen_kwargs.get("task_type", "code"), + steps=self.diffusion_steps, + gen_length=self.max_new_tokens, + temperature=float(gen_kwargs.get("temperature", self.temperature)), + top_p=float(gen_kwargs.get("top_p", self.top_p)), + top_k=gen_kwargs.get("top_k", self.top_k), + until=gen_kwargs.get("until", []), + hts_mode=True, + mask_id=self.mask_id, + eos_id=self.eos_id + ) + + results.append(final_codes[0]) + all_stats.append(stats) + return results, all_stats + + def generate_until(self, requests: List[Instance], disable_tqdm: bool = False): + res = [] + + gen_kwargs_first = requests[0].args[1] + actual_output_path = gen_kwargs_first.get("realtime_output", self.realtime_output) + + raw_val = gen_kwargs_first.get("use_hts", self.use_hts) + self.use_hts = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val + + rank_tmp_file = actual_output_path.replace(".jsonl", f"_rank{self.rank}.tmp") + + output_dir = os.path.dirname(rank_tmp_file) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until", + ) + + for batch_idx in range(0, len(requests), self.batch_size_per_gpu): + batch_requests = requests[batch_idx : batch_idx + self.batch_size_per_gpu] + contexts, task_gen_args = zip(*[req.arguments for req in batch_requests]) + + responses, stats_list = self._generate_batch(contexts, gen_kwargs=task_gen_args[0]) + + for i, r in enumerate(responses): + r = r.replace("```python", "").replace("```", "") + + for s in task_gen_args[0].get('until', []): + r = r.split(s)[0] + + target_val = getattr(batch_requests[i], "target", None) + if target_val is None or target_val == "N/A": + target_val = batch_requests[i].doc.get("answer", batch_requests[i].doc.get("solution", "N/A")) + + save_data = { + "doc": batch_requests[i].doc, + "target": target_val, + "prompt": contexts[i], + "response": r, + } + + if self.use_hts: + save_data.update(stats_list[i]) + + with open(rank_tmp_file, "a", encoding="utf-8") 
as f: + f.write(json.dumps(save_data, ensure_ascii=False) + "\n") + f.flush() + + responses[i] = r + + if self.rank == 0 and batch_idx == 0: + print(f"Sample Response:\n{responses[0]}\n") + + res.extend(responses) + pbar.update(len(batch_requests)) + + pbar.close() + + self.accelerator.wait_for_everyone() + + if self.rank == 0: + eval_logger.info(f"Merging rank files into {actual_output_path}...") + with open(actual_output_path, "w", encoding="utf-8") as final_f: + for r in range(self.world_size): + temp_f = actual_output_path.replace(".jsonl", f"_rank{r}.tmp") + if os.path.exists(temp_f): + with open(temp_f, "r", encoding="utf-8") as tf: + for line in tf: + final_f.write(line) + os.remove(temp_f) + eval_logger.info("Merge completed.") + + return res + + def _forward_process(self, batch): + b, l = batch.shape + u0 = torch.rand(1, device=batch.device, dtype=torch.float32) + indices = torch.arange(b, device=batch.device).float() + t = (u0 + indices / b) % 1 + p_mask = (1 - self.sampling_eps) * t + self.sampling_eps + p_mask = p_mask[:, None].repeat(1, l) + mask_indices = torch.rand((b, l), device=batch.device) < p_mask + mask_indices[:, 0] = False + mask_indices[:, -1] = False + noisy_batch = torch.where(mask_indices, self.mask_id, batch) + return noisy_batch, p_mask + + @torch.no_grad() + def get_logits(self, batch, prompt_index): + if self.classifier_free_guidance > 1.: + assert len(prompt_index) == batch.shape[1] + prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) + un_batch = batch.clone() + un_batch[prompt_index] = self.mask_id + batch = torch.cat([batch, un_batch]) + + if self.pad_to_max_len: + raise NotImplementedError + else: + input = batch + + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + logits = self.model(input, 'full').logits + + if self.classifier_free_guidance > 1.: + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + self.classifier_free_guidance * (logits - un_logits) + return logits[:, :batch.shape[1]] + + @torch.no_grad() + def _eval_target_nll_mc(self, prefix, target): + if prefix is None: + seq = target[None, :] + else: + seq = torch.concatenate([prefix, target])[None, :] + seq = seq.repeat((self.batch_size, 1)).to(self.device) + + if self.log_type == 'ftb': + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + else: + prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix) + + loss_acc = [] + mc_num = self.diffusion_steps + for _ in range(max(mc_num // self.batch_size, 1)): + perturbed_seq = seq.clone() + perturbed_seq_, p_mask = self._forward_process(seq) + if self.log_type == 'ftb': + perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):] + elif self.log_type == 'btf': + perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)] + elif self.log_type == 'union': + perturbed_seq = perturbed_seq_ + else: + raise NotImplementedError(self.log_type) + + mask_indices = perturbed_seq == self.mask_id + + logits = self.get_logits(perturbed_seq, prompt_index) + + loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices] + loss = loss.sum() / self.batch_size + loss_acc.append(loss.item()) + del logits, loss, perturbed_seq, perturbed_seq_, p_mask, mask_indices + empty_cache_by_memory(threshold_gb=70) + + return sum(loss_acc) / len(loss_acc) + + @torch.no_grad() + def _eval_target_nll_ar(self, prefix, target): + prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2 + assert self.log_type in ['ftb', 'btf'] 
+ assert self.nll_type in ['ar_ftb', 'ar_btf'] + + if self.log_type == 'ftb': + prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1] + else: + prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1] + + if self.log_type == 'ftb': + perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2 + else: + perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1 + + mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool) + if self.nll_type == 'ar_ftb': + mask_index = torch.triu(mask_index) + else: + mask_index = torch.tril(mask_index) + perturbed_[mask_index] = self.mask_id + if self.log_type == 'ftb': + perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1) + else: + perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1) + + logits_ = [] + num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1 + for i in range(num): + end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq) + perturbed_seq_ = perturbed_seq[i * self.batch_size: end] + perturbed_seq_ = perturbed_seq_.to(self.device) + if len(perturbed_seq_.shape) == 1: + perturbed_seq_ = perturbed_seq_.unsqueeze(0) + logits = self.get_logits(perturbed_seq_, prompt_index) + logits_.append(logits.cpu()) + logits = torch.cat(logits_, dim=0) + + temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool) + if self.nll_type == 'ar_ftb': + temp_index = torch.triu(temp_index, diagonal=1) + else: + temp_index = torch.tril(temp_index, diagonal=-1) + mask_index[temp_index] = False + if self.log_type == 'ftb': + logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1) + else: + logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1) + + if self.log_type == 'ftb': + loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item() + else: + loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item() + return loss + + def _encode_pair(self, context, continuation): + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + whole_enc = self.tokenizer.encode(context + continuation) + [ + self.tokenizer.eos_token_id + ] + context_enc = self.tokenizer.encode(context) + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + def _tokenize(e): + prefix, target = self._encode_pair(e["prefix"], e["target"]) + return { + "prefix_text": e["prefix"], + "target_text": e["target"], + "prefix": prefix, + "target": target, + } + + ds = [] + ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + + out = [] + with torch.no_grad(): + for elem in tqdm(ds, desc="Computing likelihood..."): + prefix = elem["prefix"] + target = elem["target"] + + if self.nll_type == 'mc': + ll = -self._eval_target_nll_mc(prefix, target) + if self.log_type == 'union': + ll = ll / (len(target) + len(prefix)) + 
elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf': + ll = -self._eval_target_nll_ar(prefix, target) + else: + raise NotImplementedError(self.nll_type) + + is_target_greedy_dec = False + out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) + return out + + def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + raise NotImplementedError \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/dummy.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..014ad49ee36f756acd0428340f945312d79590e8 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/dummy.py @@ -0,0 +1,41 @@ +import random + +from tqdm import tqdm + +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model + + +@register_model("dummy") +class DummyLM(LM): + def __init__(self) -> None: + super().__init__() + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + return cls() + + def loglikelihood(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append((-random.random(), False)) + + return res + + def generate_until(self, requests, disable_tqdm: bool = False): + res = [] + + for request in tqdm(requests, disable=disable_tqdm): + res.append("lol") + assert request.arguments[0].strip() != "" + + return res + + def loglikelihood_rolling(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append(-random.random()) + + return res diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/hts_sampler.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/hts_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3b10898220d7100bed7f8a3b56531b263c180ab6 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/hts_sampler.py @@ -0,0 +1,257 @@ +import torch +import torch.nn.functional as F +import numpy as np +from .verifier import CodeVerifier +import logging +import re +import math + +logger = logging.getLogger(__name__) + +class HTSSampler: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + self.verifier = CodeVerifier(model, tokenizer, device) + + def _get_num_transfer_tokens(self, block_length, steps): + if steps == 0: return torch.tensor([], dtype=torch.int64) + base = block_length // steps + remainder = block_length % steps + num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64) + num_transfer_tokens[:remainder] += 1 + return num_transfer_tokens + + def _sample_with_temperature(self, logits, temperature, top_k, top_p): + logits = logits.to(torch.float32) + orig_probs = torch.softmax(logits, dim=-1) + x0_p, _ = torch.max(orig_probs, dim=-1) + + if temperature > 0.0: + noise = torch.rand_like(logits, dtype=torch.float32) + gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10) + logits = logits / temperature + gumbel_noise + + if top_k is not None and top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = -float('Inf') + + x0 = torch.argmax(logits, dim=-1) + return x0, x0_p + + def _safe_scalar(self, val): + if isinstance(val, torch.Tensor): + if val.numel() > 1: return val.mean().item() + return val.item() + return float(val) + + def _analyze_structure(self, text, 
task_type="code"): + score = 0.0 + stripped = text.strip() + if task_type == "code": + if len(stripped) < 5: return -0.1 + keywords = ["return", "print", "yield", "lambda", "class ", "def "] + if any(k in stripped for k in keywords): score += 0.05 + if ":" in stripped: score += 0.02 + if " " in text: score += 0.03 + elif task_type == "math": + if "\\boxed{" in stripped: score += 0.1 + if "The answer is" in stripped: score += 0.05 + return score + + def _chunked_forward(self, x, chunk_size=32, slice_indices=None): + total_batch = x.shape[0] + logits_list = [] + for i in range(0, total_batch, chunk_size): + end_idx = min(i + chunk_size, total_batch) + sub_x = x[i:end_idx] + with torch.no_grad(): + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + outputs = self.model(sub_x, 'full') + sub_logits = outputs.logits + sub_logits = torch.cat([sub_logits[:, :1, :], sub_logits[:, :-1, :]], dim=1) + if slice_indices is not None: + s_start, s_end = slice_indices + sub_logits = sub_logits[:, s_start:s_end, :] + logits_list.append(sub_logits.detach().clone()) + return torch.cat(logits_list, dim=0) + + def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id, + prompt_length, resample_window=6, task_type="code"): + num_survivors = len(survivor_indices) + if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone() + + base_repeat = target_width // num_survivors + remainder = target_width % num_survivors + new_x_list, new_conf_list = [], [] + + for i in range(num_survivors): + count = base_repeat + (1 if i < remainder else 0) + if count == 0: continue + survivor_x = x[survivor_indices[i]] + survivor_conf = conf_scores[survivor_indices[i]] + + new_x_list.append(survivor_x.unsqueeze(0)) + new_conf_list.append(survivor_conf.unsqueeze(0)) + + if count > 1: + gen_part = survivor_x[prompt_length:] + gen_conf = survivor_conf[prompt_length:] + non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0] + for _ in range(count - 1): + perturbed_x = survivor_x.clone() + perturbed_conf = survivor_conf.clone() + if len(non_mask_indices) > 0: + pool_size = min(resample_window * 2, len(non_mask_indices)) + current_token_confs = gen_conf[non_mask_indices] + _, candidate_pool = torch.topk(current_token_confs, k=pool_size, largest=False) + + num_to_perturb = min(resample_window, pool_size) + rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb] + selected_sub_indices = candidate_pool[rand_indices] + + target_idx_in_x = prompt_length + non_mask_indices[selected_sub_indices] + perturbed_x[target_idx_in_x] = mask_id + perturbed_conf[target_idx_in_x] = 0.0 + new_x_list.append(perturbed_x.unsqueeze(0)) + new_conf_list.append(perturbed_conf.unsqueeze(0)) + + return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0) + + @torch.no_grad() + def generate_hts(self, prompt_text, input_ids, problem_data=None, + initial_N=1, final_K=1, survivor_K=None, + prune_step_pct=0.0, reward_mode="confidence", + temperature=0.7, block_length=32, steps=64, gen_length=1024, + top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9, + eos_id=151643, mask_id=151666, + hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.2, + hts_survivor_k=4, task_type="code", until=None, pruning_interval=20): + + input_ids = input_ids.to(self.device) + prompt_length = input_ids.shape[1] + total_length = prompt_length + gen_length + + x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device) + x[:, :prompt_length] = 
input_ids.repeat(initial_N, 1) + conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device) + conf_scores[:, :prompt_length] = 1.0 + + schedule = self._get_num_transfer_tokens(gen_length, steps) + current_bsz = initial_N + schedule_map = {} + ts_start, tr_end = 0, 0 + + if hts_mode: + ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct) + else: + final_K_list = [final_K] if not isinstance(final_K, list) else final_K + prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct + for pct, width in zip(prune_pct_list, final_K_list): + if pct > 0: schedule_map[int(steps * pct)] = width + + stats = { + "initial_n": initial_N, + "final_k": final_K if not isinstance(final_K, list) else final_K[-1], + "nfe": 0, + "svf_calls": 0, + "pruning_history": [], + "entropy_history": [], + "final_scores": [] + } + + next_allowed_pruning_step = ts_start + + for step in range(steps): + perform_pruning = False + num_parents_to_select = hts_survivor_k + + if hts_mode and ts_start <= step < tr_end and step >= next_allowed_pruning_step: + target_width = max(stats["final_k"], math.ceil(initial_N * (decay_factor ** -(step - ts_start)))) + if current_bsz > target_width: + perform_pruning = True + elif not hts_mode and step in schedule_map: + target_width = schedule_map[step] + num_parents_to_select = target_width + if current_bsz > target_width: + perform_pruning = True + + if perform_pruning: + stats["svf_calls"] += current_bsz + full_logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length)) + rough_ids = torch.argmax(full_logits, dim=-1) + rough_codes = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True) + + candidates = [] + for i in range(current_bsz): + s = self._safe_scalar(self.verifier.get_reward(prompt_text, rough_codes[i], mode=reward_mode, current_logits=full_logits[i], task_type=task_type)) + s += self._analyze_structure(rough_codes[i], task_type=task_type) + clean_text = rough_codes[i].strip().replace(" ", "").replace("\n", "") + content_key = hash(clean_text[:150] + clean_text[-150:]) if clean_text else i + candidates.append({'score': s, 'idx': i, 'key': content_key}) + + stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]}) + candidates.sort(key=lambda c: c['score'], reverse=True) + + selected_indices, seen_keys = [], set() + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['key'] not in seen_keys: + selected_indices.append(cand['idx']); seen_keys.add(cand['key']) + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['idx'] not in selected_indices: selected_indices.append(cand['idx']) + + top_indices = torch.tensor(selected_indices, device=self.device) + x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type) + + current_bsz = target_width + next_allowed_pruning_step = step + pruning_interval + + active_mask = (x[:current_bsz, prompt_length:] == mask_id) + if active_mask.sum() == 0: break + + stats["nfe"] += current_bsz + logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length)) + + with torch.no_grad(): + probs = torch.softmax(logits.float(), dim=-1) + token_entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1) + sample_entropy = token_entropy.mean(dim=-1) + stats["entropy_history"].append(sample_entropy.tolist()) + + x0, x0_p 
= self._sample_with_temperature(logits, temperature, top_k, top_p) + num_transfer = schedule[step].item() + confidence = torch.where(active_mask, x0_p, -torch.inf) + transfer_idx = torch.zeros_like(x0, dtype=torch.bool) + + for b in range(current_bsz): + k = min(num_transfer, active_mask[b].sum().item()) + if k <= 0: continue + high_conf_mask = (confidence[b] > threshold) + if high_conf_mask.sum() >= k: + transfer_idx[b] = high_conf_mask + else: + _, topk_ids = torch.topk(confidence[b], k=k) + transfer_idx[b, topk_ids] = True + + if transfer_idx.any(): + x[:current_bsz, prompt_length:][transfer_idx] = x0[transfer_idx] + conf_scores[:current_bsz, prompt_length:][transfer_idx] = x0_p[transfer_idx] + + final_codes = self.tokenizer.batch_decode(x[:current_bsz, prompt_length:], skip_special_tokens=True) + final_candidates = [] + for i, code in enumerate(final_codes): + txt = code.split(self.tokenizer.eos_token)[0] + if until: + for term in until: + if term in txt: txt = txt.split(term)[0] + s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type)) + final_candidates.append({'resp': txt, 'score': s}) + + final_candidates.sort(key=lambda c: c['score'], reverse=True) + stats["final_scores"] = [c['score'] for c in final_candidates] + stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)] + + return [c['resp'] for c in final_candidates], stats \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/huggingface.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e098a2e88d1c63276216a42b427ab28d77d835 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/huggingface.py @@ -0,0 +1,1459 @@ +import copy +import logging +import os +from datetime import timedelta +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple, Union + +import jinja2 +import torch +import torch.nn.functional as F +import transformers +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, + find_executable_batch_size, +) +from accelerate.utils import get_max_memory +from huggingface_hub import HfApi +from packaging import version +from peft import PeftModel +from peft import __version__ as PEFT_VERSION +from tqdm import tqdm +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, +) + +from lm_eval import utils +from lm_eval.api.instance import Instance +from lm_eval.api.model import TemplateLM +from lm_eval.api.registry import register_model +from lm_eval.models.utils import ( + Collator, + clear_torch_cache, + configure_pad_token, + get_dtype, + handle_stop_sequences, + pad_and_concat, + stop_sequences_criteria, +) + + +eval_logger = logging.getLogger(__name__) + + +@register_model("hf-auto", "hf", "huggingface") +class HFLM(TemplateLM): + """ + An abstracted Huggingface model class. Enables usage with both models of + `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes. + + Supports data-parallel multi-GPU with HF Accelerate. 
+ """ + + AUTO_MODEL_CLASS = None + _DEFAULT_MAX_LENGTH = 2048 + + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + backend: Literal["default", "causal", "seq2seq"] = "default", + # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) + revision: Optional[str] = "main", + subfolder: Optional[str] = None, + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ] = None, + truncation: Optional[bool] = False, + logits_cache: bool = True, + max_length: Optional[int] = None, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + batch_size: Optional[Union[int, str]] = 1, + max_batch_size: Optional[int] = 64, + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + add_bos_token: Optional[bool] = False, + prefix_token_id: Optional[int] = None, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + parallelize: Optional[bool] = False, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + # optionally: take in an already-initialized transformers.PreTrainedModel + if not isinstance(pretrained, str): + eval_logger.warning( + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." + ) + assert not parallelize, ( + "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" + ) + self._model = pretrained + self._device = self._model.device + self._config = self._model.config + gpus = 0 + + else: + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self.accelerator = accelerator + + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + + # using one process with no model parallelism + if not (parallelize or accelerator.num_processes > 1): + # use user-passed device + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? 
{torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: # Parallelism managed by accelerate + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." + ) + # TODO: include in warning that `load_in_8bit` etc. affect this too + self._device = ( + self.accelerator.device + if hasattr(self, "accelerator") + else torch.device(device) + ) + + revision = str(revision) # cast to string if not already one + # TODO: update this to be less of a hack once subfolder is fixed in HF + revision = revision + ("/" + subfolder if subfolder is not None else "") + + self._get_config( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + ) + + # determine which of 'causal' and 'seq2seq' backends to use for HF models + self._get_backend( + config=self.config, backend=backend, trust_remote_code=trust_remote_code + ) + + # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT + self._create_tokenizer( + pretrained, + tokenizer, + revision=revision, + trust_remote_code=trust_remote_code, + use_fast_tokenizer=use_fast_tokenizer, + gguf_file=gguf_file, + add_bos_token=add_bos_token, + ) + + # if we passed `pretrained` as a string, initialize our model now + if isinstance(pretrained, str): + self._create_model( + pretrained=pretrained, + revision=revision, + dtype=dtype, + trust_remote_code=trust_remote_code, + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + peft=peft, + delta=delta, + autogptq=autogptq, + gptqmodel=gptqmodel, + gguf_file=gguf_file, + **kwargs, + ) + + # access self._model through self.model property outside this method + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + # select (or create) a pad token to use + self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) + + self.add_bos_token = add_bos_token + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it." + ) + + self._max_length = max_length + self.pretrained = pretrained + self.delta = delta + self.peft = peft + self.revision = revision + self.batch_schedule = 1 + self.batch_sizes = {} + self.max_batch_size = max_batch_size + + if str(batch_size).startswith("auto"): + batch_size = batch_size.split(":") + self.batch_size_per_gpu = batch_size[0] + self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 + else: + self.batch_size_per_gpu = int(batch_size) + + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + # TODO: can remove this whole snippet except in the mps case, perhaps? + if not (parallelize or autogptq or hasattr(self, "accelerator")): + # place model onto device requested manually, + # if not using HF Accelerate or device_map + # or any other option that preloads model onto device + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. 
This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." + ) + # multigpu data-parallel support when launched with accelerate + if gpus > 1: + if accelerator.num_processes > 1: + if parallelize: + eval_logger.warning( + "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." + ) + elif gpus > accelerator.num_processes: + eval_logger.warning( + "WARNING: The number of total system GPUs does not match the number of spawned processes. " + "If you would like to use data parallelism, please launch the script " + "with 'accelerate launch *script*'. " + f"Current run will proceed with {accelerator.num_processes} devices." + ) + if self.accelerator.is_local_main_process: + eval_logger.info( + f"Using {gpus} devices with data parallelism" + ) + + self._device = torch.device(f"{accelerator.device}") + self.accelerator = accelerator + + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + # if we aren't launching via accelerate, ditch + self._rank = 0 + self._world_size = 1 + else: + # if a PreTrainedModel was passed into HFLM, we forgo distributed setup. + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" + ) + self._rank = 0 + self._world_size = 1 + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info( + f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" + ) + + def _get_accelerate_args( + self, + parallelize: Optional[bool] = None, + device_map: Optional[str] = "auto", + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + gpus: Optional[int] = None, + ) -> dict: + """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.""" + num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes + if ( + num_machines == 0 + and hasattr(self, "accelerator") + and self.accelerator is not None + ): + eval_logger.info( + "We are not in a distributed setting for accelerate. Setting model_parallel to False." 
+ ) + parallelize = False + + if parallelize is None: + # If parallelism is unset by the user, we automatically assign model parallelism + # if enough extra GPUs are available + max_memory_all_gpus = get_max_memory() + # We just want gpu, not cpu, max memory + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + parallelize = bool(num_local_processes < len(max_memory_all_gpus)) + eval_logger.info( + f"Setting model parallel to {parallelize} since " + f"the number of local processes is {num_local_processes} " + f"and the number of GPUs is {len(max_memory_all_gpus)}" + ) + + args = {} + if parallelize: # Model parallelism will be used + max_memory = {} + if max_memory_per_gpu is not None: # Using the provided memory requirements + max_memory_per_gpu_map = { + device_idx: max_memory_per_gpu for device_idx in range(gpus) + } + else: # Estimating the possible memory requirements + max_memory_all_gpus = get_max_memory() + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + if not hasattr(self, "accelerator"): + max_memory_per_gpu_map = { + k: v for k, v in max_memory_all_gpus.items() + } + else: + # use only 1 / num_processes of the GPUs if we are running under accelerate launch + max_memory_per_gpu_map = { + k: v + for k, v in max_memory_all_gpus.items() + if k % num_local_processes + == (self.accelerator.process_index % num_local_processes) + } + args["max_memory"] = max_memory_per_gpu_map + args["device_map"] = "auto" if device_map is None else device_map + eval_logger.info( + f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}" + ) + + if max_cpu_memory is not None: + max_memory["cpu"] = max_cpu_memory + + args["offload_folder"] = offload_folder + elif ( + device_map is None + ): # No model parallelism, we use the default provided device for our model + if hasattr(self, "accelerator"): + device_map = {"": f"{self.accelerator.device}"} + else: + device_map = {"": str(self.device)} + args["max_memory"] = None + args["device_map"] = device_map + eval_logger.info( + f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}" + ) + else: + args["max_memory"] = None + args["device_map"] = None + eval_logger.info("Model parallel was set to False.") + + return args + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. 
+ return self._config + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: # if max length manually set, return it + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def _get_backend( + self, + config: Union[transformers.PretrainedConfig, transformers.AutoConfig], + backend: Literal["default", "causal", "seq2seq"] = "default", + trust_remote_code: Optional[bool] = False, + ) -> None: + """ + Helper method during initialization. + Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used. + sets `self.AUTO_MODEL_CLASS` appropriately if not already set. + + **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM, + user must set `self.backend` to be either "causal" or "seq2seq" manually!** + """ + + assert backend in ["default", "causal", "seq2seq"] + + if backend != "default": + # if we've settled on non-default backend, use that manually + if backend == "causal": + self.backend = backend + elif backend == "seq2seq": + self.backend = backend + eval_logger.info( + f"Overrode HF model backend type, and using type '{self.backend}'" + ) + else: + # determine and use the default HF backend for this model, based on its config + metadata. + if ( + getattr(config, "model_type") + in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + ): + # first check if model type is listed under seq2seq models, since some + # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. + # these special cases should be treated as seq2seq models. + self.backend = "seq2seq" + eval_logger.debug(f"Using model type '{self.backend}'") + elif ( + getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "causal" + eval_logger.debug(f"Using model type '{self.backend}'") + else: + if not trust_remote_code: + eval_logger.warning( + "HF model type is neither marked as CausalLM or Seq2SeqLM. \ + This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." 
+ "Setting backend to causal" + ) + # if model type is neither in HF transformers causal or seq2seq model registries + # then we default to assuming AutoModelForCausalLM + self.backend = "causal" + eval_logger.info( + f"Model type cannot be determined. Using default model type '{self.backend}'" + ) + + if self.AUTO_MODEL_CLASS is None: + if self.backend == "causal": + self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM + elif self.backend == "seq2seq": + self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM + + def _get_config( + self, + pretrained: str, + revision: str = "main", + trust_remote_code: bool = False, + gguf_file: Optional[str] = None, + ) -> None: + """Return the model config for HuggingFace models""" + self._config = transformers.AutoConfig.from_pretrained( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + ) + + def _create_model( + self, + pretrained: str, + revision: Optional[str] = "main", + dtype: Optional[Union[str, torch.dtype]] = "auto", + trust_remote_code: Optional[bool] = False, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + # (accelerate naive PP (device_map) options) + parallelize: Optional[bool] = False, + gpus: Optional[int] = None, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + **kwargs, + ) -> None: + """ + Initializes an HF or HF-compatible PreTrainedModel from scratch + inside HFLM, using the kwargs passed into self.__init__(). + + Also handles functionality such as AutoGPTQ usage and PEFT wrapping. + + For future similar extensions to AutoGPTQ that are not core to HF's ecosystem, + (such as PyTorch models that are nearly, but not quite, fully mirroring + HF's public interface relied on in this HFLM class) + please consider subclassing HFLM and overriding this and other methods as needed. + """ + + model_kwargs = kwargs if kwargs else {} + + model_kwargs.update( + self._get_accelerate_args( + parallelize=parallelize, + device_map=kwargs.get("device_map", None), + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + gpus=gpus, + ) + ) + + if not autogptq and not gptqmodel: + if model_kwargs.get("load_in_4bit", None): + assert transformers.__version__ >= "4.30.0", ( + "load_in_4bit requires transformers >= 4.30.0" + ) + if transformers.__version__ >= "4.30.0": + if model_kwargs.get("load_in_4bit", None): + if model_kwargs.get("bnb_4bit_compute_dtype", None): + model_kwargs["bnb_4bit_compute_dtype"] = get_dtype( + model_kwargs["bnb_4bit_compute_dtype"] + ) + + self._model = self.AUTO_MODEL_CLASS.from_pretrained( + pretrained, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + **model_kwargs, + ) + else: + if autogptq and gptqmodel: + raise ValueError( + "Cannot use both 'autogptq' and 'gptqmodel' options at the same time." 
+ ) + + if autogptq: + try: + from auto_gptq import AutoGPTQForCausalLM + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load auto_gptq, but auto-gptq is not installed ", + "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]", + ) + + self._model = AutoGPTQForCausalLM.from_quantized( + pretrained, + trust_remote_code=trust_remote_code, + model_basename=None if autogptq is True else Path(autogptq).stem, + use_safetensors=True + if autogptq is True + else autogptq.endswith(".safetensors"), + **model_kwargs, + ) + + if gptqmodel: + try: + from gptqmodel import GPTQModel + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load gptqmodel, but gptqmodel is not installed ", + "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`", + ) + + self._model = GPTQModel.from_quantized( + pretrained, trust_remote_code=trust_remote_code, **model_kwargs + ) + + if peft and delta: + raise ValueError( + "Cannot use both 'peft' and 'delta' options at the same time." + ) + + if peft: + if model_kwargs.get("load_in_4bit", None): + if version.parse(PEFT_VERSION) < version.parse("0.4.0"): + raise AssertionError("load_in_4bit requires peft >= 0.4.0") + if self._model.config.vocab_size != len(self.tokenizer): + # resize model for LoRAs with added tokens + eval_logger.info( + f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..." + ) + self._model.resize_token_embeddings(len(self.tokenizer)) + self._model = PeftModel.from_pretrained( + self._model, peft, revision=revision + ) + elif delta: + if autogptq: + eval_logger.warning( + "Delta weights might trigger unexpected behavior when used with AutoGPTQ." + ) + _model_delta = self.AUTO_MODEL_CLASS.from_pretrained( + delta, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + for name, param in self._model.state_dict().items(): + try: + param.data += _model_delta.state_dict()[name] + except KeyError: + raise KeyError(f"Delta model is missing weights for layer: {name}") + except Exception as e: + raise RuntimeError( + f"Failed to add delta weights to layer {name}. Error: {e}" + ) + + del _model_delta + + return None + + def _create_tokenizer( + self, + pretrained: Union[str, transformers.PreTrainedModel], + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ], + revision: Optional[str] = "main", + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + gguf_file: Optional[str] = None, + add_bos_token: Optional[bool] = False, + ) -> None: + """ + Helper method during initialization. + + Create a tokenizer object corresponding to the correct + tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed. 
+ """ + kwargs = { + "revision": revision, + "trust_remote_code": trust_remote_code, + } + + # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param + if gguf_file is not None: + kwargs["gguf_file"] = gguf_file + else: + kwargs["use_fast"] = use_fast_tokenizer + + if add_bos_token: + kwargs["add_bos_token"] = True + + if tokenizer: + if isinstance(tokenizer, str): + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + tokenizer, **kwargs + ) + else: + assert isinstance( + tokenizer, transformers.PreTrainedTokenizer + ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast) + self.tokenizer = tokenizer + else: + # Get tokenizer based on 'pretrained' + if isinstance(pretrained, str): + model_name = pretrained + else: + # get the HF hub name via accessor on model + model_name = self.model.name_or_path + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, **kwargs + ) + return None + + def _detect_batch_size(self, requests=None, pos: int = 0): + if requests: + _, context_enc, continuation_enc = requests[pos] + max_length = len( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] + ) + max_context_enc = len(context_enc[-(self.max_length + 1) :]) + max_cont_enc = len(continuation_enc[-(self.max_length + 1) :]) + else: + max_length = self.max_length + max_context_enc = max_length + max_cont_enc = max_length + + # if OOM, then halves batch_size and tries again + @find_executable_batch_size(starting_batch_size=self.max_batch_size) + def forward_batch(batch_size): + if self.backend == "seq2seq": + length = max(max_context_enc, max_cont_enc) + batched_conts = torch.ones( + (batch_size, length), device=self.device + ).long() + test_batch = torch.ones((batch_size, length), device=self.device).long() + call_kwargs = { + "attn_mask": test_batch, + "labels": batched_conts, + } + else: + call_kwargs = {} + test_batch = torch.ones( + (batch_size, max_length), device=self.device + ).long() + for _ in range(5): + out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) # noqa: F841 + + return batch_size + + try: + batch_size = forward_batch() + except RuntimeError as e: + if "No executable batch size found" in str(e): + batch_size = 1 + else: + raise + + if self.world_size > 1: + # if multi-GPU, always take minimum over all selected batch sizes + max_rnk_bs = torch.tensor([batch_size], device=self.device) + gathered = ( + self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist() + ) + batch_size = min(gathered) + clear_torch_cache() + return batch_size + + clear_torch_cache() + return batch_size + + def tok_encode( + self, string: str, left_truncate_len=None, add_special_tokens=None + ) -> List[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs = { + "add_special_tokens": False or self.add_bos_token + } + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: List[str], 
+ padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self.backend == "causal": + add_special_tokens = {"add_special_tokens": False or self.add_bos_token} + + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + eval_logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][ + :, -left_truncate_len: + ] + self.tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _model_call(self, inps, attn_mask=None, labels=None): + """ + :param inps: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attn_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. 
Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + with torch.no_grad(): + if attn_mask is not None or labels is not None: + assert attn_mask is not None and labels is not None + assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM + return self.model( + input_ids=inps, attention_mask=attn_mask, labels=labels + ).logits + else: + assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM + return self.model(inps).logits + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = stop_sequences_criteria( + self.tokenizer, stop, context.shape[1], context.shape[0] + ) + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=True, + **generation_kwargs, + ) + + def _select_cont_toks( + self, logits: torch.Tensor, contlen: int = None, inplen: int = None + ) -> torch.Tensor: + if self.backend == "causal": + assert contlen and inplen, ( + "Must pass input len and cont. len to select scored logits for causal LM" + ) + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self.backend == "seq2seq": + assert contlen and not inplen, ( + "Selecting scored logits for Seq2SeqLM requires only cont. len" + ) + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[float]: + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. 
Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + rolling_token_windows: List[Tuple[List[int], List[int]]] = list( + map( + utils.make_disjoint_window, + utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self.device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self.batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial( + "loglikelihood_rolling", (string,), request_total + ) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self.batch_schedule) + if sched in self.batch_sizes: + return self.batch_sizes[sched] + if (len(self.batch_sizes) > 1) and ( + self.batch_sizes[sched - 1] == self.max_batch_size + ): + # if previous batch size is already maximal, skip recomputation + self.batch_sizes[sched] = self.max_batch_size + return self.batch_sizes[sched] + print( + f"Passed argument batch_size = auto:{self.batch_schedule}. 
Detecting largest batch size" + ) + self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self.batch_sizes[sched]}") + return self.batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: List[Tuple[Tuple[str, str], List[int], List[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> List[Tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = req[1] + req[2] + return -len(toks), tuple(toks) + + def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key to group and lookup one-token continuations""" + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. + return req[-2] + req[-1][:-1] + + re_ord = Collator( + requests, + sort_fn=_collate, + group_by="contexts" + if self.backend == "causal" and self.logits_cache + else None, + group_fn=_lookup_one_token_cont, + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else override_bs + if override_bs is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" + and n_reordered_requests > 0 + and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self.backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if 
total_length > self.max_length + 1: + eval_logger.warn( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + elif self.backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self.device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = ( + max(padding_len_cont, contlen) + if padding_len_cont is not None + else contlen + ) + + padding_len_inp = ( + max(padding_len_inp, inplen) + if padding_len_inp is not None + else inplen + ) + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self.backend == "causal": + batched_inps = pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self.backend == "seq2seq": + # TODO: left-pad encoder inps and mask? + batched_inps = pad_and_concat( + padding_len_inp, inps + ) # [batch, padding_len_inp] + batched_conts = pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attn_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), dim=-1 + ) # [batch, padding_length (inp or cont), vocab] + + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = ( + inplen + (logits.shape[0] - padding_len_inp) + if self.backend == "causal" + else None + ) + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. 
+ # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + cont_toks = torch.tensor( + cont_toks, dtype=torch.long, device=self.device + ).unsqueeze(0) # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( + -1 + ) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial( + "loglikelihood", request_str, answer + ) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[str]: + res = [] + + def _collate(req: Tuple[str, dict]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(req[0]) + return -len(toks), req[0] + + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else adaptive_batch_size + if adaptive_batch_size is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" and not adaptive_batch_size + else None + ) + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + re_ords = Collator( + [reg.args for reg in requests], + sort_fn=_collate, + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. 
+ if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError( + f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" + ) + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self.max_gen_toks + + # set the max length in tokens of inputs ("context_enc") + if self.backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert max_ctx_len > 0, ( + f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." + ) + elif self.backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + context_enc, attn_masks = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self.truncation, + ) + context_enc = context_enc.to(self.device) + attn_masks = attn_masks.to(self.device) + + if "max_length" not in kwargs: + kwargs["max_length"] = context_enc.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + context=context_enc, + attention_mask=attn_masks, + stop=until, + **kwargs, + ) + + cont_toks_list = cont.tolist() + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self.backend == "causal": + cont_toks = cont_toks[context_enc.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + eval_logger.warning( + "Failed to apply chat template. removing the system role in chat history." + ) + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated + + def get_model_info(self) -> dict: + """ + Method to get Hugging Face model information for experiment reproducibility. 
+ """ + + def get_model_num_params(model) -> int: + if hasattr(model, "num_parameters"): + return model.num_parameters() + if hasattr(model, "parameters"): + return sum(p.numel() for p in model.parameters()) + else: + return -1 + + def get_model_dtype(model) -> str: + if hasattr(model, "dtype"): + return model.dtype + else: + return "" + + def get_model_sha(pretrained: str, revision: str) -> str: + try: + model_info = HfApi().model_info(repo_id=pretrained, revision=revision) + return model_info.sha + except Exception as e: + eval_logger.debug( + f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}" + ) + return "" + + model_info = { + "model_num_parameters": get_model_num_params(self._model), + "model_dtype": get_model_dtype(self._model), + "model_revision": self.revision, + "model_sha": get_model_sha(self.pretrained, self.revision), + } + if self.peft: + model_info["peft_sha"] = get_model_sha(self.peft, self.revision) + if self.delta: + model_info["delta_sha"] = get_model_sha(self.delta, self.revision) + return model_info diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/utils.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2878de6ea500b1e4f6723c9ef63f7d6b90be5711 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/utils.py @@ -0,0 +1,731 @@ +import collections +import fnmatch +import gc +import itertools +import logging +import time +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Tuple, + Type, + Union, +) + +import torch +import transformers + + +eval_logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + from transformers.configuration_utils import PretrainedConfig + + +def chunks(iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + for i, x in enumerate(iter): + arr.append(x) + if len(arr) == (fn(i, iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +class MultiChoice: + def __init__(self, choices) -> None: + self.choices = choices + + # Simple wildcard support (linux filename patterns) + def __contains__(self, values) -> bool: + for value in values.split(","): + if len(fnmatch.filter(self.choices, value)) == 0: + eval_logger.info("Available tasks to choose:") + for choice in self.choices: + eval_logger.info(f" - {choice}") + raise ValueError("'{}' is not in task list".format(value)) + return True + + def __iter__(self) -> Iterator: + for choice in self.choices: + yield choice + + +class Grouper: + """ + takes an array `arr` and function `fn` and returns a dictionary + with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all + objects in `arr` satisfying `key == fn(ob)`. 
+ """ + + def __init__(self, arr, fn) -> None: + # self.orig_arr = arr + self.size = len(arr) + arr = list(enumerate(arr)) + + def group_return_dict(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + return res + + arr = group_return_dict(arr, lambda x: fn(x[1])) + + # self.arr has format Dict[Tuple[int, ]] + self.arr = arr + self._grouped = None + + def get_grouped(self): + # return the contents but not indices for our grouped dict. + if self._grouped: + return self._grouped + grouped = {} + for key in self.arr.keys(): + # drop the index from each element of self.arr + grouped[key] = [y[1] for y in self.arr[key]] + self._grouped = grouped + return grouped + + def get_original(self, grouped_dict): + # take in a grouped dictionary with e.g. results for each key listed + # in the same order as the instances in `self.arr`, and + # return the results in the same (single list) order as `self.orig_arr`. + res = [None] * self.size + cov = [False] * self.size + # orig = [None] * self.size + + assert grouped_dict.keys() == self.arr.keys() + + for key in grouped_dict.keys(): + for (ind, _), v in zip(self.arr[key], grouped_dict[key]): + res[ind] = v + cov[ind] = True + # orig[ind] = _ + + assert all(cov) + # assert orig == self.orig_arr + + return res + + +def pad_and_concat( + max_length: int, + tensors: List[torch.Tensor], + padding_side: Literal["right", "left"] = "right", +): + """ + Method for padding a list of tensors given the maximum tensor + length in the batch. Used for batching inputs and continuations in + seq2seq models. + """ + assert padding_side == "left" or padding_side == "right", ( + f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" + ) + + for i, tensor in enumerate(tensors): + if len(tensor.shape) == 2: + tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size + tensor_len = tensor.shape[0] + if tensor_len < max_length: + if padding_side == "right": + # right-pad + tensors[i] = torch.cat( + [ + tensor, # [seq] + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + ], + dim=0, + ).unsqueeze(0) + else: + # left-pad + tensors[i] = torch.cat( + [ + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + tensor, # [seq] + ], + dim=0, + ).unsqueeze(0) + else: + tensors[i] = tensor.unsqueeze(0) + + return torch.cat(tensors, dim=0) + + +def clear_torch_cache() -> None: + gc.collect() + torch.cuda.empty_cache() + + +def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: + """Converts `dtype` from `str` to torch.dtype when possible. 
Does not use an instantiated HF AutoConfig""" + if isinstance(dtype, str) and dtype != "auto": + # Convert `str` args torch dtype: `float16` -> `torch.float16` + _torch_dtype = getattr(torch, dtype) + else: + _torch_dtype = dtype + return _torch_dtype + + +class MultiTokenEOSCriteria(transformers.StoppingCriteria): + """Criteria to stop on the specified multi-token sequence.""" + + def __init__( + self, + sequence: str, + tokenizer: transformers.PreTrainedTokenizer, + initial_decoder_input_length: int, + batch_size: int, + ) -> None: + self.initial_decoder_input_length = initial_decoder_input_length + self.done_tracker = [False] * batch_size + self.sequence = sequence + self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False) + # print(sequence, self.sequence_ids) + # we look back for 2 more tokens than it takes to encode our stop sequence + # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` + # and we don't want to mistakenly not stop a generation because our + # (string) stop sequence was output in a different tokenization + + # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model, + # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized + # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described. + self.sequence_id_len = len(self.sequence_ids) + 2 + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores, **kwargs) -> bool: + # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence + lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :] + + lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :] + + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) + + for i, done in enumerate(self.done_tracker): + if not done: + self.done_tracker[i] = self.sequence in lookback_tokens_batch[i] + return False not in self.done_tracker + + +def stop_sequences_criteria( + tokenizer: transformers.PreTrainedTokenizer, + stop_sequences: List[str], + initial_decoder_input_length: int, + batch_size: int, +) -> transformers.StoppingCriteriaList: + return transformers.StoppingCriteriaList( + [ + *[ + MultiTokenEOSCriteria( + sequence, tokenizer, initial_decoder_input_length, batch_size + ) + for sequence in stop_sequences + ], + ] + ) + + +def undistribute(iterable): + """ + Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute . 
+ + Re-interleaves results that have been split using more_itertools.distribute: + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + >>> undistribute([group_1, group_2]) + [1, 2, 3, 4, 5, 6] + + Handles non-uniform component lengths: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + >>> undistribute(children) + [1, 2, 3, 4, 5, 6, 7] + + Also handles when some iterables are empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + >>> undistribute(children) + [1, 2, 3] + + """ + + return [ + x + for x in itertools.chain.from_iterable( + itertools.zip_longest(*[list(x) for x in iterable]) + ) + if x is not None + ] + + +def retry_on_specific_exceptions( + on_exceptions: List[Type[Exception]], + max_retries: Optional[int] = None, + backoff_time: float = 3.0, + backoff_multiplier: float = 1.5, + on_exception_callback: Optional[Callable[[Exception, float], Any]] = None, +): + """Retry on an LLM Provider's rate limit error with exponential backoff + For example, to use for OpenAI, do the following: + ``` + from openai import RateLimitError + + # Recommend specifying max_retries to avoid infinite loops! + @retry_on_specific_exceptions([RateLimitError], max_retries=3) + def completion(...): + # Wrap OpenAI completion function here + ... + ``` + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + sleep_time = backoff_time + attempt = 0 + while max_retries is None or attempt < max_retries: + try: + return func(*args, **kwargs) + except tuple(on_exceptions) as e: + if on_exception_callback is not None: + on_exception_callback(e, sleep_time) + time.sleep(sleep_time) + sleep_time *= backoff_multiplier + attempt += 1 + + return wrapper + + return decorator + + +class Collator: + """ + A class for reordering and batching elements of an array. + + This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data. + + Objects of this class have the group_by attribute which determines the method for grouping + the data while batching it. Three options include "gen_kwargs", "contexts", or None: + If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs + If group_by == "contexts" then requests will be grouped by context + cont[:-1] + If None then requests will just be reordered by length descending. + """ + + def __init__( + self, + arr: List, + sort_fn: Callable = lambda x: x, + group_fn: Callable = lambda x: x[1], + group_by: Union[Literal["gen_kwargs", "contexts"], None] = None, + ) -> None: + self._group_by = group_by + # 0 indices are enumerated indices. Apply functions to original arr. 
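+ # Each element is wrapped as (original_index, value), so the user-supplied
+ # sort_fn / group_fn are applied to the value (x[1]) while the index (x[0])
+ # is kept so that get_original() can later restore the input ordering.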
+ self._sort_fn = lambda x: sort_fn(x[1]) + self._group_fn = lambda x: group_fn(x[1]) + self._reorder_indices: List = [] + self._size = len(arr) + self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple( + enumerate(arr) + ) # [indices, (arr)] + if self._group_by == "contexts": + self._group_by_context() + elif self._group_by == "gen_kwargs": + self._group_by_index() + + def _group_by_index(self) -> None: + """Group the elements of a list based on their indices.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs" + ) + + def _group_by_context(self) -> None: + """Group the array with indices by context.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="contexts" + ) + + def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator: + """ + Generates and yields batches from the reordered array. The method of grouping and batching + depends on the parameter `group_by`. + If `group_by` is set to "gen_kwargs", it will batch the + re-ordered values with same gen_kwargs for each batch. + If `group_by` is "contexts", it caches the requests by context before batching. + If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array + + Parameters: + - n (int): The size of each batch. Defaults to 1. + - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of + each batch. Optional, defaults to None. + + Returns: + Iterator: An iterator over batches of reordered elements grouped as per the `group_by` + attribute. + + Yields: + List of batched elements according to the `group_by` attribute. + """ + if self._group_by == "gen_kwargs": + for ( + key, + values, + ) in self._arr_with_indices.items(): # type: ignore + values = self._reorder(values) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + elif self._group_by == "contexts": + # Get one sample from each key + values = self._reorder( + [value[0] for value in self._arr_with_indices.values()] + ) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + else: + values = self._reorder(self._arr_with_indices) # type: ignore + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + + def get_cache( + self, + req_str: Tuple[str, str] = None, + cxt_toks: List[int] = None, + cont_toks: List[int] = None, + logits: torch.Tensor = None, + ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]: + """ + Retrieves cached single-token continuations and their associated arguments, updating indices as necessary. + + The behavior of this function varies depending on how the `group_by` attribute is set: + + - When `group_by` is "contexts": + The function identifies single-token continuations by checking for keys that equate to + [context+continuation][-1] and logs the indices for re-ordering. + In this mode, this function can work in two scenarios: + + 1. Cache Hit - Single Match: + If a single matching context-continuation pair is found in the cache, + the function yields the original arguments. + + 2. Cache Hit - Multiple Matches: + If multiple matching context-continuation pairs are found in the cache, + the function expands the logits batch dimension to match the number of cache hits. + It updates the original requests and continuation tokens. + + - When `group_by` is not set to "contexts": + This method yields the original arguments, logits and continuation tokens, + without checking for one-token continuations. 
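+ A rough sketch of the "contexts" lookup, using the same names as the
+ implementation below (illustrative only):
+
+     key = tuple(cxt_toks + cont_toks[:-1])       # same key used by group()
+     cache_hit = self._arr_with_indices.pop(key)  # one or more grouped requests
+     # a single hit yields the arguments unchanged; multiple hits expand the
+     # logits along the batch dimension, one slice per cached request.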
+ + Parameters: + - req_str (tuple[str, str]): Original strings used for CachingLM. + - cxt_toks (list[int]): Full context tokens used for lookup. + - cont_toks (list[int]): Continuation tokens for which logits were generated. + - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys. + + Yields: + - Iterator: + - req_str (tuple[str, str]): strings used for CachingLM. + - cont_toks (list[int]) : continuation tokens. + - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times) + """ + if self._group_by == "contexts": + cache_hit: List[ + Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]] + ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1])) + if (cache_size := len(cache_hit)) == 1: + self._reorder_indices.extend(x[0] for x in cache_hit) + yield req_str, cont_toks, logits + else: + # If we have matching requests then expand the batch dimension (no-op) and + # yield each along with its corresponding args. + multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size) + indices, req_str, cont_toks = zip( + *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit] + ) + self._reorder_indices.extend(indices) + for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits): + yield c_key, cont_tok, logit + else: + yield req_str, cont_toks, logits + + def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator: + """ + Reorders the elements in the array based on the sorting function. + + Parameters: + - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered. + + Yields: + Iterator + """ + arr = sorted(arr, key=self._sort_fn) + if not self._group_by == "contexts": + # If grouped by contexts then indices will be set in get_cache() + self._reorder_indices.extend([x[0] for x in arr]) + yield from [x[1] for x in arr] + + def get_original(self, newarr: List) -> List: + """ + Restores the original order of elements from the reordered list. + + Parameters: + - newarr (list): The reordered array. + + Returns: + list: The array with elements restored to their original order. + """ + res = [None] * self._size + cov = [False] * self._size + + for ind, v in zip(self._reorder_indices, newarr): + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + + def __len__(self): + return self._size + + @staticmethod + def group( + arr: Iterable, + fn: Callable, + group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs", + ) -> dict: + """ + Groups elements of an iterable based on a provided function. + + + The `group_by` parameter determines the method of grouping. + If `group_by` is "contexts", the elements are grouped by [context + cont][:-1]. + If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict. + + Parameters: + - arr (Iterable): The iterable to be grouped. + - fn (Callable): The function to determine the grouping. + - values (bool): If True, returns the values of the group. Defaults to False. + + Returns: + Iterator: An iterable of grouped elements. 
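+ Example (illustrative only; here `fn` simply returns each request's gen_kwargs
+ dict, which is flattened into a hashable tuple of its sorted items):
+
+     reqs = [{"max_gen_toks": 128}, {"max_gen_toks": 128}, {"max_gen_toks": 256}]
+     groups = Collator.group(reqs, fn=lambda r: r, group_by="gen_kwargs")
+     # -> two buckets: the first two requests share a key, the third gets its own.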
+ """ + res = collections.defaultdict(list) + for ob in arr: + # where ob == [context + cont] + if group_by == "contexts": + res[tuple(fn(ob))].append(ob) + else: + try: + hashable_dict = tuple( + ( + key, + tuple(value) + if isinstance(value, collections.abc.Iterable) + else value, + ) + for key, value in sorted(fn(ob).items()) + ) + res[hashable_dict].append(ob) + except (TypeError, AttributeError): + res[tuple(fn(ob))].append(ob) + return res + + @staticmethod + def get_chunks(_iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + _iter = tuple(_iter) + for i, x in enumerate(_iter): + arr.append(x) + if len(arr) == (fn(i, _iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +def configure_pad_token( + tokenizer: "PreTrainedTokenizerBase", + model_config: Optional["PretrainedConfig"] = None, +) -> "PreTrainedTokenizerBase": + """ + This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present. + Some tokenizers require special handling. + + Args: + tokenizer: The tokenizer for which the padding token is to be handled. + model_config: The configuration of the model. Default is None. + + Returns: + The tokenizer after the padding token has been handled. + + Raises: + AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0. + """ + if tokenizer.pad_token: + pass + elif tokenizer.unk_token: + tokenizer.pad_token_id = tokenizer.unk_token_id + elif tokenizer.eos_token: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + # handle special cases + if model_config and getattr(model_config, "model_type", None) == "qwen": + # Qwen's trust_remote_code tokenizer does not allow for adding special tokens + tokenizer.pad_token = "<|endoftext|>" + elif ( + tokenizer.__class__.__name__ == "RWKVWorldTokenizer" + or tokenizer.__class__.__name__ == "Rwkv5Tokenizer" + ): + # The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0) + # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer + # --- + # Note that the world tokenizer class name, might change in the future for the final huggingface merge + # https://github.com/huggingface/transformers/pull/26963 + assert tokenizer.pad_token_id == 0 + else: + tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) + + return tokenizer + + +def replace_placeholders( + string: str, default_placeholder: str, image_token: str, max_images: int +): + """ + A utility function used for local multimodal models. It locates all `placeholder` string + occurrences in the given input `string_` and replaces the first `max_count` instances with + `replacement`, and all subsequent occurrences with the empty string. 
+ + This is used to replace placeholder tags by model-specific image tokens like <|image_pad|> + and to allow for only the first `max_count` images to be passed to a model if desired. + + :param string: The original string containing placeholders. + :param default_placeholder: The placeholder text to be replaced. + :param image_token: The token to replace the placeholder with. + :param max_images: The maximum number of replacements to make. + :return: The string with placeholders replaced. + """ + count = 0 + result = [] + + parts = string.split(default_placeholder) + for part in parts[:-1]: # Iterate through all but the last part + result.append(part) + if count < max_images: + result.append(image_token) + count += 1 + elif default_placeholder != image_token: + result.append(default_placeholder) + + # Add the last part of the string + result.append(parts[-1]) + return "".join(result) + + +def flatten_image_list(images: List[List]): + """ + Takes in a list of lists of images, and returns a single list of all images in order. + Used for some multimodal models like Llava-1.5 which expects this flattened-list format for its image processor. + + :param images: A list of lists of PIL images. + :return: a list of PIL images, via concatenating all the sub-lists in order. + """ + return [image for image_list in images for image in image_list] + + +def handle_stop_sequences( + until: Union[str, List[str], None], eos: Optional[str] +) -> List[str]: + """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token.""" + if isinstance(until, str): + until = [until] + elif until is None: + until = [] + elif not isinstance(until, list): + raise ValueError( + f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" + ) + + if eos is not None and eos not in until: + until.append(eos) + return until diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/verifier.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..664cf9159e3233528671db6e60d5f521fae2da4b --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/models/verifier.py @@ -0,0 +1,155 @@ +import torch +import logging +import ast +import re +import numpy as np +import textwrap + +logger = logging.getLogger(__name__) + +class CodeVerifier: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + + self.yes_ids, self.no_ids = [], [] + for t in ["Yes", " Yes", "YES"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) > 0: self.yes_ids.append(ids[-1]) + for t in ["No", " No", "NO"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) > 0: self.no_ids.append(ids[-1]) + + self.yes_ids = list(set(self.yes_ids)) + self.no_ids = list(set(self.no_ids)) + + def _extract_python_code(self, text): + text = text.strip() + match = re.search(r"```python\s*(.*?)```", text, re.DOTALL) + if match: return match.group(1) + match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL) + if match_generic: return match_generic.group(1) + return text + + def check_syntax(self, code_str): + clean_code = self._extract_python_code(code_str) + try: + if len(clean_code.strip()) < 5: return False + ast.parse(clean_code) + return True + except: + return False + + def compute_confidence(self, logits): + if logits is None: return 0.0 + probs = torch.softmax(logits, dim=-1) + max_probs, _ = torch.max(probs, dim=-1) + 
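+ # Confidence is the geometric mean of the per-position top-1 probabilities:
+ # exp(mean(log(p_max))) stays in [0, 1] and is pulled down by any low-confidence
+ # position; the 1e-10 term below guards against log(0).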
log_probs = torch.log(max_probs + 1e-10) + return torch.exp(torch.mean(log_probs)).item() + + def svf_score(self, prompt, code_str, task_type="code"): + + max_len = 2000 + if len(code_str) > max_len: + if task_type == "reasoning": + truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):] + else: + truncated_code = code_str[-max_len:] + else: + truncated_code = code_str + + if task_type == "code": + prompt_template = f""" + You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints. + + [Problem Statement] + {prompt} + [/Problem Statement] + + [Proposed Python Solution] + ```python + {truncated_code} + ``` + [/Proposed Python Solution] + + **Analysis Steps:** + 1. Correctness: Does the core algorithm correctly solve the problem? + 2. Efficiency: Is the time complexity acceptable for the given constraints? + 3. Edge Cases & Constraints: Does the code handle all rules and edge cases? + + **Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "math": + prompt_template = f""" + You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy. + + [Math Problem] + {prompt} + [/Math Problem] + + [Proposed Mathematical Solution] + {truncated_code} + [/Proposed Mathematical Solution] + + **Analysis Steps:** + 1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly? + 2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate? + 3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem? + + **Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "reasoning": + prompt_template = f""" + You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question. + + [Context and Question] + {prompt} + [/Context and Question] + + [Proposed Answer] + {truncated_code} + [/Proposed Answer] + + **Analysis Steps :** + 1. Faithfulness: Is the answer an exact, literal span from the context? + 2. Relevance: Does the answer directly address the specific question asked without hallucinating external information? + 3. Accuracy: Does the provided context strictly support this answer? + + **Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No. 
+ **Answer:** """ + + else: + prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:" + + verify_text = textwrap.dedent(prompt_template).strip() + input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device) + + max_pos = getattr(self.model.config, "max_position_embeddings", + getattr(self.model.config, "n_positions", + getattr(self.model.config, "max_sequence_length", 20480))) + + if input_ids.shape[1] > max_pos - 16: + logger.warning("Verifier input is too long, truncating from the left.") + input_ids = input_ids[:, -(max_pos - 16):] + + with torch.no_grad(): + with torch.amp.autocast('cuda', dtype=torch.bfloat16): + outputs = self.model(input_ids, 'full') + logits = outputs.logits[0, -1, :] + + yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf')) + no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf')) + + if yes_score == -float('inf') and no_score == -float('inf'): return 0.5 + + probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0) + return probs[0].item() + + def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"): + if mode == "svf": + return self.svf_score(prompt, code_str, task_type=task_type) + else: + return self.compute_confidence(current_logits) \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/utils.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bca19866ce7ac86c8c34bfac3f1f07a2c2345565 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/utils.py @@ -0,0 +1,552 @@ +import collections +import fnmatch +import functools +import hashlib +import importlib.util +import inspect +import json +import logging +import os +import re +from dataclasses import asdict, is_dataclass +from itertools import islice +from pathlib import Path +from typing import Any, Callable, Generator, List, Optional, Tuple + +import numpy as np +import yaml +from jinja2 import BaseLoader, Environment, StrictUndefined + + +SPACING = " " * 47 + +HIGHER_IS_BETTER_SYMBOLS = { + True: "↑", + False: "↓", +} + + +def setup_logging(verbosity=logging.INFO): + # Configure the root logger + class CustomFormatter(logging.Formatter): + def format(self, record): + if record.name.startswith("lm_eval."): + record.name = record.name[len("lm_eval.") :] + return super().format(record) + + formatter = CustomFormatter( + "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d:%H:%M:%S", + ) + + log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity + + level_map = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + + log_level = level_map.get(str(log_level).upper(), logging.INFO) + + if not logging.root.handlers: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + root_logger = logging.getLogger() + root_logger.addHandler(handler) + root_logger.setLevel(log_level) + + if log_level == logging.DEBUG: + third_party_loggers = ["urllib3", "filelock", "fsspec"] + for logger_name in third_party_loggers: + logging.getLogger(logger_name).setLevel(logging.INFO) + else: + logging.getLogger().setLevel(log_level) + + +def hash_string(string: str) -> str: + return 
hashlib.sha256(string.encode("utf-8")).hexdigest() + + +def escaped_split(text, sep_char, maxsplit=-1): + """Split text into a list on occurrences of the given separation + character `sep_char`. The separation character may be escaped by a + backslash to avoid splitting at that location. + + The separation character must be a string of size 1. + + If `maxsplit` is given, at most `maxsplit` splits are done (thus, + the list will have at most `maxsplit + 1` elements). If `maxsplit` + is not specified or less than 0, then there is no limit on the + number of splits (all possible splits are made). + """ + assert len(sep_char) == 1, ( + "separation string must be a single character for escaped splitting" + ) + + if maxsplit == 0: + return text + maxsplit = max(0, maxsplit) + + return re.split(r"(? dict: + """ + Parses something like + args1=val1,arg2=val2 + Into a dictionary + """ + if args_string is None: + return {} + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = { + kv[0]: handle_arg_string("=".join(kv[1:])) + for kv in [arg.split("=") for arg in arg_list] + } + return args_dict + + +def join_iters(iters): + for iter in iters: + yield from iter + + +def group(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + + return list(res.values()) + + +# Returns a list containing all values of the source_list that +# match at least one of the patterns +def pattern_match(patterns, source_list): + if isinstance(patterns, str): + patterns = [patterns] + + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return sorted(list(task_names)) + + +def softmax(x) -> np.ndarray: + """Compute softmax values for each sets of scores in x.""" + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + +def general_detokenize(string) -> str: + string = string.replace(" n't", "n't") + string = string.replace(" )", ")") + string = string.replace("( ", "(") + string = string.replace('" ', '"') + string = string.replace(' "', '"') + string = re.sub(r" (['.,])", r"\1", string) + return string + + +def get_file_task_name(filename: str) -> str: + """ + Given the sample results filenames, extracts and returns the task name. + """ + return filename[filename.find("_") + 1 : filename.rfind("_")] + + +def get_file_datetime(filename: str) -> str: + """ + Given the results and sample results filenames, extracts and returns the datetime. + """ + return filename[filename.rfind("_") + 1 :].replace(".jsonl", "") + + +def sanitize_model_name(model_name: str) -> str: + """ + Given the model name, returns a sanitized version of it. + """ + return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name) + + +def sanitize_task_name(task_name: str) -> str: + """ + Given the task name, returns a sanitized version of it. + """ + return re.sub(r"\W", "_", task_name) + + +def get_latest_filename(filenames: List[str]) -> str: + """ + Given a list of filenames, returns the filename with the latest datetime. + """ + return max(filenames, key=lambda f: get_file_datetime(f)) + + +def get_results_filenames(filenames: List[str]) -> List[str]: + """ + Extracts filenames that correspond to aggregated results. + """ + return [f for f in filenames if "/results_" in f and ".json" in f] + + +def get_sample_results_filenames(filenames: List[str]) -> List[str]: + """ + Extracts filenames that correspond to sample results. 
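+ For example (illustrative paths):
+     get_sample_results_filenames(["out/samples_gsm8k_2024-01-01T00-00-00.jsonl",
+                                   "out/results_2024-01-01T00-00-00.json"])
+     # -> ["out/samples_gsm8k_2024-01-01T00-00-00.jsonl"]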
+ """ + return [f for f in filenames if "/samples_" in f and ".json" in f] + + +def get_rolling_token_windows( + token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int +) -> Generator[Tuple[List[int], List[int]], None, None]: + """ + - context_len allows for a rolling window context, allowing each prediction window to potentially + condition on some context + + :param token_list: list + List of tokens to be PREDICTED + :param max_seq_len: int + max_seq_len of model (or max_seq_len we want to use) + :param context_len: int + Amount of desired token context for prediction. Needs to be at least 1. + :param prefix_token: token + Dummy token like so the first token has something to condition on + :return: generator + Generator of tuples + (input_tokens, pred_tokens) + Note: Score only the last len(pred_tokens) logits of the LM + """ + assert 1 <= context_len <= max_seq_len + if not token_list: + return + # +1 offset, going from input->preds + pred_len = max_seq_len - context_len + 1 + predicted = 0 + + # Special handling for first window: predict all tokens + first_seq_len = min(max_seq_len, len(token_list)) + yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len] + predicted += first_seq_len + + while predicted < len(token_list): + window_pred_len = min(len(token_list) - predicted, pred_len) + window_end = predicted + window_pred_len + + yield ( + token_list[window_end - max_seq_len - 1 : window_end - 1], + token_list[window_end - window_pred_len : window_end], + ) + predicted += window_pred_len + + +def make_disjoint_window( + pair: Tuple[List[int], List[int]], +) -> Tuple[List[int], List[int]]: + """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation""" + a, b = pair + return a[: len(a) - (len(b) - 1)], b + + +class EnhancedJSONEncoder(json.JSONEncoder): + """ + Provides a proper json encoding for the loggers and trackers json dumps. + Notably manages the json encoding of dataclasses. + """ + + def default(self, o): + if is_dataclass(o): + return asdict(o) + return super().default(o) + + +class Reorderer: + def __init__(self, arr: List[Any], fn: Callable) -> None: + """Reorder an array according to some function + + Args: + arr (List[Any]): The initial array + fn (Callable[[Any], Any]): A function to determine the priority of elements + """ + self.size = len(arr) + arr = list(enumerate(arr)) + arr = group(arr, lambda x: fn(x[1])) + # arr = [([y[0] for y in x], x[0][1]) for x in arr] + # TODO: overhaul reorderer. 
It currently grouped requests by content but we don't want this + arr = [([y[0]], x[0][1]) for x in arr for y in x] + arr.sort(key=lambda x: fn(x[1])) + + self.arr = arr + + def get_reordered(self): + """Gets the reordered array + + Returns: + List[Any]: The reordered array + """ + return [x[1] for x in self.arr] + + def get_original(self, newarr): + """Restores the original order of a new array based on the old array's order + + Args: + newarr (List[Any]): The array to be restored + + Returns: + List[Any]: The array restored to the original order + """ + res = [None] * self.size + cov = [False] * self.size + + for (inds, _), v in zip(self.arr, newarr): + for ind in inds: + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + + +def make_table(result_dict, column: str = "results", sort_results: bool = False): + """Generate table of results.""" + from pytablewriter import LatexTableWriter, MarkdownTableWriter + + if column == "results": + column_name = "Tasks" + elif column == "groups": + column_name = "Groups" + + all_headers = [ + column_name, + "Version", + "Filter", + "n-shot", + "Metric", + "", + "Value", + "", + "Stderr", + ] + + md_writer = MarkdownTableWriter() + latex_writer = LatexTableWriter() + md_writer.headers = all_headers + latex_writer.headers = all_headers + + values = [] + + keys = result_dict[column].keys() + if sort_results: + # sort entries alphabetically by task or group name. + # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu. + # sorting here would mess that up + keys = sorted(keys) + for k in keys: + dic = result_dict[column][k] + version = result_dict["versions"].get(k, " N/A") + n = str(result_dict.get("n-shot", " ").get(k, " ")) + higher_is_better = result_dict.get("higher_is_better", {}).get(k, {}) + + if "alias" in dic: + k = dic.pop("alias") + + metric_items = dic.items() + metric_items = sorted(metric_items) + + for (mf), v in metric_items: + m, _, f = mf.partition(",") + if m.endswith("_stderr"): + continue + + hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "") + + v = "%.4f" % v if isinstance(v, float) else v + + if m + "_stderr" + "," + f in dic: + se = dic[m + "_stderr" + "," + f] + se = " N/A" if se == "N/A" else "%.4f" % se + values.append([k, version, f, n, m, hib, v, "±", se]) + else: + values.append([k, version, f, n, m, hib, v, "", ""]) + k = "" + version = "" + md_writer.value_matrix = values + latex_writer.value_matrix = values + + # todo: make latex table look good + # print(latex_writer.dumps()) + + return md_writer.dumps() + + +def positional_deprecated(fn): + """ + A decorator to nudge users into passing only keyword args (`kwargs`) to the + wrapped function, `fn`. + """ + + @functools.wraps(fn) + def _wrapper(*args, **kwargs): + if len(args) != 1 if inspect.ismethod(fn) else 0: + print( + f"WARNING: using {fn.__name__} with positional arguments is " + "deprecated and will be disallowed in a future version of " + "lm-evaluation-harness!" 
+ ) + return fn(*args, **kwargs) + + return _wrapper + + +def ignore_constructor(loader, node): + return node + + +def import_function(loader: yaml.Loader, node, yaml_path: Path): + function_name = loader.construct_scalar(node) + + *module_name, function_name = function_name.split(".") + if isinstance(module_name, list): + module_name = ".".join(module_name) + module_path = yaml_path.parent / f"{module_name}.py" + + spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix()) + + if spec is None: + raise ImportError(f"Could not import module {module_name} from {module_path}.") + module = importlib.util.module_from_spec(spec) + + if spec.loader is None: + raise ImportError(f"Module loader is None, {module_name} from {module_path}.") + spec.loader.exec_module(module) + + function = getattr(module, function_name) + return function + + +def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"): + if mode == "simple": + constructor_fn = ignore_constructor + elif mode == "full": + if yaml_path is None: + raise ValueError("yaml_path must be provided if mode is 'full'.") + # Attach yaml_path to the import function so that it can be used later + constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path)) + + loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader + # Add the import_function constructor to the YAML loader + yaml.add_constructor("!function", constructor_fn, Loader=loader) + if yaml_config is None: + with open(yaml_path, "rb") as file: + yaml_config = yaml.load(file, Loader=loader) + + if yaml_dir is None: + yaml_dir = os.path.dirname(yaml_path) + + assert yaml_dir is not None + + if "include" in yaml_config: + include_path = yaml_config["include"] + del yaml_config["include"] + + if isinstance(include_path, str): + include_path = [include_path] + + # Load from the last one first + include_path.reverse() + final_yaml_config = {} + for path in include_path: + # Assumes that path is a full path. + # If not found, assume the included yaml + # is in the same dir as the original yaml + if not os.path.isfile(path): + path = os.path.join(yaml_dir, path) + + try: + included_yaml_config = load_yaml_config(yaml_path=path, mode=mode) + final_yaml_config.update(included_yaml_config) + except Exception as ex: + # If failed to load, ignore + raise ex + + final_yaml_config.update(yaml_config) + return final_yaml_config + return yaml_config + + +def regex_replace(string, pattern, repl, count: int = 0): + """Implements the `re.sub` function as a custom Jinja filter.""" + return re.sub(pattern, repl, string, count=count) + + +env = Environment( + loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True +) +env.filters["regex_replace"] = regex_replace + + +def apply_template(template: str, doc: dict) -> str: + rtemplate = env.from_string(template) + return rtemplate.render(**doc) + + +def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None): + """ + Method for creating a (potentially) sliced and limited + iterator from a raw document iterator. 
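+ (For example, illustrative only: with world_size=2 and limit=None the rank-0 worker
+ receives items 0, 2, 4, ... of the stream and the rank-1 worker items 1, 3, 5, ....)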
Used for splitting data + among ranks in multigpu setting or only pulling a sample of documents + """ + return islice(raw_iterator, rank, limit, world_size) + + +def weighted_f1_score(items): + from sklearn.metrics import f1_score + + unzipped_list = list(zip(*items)) + golds = unzipped_list[0] + preds = unzipped_list[1] + fscore = f1_score(golds, preds, average="weighted") + return fscore diff --git a/Prism/Dream/Dream_Prism/eval_instruct/pyproject.toml b/Prism/Dream/Dream_Prism/eval_instruct/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..888b91d1ea7ee461452ebd44b9ce10103818770c --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/pyproject.toml @@ -0,0 +1,134 @@ +[build-system] +requires = ["setuptools>=40.8.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "lm_eval" +version = "0.4.8" +authors = [ + {name="EleutherAI", email="contact@eleuther.ai"} +] +description = "A framework for evaluating language models" +readme = "README.md" +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +requires-python = ">=3.9" +license = { "text" = "MIT" } +dependencies = [ + "accelerate>=0.26.0", + "evaluate", + "datasets>=2.16.0", + "evaluate>=0.4.0", + "jsonlines", + "numexpr", + "peft>=0.2.0", + "pybind11>=2.6.2", + "pytablewriter", + "rouge-score>=0.0.4", + "sacrebleu>=1.5.0", + "scikit-learn>=0.24.1", + "sqlitedict", + "torch>=1.8", + "tqdm-multiprocess", + "transformers>=4.1", + "zstandard", + "dill", + "word2number", + "more_itertools", +] + +[tool.setuptools.packages.find] +include = ["lm_eval*"] + +# required to include yaml files in pip installation +[tool.setuptools.package-data] +lm_eval = ["**/*.yaml", "tasks/**/*"] + +[project.scripts] +lm-eval = "lm_eval.__main__:cli_evaluate" +lm_eval = "lm_eval.__main__:cli_evaluate" + +[project.urls] +Homepage = "https://github.com/EleutherAI/lm-evaluation-harness" +Repository = "https://github.com/EleutherAI/lm-evaluation-harness" + +[project.optional-dependencies] +api = ["requests", "aiohttp", "tenacity", "tqdm", "tiktoken"] +audiolm_qwen = ["librosa", "soundfile"] +deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"] +dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "unitxt"] +gptq = ["auto-gptq[triton]>=0.6.0"] +gptqmodel = ["gptqmodel>=1.0.9"] +hf_transfer = ["hf_transfer"] +ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"] +ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"] +ipex = ["optimum"] +japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"] +longbench=["jeiba", "fuzzywuzzy", "rouge"] +mamba = ["mamba_ssm", "causal-conv1d==1.0.2"] +math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"] +multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"] +neuronx = ["optimum[neuronx]"] +optimum = ["optimum[openvino]"] +promptsource = ["promptsource>=0.2.3"] +ruler = ["nltk", "wonderwords", "scipy"] +sae_lens = ["sae_lens"] +sentencepiece = ["sentencepiece>=0.1.98"] +sparseml = ["sparseml-nightly[llm]>=1.8.0.20240404"] +sparsify = ["sparsify"] +testing = ["pytest", "pytest-cov", "pytest-xdist"] +vllm = ["vllm>=0.4.2"] +wandb = ["wandb>=0.16.3", "pandas", "numpy"] +zeno = ["pandas", "zeno-client"] +all = [ + "lm_eval[api]", + "lm_eval[audiolm_qwen]", + "lm_eval[deepsparse]", + "lm_eval[dev]", + "lm_eval[gptq]", + "lm_eval[gptqmodel]", + 
"lm_eval[hf_transfer]", + "lm_eval[ibm_watsonx_ai]", + "lm_eval[ifeval]", + "lm_eval[ipex]", + "lm_eval[japanese_leaderboard]", + "lm_eval[longbench]", + "lm_eval[mamba]", + "lm_eval[math]", + "lm_eval[multilingual]", + "lm_eval[neuronx]", + "lm_eval[optimum]", + "lm_eval[promptsource]", + "lm_eval[ruler]", + "lm_eval[sae_lens]", + "lm_eval[sentencepiece]", + "lm_eval[sparseml]", + "lm_eval[sparsify]", + "lm_eval[testing]", + "lm_eval[vllm]", + "lm_eval[wandb]", + "lm_eval[zeno]", +] + +[tool.pymarkdown] +plugins.md013.enabled = false # line-length +plugins.md024.allow_different_nesting = true # no-duplicate-headers +plugins.md025.enabled = false # single-header +plugins.md028.enabled = false # no-blanks-blockquote +plugins.md029.allow_extended_start_values = true # ol-prefix +plugins.md034.enabled = false # no-bare-urls + +[tool.ruff.lint] +extend-select = ["I"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["lm_eval"] + +[tool.ruff.lint.extend-per-file-ignores] +"__init__.py" = ["F401","F402","F403"] +"utils.py" = ["F401"] diff --git a/Prism/Dream/Dream_Prism/eval_instruct/requirements.txt b/Prism/Dream/Dream_Prism/eval_instruct/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d6e1198b1ab1f5a7f19c6f1fc2ba7338438cf718 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/requirements.txt @@ -0,0 +1 @@ +-e . diff --git a/Prism/Dream/Dream_Prism/eval_instruct/setup.py b/Prism/Dream/Dream_Prism/eval_instruct/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b5d8fabb8630f2b4c1e00465e4d22d60e8aa00c3 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/setup.py @@ -0,0 +1,5 @@ +import setuptools + + +# This is to make sure that the package supports editable installs +setuptools.setup() diff --git a/Prism/Dream/Dream_Prism/metrics/gsmk8_eval.py b/Prism/Dream/Dream_Prism/metrics/gsmk8_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..28ed927a5d38cb2f9a6e6c1501fb1e79821da959 --- /dev/null +++ b/Prism/Dream/Dream_Prism/metrics/gsmk8_eval.py @@ -0,0 +1,188 @@ +import json +import re +import os +import glob +import math +import argparse +from collections import Counter + + +RES_PATH = "" + + +def last_boxed_only_string(string): + if not string: return None + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: return None + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + return string[idx : right_brace_idx + 1] if right_brace_idx else None + +def remove_boxed(s): + if not s: return None + if "\\boxed " in s: return s[len("\\boxed ") :] + if "\\boxed{" in s and s.endswith("}"): return s[len("\\boxed{") : -1] + return s + +def strip_string(string): + if string is None: return "" + string = str(string).strip() + while re.search(r"(\d),(\d{3})", string): + string = re.sub(r"(\d),(\d{3})", r"\1\2", string) + string = string.replace("\n", "").replace("\\!", "") + string = string.replace("tfrac", "frac").replace("dfrac", "frac") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("^{\\circ}", "").replace("^\\circ", "") + string = string.replace("\\$", "").replace("\\%", "").replace("\%", "") + if "=" in string 
and len(string.split("=")[0]) <= 3: + string = string.split("=")[1].strip() + string = string.replace(" ", "") + return string + +def extract_answer_gsm8k(text): + if not text: return "" + boxed = last_boxed_only_string(text) + if boxed: + ans = remove_boxed(boxed) + if ans: return strip_string(ans) + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + return strip_string(tag_match.group(1)) + + nums = re.findall(r"-?\d+\.?\d*", text[-50:]) + if nums: + return strip_string(nums[-1]) + + return "" + +def extract_gold_gsm8k(target_str): + if "####" in target_str: + return strip_string(target_str.split("####")[-1]) + return strip_string(target_str) + +def is_equiv(pred, gold): + p = strip_string(pred) + g = strip_string(gold) + try: + return math.isclose(float(p), float(g), rel_tol=1e-4) + except: + return p == g + +def run_evaluation(target_path): + if os.path.isdir(target_path): + jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl")) + else: + jsonl_files = [target_path] + + for file_path in jsonl_files: + print(f">>> 正在评测: {file_path}") + detailed_results = [] + correct_count = 0 + total_count = 0 + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + item = json.loads(line) + doc = item.get("doc", {}) + + ground_truth = extract_gold_gsm8k(str(item.get("target", ""))) + nfe_list.append(item.get("nfe", 0)) + svf_list.append(item.get("svf_calls", 0)) + + ans_stats = {} + + trajectories = item.get("all_trajectories", []) + if not trajectories: + resps = item.get("resps", []) + for r in resps: + text = r[0] if isinstance(r, list) else r + trajectories.append({"resp": text, "score": 0.0}) + + for traj in trajectories: + raw_text = traj.get("resp", "") + score = traj.get("score", -float('inf')) + extracted = extract_answer_gsm8k(raw_text) + + if not extracted: continue + + norm = strip_string(extracted) + if norm not in ans_stats: + ans_stats[norm] = {"count": 0, "max_score": -float('inf'), "original": extracted} + + ans_stats[norm]["count"] += 1 + if score > ans_stats[norm]["max_score"]: + ans_stats[norm]["max_score"] = score + ans_stats[norm]["original"] = extracted + + if not ans_stats: + best_pred = "" + else: + sorted_norms = sorted( + ans_stats.keys(), + key=lambda x: (ans_stats[x]["count"], ans_stats[x]["max_score"]), + reverse=True + ) + best_norm = sorted_norms[0] + best_pred = ans_stats[best_norm]["original"] + + ans_correct = is_equiv(best_pred, ground_truth) + if ans_correct: + correct_count += 1 + total_count += 1 + + detailed_results.append({ + "question": doc.get("question", "N/A"), + "final_voted_answer": best_pred, + "ground_truth": ground_truth, + "is_correct": ans_correct, + "nfe": item.get("nfe", 0), + "svf_calls": item.get("svf_calls", 0) + }) + + accuracy = (correct_count / total_count * 100) if total_count > 0 else 0 + + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + final_report = { + "summary": { + "accuracy": f"{accuracy:.2f}%", + "correct": correct_count, + "total": total_count, + "nfe": avg_nfe, + "svf_calls": avg_svf + }, + "details": detailed_results + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + 
json.dump(final_report, out_f, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/metrics/humaneval_eval.py b/Prism/Dream/Dream_Prism/metrics/humaneval_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..c68643e27789e2c99183c915c0bdf05daf41e1bf --- /dev/null +++ b/Prism/Dream/Dream_Prism/metrics/humaneval_eval.py @@ -0,0 +1,234 @@ +import os +import sys +import json +import ast +import re +import glob +import argparse +import textwrap +import evaluate as hf_evaluate +from collections import Counter + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" + +RES_PATH = "" + +def strict_dedent(text: str) -> str: + lines = text.split('\n') + while lines and not lines[0].strip(): lines.pop(0) + while lines and not lines[-1].strip(): lines.pop() + + if not lines: + return "" + + min_indent = None + for line in lines: + if line.strip(): + indent = len(line) - len(line.lstrip()) + if min_indent is None or indent < min_indent: + min_indent = indent + + if min_indent is None: + min_indent = 0 + + dedented_lines = [] + for line in lines: + if line.strip(): + if len(line) >= min_indent: + dedented_lines.append(line[min_indent:]) + else: + dedented_lines.append(line.lstrip()) + else: + dedented_lines.append("") + + return "\n".join(dedented_lines) + +def extract_python_code(text: str) -> str: + if not text: + return "" + + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "") + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + text = tag_match.group(1) + + code_block_pattern = re.compile(r"```(?:python)?\n?(.*?)```", re.DOTALL) + match = code_block_pattern.search(text) + + if match: + content = match.group(1) + else: + if "```" in text: + content = text.split("```")[0] + else: + lines = text.split('\n') + cleaned_lines = [] + stop_words = ["Explanation:", "Example:", "Test Case:", "Output:", "Here are the tests:"] + for line in lines: + if any(sw in line for sw in stop_words): + break + cleaned_lines.append(line) + content = "\n".join(cleaned_lines) + + return strict_dedent(content) + +def normalize_code(code: str) -> str: + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (node.body and isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)): + node.body.pop(0) + return ast.unparse(tree).strip() + except: + return re.sub(r"\s+", "", code) + +def sanitize(prompt: str, completion: str, entry_point: str) -> str: + if f"def {entry_point}" in completion: + imports = [line for line in prompt.split("\n") if line.startswith("import ") or line.startswith("from ")] + return "\n".join(imports) + "\n" + completion + + clean_body = strict_dedent(completion) + if not clean_body: + return prompt + + indented_body = "\n".join([" " + line if line.strip() else "" for line in clean_body.split('\n')]) + return prompt.strip() + "\n" + indented_body + +def perform_majority_voting(trajectories, prompt, entry_point): + candidate_stats = {} + + for item in trajectories: + if isinstance(item, dict): + raw_text = item.get("resp", "") + score = item.get("score", 0.0) + else: + raw_text = str(item[0] if isinstance(item, list) else item) + score = 0.0 + + 
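+ # Candidates are keyed by their AST-normalised code so that formatting-only
+ # differences fall into the same bucket; the winner is then picked by
+ # (syntactic validity, vote count, highest trajectory score), in that order.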
extracted_code = extract_python_code(raw_text) + full_code = sanitize(prompt, extracted_code, entry_point) + + is_valid = False + try: + ast.parse(full_code) + is_valid = True + except: + is_valid = False + + norm_key = normalize_code(full_code) + if not norm_key: continue + + if norm_key not in candidate_stats: + candidate_stats[norm_key] = { + "count": 0, + "max_score": -float("inf"), + "code": full_code, + "is_valid": is_valid + } + + candidate_stats[norm_key]["count"] += 1 + candidate_stats[norm_key]["max_score"] = max(candidate_stats[norm_key]["max_score"], score) + + if not candidate_stats: + return prompt + + sorted_candidates = sorted( + candidate_stats.values(), + key=lambda x: (x["is_valid"], x["count"], x["max_score"]), + reverse=True + ) + + return sorted_candidates[0]["code"] + +def run_evaluation(target_path): + if os.path.isdir(target_path): + jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl")) + else: + jsonl_files = [target_path] + + try: + code_eval = hf_evaluate.load("code_eval") + except Exception as e: + print(f"Error loading code_eval: {e}") + return + + for file_path in jsonl_files: + print(f">>> 正在评测文件: {file_path}") + + all_voted_predictions = [] + all_references = [] + detailed_logs = [] + + nfe_sum = 0 + svf_sum = 0 + valid_samples = 0 + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + try: + data = json.loads(line) + except: + continue + + doc = data.get("doc", {}) + task_id = doc.get("task_id", f"Task_{valid_samples}") + prompt = doc.get("prompt", "") + entry_point = doc.get("entry_point", "solution") + test_code = doc.get("test", "") + f"\ncheck({entry_point})" + + nfe_sum += data.get("nfe", 0) + svf_sum += data.get("svf_calls", 0) + valid_samples += 1 + + trajectories = data.get("all_trajectories", data.get("resps", [])) + voted_code = perform_majority_voting(trajectories, prompt, entry_point) + + all_voted_predictions.append([voted_code]) + all_references.append(test_code) + + detailed_logs.append({ + "task_id": task_id, + "entry_point": entry_point, + "final_code": voted_code, + "nfe": data.get("nfe", 0), + "svf": data.get("svf_calls", 0), + }) + + if not all_voted_predictions: continue + + print(f"执行测试中...") + pass_at_k, exec_results = code_eval.compute( + references=all_references, + predictions=all_voted_predictions, + k=[1] + ) + + accuracy = pass_at_k.get("pass@1", 0.0) * 100 + avg_nfe = nfe_sum / valid_samples if valid_samples > 0 else 0 + avg_svf = svf_sum / valid_samples if valid_samples > 0 else 0 + + for i, log in enumerate(detailed_logs): + res = exec_results.get(i, []) + log["passed"] = res[0][1].get("passed", False) if res else False + log["exec_msg"] = res[0][1].get("result", "failed") if res else "failed" + + output_path = file_path.replace(".jsonl", "_voted_result.json") + final_report = { + "meta": {"file": file_path, "total_samples": valid_samples}, + "metrics": {"accuracy": f"{accuracy:.2f}%", "avg_nfe": avg_nfe, "avg_svf": avg_svf}, + "details": detailed_logs + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + print(f"Accuracy: {accuracy:.2f}% | SVF: {avg_svf:.1f}\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/metrics/math500_eval.py b/Prism/Dream/Dream_Prism/metrics/math500_eval.py 
new file mode 100644 index 0000000000000000000000000000000000000000..9a091896593d42a14ffc5e05d1726168dca9b14b --- /dev/null +++ b/Prism/Dream/Dream_Prism/metrics/math500_eval.py @@ -0,0 +1,205 @@ +import json +import re +import os +import math +import argparse +from collections import Counter + +RES_PATH = "" + +def extract_answer(text): + if not text: + return "", False + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip() + + boxed_pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}" + all_boxes = re.findall(boxed_pattern, text) + if all_boxes: + return all_boxes[-1], True + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + return tag_match.group(1).strip(), True + + marker = "the answer is" + if marker in text.lower(): + pos = text.lower().rfind(marker) + after_text = text[pos + len(marker):].strip() + after_text = re.sub(r"^[:\s]+", "", after_text) + return after_text.split('\n')[0].split('$')[0].strip(), True + + tail = text[-50:].strip() + nums = re.findall(r"(-?\d+[\./\d]*|\\sqrt\{\d+\}|\(-?\d+.*?\))", tail) + if nums: + return nums[-1], False + return "", False + +def normalize_math(string): + if not string: return "" + string = str(string).lower().strip() + + string = string.replace("", "").replace("", "").replace("", "") + string = string.replace("...", "").replace("cannot be determined", "") + + string = re.sub(r"([a-z]+|\\theta|\\alpha|\\pi)\s*=\s*", "", string) + string = re.sub(r"\\text\{([^}]*)\}", r"\1", string) + string = re.sub(r"\\(mathbf|mathrm|bold|unit|mbox|operatorname|mathrm)\{([^}]*)\}", r"\2", string) + string = re.sub(r"\\(d|t)?frac\{([^{}]*)\}\{([^{}]*)\}", r"\2/\3", string) + string = string.replace("\\!", "").replace("\\ ", "").replace("{", "").replace("}", "") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("\\$", "").replace("$", "").replace("\\%", "").replace("%", "") + + units_pattern = r"(units?|cm\^2|cm|inches|inch|square|degrees?|radians?|miles?|per|hour|cents?)" + string = re.sub(units_pattern, "", string) + string = string.replace("^{\\circ}", "").replace("^\\circ", "").replace("°", "").replace("\\degree", "") + string = string.replace("\\pi", "pi") + string = re.sub(r"(\d),(\d{3})", r"\1\2", string) + string = string.rstrip(".:,; ").replace(" ", "") + + if "=" in string: + string = string.split("=")[-1] + + return string + +def is_equiv(pred, gold): + if not pred: return False + p, g = normalize_math(pred), normalize_math(gold) + if p == g: return True + + if "=" in pred: + if normalize_math(pred.split("=")[-1]) == g: + return True + + try: + def to_float(s): + if '/' in s and s.count('/') == 1: + parts = s.split('/') + return float(parts[0]) / float(parts[1]) + if '_' in s: s = s.split('_')[0] + return float(s) + return math.isclose(to_float(p), to_float(g), rel_tol=1e-4) + except: + p_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", p) + g_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", g) + return p_fuzzy == g_fuzzy if p_fuzzy else False + +def run_evaluation(target_path): + jsonl_files = [] + if os.path.isdir(target_path): + for root, dirs, files in os.walk(target_path): + for file in files: + if file.endswith(".jsonl") and not file.startswith("eval_voted_"): + jsonl_files.append(os.path.join(root, file)) + else: + jsonl_files = [target_path] + + for file_path in jsonl_files: + print(f">>> 正在评测: {file_path}") + detailed_results = [] + + voted_correct_count = 0 + total_count = 0 + + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line 
in f: + if not line.strip(): continue + try: + item = json.loads(line) + except: + continue + + doc = item.get("doc", {}) + ground_truth = str(item.get("target", doc.get("answer", ""))) + + current_nfe = item.get("nfe", 0) + nfe_list.append(current_nfe) + current_svf = item.get("svf_calls", 0) + svf_list.append(current_svf) + + ans_stats = {} + trajectories = item.get("all_trajectories", []) + + for traj in trajectories: + raw_text = traj.get("resp", "") + score = traj.get("score", 0) + + extracted, _ = extract_answer(raw_text) + if not extracted: continue + + norm = normalize_math(extracted) + if norm not in ans_stats: + ans_stats[norm] = { + "count": 0, + "max_score": -float('inf'), + "total_weight": 0.0, + "original": extracted + } + + ans_stats[norm]["count"] += 1 + if score > ans_stats[norm]["max_score"]: + ans_stats[norm]["max_score"] = score + + try: + weight = math.exp(score) + except OverflowError: + weight = float('inf') + ans_stats[norm]["total_weight"] += weight + + if not ans_stats: + best_pred = "" + else: + sorted_norms = sorted( + ans_stats.keys(), + key=lambda x: (ans_stats[x]["total_weight"], ans_stats[x]["max_score"], ans_stats[x]["count"]), + reverse=True + ) + best_norm = sorted_norms[0] + best_pred = ans_stats[best_norm]["original"] + + is_voted_correct = False + if best_pred and is_equiv(best_pred, ground_truth): + voted_correct_count += 1 + is_voted_correct = True + + total_count += 1 + + detailed_results.append({ + "question": doc.get("problem", "N/A"), + "final_voted_answer": best_pred, + "ground_truth": ground_truth, + "is_voted_correct": is_voted_correct, + "nfe": current_nfe, + "svf_calls": current_svf + }) + + accuracy = (voted_correct_count / total_count * 100) if total_count > 0 else 0 + + avg_nfe = sum(nfe_list) / len(nfe_list) if nfe_list else 0 + avg_svf = sum(svf_list) / len(svf_list) if svf_list else 0 + + print(f"--- Accuracy : {accuracy:.2f}% | NFE: {avg_nfe:.1f} | SVF: {avg_svf:.1f} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + final_report = { + "summary": { + "Accuracy": f"{accuracy:.2f}%", + "correct_voted_count": voted_correct_count, + "total": total_count, + "avg_nfe": avg_nfe, + "avg_svf": avg_svf + }, + "details": detailed_results + } + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/metrics/mbpp_eval.py b/Prism/Dream/Dream_Prism/metrics/mbpp_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..02122b8f5c3a40a1a46675fa8328e60b0b11a8f5 --- /dev/null +++ b/Prism/Dream/Dream_Prism/metrics/mbpp_eval.py @@ -0,0 +1,281 @@ +import os +import sys +import json +import ast +import re +import glob +import argparse +import textwrap +import evaluate as hf_evaluate +from collections import Counter + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +RES_PATH = "" + +def strict_dedent(text: str) -> str: + lines = text.split('\n') + while lines and not lines[0].strip(): lines.pop(0) + while lines and not lines[-1].strip(): lines.pop() + + if not lines: + return "" + + min_indent = None + for line in lines: + if line.strip(): + indent = len(line) - 
len(line.lstrip()) + if min_indent is None or indent < min_indent: + min_indent = indent + + if min_indent is None: + min_indent = 0 + + dedented_lines = [] + for line in lines: + if line.strip(): + if len(line) >= min_indent: + dedented_lines.append(line[min_indent:]) + else: + dedented_lines.append(line.lstrip()) + else: + dedented_lines.append("") + + return "\n".join(dedented_lines) + +def extract_python_code(text: str) -> str: + if not text: + return "" + + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "") + text = text.replace("[DONE]", "") + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + text = tag_match.group(1) + + code_block_pattern = re.compile(r"```(?:python)?\n?(.*?)```", re.DOTALL) + match = code_block_pattern.search(text) + + if match: + content = match.group(1) + else: + if "```" in text: + content = text.split("```")[0] + else: + lines = text.split('\n') + start_idx = 0 + stop_words = ["Here is", "Explanation", "Example", "Note", "python", "The code"] + + for i, line in enumerate(lines): + stripped = line.strip() + if stripped.startswith(("def ", "import ", "from ", "class ")): + start_idx = i + break + if any(sw in line for sw in stop_words) and not stripped.endswith(":"): + continue + + content = "\n".join(lines[start_idx:]) + + return strict_dedent(content) + +def normalize_code(code: str) -> str: + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (node.body and isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)): + node.body.pop(0) + return ast.unparse(tree).strip() + except: + return re.sub(r"\s+", "", code) + +def perform_majority_voting(trajectories): + candidate_stats = {} + + for item in trajectories: + if isinstance(item, dict): + raw_text = item.get("resp", "") + score = item.get("score", 0.0) + elif isinstance(item, (list, tuple)): + raw_text = item[0] + score = 0.0 + else: + raw_text = str(item) + score = 0.0 + + extracted_code = extract_python_code(raw_text) + + if not extracted_code.strip(): + continue + + is_valid = False + try: + ast.parse(extracted_code) + is_valid = True + except: + is_valid = False + + norm_key = normalize_code(extracted_code) + if not norm_key: continue + + if norm_key not in candidate_stats: + candidate_stats[norm_key] = { + "count": 0, + "max_score": -float("inf"), + "code": extracted_code, + "is_valid": is_valid + } + + candidate_stats[norm_key]["count"] += 1 + candidate_stats[norm_key]["max_score"] = max(candidate_stats[norm_key]["max_score"], score) + + if not candidate_stats: + return "" + + sorted_candidates = sorted( + candidate_stats.values(), + key=lambda x: (x["is_valid"], x["count"], x["max_score"]), + reverse=True + ) + + return sorted_candidates[0]["code"] + +def run_evaluation(target_path): + if os.path.isdir(target_path): + jsonl_files = glob.glob(os.path.join(target_path, "*.jsonl")) + else: + jsonl_files = [target_path] + + try: + code_eval = hf_evaluate.load("code_eval") + except Exception as e: + print(f"Error loading code_eval: {e}") + return + + for file_path in jsonl_files: + print(f"\n>>> 正在评测 MBPP 文件: {file_path}") + + all_voted_predictions = [] + all_references = [] + detailed_logs = [] + + nfe_total = 0 + svf_total = 0 + count_valid_samples = 0 + + with open(file_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + for idx, line in enumerate(lines): + if not 
line.strip(): continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + + doc = data.get("doc", {}) + task_id = doc.get("task_id", f"MBPP_{idx}") + + test_list = doc.get("test_list", []) + test_setup = doc.get("test_setup_code", "") + challenge_tests = doc.get("challenge_test_list", []) + + full_test_code = "" + if test_setup: + full_test_code += test_setup + "\n" + if test_list: + full_test_code += "\n".join(test_list) + "\n" + if challenge_tests: + full_test_code += "\n".join(challenge_tests) + + current_nfe = data.get("nfe", 0) + current_svf = data.get("svf_calls", 0) + + nfe_total += current_nfe + svf_total += current_svf + count_valid_samples += 1 + + trajectories = data.get("all_trajectories", []) + if not trajectories: + resps = data.get("resps", []) + trajectories = [{"resp": r} for r in resps] + + voted_code = perform_majority_voting(trajectories) + + if not voted_code: + voted_code = "def placeholder(): pass" + + all_voted_predictions.append([voted_code]) + all_references.append(full_test_code) + + detailed_logs.append({ + "task_id": task_id, + "final_code": voted_code, + "reference": full_test_code, + "nfe": current_nfe, + "svf": current_svf, + "traj_count": len(trajectories) + }) + + if not all_voted_predictions: + print("未找到有效数据。") + continue + + print(f"正在执行代码测试 (共 {len(all_voted_predictions)} 题)...") + + pass_at_k, exec_results = code_eval.compute( + references=all_references, + predictions=all_voted_predictions, + k=[1], + num_workers=4 + ) + + accuracy = pass_at_k.get("pass@1", 0.0) * 100 + avg_nfe = nfe_total / count_valid_samples if count_valid_samples > 0 else 0 + avg_svf = svf_total / count_valid_samples if count_valid_samples > 0 else 0 + print(f"Accuracy: {accuracy:.2f}% | NFE: {avg_nfe:.1f} | SVF: {avg_svf:.1f}") + + for i, log in enumerate(detailed_logs): + res = exec_results.get(i, []) + if res and len(res) > 0: + is_passed = res[0][1].get("passed", False) + eval_result_str = res[0][1].get("result", "passed") if not is_passed else "passed" + else: + is_passed = False + eval_result_str = "Execution Failed" + + log["passed"] = is_passed + log["exec_msg"] = eval_result_str + + output_name = f"eval_mbpp_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + final_report = { + "meta": { + "file": file_path, + "total_samples": count_valid_samples + }, + "metrics": { + "accuracy": f"{accuracy:.2f}%", + "avg_nfe": avg_nfe, + "avg_svf": avg_svf + }, + "details": detailed_logs + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + print(f"结果已保存至: {output_path}\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MBPP Metrics Evaluation Script") + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH, help="Path to jsonl result file or directory") + args = parser.parse_args() + + if os.path.exists(args.res_path): + run_evaluation(args.res_path) + else: + print(f"Path not found: {args.res_path}") \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/scripts/run_gsm8k.sh b/Prism/Dream/Dream_Prism/scripts/run_gsm8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..cd4bcdd1b0c04ec8280650e5f8e224fc1d700e73 --- /dev/null +++ b/Prism/Dream/Dream_Prism/scripts/run_gsm8k.sh @@ -0,0 +1,31 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_gsm8k" + +cd ${PROJECT_ROOT} 
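+# NOTE: PROJECT_ROOT and MODEL_PATH are left blank in this checked-in script; point them
+# at your local clone and model checkpoint before launching (illustrative values only,
+# e.g. PROJECT_ROOT=/path/to/Prism/Dream/Dream_Prism), otherwise the `cd` above and the
+# pretrained= entry in model_args below will fail.
+# The res.jsonl written via realtime_output is the file the voting evaluation scripts
+# (e.g. metrics/mbpp_eval.py) consume through their -r/--res_path argument.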
+export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com +export HF_ALLOW_CODE_EVAL=1 +export PYTHONPATH=. + +TASK="gsm8k" +LENGTH=256 +STEPS=256 +PORT=12334 +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch --main_process_port ${PORT} -m lm_eval\ + --model diffllm \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \ + --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --confirm_run_unsafe_code \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/scripts/run_humaneval.sh b/Prism/Dream/Dream_Prism/scripts/run_humaneval.sh new file mode 100644 index 0000000000000000000000000000000000000000..bd15d5c03526d97ef327b70f2d5625479a663d10 --- /dev/null +++ b/Prism/Dream/Dream_Prism/scripts/run_humaneval.sh @@ -0,0 +1,31 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_humaneval" + +cd ${PROJECT_ROOT} +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com +export HF_ALLOW_CODE_EVAL=1 +export PYTHONPATH=. + +TASK="humaneval_instruct" +LENGTH=512 +STEPS=512 +PORT=12334 +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch --main_process_port ${PORT} -m lm_eval\ + --model diffllm \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \ + --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=20,decay_factor=1.8,reward_mode=svf,task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --confirm_run_unsafe_code \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/scripts/run_math500.sh b/Prism/Dream/Dream_Prism/scripts/run_math500.sh new file mode 100644 index 0000000000000000000000000000000000000000..b7ac4334acbef3a84e4d9d7464d677a0da0fc82c --- /dev/null +++ b/Prism/Dream/Dream_Prism/scripts/run_math500.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_math500" + +cd ${PROJECT_ROOT} +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com +export HF_ALLOW_CODE_EVAL=1 +export PYTHONPATH=. 
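+# HF_ENDPOINT points the Hugging Face hub client at the hf-mirror.com mirror; drop it if
+# huggingface.co is reachable directly. HF_ALLOW_CODE_EVAL=1 is what the evaluate
+# library's code_eval metric checks before running generated code, so it only matters for
+# the code benchmarks (HumanEval/MBPP); it is harmless but unused for this math task.
+# The run NAME below appears to encode the HTS settings passed in gen_kwargs
+# (window hts_start_pct=0.1..hts_end_pct=0.6, hts_survivor_k=2, final_K=4).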
+ +TASK="math500" +LENGTH=256 +STEPS=256 +PORT=12334 +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch --main_process_port ${PORT} -m lm_eval\ + --model diffllm \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \ + --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=10,decay_factor=1.8,reward_mode=svf,task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/scripts/run_mbpp.sh b/Prism/Dream/Dream_Prism/scripts/run_mbpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..3f83a4303005d305ed9dd538e02c0a6ced2e966f --- /dev/null +++ b/Prism/Dream/Dream_Prism/scripts/run_mbpp.sh @@ -0,0 +1,31 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/dream_mbpp" + +cd ${PROJECT_ROOT} +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com +export HF_ALLOW_CODE_EVAL=1 +export PYTHONPATH=. + +TASK="mbpp" +LENGTH=512 +STEPS=512 +PORT=12334 +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch --main_process_port ${PORT} -m lm_eval\ + --model diffllm \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},trust_remote_code=True,dtype=bfloat16,max_new_tokens=${LENGTH},diffusion_steps=${STEPS}" \ + --gen_kwargs "use_hts=True,initial_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --confirm_run_unsafe_code \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/Dream/Dream_Prism/src/__init__.py b/Prism/Dream/Dream_Prism/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb71e96cdaf62b79997bf7db062fdd398785e92f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/types.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/components/semiconnected.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/components/semiconnected.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca5d762ca882524d1406f9295fa3a238fedb724 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/components/semiconnected.py @@ -0,0 +1,71 @@ +"""Semiconnectedness.""" + +import networkx as nx +from networkx.utils import not_implemented_for, pairwise + +__all__ = ["is_semiconnected"] + + +@not_implemented_for("undirected") +@nx._dispatchable +def is_semiconnected(G): + r"""Returns True if the graph is semiconnected, False otherwise. 
+ + A graph is semiconnected if and only if for any pair of nodes, either one + is reachable from the other, or they are mutually reachable. + + This function uses a theorem that states that a DAG is semiconnected + if for any topological sort, for node $v_n$ in that sort, there is an + edge $(v_i, v_{i+1})$. That allows us to check if a non-DAG `G` is + semiconnected by condensing the graph: i.e. constructing a new graph `H` + with nodes being the strongly connected components of `G`, and edges + (scc_1, scc_2) if there is a edge $(v_1, v_2)$ in `G` for some + $v_1 \in scc_1$ and $v_2 \in scc_2$. That results in a DAG, so we compute + the topological sort of `H` and check if for every $n$ there is an edge + $(scc_n, scc_{n+1})$. + + Parameters + ---------- + G : NetworkX graph + A directed graph. + + Returns + ------- + semiconnected : bool + True if the graph is semiconnected, False otherwise. + + Raises + ------ + NetworkXNotImplemented + If the input graph is undirected. + + NetworkXPointlessConcept + If the graph is empty. + + Examples + -------- + >>> G = nx.path_graph(4, create_using=nx.DiGraph()) + >>> print(nx.is_semiconnected(G)) + True + >>> G = nx.DiGraph([(1, 2), (3, 2)]) + >>> print(nx.is_semiconnected(G)) + False + + See Also + -------- + is_strongly_connected + is_weakly_connected + is_connected + is_biconnected + """ + if len(G) == 0: + raise nx.NetworkXPointlessConcept( + "Connectivity is undefined for the null graph." + ) + + if not nx.is_weakly_connected(G): + return False + + H = nx.condensation(G) + + return all(H.has_edge(u, v) for u, v in pairwise(nx.topological_sort(H))) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebc6ab9998db144234c2601c24861b2c48fa339 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/__init__.py @@ -0,0 +1,4 @@ +from networkx.algorithms.operators.all import * +from networkx.algorithms.operators.binary import * +from networkx.algorithms.operators.product import * +from networkx.algorithms.operators.unary import * diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/all.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/all.py new file mode 100644 index 0000000000000000000000000000000000000000..322a15ace64c99f0175eae4647ac39410d6c1b26 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/all.py @@ -0,0 +1,324 @@ +"""Operations on many graphs.""" + +from itertools import chain, repeat + +import networkx as nx + +__all__ = ["union_all", "compose_all", "disjoint_union_all", "intersection_all"] + + +@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True) +def union_all(graphs, rename=()): + """Returns the union of all graphs. + + The graphs must be disjoint, otherwise an exception is raised. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + rename : iterable , optional + Node names of graphs can be changed by specifying the tuple + rename=('G-','H-') (for example). Node "u" in G is then renamed + "G-u" and "v" in H is renamed "H-v". Infinite generators (like itertools.count) + are also supported. 
+ + Returns + ------- + U : a graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + >>> G = nx.Graph() + >>> H = nx.DiGraph() + >>> GH = union_all([nx.DiGraph(G), H]) + + To force a disjoint union with node relabeling, use + disjoint_union_all(G,H) or convert_node_labels_to integers(). + + Graph, edge, and node attributes are propagated to the union graph. + If a graph attribute is present in multiple graphs, then the value + from the last graph in the list with that attribute is used. + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(4, 5), (5, 6)]) + >>> result_graph = nx.union_all([G1, G2]) + >>> result_graph.nodes() + NodeView((1, 2, 3, 4, 5, 6)) + >>> result_graph.edges() + EdgeView([(1, 2), (2, 3), (4, 5), (5, 6)]) + + See Also + -------- + union + disjoint_union_all + """ + R = None + seen_nodes = set() + + # rename graph to obtain disjoint node labels + def add_prefix(graph, prefix): + if prefix is None: + return graph + + def label(x): + return f"{prefix}{x}" + + return nx.relabel_nodes(graph, label) + + rename = chain(rename, repeat(None)) + graphs = (add_prefix(G, name) for G, name in zip(graphs, rename)) + + for i, G in enumerate(graphs): + G_nodes_set = set(G.nodes) + if i == 0: + # Union is the same type as first graph + R = G.__class__() + elif G.is_directed() != R.is_directed(): + raise nx.NetworkXError("All graphs must be directed or undirected.") + elif G.is_multigraph() != R.is_multigraph(): + raise nx.NetworkXError("All graphs must be graphs or multigraphs.") + elif not seen_nodes.isdisjoint(G_nodes_set): + raise nx.NetworkXError( + "The node sets of the graphs are not disjoint.\n" + "Use `rename` to specify prefixes for the graphs or use\n" + "disjoint_union(G1, G2, ..., GN)." + ) + + seen_nodes |= G_nodes_set + R.graph.update(G.graph) + R.add_nodes_from(G.nodes(data=True)) + R.add_edges_from( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + + if R is None: + raise ValueError("cannot apply union_all to an empty list") + + return R + + +@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True) +def disjoint_union_all(graphs): + """Returns the disjoint union of all graphs. + + This operation forces distinct integer node labels starting with 0 + for the first graph in the list and numbering consecutively. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + Returns + ------- + U : A graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(4, 5), (5, 6)]) + >>> U = nx.disjoint_union_all([G1, G2]) + >>> list(U.nodes()) + [0, 1, 2, 3, 4, 5] + >>> list(U.edges()) + [(0, 1), (1, 2), (3, 4), (4, 5)] + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + + Graph, edge, and node attributes are propagated to the union graph. + If a graph attribute is present in multiple graphs, then the value + from the last graph in the list with that attribute is used. 
+ """ + + def yield_relabeled(graphs): + first_label = 0 + for G in graphs: + yield nx.convert_node_labels_to_integers(G, first_label=first_label) + first_label += len(G) + + R = union_all(yield_relabeled(graphs)) + + return R + + +@nx._dispatchable(graphs="[graphs]", preserve_all_attrs=True, returns_graph=True) +def compose_all(graphs): + """Returns the composition of all graphs. + + Composition is the simple union of the node sets and edge sets. + The node sets of the supplied graphs need not be disjoint. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + Returns + ------- + C : A graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(3, 4), (5, 6)]) + >>> C = nx.compose_all([G1, G2]) + >>> list(C.nodes()) + [1, 2, 3, 4, 5, 6] + >>> list(C.edges()) + [(1, 2), (2, 3), (3, 4), (5, 6)] + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + + Graph, edge, and node attributes are propagated to the union graph. + If a graph attribute is present in multiple graphs, then the value + from the last graph in the list with that attribute is used. + """ + R = None + + # add graph attributes, H attributes take precedent over G attributes + for i, G in enumerate(graphs): + if i == 0: + # create new graph + R = G.__class__() + elif G.is_directed() != R.is_directed(): + raise nx.NetworkXError("All graphs must be directed or undirected.") + elif G.is_multigraph() != R.is_multigraph(): + raise nx.NetworkXError("All graphs must be graphs or multigraphs.") + + R.graph.update(G.graph) + R.add_nodes_from(G.nodes(data=True)) + R.add_edges_from( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + + if R is None: + raise ValueError("cannot apply compose_all to an empty list") + + return R + + +@nx._dispatchable(graphs="[graphs]", returns_graph=True) +def intersection_all(graphs): + """Returns a new graph that contains only the nodes and the edges that exist in + all graphs. + + Parameters + ---------- + graphs : iterable + Iterable of NetworkX graphs + + Returns + ------- + R : A new graph with the same type as the first graph in list + + Raises + ------ + ValueError + If `graphs` is an empty list. + + NetworkXError + In case of mixed type graphs, like MultiGraph and Graph, or directed and undirected graphs. + + Notes + ----- + For operating on mixed type graphs, they should be converted to the same type. + + Attributes from the graph, nodes, and edges are not copied to the new + graph. + + The resulting graph can be updated with attributes if desired. + For example, code which adds the minimum attribute for each node across all + graphs could work:: + + >>> g = nx.Graph() + >>> g.add_node(0, capacity=4) + >>> g.add_node(1, capacity=3) + >>> g.add_edge(0, 1) + + >>> h = g.copy() + >>> h.nodes[0]["capacity"] = 2 + + >>> gh = nx.intersection_all([g, h]) + + >>> new_node_attr = { + ... n: min(*(anyG.nodes[n].get("capacity", float("inf")) for anyG in [g, h])) + ... for n in gh + ... 
} + >>> nx.set_node_attributes(gh, new_node_attr, "new_capacity") + >>> gh.nodes(data=True) + NodeDataView({0: {'new_capacity': 2}, 1: {'new_capacity': 3}}) + + Examples + -------- + >>> G1 = nx.Graph([(1, 2), (2, 3)]) + >>> G2 = nx.Graph([(2, 3), (3, 4)]) + >>> R = nx.intersection_all([G1, G2]) + >>> list(R.nodes()) + [2, 3] + >>> list(R.edges()) + [(2, 3)] + + """ + R = None + + for i, G in enumerate(graphs): + G_nodes_set = set(G.nodes) + G_edges_set = set(G.edges) + if not G.is_directed(): + if G.is_multigraph(): + G_edges_set.update((v, u, k) for u, v, k in list(G_edges_set)) + else: + G_edges_set.update((v, u) for u, v in list(G_edges_set)) + if i == 0: + # create new graph + R = G.__class__() + node_intersection = G_nodes_set + edge_intersection = G_edges_set + elif G.is_directed() != R.is_directed(): + raise nx.NetworkXError("All graphs must be directed or undirected.") + elif G.is_multigraph() != R.is_multigraph(): + raise nx.NetworkXError("All graphs must be graphs or multigraphs.") + else: + node_intersection &= G_nodes_set + edge_intersection &= G_edges_set + + if R is None: + raise ValueError("cannot apply intersection_all to an empty list") + + R.add_nodes_from(node_intersection) + R.add_edges_from(edge_intersection) + + return R diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/binary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..c1212927a6b855f85d5464bb349b0c20ed79beb0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/binary.py @@ -0,0 +1,468 @@ +""" +Operations on graphs including union, intersection, difference. +""" + +import networkx as nx + +__all__ = [ + "union", + "compose", + "disjoint_union", + "intersection", + "difference", + "symmetric_difference", + "full_join", +] +_G_H = {"G": 0, "H": 1} + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def union(G, H, rename=()): + """Combine graphs G and H. The names of nodes must be unique. + + A name collision between the graphs will raise an exception. + + A renaming facility is provided to avoid name collisions. + + + Parameters + ---------- + G, H : graph + A NetworkX graph + + rename : iterable , optional + Node names of G and H can be changed by specifying the tuple + rename=('G-','H-') (for example). Node "u" in G is then renamed + "G-u" and "v" in H is renamed "H-v". + + Returns + ------- + U : A union graph with the same type as G. + + See Also + -------- + compose + :func:`~networkx.Graph.update` + disjoint_union + + Notes + ----- + To combine graphs that have common nodes, consider compose(G, H) + or the method, Graph.update(). + + disjoint_union() is similar to union() except that it avoids name clashes + by relabeling the nodes with sequential integers. + + Edge and node attributes are propagated from G and H to the union graph. + Graph attributes are also propagated, but if they are present in both G and H, + then the value from H is used. 
+ + Examples + -------- + >>> from pprint import pprint + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)]) + >>> H = nx.Graph([(0, 1), (0, 3), (1, 3), (1, 2)]) + >>> U = nx.union(G, H, rename=("G", "H")) + >>> U.nodes + NodeView(('G0', 'G1', 'G2', 'H0', 'H1', 'H3', 'H2')) + >>> edgelist = list(U.edges) + >>> pprint(edgelist) + [('G0', 'G1'), + ('G0', 'G2'), + ('G1', 'G2'), + ('H0', 'H1'), + ('H0', 'H3'), + ('H1', 'H3'), + ('H1', 'H2')] + + + """ + return nx.union_all([G, H], rename) + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def disjoint_union(G, H): + """Combine graphs G and H. The nodes are assumed to be unique (disjoint). + + This algorithm automatically relabels nodes to avoid name collisions. + + Parameters + ---------- + G,H : graph + A NetworkX graph + + Returns + ------- + U : A union graph with the same type as G. + + See Also + -------- + union + compose + :func:`~networkx.Graph.update` + + Notes + ----- + A new graph is created, of the same class as G. It is recommended + that G and H be either both directed or both undirected. + + The nodes of G are relabeled 0 to len(G)-1, and the nodes of H are + relabeled len(G) to len(G)+len(H)-1. + + Renumbering forces G and H to be disjoint, so no exception is ever raised for a name collision. + To preserve the check for common nodes, use union(). + + Edge and node attributes are propagated from G and H to the union graph. + Graph attributes are also propagated, but if they are present in both G and H, + then the value from H is used. + + To combine graphs that have common nodes, consider compose(G, H) + or the method, Graph.update(). + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)]) + >>> H = nx.Graph([(0, 3), (1, 2), (2, 3)]) + >>> G.nodes[0]["key1"] = 5 + >>> H.nodes[0]["key2"] = 10 + >>> U = nx.disjoint_union(G, H) + >>> U.nodes(data=True) + NodeDataView({0: {'key1': 5}, 1: {}, 2: {}, 3: {'key2': 10}, 4: {}, 5: {}, 6: {}}) + >>> U.edges + EdgeView([(0, 1), (0, 2), (1, 2), (3, 4), (4, 6), (5, 6)]) + """ + return nx.disjoint_union_all([G, H]) + + +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def intersection(G, H): + """Returns a new graph that contains only the nodes and the edges that exist in + both G and H. + + Parameters + ---------- + G,H : graph + A NetworkX graph. G and H can have different node sets but must be both graphs or both multigraphs. + + Raises + ------ + NetworkXError + If one is a MultiGraph and the other one is a graph. + + Returns + ------- + GH : A new graph with the same type as G. + + Notes + ----- + Attributes from the graph, nodes, and edges are not copied to the new + graph. If you want a new graph of the intersection of G and H + with the attributes (including edge data) from G use remove_nodes_from() + as follows + + >>> G = nx.path_graph(3) + >>> H = nx.path_graph(5) + >>> R = G.copy() + >>> R.remove_nodes_from(n for n in G if n not in H) + >>> R.remove_edges_from(e for e in G.edges if e not in H.edges) + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2)]) + >>> H = nx.Graph([(0, 3), (1, 2), (2, 3)]) + >>> R = nx.intersection(G, H) + >>> R.nodes + NodeView((0, 1, 2)) + >>> R.edges + EdgeView([(1, 2)]) + """ + return nx.intersection_all([G, H]) + + +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def difference(G, H): + """Returns a new graph that contains the edges that exist in G but not in H. + + The node sets of H and G must be the same. + + Parameters + ---------- + G,H : graph + A NetworkX graph. 
G and H must have the same node sets. + + Returns + ------- + D : A new graph with the same type as G. + + Notes + ----- + Attributes from the graph, nodes, and edges are not copied to the new + graph. If you want a new graph of the difference of G and H with + the attributes (including edge data) from G use remove_nodes_from() + as follows: + + >>> G = nx.path_graph(3) + >>> H = nx.path_graph(5) + >>> R = G.copy() + >>> R.remove_nodes_from(n for n in G if n in H) + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 3)]) + >>> R = nx.difference(G, H) + >>> R.nodes + NodeView((0, 1, 2, 3)) + >>> R.edges + EdgeView([(0, 2), (1, 3)]) + """ + # create new graph + if not G.is_multigraph() == H.is_multigraph(): + raise nx.NetworkXError("G and H must both be graphs or multigraphs.") + R = nx.create_empty_copy(G, with_data=False) + + if set(G) != set(H): + raise nx.NetworkXError("Node sets of graphs not equal") + + if G.is_multigraph(): + edges = G.edges(keys=True) + else: + edges = G.edges() + for e in edges: + if not H.has_edge(*e): + R.add_edge(*e) + return R + + +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def symmetric_difference(G, H): + """Returns new graph with edges that exist in either G or H but not both. + + The node sets of H and G must be the same. + + Parameters + ---------- + G,H : graph + A NetworkX graph. G and H must have the same node sets. + + Returns + ------- + D : A new graph with the same type as G. + + Notes + ----- + Attributes from the graph, nodes, and edges are not copied to the new + graph. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2), (1, 2), (1, 3)]) + >>> H = nx.Graph([(0, 1), (1, 2), (0, 3)]) + >>> R = nx.symmetric_difference(G, H) + >>> R.nodes + NodeView((0, 1, 2, 3)) + >>> R.edges + EdgeView([(0, 2), (0, 3), (1, 3)]) + """ + # create new graph + if not G.is_multigraph() == H.is_multigraph(): + raise nx.NetworkXError("G and H must both be graphs or multigraphs.") + R = nx.create_empty_copy(G, with_data=False) + + if set(G) != set(H): + raise nx.NetworkXError("Node sets of graphs not equal") + + gnodes = set(G) # set of nodes in G + hnodes = set(H) # set of nodes in H + nodes = gnodes.symmetric_difference(hnodes) + R.add_nodes_from(nodes) + + if G.is_multigraph(): + edges = G.edges(keys=True) + else: + edges = G.edges() + # we could copy the data here but then this function doesn't + # match intersection and difference + for e in edges: + if not H.has_edge(*e): + R.add_edge(*e) + + if H.is_multigraph(): + edges = H.edges(keys=True) + else: + edges = H.edges() + for e in edges: + if not G.has_edge(*e): + R.add_edge(*e) + return R + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def compose(G, H): + """Compose graph G with H by combining nodes and edges into a single graph. + + The node sets and edges sets do not need to be disjoint. + + Composing preserves the attributes of nodes and edges. + Attribute values from H take precedent over attribute values from G. + + Parameters + ---------- + G, H : graph + A NetworkX graph + + Returns + ------- + C: A new graph with the same type as G + + See Also + -------- + :func:`~networkx.Graph.update` + union + disjoint_union + + Notes + ----- + It is recommended that G and H be either both directed or both undirected. + + For MultiGraphs, the edges are identified by incident nodes AND edge-key. 
+ This can cause surprises (i.e., edge `(1, 2)` may or may not be the same + in two graphs) if you use MultiGraph without keeping track of edge keys. + + If combining the attributes of common nodes is not desired, consider union(), + which raises an exception for name collisions. + + Examples + -------- + >>> G = nx.Graph([(0, 1), (0, 2)]) + >>> H = nx.Graph([(0, 1), (1, 2)]) + >>> R = nx.compose(G, H) + >>> R.nodes + NodeView((0, 1, 2)) + >>> R.edges + EdgeView([(0, 1), (0, 2), (1, 2)]) + + By default, the attributes from `H` take precedent over attributes from `G`. + If you prefer another way of combining attributes, you can update them after the compose operation: + + >>> G = nx.Graph([(0, 1, {"weight": 2.0}), (3, 0, {"weight": 100.0})]) + >>> H = nx.Graph([(0, 1, {"weight": 10.0}), (1, 2, {"weight": -1.0})]) + >>> nx.set_node_attributes(G, {0: "dark", 1: "light", 3: "black"}, name="color") + >>> nx.set_node_attributes(H, {0: "green", 1: "orange", 2: "yellow"}, name="color") + >>> GcomposeH = nx.compose(G, H) + + Normally, color attribute values of nodes of GcomposeH come from H. We can workaround this as follows: + + >>> node_data = { + ... n: G.nodes[n]["color"] + " " + H.nodes[n]["color"] + ... for n in G.nodes & H.nodes + ... } + >>> nx.set_node_attributes(GcomposeH, node_data, "color") + >>> print(GcomposeH.nodes[0]["color"]) + dark green + + >>> print(GcomposeH.nodes[3]["color"]) + black + + Similarly, we can update edge attributes after the compose operation in a way we prefer: + + >>> edge_data = { + ... e: G.edges[e]["weight"] * H.edges[e]["weight"] for e in G.edges & H.edges + ... } + >>> nx.set_edge_attributes(GcomposeH, edge_data, "weight") + >>> print(GcomposeH.edges[(0, 1)]["weight"]) + 20.0 + + >>> print(GcomposeH.edges[(3, 0)]["weight"]) + 100.0 + """ + return nx.compose_all([G, H]) + + +@nx._dispatchable(graphs=_G_H, preserve_all_attrs=True, returns_graph=True) +def full_join(G, H, rename=(None, None)): + """Returns the full join of graphs G and H. + + Full join is the union of G and H in which all edges between + G and H are added. + The node sets of G and H must be disjoint, + otherwise an exception is raised. + + Parameters + ---------- + G, H : graph + A NetworkX graph + + rename : tuple , default=(None, None) + Node names of G and H can be changed by specifying the tuple + rename=('G-','H-') (for example). Node "u" in G is then renamed + "G-u" and "v" in H is renamed "H-v". + + Returns + ------- + U : The full join graph with the same type as G. + + Notes + ----- + It is recommended that G and H be either both directed or both undirected. + + If G is directed, then edges from G to H are added as well as from H to G. + + Note that full_join() does not produce parallel edges for MultiGraphs. + + The full join operation of graphs G and H is the same as getting + their complement, performing a disjoint union, and finally getting + the complement of the resulting graph. + + Graph, edge, and node attributes are propagated from G and H + to the union graph. If a graph attribute is present in both + G and H the value from H is used. 
+ + Examples + -------- + >>> from pprint import pprint + >>> G = nx.Graph([(0, 1), (0, 2)]) + >>> H = nx.Graph([(3, 4)]) + >>> R = nx.full_join(G, H, rename=("G", "H")) + >>> R.nodes + NodeView(('G0', 'G1', 'G2', 'H3', 'H4')) + >>> edgelist = list(R.edges) + >>> pprint(edgelist) + [('G0', 'G1'), + ('G0', 'G2'), + ('G0', 'H3'), + ('G0', 'H4'), + ('G1', 'H3'), + ('G1', 'H4'), + ('G2', 'H3'), + ('G2', 'H4'), + ('H3', 'H4')] + + See Also + -------- + union + disjoint_union + """ + R = union(G, H, rename) + + def add_prefix(graph, prefix): + if prefix is None: + return graph + + def label(x): + return f"{prefix}{x}" + + return nx.relabel_nodes(graph, label) + + G = add_prefix(G, rename[0]) + H = add_prefix(H, rename[1]) + + for i in G: + for j in H: + R.add_edge(i, j) + if R.is_directed(): + for i in H: + for j in G: + R.add_edge(i, j) + + return R diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/product.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/product.py new file mode 100644 index 0000000000000000000000000000000000000000..28ca78bf4deb45ffa422d2792b966adfa112692f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/product.py @@ -0,0 +1,633 @@ +""" +Graph products. +""" + +from itertools import product + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = [ + "tensor_product", + "cartesian_product", + "lexicographic_product", + "strong_product", + "power", + "rooted_product", + "corona_product", + "modular_product", +] +_G_H = {"G": 0, "H": 1} + + +def _dict_product(d1, d2): + return {k: (d1.get(k), d2.get(k)) for k in set(d1) | set(d2)} + + +# Generators for producing graph products +def _node_product(G, H): + for u, v in product(G, H): + yield ((u, v), _dict_product(G.nodes[u], H.nodes[v])) + + +def _directed_edges_cross_edges(G, H): + if not G.is_multigraph() and not H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + yield (u, x), (v, y), _dict_product(c, d) + if not G.is_multigraph() and H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (u, x), (v, y), k, _dict_product(c, d) + if G.is_multigraph() and not H.is_multigraph(): + for u, v, k, c in G.edges(data=True, keys=True): + for x, y, d in H.edges(data=True): + yield (u, x), (v, y), k, _dict_product(c, d) + if G.is_multigraph() and H.is_multigraph(): + for u, v, j, c in G.edges(data=True, keys=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (u, x), (v, y), (j, k), _dict_product(c, d) + + +def _undirected_edges_cross_edges(G, H): + if not G.is_multigraph() and not H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + yield (v, x), (u, y), _dict_product(c, d) + if not G.is_multigraph() and H.is_multigraph(): + for u, v, c in G.edges(data=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (v, x), (u, y), k, _dict_product(c, d) + if G.is_multigraph() and not H.is_multigraph(): + for u, v, k, c in G.edges(data=True, keys=True): + for x, y, d in H.edges(data=True): + yield (v, x), (u, y), k, _dict_product(c, d) + if G.is_multigraph() and H.is_multigraph(): + for u, v, j, c in G.edges(data=True, keys=True): + for x, y, k, d in H.edges(data=True, keys=True): + yield (v, x), (u, y), (j, k), _dict_product(c, d) + + +def _edges_cross_nodes(G, H): + if G.is_multigraph(): + for u, 
v, k, d in G.edges(data=True, keys=True): + for x in H: + yield (u, x), (v, x), k, d + else: + for u, v, d in G.edges(data=True): + for x in H: + if H.is_multigraph(): + yield (u, x), (v, x), None, d + else: + yield (u, x), (v, x), d + + +def _nodes_cross_edges(G, H): + if H.is_multigraph(): + for x in G: + for u, v, k, d in H.edges(data=True, keys=True): + yield (x, u), (x, v), k, d + else: + for x in G: + for u, v, d in H.edges(data=True): + if G.is_multigraph(): + yield (x, u), (x, v), None, d + else: + yield (x, u), (x, v), d + + +def _edges_cross_nodes_and_nodes(G, H): + if G.is_multigraph(): + for u, v, k, d in G.edges(data=True, keys=True): + for x in H: + for y in H: + yield (u, x), (v, y), k, d + else: + for u, v, d in G.edges(data=True): + for x in H: + for y in H: + if H.is_multigraph(): + yield (u, x), (v, y), None, d + else: + yield (u, x), (v, y), d + + +def _init_product_graph(G, H): + if G.is_directed() != H.is_directed(): + msg = "G and H must be both directed or both undirected" + raise nx.NetworkXError(msg) + if G.is_multigraph() or H.is_multigraph(): + GH = nx.MultiGraph() + else: + GH = nx.Graph() + if G.is_directed(): + GH = GH.to_directed() + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def tensor_product(G, H): + r"""Returns the tensor product of G and H. + + The tensor product $P$ of the graphs $G$ and $H$ has a node set that + is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,v), (x,y))$ if and only if $(u,x)$ is an edge in $G$ + and $(v,y)$ is an edge in $H$. + + Tensor product is sometimes also referred to as the categorical product, + direct product, cardinal product or conjunction. + + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The tensor product of G and H. P will be a multi-graph if either G + or H is a multi-graph, will be a directed if G and H are directed, + and undirected if G and H are undirected. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. + + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.tensor_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + GH.add_edges_from(_directed_edges_cross_edges(G, H)) + if not GH.is_directed(): + GH.add_edges_from(_undirected_edges_cross_edges(G, H)) + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def cartesian_product(G, H): + r"""Returns the Cartesian product of G and H. + + The Cartesian product $P$ of the graphs $G$ and $H$ has a node set that + is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,v),(x,y))$ if and only if either $u$ is equal to $x$ + and both $v$ and $y$ are adjacent in $H$ or if $v$ is equal to $y$ and + both $u$ and $x$ are adjacent in $G$. + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The Cartesian product of G and H. P will be a multi-graph if either G + or H is a multi-graph. Will be a directed if G and H are directed, + and undirected if G and H are undirected. 
+ + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. + + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.cartesian_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + GH.add_edges_from(_edges_cross_nodes(G, H)) + GH.add_edges_from(_nodes_cross_edges(G, H)) + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def lexicographic_product(G, H): + r"""Returns the lexicographic product of G and H. + + The lexicographical product $P$ of the graphs $G$ and $H$ has a node set + that is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,v), (x,y))$ if and only if $(u,v)$ is an edge in $G$ + or $u==v$ and $(x,y)$ is an edge in $H$. + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The Cartesian product of G and H. P will be a multi-graph if either G + or H is a multi-graph. Will be a directed if G and H are directed, + and undirected if G and H are undirected. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. + + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.lexicographic_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + # Edges in G regardless of H designation + GH.add_edges_from(_edges_cross_nodes_and_nodes(G, H)) + # For each x in G, only if there is an edge in H + GH.add_edges_from(_nodes_cross_edges(G, H)) + return GH + + +@nx._dispatchable(graphs=_G_H, preserve_node_attrs=True, returns_graph=True) +def strong_product(G, H): + r"""Returns the strong product of G and H. + + The strong product $P$ of the graphs $G$ and $H$ has a node set that + is the Cartesian product of the node sets, $V(P)=V(G) \times V(H)$. + $P$ has an edge $((u,x), (v,y))$ if any of the following conditions + are met: + + - $u=v$ and $(x,y)$ is an edge in $H$ + - $x=y$ and $(u,v)$ is an edge in $G$ + - $(u,v)$ is an edge in $G$ and $(x,y)$ is an edge in $H$ + + Parameters + ---------- + G, H: graphs + Networkx graphs. + + Returns + ------- + P: NetworkX graph + The Cartesian product of G and H. P will be a multi-graph if either G + or H is a multi-graph. Will be a directed if G and H are directed, + and undirected if G and H are undirected. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Notes + ----- + Node attributes in P are two-tuple of the G and H node attributes. + Missing attributes are assigned None. 
+ + Examples + -------- + >>> G = nx.Graph() + >>> H = nx.Graph() + >>> G.add_node(0, a1=True) + >>> H.add_node("a", a2="Spam") + >>> P = nx.strong_product(G, H) + >>> list(P) + [(0, 'a')] + + Edge attributes and edge keys (for multigraphs) are also copied to the + new product graph + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + GH.add_edges_from(_nodes_cross_edges(G, H)) + GH.add_edges_from(_edges_cross_nodes(G, H)) + GH.add_edges_from(_directed_edges_cross_edges(G, H)) + if not GH.is_directed(): + GH.add_edges_from(_undirected_edges_cross_edges(G, H)) + return GH + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(returns_graph=True) +def power(G, k): + """Returns the specified power of a graph. + + The $k$th power of a simple graph $G$, denoted $G^k$, is a + graph on the same set of nodes in which two distinct nodes $u$ and + $v$ are adjacent in $G^k$ if and only if the shortest path + distance between $u$ and $v$ in $G$ is at most $k$. + + Parameters + ---------- + G : graph + A NetworkX simple graph object. + + k : positive integer + The power to which to raise the graph `G`. + + Returns + ------- + NetworkX simple graph + `G` to the power `k`. + + Raises + ------ + ValueError + If the exponent `k` is not positive. + + NetworkXNotImplemented + If `G` is not a simple graph. + + Examples + -------- + The number of edges will never decrease when taking successive + powers: + + >>> G = nx.path_graph(4) + >>> list(nx.power(G, 2).edges) + [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)] + >>> list(nx.power(G, 3).edges) + [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] + + The `k` th power of a cycle graph on *n* nodes is the complete graph + on *n* nodes, if `k` is at least ``n // 2``: + + >>> G = nx.cycle_graph(5) + >>> H = nx.complete_graph(5) + >>> nx.is_isomorphic(nx.power(G, 2), H) + True + >>> G = nx.cycle_graph(8) + >>> H = nx.complete_graph(8) + >>> nx.is_isomorphic(nx.power(G, 4), H) + True + + References + ---------- + .. [1] J. A. Bondy, U. S. R. Murty, *Graph Theory*. Springer, 2008. + + Notes + ----- + This definition of "power graph" comes from Exercise 3.1.6 of + *Graph Theory* by Bondy and Murty [1]_. + + """ + if k <= 0: + raise ValueError("k must be a positive integer") + H = nx.Graph() + H.add_nodes_from(G) + # update BFS code to ignore self loops. + for n in G: + seen = {} # level (number of hops) when seen in BFS + level = 1 # the current level + nextlevel = G[n] + while nextlevel: + thislevel = nextlevel # advance to next level + nextlevel = {} # and start a new list (fringe) + for v in thislevel: + if v == n: # avoid self loop + continue + if v not in seen: + seen[v] = level # set the level of vertex v + nextlevel.update(G[v]) # add neighbors of v + if k <= level: + break + level += 1 + H.add_edges_from((n, nbr) for nbr in seen) + return H + + +@not_implemented_for("multigraph") +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def rooted_product(G, H, root): + """Return the rooted product of graphs G and H rooted at root in H. + + A new graph is constructed representing the rooted product of + the inputted graphs, G and H, with a root in H. + A rooted product duplicates H for each nodes in G with the root + of H corresponding to the node in G. Nodes are renamed as the direct + product of G and H. The result is a subgraph of the cartesian product. 
+ + Parameters + ---------- + G,H : graph + A NetworkX graph + root : node + A node in H + + Returns + ------- + R : The rooted product of G and H with a specified root in H + + Notes + ----- + The nodes of R are the Cartesian Product of the nodes of G and H. + The nodes of G and H are not relabeled. + """ + if root not in H: + raise nx.NodeNotFound("root must be a vertex in H") + + R = nx.Graph() + R.add_nodes_from(product(G, H)) + + R.add_edges_from(((e[0], root), (e[1], root)) for e in G.edges()) + R.add_edges_from(((g, e[0]), (g, e[1])) for g in G for e in H.edges()) + + return R + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@nx._dispatchable(graphs=_G_H, returns_graph=True) +def corona_product(G, H): + r"""Returns the Corona product of G and H. + + The corona product of $G$ and $H$ is the graph $C = G \circ H$ obtained by + taking one copy of $G$, called the center graph, $|V(G)|$ copies of $H$, + called the outer graph, and making the $i$-th vertex of $G$ adjacent to + every vertex of the $i$-th copy of $H$, where $1 ≤ i ≤ |V(G)|$. + + Parameters + ---------- + G, H: NetworkX graphs + The graphs to take the carona product of. + `G` is the center graph and `H` is the outer graph + + Returns + ------- + C: NetworkX graph + The Corona product of G and H. + + Raises + ------ + NetworkXError + If G and H are not both directed or both undirected. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> H = nx.path_graph(2) + >>> C = nx.corona_product(G, H) + >>> list(C) + [0, 1, 2, 3, (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)] + >>> print(C) + Graph with 12 nodes and 16 edges + + References + ---------- + [1] M. Tavakoli, F. Rahbarnia, and A. R. Ashrafi, + "Studying the corona product of graphs under some graph invariants," + Transactions on Combinatorics, vol. 3, no. 3, pp. 43–49, Sep. 2014, + doi: 10.22108/toc.2014.5542. + [2] A. Faraji, "Corona Product in Graph Theory," Ali Faraji, May 11, 2021. + https://blog.alifaraji.ir/math/graph-theory/corona-product.html (accessed Dec. 07, 2021). + """ + GH = _init_product_graph(G, H) + GH.add_nodes_from(G) + GH.add_edges_from(G.edges) + + for G_node in G: + # copy nodes of H in GH, call it H_i + GH.add_nodes_from((G_node, v) for v in H) + + # copy edges of H_i based on H + GH.add_edges_from( + ((G_node, e0), (G_node, e1), d) for e0, e1, d in H.edges.data() + ) + + # creating new edges between H_i and a G's node + GH.add_edges_from((G_node, (G_node, H_node)) for H_node in H) + + return GH + + +@nx._dispatchable( + graphs=_G_H, preserve_edge_attrs=True, preserve_node_attrs=True, returns_graph=True +) +def modular_product(G, H): + r"""Returns the Modular product of G and H. + + The modular product of `G` and `H` is the graph $M = G \nabla H$, + consisting of the node set $V(M) = V(G) \times V(H)$ that is the Cartesian + product of the node sets of `G` and `H`. Further, M contains an edge ((u, v), (x, y)): + + - if u is adjacent to x in `G` and v is adjacent to y in `H`, or + - if u is not adjacent to x in `G` and v is not adjacent to y in `H`. + + More formally:: + + E(M) = {((u, v), (x, y)) | ((u, x) in E(G) and (v, y) in E(H)) or + ((u, x) not in E(G) and (v, y) not in E(H))} + + Parameters + ---------- + G, H: NetworkX graphs + The graphs to take the modular product of. + + Returns + ------- + M: NetworkX graph + The Modular product of `G` and `H`. + + Raises + ------ + NetworkXNotImplemented + If `G` is not a simple graph. 
+ + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> H = nx.path_graph(2) + >>> M = nx.modular_product(G, H) + >>> list(M) + [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)] + >>> print(M) + Graph with 8 nodes and 8 edges + + Notes + ----- + The *modular product* is defined in [1]_ and was first + introduced as the *weak modular product*. + + The modular product reduces the problem of counting isomorphic subgraphs + in `G` and `H` to the problem of counting cliques in M. The subgraphs of + `G` and `H` that are induced by the nodes of a clique in M are + isomorphic [2]_ [3]_. + + References + ---------- + .. [1] R. Hammack, W. Imrich, and S. Klavžar, + "Handbook of Product Graphs", CRC Press, 2011. + + .. [2] H. G. Barrow and R. M. Burstall, + "Subgraph isomorphism, matching relational structures and maximal + cliques", Information Processing Letters, vol. 4, issue 4, pp. 83-84, + 1976, https://doi.org/10.1016/0020-0190(76)90049-1. + + .. [3] V. G. Vizing, "Reduction of the problem of isomorphism and isomorphic + entrance to the task of finding the nondensity of a graph." Proc. Third + All-Union Conference on Problems of Theoretical Cybernetics. 1974. + """ + if G.is_directed() or H.is_directed(): + raise nx.NetworkXNotImplemented( + "Modular product not implemented for directed graphs" + ) + if G.is_multigraph() or H.is_multigraph(): + raise nx.NetworkXNotImplemented( + "Modular product not implemented for multigraphs" + ) + + GH = _init_product_graph(G, H) + GH.add_nodes_from(_node_product(G, H)) + + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + GH.add_edge((u, x), (v, y), **_dict_product(c, d)) + GH.add_edge((v, x), (u, y), **_dict_product(c, d)) + + G = nx.complement(G) + H = nx.complement(H) + + for u, v, c in G.edges(data=True): + for x, y, d in H.edges(data=True): + GH.add_edge((u, x), (v, y), **_dict_product(c, d)) + GH.add_edge((v, x), (u, y), **_dict_product(c, d)) + + return GH diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/unary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/unary.py new file mode 100644 index 0000000000000000000000000000000000000000..79e44d1cc04cff72c5c87d1852544514a6f53246 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/operators/unary.py @@ -0,0 +1,77 @@ +"""Unary operations on graphs""" + +import networkx as nx + +__all__ = ["complement", "reverse"] + + +@nx._dispatchable(returns_graph=True) +def complement(G): + """Returns the graph complement of G. + + Parameters + ---------- + G : graph + A NetworkX graph + + Returns + ------- + GC : A new graph. + + Notes + ----- + Note that `complement` does not create self-loops and also + does not produce parallel edges for MultiGraphs. + + Graph, node, and edge data are not propagated to the new graph. + + Examples + -------- + >>> G = nx.Graph([(1, 2), (1, 3), (2, 3), (3, 4), (3, 5)]) + >>> G_complement = nx.complement(G) + >>> G_complement.edges() # This shows the edges of the complemented graph + EdgeView([(1, 4), (1, 5), (2, 4), (2, 5), (4, 5)]) + + """ + R = G.__class__() + R.add_nodes_from(G) + R.add_edges_from( + ((n, n2) for n, nbrs in G.adjacency() for n2 in G if n2 not in nbrs if n != n2) + ) + return R + + +@nx._dispatchable(returns_graph=True) +def reverse(G, copy=True): + """Returns the reverse directed graph of G. 
+ + Parameters + ---------- + G : directed graph + A NetworkX directed graph + copy : bool + If True, then a new graph is returned. If False, then the graph is + reversed in place. + + Returns + ------- + H : directed graph + The reversed G. + + Raises + ------ + NetworkXError + If graph is undirected. + + Examples + -------- + >>> G = nx.DiGraph([(1, 2), (1, 3), (2, 3), (3, 4), (3, 5)]) + >>> G_reversed = nx.reverse(G) + >>> G_reversed.edges() + OutEdgeView([(2, 1), (3, 1), (3, 2), (4, 3), (5, 3)]) + + """ + if not G.is_directed(): + raise nx.NetworkXError("Cannot reverse an undirected graph.") + else: + return G.reverse(copy=copy) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93e6cdd08ca959835cc5c5f9a3e6dc353f4b217d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/__init__.py @@ -0,0 +1,5 @@ +from .beamsearch import * +from .breadth_first_search import * +from .depth_first_search import * +from .edgedfs import * +from .edgebfs import * diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/beamsearch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/beamsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..23fbe7bbb3f037eca4f7ede25b7edf6305356587 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/beamsearch.py @@ -0,0 +1,90 @@ +"""Basic algorithms for breadth-first searching the nodes of a graph.""" + +import networkx as nx + +__all__ = ["bfs_beam_edges"] + + +@nx._dispatchable +def bfs_beam_edges(G, source, value, width=None): + """Iterates over edges in a beam search. + + The beam search is a generalized breadth-first search in which only + the "best" *w* neighbors of the current node are enqueued, where *w* + is the beam width and "best" is an application-specific + heuristic. In general, a beam search with a small beam width might + not visit each node in the graph. + + .. note:: + + With the default value of ``width=None`` or `width` greater than the + maximum degree of the graph, this function equates to a slower + version of `~networkx.algorithms.traversal.breadth_first_search.bfs_edges`. + All nodes will be visited, though the order of the reported edges may + vary. In such cases, `value` has no effect - consider using `bfs_edges` + directly instead. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for the breadth-first search; this function + iterates over only those edges in the component reachable from + this node. + + value : function + A function that takes a node of the graph as input and returns a + real number indicating how "good" it is. A higher value means it + is more likely to be visited sooner during the search. When + visiting a new node, only the `width` neighbors with the highest + `value` are enqueued (in decreasing order of `value`). + + width : int (default = None) + The beam width for the search. This is the number of neighbors + (ordered by `value`) to enqueue when visiting each new node. + + Yields + ------ + edge + Edges in the beam search starting from `source`, given as a pair + of nodes. 
+ + Examples + -------- + To give nodes with, for example, a higher centrality precedence + during the search, set the `value` function to return the centrality + value of the node: + + >>> G = nx.karate_club_graph() + >>> centrality = nx.eigenvector_centrality(G) + >>> list(nx.bfs_beam_edges(G, source=0, value=centrality.get, width=3)) + [(0, 2), (0, 1), (0, 8), (2, 32), (1, 13), (8, 33)] + """ + + if width is None: + width = len(G) + + def successors(v): + """Returns a list of the best neighbors of a node. + + `v` is a node in the graph `G`. + + The "best" neighbors are chosen according to the `value` + function (higher is better). Only the `width` best neighbors of + `v` are returned. + """ + # TODO The Python documentation states that for small values, it + # is better to use `heapq.nlargest`. We should determine the + # threshold at which its better to use `heapq.nlargest()` + # instead of `sorted()[:]` and apply that optimization here. + # + # If `width` is greater than the number of neighbors of `v`, all + # neighbors are returned by the semantics of slicing in + # Python. This occurs in the special case that the user did not + # specify a `width`: in this case all neighbors are always + # returned, so this is just a (slower) implementation of + # `bfs_edges(G, source)` but with a sorted enqueue step. + return iter(sorted(G.neighbors(v), key=value, reverse=True)[:width]) + + yield from nx.generic_bfs_edges(G, source, successors) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/breadth_first_search.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/breadth_first_search.py new file mode 100644 index 0000000000000000000000000000000000000000..71076ebb03c445ca842b040ff86e3a2d4e719af1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/breadth_first_search.py @@ -0,0 +1,576 @@ +"""Basic algorithms for breadth-first searching the nodes of a graph.""" + +from collections import deque + +import networkx as nx + +__all__ = [ + "bfs_edges", + "bfs_tree", + "bfs_predecessors", + "bfs_successors", + "descendants_at_distance", + "bfs_layers", + "bfs_labeled_edges", + "generic_bfs_edges", +] + + +@nx._dispatchable +def generic_bfs_edges(G, source, neighbors=None, depth_limit=None): + """Iterate over edges in a breadth-first search. + + The breadth-first search begins at `source` and enqueues the + neighbors of newly visited nodes specified by the `neighbors` + function. + + Parameters + ---------- + G : NetworkX graph + + source : node + Starting node for the breadth-first search; this function + iterates over only those edges in the component reachable from + this node. + + neighbors : function + A function that takes a newly visited node of the graph as input + and returns an *iterator* (not just a list) of nodes that are + neighbors of that node with custom ordering. If not specified, this is + just the ``G.neighbors`` method, but in general it can be any function + that returns an iterator over some or all of the neighbors of a + given node, in any order. + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth. + + Yields + ------ + edge + Edges in the breadth-first search starting from `source`. 
+ + Examples + -------- + >>> G = nx.path_graph(7) + >>> list(nx.generic_bfs_edges(G, source=0)) + [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)] + >>> list(nx.generic_bfs_edges(G, source=2)) + [(2, 1), (2, 3), (1, 0), (3, 4), (4, 5), (5, 6)] + >>> list(nx.generic_bfs_edges(G, source=2, depth_limit=2)) + [(2, 1), (2, 3), (1, 0), (3, 4)] + + The `neighbors` param can be used to specify the visitation order of each + node's neighbors generically. In the following example, we modify the default + neighbor to return *odd* nodes first: + + >>> def odd_first(n): + ... return sorted(G.neighbors(n), key=lambda x: x % 2, reverse=True) + + >>> G = nx.star_graph(5) + >>> list(nx.generic_bfs_edges(G, source=0)) # Default neighbor ordering + [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)] + >>> list(nx.generic_bfs_edges(G, source=0, neighbors=odd_first)) + [(0, 1), (0, 3), (0, 5), (0, 2), (0, 4)] + + Notes + ----- + This implementation is from `PADS`_, which was in the public domain + when it was first accessed in July, 2004. The modifications + to allow depth limits are based on the Wikipedia article + "`Depth-limited-search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS/BFS.py + .. _Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + """ + if neighbors is None: + neighbors = G.neighbors + if depth_limit is None: + depth_limit = len(G) + + seen = {source} + n = len(G) + depth = 0 + next_parents_children = [(source, neighbors(source))] + while next_parents_children and depth < depth_limit: + this_parents_children = next_parents_children + next_parents_children = [] + for parent, children in this_parents_children: + for child in children: + if child not in seen: + seen.add(child) + next_parents_children.append((child, neighbors(child))) + yield parent, child + if len(seen) == n: + return + depth += 1 + + +@nx._dispatchable +def bfs_edges(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + """Iterate over edges in a breadth-first-search starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search; this function + iterates over only those edges in the component reachable from + this node. + + reverse : bool, optional + If True traverse a directed graph in the reverse direction + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Yields + ------ + edge: 2-tuple of nodes + Yields edges resulting from the breadth-first search. + + Examples + -------- + To get the edges in a breadth-first search: + + >>> G = nx.path_graph(3) + >>> list(nx.bfs_edges(G, 0)) + [(0, 1), (1, 2)] + >>> list(nx.bfs_edges(G, source=0, depth_limit=1)) + [(0, 1)] + + To get the nodes in a breadth-first search order: + + >>> G = nx.path_graph(3) + >>> root = 2 + >>> edges = nx.bfs_edges(G, root) + >>> nodes = [root] + [v for u, v in edges] + >>> nodes + [2, 1, 0] + + Notes + ----- + The naming of this function is very similar to + :func:`~networkx.algorithms.traversal.edgebfs.edge_bfs`. 
The difference + is that ``edge_bfs`` yields edges even if they extend back to an already + explored node while this generator yields the edges of the tree that results + from a breadth-first-search (BFS) so no edges are reported if they extend + to already explored nodes. That means ``edge_bfs`` reports all edges while + ``bfs_edges`` only reports those traversed by a node-based BFS. Yet another + description is that ``bfs_edges`` reports the edges traversed during BFS + while ``edge_bfs`` reports all edges in the order they are explored. + + Based on the breadth-first search implementation in PADS [1]_ + by D. Eppstein, July 2004; with modifications to allow depth limits + as described in [2]_. + + References + ---------- + .. [1] http://www.ics.uci.edu/~eppstein/PADS/BFS.py. + .. [2] https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + bfs_tree + :func:`~networkx.algorithms.traversal.depth_first_search.dfs_edges` + :func:`~networkx.algorithms.traversal.edgebfs.edge_bfs` + + """ + if reverse and G.is_directed(): + successors = G.predecessors + else: + successors = G.neighbors + + if sort_neighbors is not None: + yield from generic_bfs_edges( + G, source, lambda node: iter(sort_neighbors(successors(node))), depth_limit + ) + else: + yield from generic_bfs_edges(G, source, successors, depth_limit) + + +@nx._dispatchable(returns_graph=True) +def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None): + """Returns an oriented tree constructed from of a breadth-first-search + starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search + + reverse : bool, optional + If True traverse a directed graph in the reverse direction + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + T: NetworkX DiGraph + An oriented tree + + Examples + -------- + >>> G = nx.path_graph(3) + >>> list(nx.bfs_tree(G, 1).edges()) + [(1, 0), (1, 2)] + >>> H = nx.Graph() + >>> nx.add_path(H, [0, 1, 2, 3, 4, 5, 6]) + >>> nx.add_path(H, [2, 7, 8, 9, 10]) + >>> sorted(list(nx.bfs_tree(H, source=3, depth_limit=3).edges())) + [(1, 0), (2, 1), (2, 7), (3, 2), (3, 4), (4, 5), (5, 6), (7, 8)] + + + Notes + ----- + Based on http://www.ics.uci.edu/~eppstein/PADS/BFS.py + by D. Eppstein, July 2004. The modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited-search`_". + + .. _Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_tree + bfs_edges + edge_bfs + """ + T = nx.DiGraph() + T.add_node(source) + edges_gen = bfs_edges( + G, + source, + reverse=reverse, + depth_limit=depth_limit, + sort_neighbors=sort_neighbors, + ) + T.add_edges_from(edges_gen) + return T + + +@nx._dispatchable +def bfs_predecessors(G, source, depth_limit=None, sort_neighbors=None): + """Returns an iterator of predecessors in breadth-first-search from source. 
+ + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + pred: iterator + (node, predecessor) iterator where `predecessor` is the predecessor of + `node` in a breadth first search starting from `source`. + + Examples + -------- + >>> G = nx.path_graph(3) + >>> dict(nx.bfs_predecessors(G, 0)) + {1: 0, 2: 1} + >>> H = nx.Graph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> dict(nx.bfs_predecessors(H, 0)) + {1: 0, 2: 0, 3: 1, 4: 1, 5: 2, 6: 2} + >>> M = nx.Graph() + >>> nx.add_path(M, [0, 1, 2, 3, 4, 5, 6]) + >>> nx.add_path(M, [2, 7, 8, 9, 10]) + >>> sorted(nx.bfs_predecessors(M, source=1, depth_limit=3)) + [(0, 1), (2, 1), (3, 2), (4, 3), (7, 2), (8, 7)] + >>> N = nx.DiGraph() + >>> nx.add_path(N, [0, 1, 2, 3, 4, 7]) + >>> nx.add_path(N, [3, 5, 6, 7]) + >>> sorted(nx.bfs_predecessors(N, source=2)) + [(3, 2), (4, 3), (5, 3), (6, 5), (7, 4)] + + Notes + ----- + Based on http://www.ics.uci.edu/~eppstein/PADS/BFS.py + by D. Eppstein, July 2004. The modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited-search`_". + + .. _Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + bfs_tree + bfs_edges + edge_bfs + """ + for s, t in bfs_edges( + G, source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ): + yield (t, s) + + +@nx._dispatchable +def bfs_successors(G, source, depth_limit=None, sort_neighbors=None): + """Returns an iterator of successors in breadth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node + Specify starting node for breadth-first search + + depth_limit : int, optional(default=len(G)) + Specify the maximum search depth + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + succ: iterator + (node, successors) iterator where `successors` is the non-empty list of + successors of `node` in a breadth first search from `source`. + To appear in the iterator, `node` must have successors. + + Examples + -------- + >>> G = nx.path_graph(3) + >>> dict(nx.bfs_successors(G, 0)) + {0: [1], 1: [2]} + >>> H = nx.Graph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> dict(nx.bfs_successors(H, 0)) + {0: [1, 2], 1: [3, 4], 2: [5, 6]} + >>> G = nx.Graph() + >>> nx.add_path(G, [0, 1, 2, 3, 4, 5, 6]) + >>> nx.add_path(G, [2, 7, 8, 9, 10]) + >>> dict(nx.bfs_successors(G, source=1, depth_limit=3)) + {1: [0, 2], 2: [3, 7], 3: [4], 7: [8]} + >>> G = nx.DiGraph() + >>> nx.add_path(G, [0, 1, 2, 3, 4, 5]) + >>> dict(nx.bfs_successors(G, source=3)) + {3: [4], 4: [5]} + + Notes + ----- + Based on http://www.ics.uci.edu/~eppstein/PADS/BFS.py + by D. Eppstein, July 2004.The modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited-search`_". + + .. 
_Depth-limited-search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + bfs_tree + bfs_edges + edge_bfs + """ + parent = source + children = [] + for p, c in bfs_edges( + G, source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ): + if p == parent: + children.append(c) + continue + yield (parent, children) + children = [c] + parent = p + yield (parent, children) + + +@nx._dispatchable +def bfs_layers(G, sources): + """Returns an iterator of all the layers in breadth-first search traversal. + + Parameters + ---------- + G : NetworkX graph + A graph over which to find the layers using breadth-first search. + + sources : node in `G` or iterable of nodes in `G` + Specify starting nodes for single source or multiple sources + breadth-first search. + + Yields + ------ + layer : list of nodes + Yields list of nodes at the same distance from `sources`. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> dict(enumerate(nx.bfs_layers(G, [0, 4]))) + {0: [0, 4], 1: [1, 3], 2: [2]} + >>> H = nx.Graph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> dict(enumerate(nx.bfs_layers(H, [1]))) + {0: [1], 1: [0, 3, 4], 2: [2], 3: [5, 6]} + >>> dict(enumerate(nx.bfs_layers(H, [1, 6]))) + {0: [1, 6], 1: [0, 3, 4, 2], 2: [5]} + """ + if sources in G: + sources = [sources] + + visited = set(sources) + current_layer = list(visited) + + for source in current_layer: + if source not in G: + raise nx.NetworkXError(f"The node {source} is not in the graph.") + + # this is basically BFS, except that the current layer only stores the nodes at + # same distance from sources at each iteration + while current_layer: + yield current_layer + next_layer = [] + for node in current_layer: + for child in G[node]: + if child not in visited: + visited.add(child) + next_layer.append(child) + current_layer = next_layer + + +REVERSE_EDGE = "reverse" +TREE_EDGE = "tree" +FORWARD_EDGE = "forward" +LEVEL_EDGE = "level" + + +@nx._dispatchable +def bfs_labeled_edges(G, sources): + """Iterate over edges in a breadth-first search (BFS) labeled by type. + + We generate triple of the form (*u*, *v*, *d*), where (*u*, *v*) is the + edge being explored in the breadth-first search and *d* is one of the + strings 'tree', 'forward', 'level', or 'reverse'. A 'tree' edge is one in + which *v* is first discovered and placed into the layer below *u*. A + 'forward' edge is one in which *u* is on the layer above *v* and *v* has + already been discovered. A 'level' edge is one in which both *u* and *v* + occur on the same layer. A 'reverse' edge is one in which *u* is on a layer + below *v*. + + We emit each edge exactly once. In an undirected graph, 'reverse' edges do + not occur, because each is discovered either as a 'tree' or 'forward' edge. + + Parameters + ---------- + G : NetworkX graph + A graph over which to find the layers using breadth-first search. + + sources : node in `G` or list of nodes in `G` + Starting nodes for single source or multiple sources breadth-first search + + Yields + ------ + edges: generator + A generator of triples (*u*, *v*, *d*) where (*u*, *v*) is the edge being + explored and *d* is described above. 
+ + Examples + -------- + >>> G = nx.cycle_graph(4, create_using=nx.DiGraph) + >>> list(nx.bfs_labeled_edges(G, 0)) + [(0, 1, 'tree'), (1, 2, 'tree'), (2, 3, 'tree'), (3, 0, 'reverse')] + >>> G = nx.complete_graph(3) + >>> list(nx.bfs_labeled_edges(G, 0)) + [(0, 1, 'tree'), (0, 2, 'tree'), (1, 2, 'level')] + >>> list(nx.bfs_labeled_edges(G, [0, 1])) + [(0, 1, 'level'), (0, 2, 'tree'), (1, 2, 'forward')] + """ + if sources in G: + sources = [sources] + + neighbors = G._adj + directed = G.is_directed() + visited = set() + visit = visited.discard if directed else visited.add + # We use visited in a negative sense, so the visited set stays empty for the + # directed case and level edges are reported on their first occurrence in + # the undirected case. Note our use of visited.discard -- this is built-in + # thus somewhat faster than a python-defined def nop(x): pass + depth = {s: 0 for s in sources} + queue = deque(depth.items()) + push = queue.append + pop = queue.popleft + while queue: + u, du = pop() + for v in neighbors[u]: + if v not in depth: + depth[v] = dv = du + 1 + push((v, dv)) + yield u, v, TREE_EDGE + else: + dv = depth[v] + if du == dv: + if v not in visited: + yield u, v, LEVEL_EDGE + elif du < dv: + yield u, v, FORWARD_EDGE + elif directed: + yield u, v, REVERSE_EDGE + visit(u) + + +@nx._dispatchable +def descendants_at_distance(G, source, distance): + """Returns all nodes at a fixed `distance` from `source` in `G`. + + Parameters + ---------- + G : NetworkX graph + A graph + source : node in `G` + distance : the distance of the wanted nodes from `source` + + Returns + ------- + set() + The descendants of `source` in `G` at the given `distance` from `source` + + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.descendants_at_distance(G, 2, 2) + {0, 4} + >>> H = nx.DiGraph() + >>> H.add_edges_from([(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]) + >>> nx.descendants_at_distance(H, 0, 2) + {3, 4, 5, 6} + >>> nx.descendants_at_distance(H, 5, 0) + {5} + >>> nx.descendants_at_distance(H, 5, 1) + set() + """ + if source not in G: + raise nx.NetworkXError(f"The node {source} is not in the graph.") + + bfs_generator = nx.bfs_layers(G, source) + for i, layer in enumerate(bfs_generator): + if i == distance: + return set(layer) + return set() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/depth_first_search.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/depth_first_search.py new file mode 100644 index 0000000000000000000000000000000000000000..5bac5ecfd1cbefcba5707cac2885ef32987ee98b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/depth_first_search.py @@ -0,0 +1,529 @@ +"""Basic algorithms for depth-first searching the nodes of a graph.""" + +from collections import defaultdict + +import networkx as nx + +__all__ = [ + "dfs_edges", + "dfs_tree", + "dfs_predecessors", + "dfs_successors", + "dfs_preorder_nodes", + "dfs_postorder_nodes", + "dfs_labeled_edges", +] + + +@nx._dispatchable +def dfs_edges(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Iterate over edges in a depth-first-search (DFS). + + Perform a depth-first-search over the nodes of `G` and yield + the edges in order. This may not generate all edges in `G` + (see `~networkx.algorithms.traversal.edgedfs.edge_dfs`). 
+ + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search and yield edges in + the component reachable from source. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Yields + ------ + edge: 2-tuple of nodes + Yields edges resulting from the depth-first-search. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> list(nx.dfs_edges(G, source=0)) + [(0, 1), (1, 2), (2, 3), (3, 4)] + >>> list(nx.dfs_edges(G, source=0, depth_limit=2)) + [(0, 1), (1, 2)] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in PADS [1]_, with modifications + to allow depth limits based on the Wikipedia article + "Depth-limited search" [2]_. + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_edges` + + References + ---------- + .. [1] http://www.ics.uci.edu/~eppstein/PADS + .. [2] https://en.wikipedia.org/wiki/Depth-limited_search + """ + if source is None: + # edges for all components + nodes = G + else: + # edges for components with source + nodes = [source] + if depth_limit is None: + depth_limit = len(G) + + get_children = ( + G.neighbors + if sort_neighbors is None + else lambda n: iter(sort_neighbors(G.neighbors(n))) + ) + + visited = set() + for start in nodes: + if start in visited: + continue + visited.add(start) + stack = [(start, get_children(start))] + depth_now = 1 + while stack: + parent, children = stack[-1] + for child in children: + if child not in visited: + yield parent, child + visited.add(child) + if depth_now < depth_limit: + stack.append((child, get_children(child))) + depth_now += 1 + break + else: + stack.pop() + depth_now -= 1 + + +@nx._dispatchable(returns_graph=True) +def dfs_tree(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Returns oriented tree constructed from a depth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. 
+ + Returns + ------- + T : NetworkX DiGraph + An oriented tree + + Examples + -------- + >>> G = nx.path_graph(5) + >>> T = nx.dfs_tree(G, source=0, depth_limit=2) + >>> list(T.edges()) + [(0, 1), (1, 2)] + >>> T = nx.dfs_tree(G, source=0) + >>> list(T.edges()) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + T = nx.DiGraph() + if source is None: + T.add_nodes_from(G) + else: + T.add_node(source) + T.add_edges_from(dfs_edges(G, source, depth_limit, sort_neighbors=sort_neighbors)) + return T + + +@nx._dispatchable +def dfs_predecessors(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Returns dictionary of predecessors in depth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + Note that you will get predecessors for all nodes in the + component containing `source`. This input only specifies + where the DFS starts. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + pred: dict + A dictionary with nodes as keys and predecessor nodes as values. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> nx.dfs_predecessors(G, source=0) + {1: 0, 2: 1, 3: 2} + >>> nx.dfs_predecessors(G, source=0, depth_limit=2) + {1: 0, 2: 1} + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + return { + t: s + for s, t in dfs_edges(G, source, depth_limit, sort_neighbors=sort_neighbors) + } + + +@nx._dispatchable +def dfs_successors(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Returns dictionary of successors in depth-first-search from source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + Note that you will get successors for all nodes in the + component containing `source`. This input only specifies + where the DFS starts. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + succ: dict + A dictionary with nodes as keys and list of successor nodes as values. 
+ + Examples + -------- + >>> G = nx.path_graph(5) + >>> nx.dfs_successors(G, source=0) + {0: [1], 1: [2], 2: [3], 3: [4]} + >>> nx.dfs_successors(G, source=0, depth_limit=2) + {0: [1], 1: [2]} + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_preorder_nodes + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + d = defaultdict(list) + for s, t in dfs_edges( + G, + source=source, + depth_limit=depth_limit, + sort_neighbors=sort_neighbors, + ): + d[s].append(t) + return dict(d) + + +@nx._dispatchable +def dfs_postorder_nodes(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Generate nodes in a depth-first-search post-ordering starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + nodes: generator + A generator of nodes in a depth-first-search post-ordering. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> list(nx.dfs_postorder_nodes(G, source=0)) + [4, 3, 2, 1, 0] + >>> list(nx.dfs_postorder_nodes(G, source=0, depth_limit=2)) + [1, 0] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_edges + dfs_preorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.edgedfs.edge_dfs` + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_tree` + """ + edges = nx.dfs_labeled_edges( + G, source=source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ) + return (v for u, v, d in edges if d == "reverse") + + +@nx._dispatchable +def dfs_preorder_nodes(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Generate nodes in a depth-first-search pre-ordering starting at source. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search and return nodes in + the component reachable from source. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. 
+ For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + nodes: generator + A generator of nodes in a depth-first-search pre-ordering. + + Examples + -------- + >>> G = nx.path_graph(5) + >>> list(nx.dfs_preorder_nodes(G, source=0)) + [0, 1, 2, 3, 4] + >>> list(nx.dfs_preorder_nodes(G, source=0, depth_limit=2)) + [0, 1, 2] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. _Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_edges + dfs_postorder_nodes + dfs_labeled_edges + :func:`~networkx.algorithms.traversal.breadth_first_search.bfs_edges` + """ + edges = nx.dfs_labeled_edges( + G, source=source, depth_limit=depth_limit, sort_neighbors=sort_neighbors + ) + return (v for u, v, d in edges if d == "forward") + + +@nx._dispatchable +def dfs_labeled_edges(G, source=None, depth_limit=None, *, sort_neighbors=None): + """Iterate over edges in a depth-first-search (DFS) labeled by type. + + Parameters + ---------- + G : NetworkX graph + + source : node, optional + Specify starting node for depth-first search and return edges in + the component reachable from source. + + depth_limit : int, optional (default=len(G)) + Specify the maximum search depth. + + sort_neighbors : function (default=None) + A function that takes an iterator over nodes as the input, and + returns an iterable of the same nodes with a custom ordering. + For example, `sorted` will sort the nodes in increasing order. + + Returns + ------- + edges: generator + A generator of triples of the form (*u*, *v*, *d*), where (*u*, + *v*) is the edge being explored in the depth-first search and *d* + is one of the strings 'forward', 'nontree', 'reverse', or 'reverse-depth_limit'. + A 'forward' edge is one in which *u* has been visited but *v* has + not. A 'nontree' edge is one in which both *u* and *v* have been + visited but the edge is not in the DFS tree. A 'reverse' edge is + one in which both *u* and *v* have been visited and the edge is in + the DFS tree. When the `depth_limit` is reached via a 'forward' edge, + a 'reverse' edge is immediately generated rather than the subtree + being explored. To indicate this flavor of 'reverse' edge, the string + yielded is 'reverse-depth_limit'. + + Examples + -------- + + The labels reveal the complete transcript of the depth-first search + algorithm in more detail than, for example, :func:`dfs_edges`:: + + >>> from pprint import pprint + >>> + >>> G = nx.DiGraph([(0, 1), (1, 2), (2, 1)]) + >>> pprint(list(nx.dfs_labeled_edges(G, source=0))) + [(0, 0, 'forward'), + (0, 1, 'forward'), + (1, 2, 'forward'), + (2, 1, 'nontree'), + (1, 2, 'reverse'), + (0, 1, 'reverse'), + (0, 0, 'reverse')] + + Notes + ----- + If a source is not specified then a source is chosen arbitrarily and + repeatedly until all components in the graph are searched. + + The implementation of this function is adapted from David Eppstein's + depth-first search function in `PADS`_, with modifications + to allow depth limits based on the Wikipedia article + "`Depth-limited search`_". + + .. _PADS: http://www.ics.uci.edu/~eppstein/PADS + .. 
_Depth-limited search: https://en.wikipedia.org/wiki/Depth-limited_search + + See Also + -------- + dfs_edges + dfs_preorder_nodes + dfs_postorder_nodes + """ + # Based on http://www.ics.uci.edu/~eppstein/PADS/DFS.py + # by D. Eppstein, July 2004. + if source is None: + # edges for all components + nodes = G + else: + # edges for components with source + nodes = [source] + if depth_limit is None: + depth_limit = len(G) + + get_children = ( + G.neighbors + if sort_neighbors is None + else lambda n: iter(sort_neighbors(G.neighbors(n))) + ) + + visited = set() + for start in nodes: + if start in visited: + continue + yield start, start, "forward" + visited.add(start) + stack = [(start, get_children(start))] + depth_now = 1 + while stack: + parent, children = stack[-1] + for child in children: + if child in visited: + yield parent, child, "nontree" + else: + yield parent, child, "forward" + visited.add(child) + if depth_now < depth_limit: + stack.append((child, iter(get_children(child)))) + depth_now += 1 + break + else: + yield parent, child, "reverse-depth_limit" + else: + stack.pop() + depth_now -= 1 + if stack: + yield stack[-1][0], parent, "reverse" + yield start, start, "reverse" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/edgebfs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/edgebfs.py new file mode 100644 index 0000000000000000000000000000000000000000..2f468e10b3ebbf43888ac6761a473bbc2e1b106b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/edgebfs.py @@ -0,0 +1,185 @@ +""" +============================= +Breadth First Search on Edges +============================= + +Algorithms for a breadth-first traversal of edges in a graph. + +""" + +from collections import deque + +import networkx as nx + +FORWARD = "forward" +REVERSE = "reverse" + +__all__ = ["edge_bfs"] + + +@nx._dispatchable +def edge_bfs(G, source=None, orientation=None): + """A directed, breadth-first-search of edges in `G`, beginning at `source`. + + Yield the edges of G in a breadth-first-search order continuing until + all edges are generated. + + Parameters + ---------- + G : graph + A directed/undirected graph/multigraph. + + source : node, list of nodes + The node from which the traversal begins. If None, then a source + is chosen arbitrarily and repeatedly until all edges from each node in + the graph are searched. + + orientation : None | 'original' | 'reverse' | 'ignore' (default: None) + For directed graphs and directed multigraphs, edge traversals need not + respect the original orientation of the edges. + When set to 'reverse' every edge is traversed in the reverse direction. + When set to 'ignore', every edge is treated as undirected. + When set to 'original', every edge is treated as directed. + In all three cases, the yielded edge tuples add a last entry to + indicate the direction in which that edge was traversed. + If orientation is None, the yielded edge has no direction indicated. + The direction is respected, but not reported. + + Yields + ------ + edge : directed edge + A directed edge indicating the path taken by the breadth-first-search. + For graphs, `edge` is of the form `(u, v)` where `u` and `v` + are the tail and head of the edge as determined by the traversal. + For multigraphs, `edge` is of the form `(u, v, key)`, where `key` is + the key of the edge. 
When the graph is directed, then `u` and `v` + are always in the order of the actual directed edge. + If orientation is not None then the edge tuple is extended to include + the direction of traversal ('forward' or 'reverse') on that edge. + + Examples + -------- + >>> from pprint import pprint + >>> nodes = [0, 1, 2, 3] + >>> edges = [(0, 1), (1, 0), (1, 0), (2, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_bfs(nx.Graph(edges), nodes)) + [(0, 1), (0, 2), (1, 2), (1, 3)] + + >>> list(nx.edge_bfs(nx.DiGraph(edges), nodes)) + [(0, 1), (1, 0), (2, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_bfs(nx.MultiGraph(edges), nodes)) + [(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 2, 0), (1, 2, 0), (1, 3, 0)] + + >>> list(nx.edge_bfs(nx.MultiDiGraph(edges), nodes)) + [(0, 1, 0), (1, 0, 0), (1, 0, 1), (2, 0, 0), (2, 1, 0), (3, 1, 0)] + + >>> list(nx.edge_bfs(nx.DiGraph(edges), nodes, orientation="ignore")) + [(0, 1, 'forward'), (1, 0, 'reverse'), (2, 0, 'reverse'), (2, 1, 'reverse'), (3, 1, 'reverse')] + + >>> elist = list(nx.edge_bfs(nx.MultiDiGraph(edges), nodes, orientation="ignore")) + >>> pprint(elist) + [(0, 1, 0, 'forward'), + (1, 0, 0, 'reverse'), + (1, 0, 1, 'reverse'), + (2, 0, 0, 'reverse'), + (2, 1, 0, 'reverse'), + (3, 1, 0, 'reverse')] + + Notes + ----- + The goal of this function is to visit edges. It differs from the more + familiar breadth-first-search of nodes, as provided by + :func:`networkx.algorithms.traversal.breadth_first_search.bfs_edges`, in + that it does not stop once every node has been visited. In a directed graph + with edges [(0, 1), (1, 2), (2, 1)], the edge (2, 1) would not be visited + if not for the functionality provided by this function. + + The naming of this function is very similar to bfs_edges. The difference + is that 'edge_bfs' yields edges even if they extend back to an already + explored node while 'bfs_edges' yields the edges of the tree that results + from a breadth-first-search (BFS) so no edges are reported if they extend + to already explored nodes. That means 'edge_bfs' reports all edges while + 'bfs_edges' only report those traversed by a node-based BFS. Yet another + description is that 'bfs_edges' reports the edges traversed during BFS + while 'edge_bfs' reports all edges in the order they are explored. 
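+
+    For instance, for the directed graph with edges ``[(0, 1), (1, 2), (2, 1)]``
+    mentioned above, a node-based BFS stops once every node has been
+    discovered, while this function also reports the edge leading back to an
+    already explored node:
+
+    >>> G = nx.DiGraph([(0, 1), (1, 2), (2, 1)])
+    >>> list(nx.bfs_edges(G, 0))
+    [(0, 1), (1, 2)]
+    >>> list(nx.edge_bfs(G, 0))
+    [(0, 1), (1, 2), (2, 1)]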
+ + See Also + -------- + bfs_edges + bfs_tree + edge_dfs + + """ + nodes = list(G.nbunch_iter(source)) + if not nodes: + return + + directed = G.is_directed() + kwds = {"data": False} + if G.is_multigraph() is True: + kwds["keys"] = True + + # set up edge lookup + if orientation is None: + + def edges_from(node): + return iter(G.edges(node, **kwds)) + + elif not directed or orientation == "original": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + + elif orientation == "reverse": + + def edges_from(node): + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + elif orientation == "ignore": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + else: + raise nx.NetworkXError("invalid orientation argument.") + + if directed: + neighbors = G.successors + + def edge_id(edge): + # remove direction indicator + return edge[:-1] if orientation is not None else edge + + else: + neighbors = G.neighbors + + def edge_id(edge): + return (frozenset(edge[:2]),) + edge[2:] + + check_reverse = directed and orientation in ("reverse", "ignore") + + # start BFS + visited_nodes = set(nodes) + visited_edges = set() + queue = deque([(n, edges_from(n)) for n in nodes]) + while queue: + parent, children_edges = queue.popleft() + for edge in children_edges: + if check_reverse and edge[-1] == REVERSE: + child = edge[0] + else: + child = edge[1] + if child not in visited_nodes: + visited_nodes.add(child) + queue.append((child, edges_from(child))) + edgeid = edge_id(edge) + if edgeid not in visited_edges: + visited_edges.add(edgeid) + yield edge diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/edgedfs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/edgedfs.py new file mode 100644 index 0000000000000000000000000000000000000000..f8173105ddb380f42accc31af87e4190b6af6ce1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/traversal/edgedfs.py @@ -0,0 +1,182 @@ +""" +=========================== +Depth First Search on Edges +=========================== + +Algorithms for a depth-first traversal of edges in a graph. + +""" + +import networkx as nx + +FORWARD = "forward" +REVERSE = "reverse" + +__all__ = ["edge_dfs"] + + +@nx._dispatchable +def edge_dfs(G, source=None, orientation=None): + """A directed, depth-first-search of edges in `G`, beginning at `source`. + + Yield the edges of G in a depth-first-search order continuing until + all edges are generated. + + Parameters + ---------- + G : graph + A directed/undirected graph/multigraph. + + source : node, list of nodes + The node from which the traversal begins. If None, then a source + is chosen arbitrarily and repeatedly until all edges from each node in + the graph are searched. + + orientation : None | 'original' | 'reverse' | 'ignore' (default: None) + For directed graphs and directed multigraphs, edge traversals need not + respect the original orientation of the edges. + When set to 'reverse' every edge is traversed in the reverse direction. + When set to 'ignore', every edge is treated as undirected. + When set to 'original', every edge is treated as directed. + In all three cases, the yielded edge tuples add a last entry to + indicate the direction in which that edge was traversed. + If orientation is None, the yielded edge has no direction indicated. 
+ The direction is respected, but not reported. + + Yields + ------ + edge : directed edge + A directed edge indicating the path taken by the depth-first traversal. + For graphs, `edge` is of the form `(u, v)` where `u` and `v` + are the tail and head of the edge as determined by the traversal. + For multigraphs, `edge` is of the form `(u, v, key)`, where `key` is + the key of the edge. When the graph is directed, then `u` and `v` + are always in the order of the actual directed edge. + If orientation is not None then the edge tuple is extended to include + the direction of traversal ('forward' or 'reverse') on that edge. + + Examples + -------- + >>> from pprint import pprint + >>> nodes = [0, 1, 2, 3] + >>> edges = [(0, 1), (1, 0), (1, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_dfs(nx.Graph(edges), nodes)) + [(0, 1), (1, 2), (1, 3)] + + >>> list(nx.edge_dfs(nx.DiGraph(edges), nodes)) + [(0, 1), (1, 0), (2, 1), (3, 1)] + + >>> list(nx.edge_dfs(nx.MultiGraph(edges), nodes)) + [(0, 1, 0), (1, 0, 1), (0, 1, 2), (1, 2, 0), (1, 3, 0)] + + >>> list(nx.edge_dfs(nx.MultiDiGraph(edges), nodes)) + [(0, 1, 0), (1, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 0)] + + >>> list(nx.edge_dfs(nx.DiGraph(edges), nodes, orientation="ignore")) + [(0, 1, 'forward'), (1, 0, 'forward'), (2, 1, 'reverse'), (3, 1, 'reverse')] + + >>> elist = list(nx.edge_dfs(nx.MultiDiGraph(edges), nodes, orientation="ignore")) + >>> pprint(elist) + [(0, 1, 0, 'forward'), + (1, 0, 0, 'forward'), + (1, 0, 1, 'reverse'), + (2, 1, 0, 'reverse'), + (3, 1, 0, 'reverse')] + + Notes + ----- + The goal of this function is to visit edges. It differs from the more + familiar depth-first traversal of nodes, as provided by + :func:`~networkx.algorithms.traversal.depth_first_search.dfs_edges`, in + that it does not stop once every node has been visited. In a directed graph + with edges [(0, 1), (1, 2), (2, 1)], the edge (2, 1) would not be visited + if not for the functionality provided by this function. 
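+
+    Continuing that example, a node-based DFS never reports the edge that
+    returns to an already visited node, while this function does:
+
+    >>> G = nx.DiGraph([(0, 1), (1, 2), (2, 1)])
+    >>> list(nx.dfs_edges(G, 0))
+    [(0, 1), (1, 2)]
+    >>> list(nx.edge_dfs(G, 0))
+    [(0, 1), (1, 2), (2, 1)]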
+ + See Also + -------- + :func:`~networkx.algorithms.traversal.depth_first_search.dfs_edges` + + """ + nodes = list(G.nbunch_iter(source)) + if not nodes: + return + + directed = G.is_directed() + kwds = {"data": False} + if G.is_multigraph() is True: + kwds["keys"] = True + + # set up edge lookup + if orientation is None: + + def edges_from(node): + return iter(G.edges(node, **kwds)) + + elif not directed or orientation == "original": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + + elif orientation == "reverse": + + def edges_from(node): + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + elif orientation == "ignore": + + def edges_from(node): + for e in G.edges(node, **kwds): + yield e + (FORWARD,) + for e in G.in_edges(node, **kwds): + yield e + (REVERSE,) + + else: + raise nx.NetworkXError("invalid orientation argument.") + + # set up formation of edge_id to easily look up if edge already returned + if directed: + + def edge_id(edge): + # remove direction indicator + return edge[:-1] if orientation is not None else edge + + else: + + def edge_id(edge): + # single id for undirected requires frozenset on nodes + return (frozenset(edge[:2]),) + edge[2:] + + # Basic setup + check_reverse = directed and orientation in ("reverse", "ignore") + + visited_edges = set() + visited_nodes = set() + edges = {} + + # start DFS + for start_node in nodes: + stack = [start_node] + while stack: + current_node = stack[-1] + if current_node not in visited_nodes: + edges[current_node] = edges_from(current_node) + visited_nodes.add(current_node) + + try: + edge = next(edges[current_node]) + except StopIteration: + # No more edges from the current node. + stack.pop() + else: + edgeid = edge_id(edge) + if edgeid not in visited_edges: + visited_edges.add(edgeid) + # Mark the traversed "to" node as to-be-explored. + if check_reverse and edge[-1] == REVERSE: + stack.append(edge[0]) + else: + stack.append(edge[1]) + yield edge diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b6da30e2da0924e96987eb75de268f1b0278f0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/__init__.py @@ -0,0 +1,7 @@ +from .branchings import * +from .coding import * +from .distance_measures import * +from .mst import * +from .recognition import * +from .operations import * +from .decomposition import * diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/branchings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/branchings.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9c7cf1189d341577a81501e5ca6760ed73a58c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/branchings.py @@ -0,0 +1,1042 @@ +""" +Algorithms for finding optimum branchings and spanning arborescences. + +This implementation is based on: + + J. Edmonds, Optimum branchings, J. Res. Natl. Bur. Standards 71B (1967), + 233–240. 
URL: http://archive.org/details/jresv71Bn4p233 + +""" + +# TODO: Implement method from Gabow, Galil, Spence and Tarjan: +# +# @article{ +# year={1986}, +# issn={0209-9683}, +# journal={Combinatorica}, +# volume={6}, +# number={2}, +# doi={10.1007/BF02579168}, +# title={Efficient algorithms for finding minimum spanning trees in +# undirected and directed graphs}, +# url={https://doi.org/10.1007/BF02579168}, +# publisher={Springer-Verlag}, +# keywords={68 B 15; 68 C 05}, +# author={Gabow, Harold N. and Galil, Zvi and Spencer, Thomas and Tarjan, +# Robert E.}, +# pages={109-122}, +# language={English} +# } +import string +from dataclasses import dataclass, field +from operator import itemgetter +from queue import PriorityQueue + +import networkx as nx +from networkx.utils import py_random_state + +from .recognition import is_arborescence, is_branching + +__all__ = [ + "branching_weight", + "greedy_branching", + "maximum_branching", + "minimum_branching", + "minimal_branching", + "maximum_spanning_arborescence", + "minimum_spanning_arborescence", + "ArborescenceIterator", +] + +KINDS = {"max", "min"} + +STYLES = { + "branching": "branching", + "arborescence": "arborescence", + "spanning arborescence": "arborescence", +} + +INF = float("inf") + + +@py_random_state(1) +def random_string(L=15, seed=None): + return "".join([seed.choice(string.ascii_letters) for n in range(L)]) + + +def _min_weight(weight): + return -weight + + +def _max_weight(weight): + return weight + + +@nx._dispatchable(edge_attrs={"attr": "default"}) +def branching_weight(G, attr="weight", default=1): + """ + Returns the total weight of a branching. + + You must access this function through the networkx.algorithms.tree module. + + Parameters + ---------- + G : DiGraph + The directed graph. + attr : str + The attribute to use as weights. If None, then each edge will be + treated equally with a weight of 1. + default : float + When `attr` is not None, then if an edge does not have that attribute, + `default` specifies what value it should take. + + Returns + ------- + weight: int or float + The total weight of the branching. + + Examples + -------- + >>> G = nx.DiGraph() + >>> G.add_weighted_edges_from([(0, 1, 2), (1, 2, 4), (2, 3, 3), (3, 4, 2)]) + >>> nx.tree.branching_weight(G) + 11 + + """ + return sum(edge[2].get(attr, default) for edge in G.edges(data=True)) + + +@py_random_state(4) +@nx._dispatchable(edge_attrs={"attr": "default"}, returns_graph=True) +def greedy_branching(G, attr="weight", default=1, kind="max", seed=None): + """ + Returns a branching obtained through a greedy algorithm. + + This algorithm is wrong, and cannot give a proper optimal branching. + However, we include it for pedagogical reasons, as it can be helpful to + see what its outputs are. + + The output is a branching, and possibly, a spanning arborescence. However, + it is not guaranteed to be optimal in either case. + + Parameters + ---------- + G : DiGraph + The directed graph to scan. + attr : str + The attribute to use as weights. If None, then each edge will be + treated equally with a weight of 1. + default : float + When `attr` is not None, then if an edge does not have that attribute, + `default` specifies what value it should take. + kind : str + The type of optimum to search for: 'min' or 'max' greedy branching. + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + B : directed graph + The greedily obtained branching. 
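+
+    Examples
+    --------
+    A small illustration: the greedy strategy keeps the two heaviest edges and
+    skips the light edge, which would both close a cycle and give node 2 two
+    parents:
+
+    >>> G = nx.DiGraph()
+    >>> G.add_weighted_edges_from([(0, 1, 5), (1, 2, 1), (0, 2, 10)])
+    >>> B = nx.tree.greedy_branching(G)
+    >>> sorted(B.edges(data="weight"))
+    [(0, 1, 5), (0, 2, 10)]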
+ + """ + if kind not in KINDS: + raise nx.NetworkXException("Unknown value for `kind`.") + + if kind == "min": + reverse = False + else: + reverse = True + + if attr is None: + # Generate a random string the graph probably won't have. + attr = random_string(seed=seed) + + edges = [(u, v, data.get(attr, default)) for (u, v, data) in G.edges(data=True)] + + # We sort by weight, but also by nodes to normalize behavior across runs. + try: + edges.sort(key=itemgetter(2, 0, 1), reverse=reverse) + except TypeError: + # This will fail in Python 3.x if the nodes are of varying types. + # In that case, we use the arbitrary order. + edges.sort(key=itemgetter(2), reverse=reverse) + + # The branching begins with a forest of no edges. + B = nx.DiGraph() + B.add_nodes_from(G) + + # Now we add edges greedily so long we maintain the branching. + uf = nx.utils.UnionFind() + for i, (u, v, w) in enumerate(edges): + if uf[u] == uf[v]: + # Adding this edge would form a directed cycle. + continue + elif B.in_degree(v) == 1: + # The edge would increase the degree to be greater than one. + continue + else: + # If attr was None, then don't insert weights... + data = {} + if attr is not None: + data[attr] = w + B.add_edge(u, v, **data) + uf.union(u, v) + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, returns_graph=True) +def maximum_branching( + G, + attr="weight", + default=1, + preserve_attrs=False, + partition=None, +): + ####################################### + ### Data Structure Helper Functions ### + ####################################### + + def edmonds_add_edge(G, edge_index, u, v, key, **d): + """ + Adds an edge to `G` while also updating the edge index. + + This algorithm requires the use of an external dictionary to track + the edge keys since it is possible that the source or destination + node of an edge will be changed and the default key-handling + capabilities of the MultiDiGraph class do not account for this. + + Parameters + ---------- + G : MultiDiGraph + The graph to insert an edge into. + edge_index : dict + A mapping from integers to the edges of the graph. + u : node + The source node of the new edge. + v : node + The destination node of the new edge. + key : int + The key to use from `edge_index`. + d : keyword arguments, optional + Other attributes to store on the new edge. + """ + + if key in edge_index: + uu, vv, _ = edge_index[key] + if (u != uu) or (v != vv): + raise Exception(f"Key {key!r} is already in use.") + + G.add_edge(u, v, key, **d) + edge_index[key] = (u, v, G.succ[u][v][key]) + + def edmonds_remove_node(G, edge_index, n): + """ + Remove a node from the graph, updating the edge index to match. + + Parameters + ---------- + G : MultiDiGraph + The graph to remove an edge from. + edge_index : dict + A mapping from integers to the edges of the graph. + n : node + The node to remove from `G`. + """ + keys = set() + for keydict in G.pred[n].values(): + keys.update(keydict) + for keydict in G.succ[n].values(): + keys.update(keydict) + + for key in keys: + del edge_index[key] + + G.remove_node(n) + + ####################### + ### Algorithm Setup ### + ####################### + + # Pick an attribute name that the original graph is unlikly to have + candidate_attr = "edmonds' secret candidate attribute" + new_node_base_name = "edmonds new node base name " + + G_original = G + G = nx.MultiDiGraph() + G.__networkx_cache__ = None # Disable caching + + # A dict to reliably track mutations to the edges using the key of the edge. 
+ G_edge_index = {} + # Each edge is given an arbitrary numerical key + for key, (u, v, data) in enumerate(G_original.edges(data=True)): + d = {attr: data.get(attr, default)} + + if data.get(partition) is not None: + d[partition] = data.get(partition) + + if preserve_attrs: + for d_k, d_v in data.items(): + if d_k != attr: + d[d_k] = d_v + + edmonds_add_edge(G, G_edge_index, u, v, key, **d) + + level = 0 # Stores the number of contracted nodes + + # These are the buckets from the paper. + # + # In the paper, G^i are modified versions of the original graph. + # D^i and E^i are the nodes and edges of the maximal edges that are + # consistent with G^i. In this implementation, D^i and E^i are stored + # together as the graph B^i. We will have strictly more B^i then the + # paper will have. + # + # Note that the data in graphs and branchings are tuples with the graph as + # the first element and the edge index as the second. + B = nx.MultiDiGraph() + B_edge_index = {} + graphs = [] # G^i list + branchings = [] # B^i list + selected_nodes = set() # D^i bucket + uf = nx.utils.UnionFind() + + # A list of lists of edge indices. Each list is a circuit for graph G^i. + # Note the edge list is not required to be a circuit in G^0. + circuits = [] + + # Stores the index of the minimum edge in the circuit found in G^i and B^i. + # The ordering of the edges seems to preserver the weight ordering from + # G^0. So even if the circuit does not form a circuit in G^0, it is still + # true that the minimum edges in circuit G^0 (despite their weights being + # different) + minedge_circuit = [] + + ########################### + ### Algorithm Structure ### + ########################### + + # Each step listed in the algorithm is an inner function. Thus, the overall + # loop structure is: + # + # while True: + # step_I1() + # if cycle detected: + # step_I2() + # elif every node of G is in D and E is a branching: + # break + + ################################## + ### Algorithm Helper Functions ### + ################################## + + def edmonds_find_desired_edge(v): + """ + Find the edge directed towards v with maximal weight. + + If an edge partition exists in this graph, return the included + edge if it exists and never return any excluded edge. + + Note: There can only be one included edge for each vertex otherwise + the edge partition is empty. + + Parameters + ---------- + v : node + The node to search for the maximal weight incoming edge. + """ + edge = None + max_weight = -INF + for u, _, key, data in G.in_edges(v, data=True, keys=True): + # Skip excluded edges + if data.get(partition) == nx.EdgePartition.EXCLUDED: + continue + + new_weight = data[attr] + + # Return the included edge + if data.get(partition) == nx.EdgePartition.INCLUDED: + max_weight = new_weight + edge = (u, v, key, new_weight, data) + break + + # Find the best open edge + if new_weight > max_weight: + max_weight = new_weight + edge = (u, v, key, new_weight, data) + + return edge, max_weight + + def edmonds_step_I2(v, desired_edge, level): + """ + Perform step I2 from Edmonds' paper + + First, check if the last step I1 created a cycle. If it did not, do nothing. + If it did, store the cycle for later reference and contract it. + + Parameters + ---------- + v : node + The current node to consider + desired_edge : edge + The minimum desired edge to remove from the cycle. + level : int + The current level, i.e. the number of cycles that have already been removed. 
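+
+        Notes
+        -----
+        When the circuit is contracted into a single new node, each edge
+        entering the circuit at a node ``v`` has its weight rewritten as
+        ``w + (minweight - Q_incoming_weight[v])``, where ``minweight`` is
+        the weight of the cheapest removable circuit edge and
+        ``Q_incoming_weight[v]`` is the weight of the circuit edge that
+        currently enters ``v``.  Comparisons on the contracted graph then
+        account for the circuit edge that would have to be dropped if the
+        entering edge were chosen.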
+ """ + u = desired_edge[0] + + Q_nodes = nx.shortest_path(B, v, u) + Q_edges = [ + list(B[Q_nodes[i]][vv].keys())[0] for i, vv in enumerate(Q_nodes[1:]) + ] + Q_edges.append(desired_edge[2]) # Add the new edge key to complete the circuit + + # Get the edge in the circuit with the minimum weight. + # Also, save the incoming weights for each node. + minweight = INF + minedge = None + Q_incoming_weight = {} + for edge_key in Q_edges: + u, v, data = B_edge_index[edge_key] + w = data[attr] + # We cannot remove an included edge, even if it is the + # minimum edge in the circuit + Q_incoming_weight[v] = w + if data.get(partition) == nx.EdgePartition.INCLUDED: + continue + if w < minweight: + minweight = w + minedge = edge_key + + circuits.append(Q_edges) + minedge_circuit.append(minedge) + graphs.append((G.copy(), G_edge_index.copy())) + branchings.append((B.copy(), B_edge_index.copy())) + + # Mutate the graph to contract the circuit + new_node = new_node_base_name + str(level) + G.add_node(new_node) + new_edges = [] + for u, v, key, data in G.edges(data=True, keys=True): + if u in Q_incoming_weight: + if v in Q_incoming_weight: + # Circuit edge. For the moment do nothing, + # eventually it will be removed. + continue + else: + # Outgoing edge from a node in the circuit. + # Make it come from the new node instead + dd = data.copy() + new_edges.append((new_node, v, key, dd)) + else: + if v in Q_incoming_weight: + # Incoming edge to the circuit. + # Update it's weight + w = data[attr] + w += minweight - Q_incoming_weight[v] + dd = data.copy() + dd[attr] = w + new_edges.append((u, new_node, key, dd)) + else: + # Outside edge. No modification needed + continue + + for node in Q_nodes: + edmonds_remove_node(G, G_edge_index, node) + edmonds_remove_node(B, B_edge_index, node) + + selected_nodes.difference_update(set(Q_nodes)) + + for u, v, key, data in new_edges: + edmonds_add_edge(G, G_edge_index, u, v, key, **data) + if candidate_attr in data: + del data[candidate_attr] + edmonds_add_edge(B, B_edge_index, u, v, key, **data) + uf.union(u, v) + + def is_root(G, u, edgekeys): + """ + Returns True if `u` is a root node in G. + + Node `u` is a root node if its in-degree over the specified edges is zero. + + Parameters + ---------- + G : Graph + The current graph. + u : node + The node in `G` to check if it is a root. + edgekeys : iterable of edges + The edges for which to check if `u` is a root of. 
+ """ + if u not in G: + raise Exception(f"{u!r} not in G") + + for v in G.pred[u]: + for edgekey in G.pred[u][v]: + if edgekey in edgekeys: + return False, edgekey + else: + return True, None + + nodes = iter(list(G.nodes)) + while True: + try: + v = next(nodes) + except StopIteration: + # If there are no more new nodes to consider, then we should + # meet stopping condition (b) from the paper: + # (b) every node of G^i is in D^i and E^i is a branching + assert len(G) == len(B) + if len(B): + assert is_branching(B) + + graphs.append((G.copy(), G_edge_index.copy())) + branchings.append((B.copy(), B_edge_index.copy())) + circuits.append([]) + minedge_circuit.append(None) + + break + else: + ##################### + ### BEGIN STEP I1 ### + ##################### + + # This is a very simple step, so I don't think it needs a method of it's own + if v in selected_nodes: + continue + + selected_nodes.add(v) + B.add_node(v) + desired_edge, desired_edge_weight = edmonds_find_desired_edge(v) + + # There might be no desired edge if all edges are excluded or + # v is the last node to be added to B, the ultimate root of the branching + if desired_edge is not None and desired_edge_weight > 0: + u = desired_edge[0] + # Flag adding the edge will create a circuit before merging the two + # connected components of u and v in B + circuit = uf[u] == uf[v] + dd = {attr: desired_edge_weight} + if desired_edge[4].get(partition) is not None: + dd[partition] = desired_edge[4].get(partition) + + edmonds_add_edge(B, B_edge_index, u, v, desired_edge[2], **dd) + G[u][v][desired_edge[2]][candidate_attr] = True + uf.union(u, v) + + ################### + ### END STEP I1 ### + ################### + + ##################### + ### BEGIN STEP I2 ### + ##################### + + if circuit: + edmonds_step_I2(v, desired_edge, level) + nodes = iter(list(G.nodes())) + level += 1 + + ################### + ### END STEP I2 ### + ################### + + ##################### + ### BEGIN STEP I3 ### + ##################### + + # Create a new graph of the same class as the input graph + H = G_original.__class__() + + # Start with the branching edges in the last level. + edges = set(branchings[level][1]) + while level > 0: + level -= 1 + + # The current level is i, and we start counting from 0. + # + # We need the node at level i+1 that results from merging a circuit + # at level i. basename_0 is the first merged node and this happens + # at level 1. That is basename_0 is a node at level 1 that results + # from merging a circuit at level 0. + + merged_node = new_node_base_name + str(level) + circuit = circuits[level] + isroot, edgekey = is_root(graphs[level + 1][0], merged_node, edges) + edges.update(circuit) + + if isroot: + minedge = minedge_circuit[level] + if minedge is None: + raise Exception + + # Remove the edge in the cycle with minimum weight + edges.remove(minedge) + else: + # We have identified an edge at the next higher level that + # transitions into the merged node at this level. That edge + # transitions to some corresponding node at the current level. + # + # We want to remove an edge from the cycle that transitions + # into the corresponding node, otherwise the result would not + # be a branching. 
+ + G, G_edge_index = graphs[level] + target = G_edge_index[edgekey][1] + for edgekey in circuit: + u, v, data = G_edge_index[edgekey] + if v == target: + break + else: + raise Exception("Couldn't find edge incoming to merged node.") + + edges.remove(edgekey) + + H.add_nodes_from(G_original) + for edgekey in edges: + u, v, d = graphs[0][1][edgekey] + dd = {attr: d[attr]} + + if preserve_attrs: + for key, value in d.items(): + if key not in [attr, candidate_attr]: + dd[key] = value + + H.add_edge(u, v, **dd) + + ################### + ### END STEP I3 ### + ################### + + return H + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def minimum_branching( + G, attr="weight", default=1, preserve_attrs=False, partition=None +): + for _, _, d in G.edges(data=True): + d[attr] = -d.get(attr, default) + nx._clear_cache(G) + + B = maximum_branching(G, attr, default, preserve_attrs, partition) + + for _, _, d in G.edges(data=True): + d[attr] = -d.get(attr, default) + nx._clear_cache(G) + + for _, _, d in B.edges(data=True): + d[attr] = -d.get(attr, default) + nx._clear_cache(B) + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def minimal_branching( + G, /, *, attr="weight", default=1, preserve_attrs=False, partition=None +): + """ + Returns a minimal branching from `G`. + + A minimal branching is a branching similar to a minimal arborescence but + without the requirement that the result is actually a spanning arborescence. + This allows minimal branchinges to be computed over graphs which may not + have arborescence (such as multiple components). + + Parameters + ---------- + G : (multi)digraph-like + The graph to be searched. + attr : str + The edge attribute used in determining optimality. + default : float + The value of the edge attribute used if an edge does not have + the attribute `attr`. + preserve_attrs : bool + If True, preserve the other attributes of the original graph (that are not + passed to `attr`) + partition : str + The key for the edge attribute containing the partition + data on the graph. Edges can be included, excluded or open using the + `EdgePartition` enum. + + Returns + ------- + B : (multi)digraph-like + A minimal branching. + """ + max_weight = -INF + min_weight = INF + for _, _, w in G.edges(data=attr, default=default): + if w > max_weight: + max_weight = w + if w < min_weight: + min_weight = w + + for _, _, d in G.edges(data=True): + # Transform the weights so that the minimum weight is larger than + # the difference between the max and min weights. This is important + # in order to prevent the edge weights from becoming negative during + # computation + d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default) + nx._clear_cache(G) + + B = maximum_branching(G, attr, default, preserve_attrs, partition) + + # Reverse the weight transformations + for _, _, d in G.edges(data=True): + d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default) + nx._clear_cache(G) + + for _, _, d in B.edges(data=True): + d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default) + nx._clear_cache(B) + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def maximum_spanning_arborescence( + G, attr="weight", default=1, preserve_attrs=False, partition=None +): + # In order to use the same algorithm is the maximum branching, we need to adjust + # the weights of the graph. 
The branching algorithm can choose to not include an + # edge if it doesn't help find a branching, mainly triggered by edges with negative + # weights. + # + # To prevent this from happening while trying to find a spanning arborescence, we + # just have to tweak the edge weights so that they are all positive and cannot + # become negative during the branching algorithm, find the maximum branching and + # then return them to their original values. + + min_weight = INF + max_weight = -INF + for _, _, w in G.edges(data=attr, default=default): + if w < min_weight: + min_weight = w + if w > max_weight: + max_weight = w + + for _, _, d in G.edges(data=True): + d[attr] = d.get(attr, default) - min_weight + 1 - (min_weight - max_weight) + nx._clear_cache(G) + + B = maximum_branching(G, attr, default, preserve_attrs, partition) + + for _, _, d in G.edges(data=True): + d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight) + nx._clear_cache(G) + + for _, _, d in B.edges(data=True): + d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight) + nx._clear_cache(B) + + if not is_arborescence(B): + raise nx.exception.NetworkXException("No maximum spanning arborescence in G.") + + return B + + +@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True) +def minimum_spanning_arborescence( + G, attr="weight", default=1, preserve_attrs=False, partition=None +): + B = minimal_branching( + G, + attr=attr, + default=default, + preserve_attrs=preserve_attrs, + partition=partition, + ) + + if not is_arborescence(B): + raise nx.exception.NetworkXException("No minimum spanning arborescence in G.") + + return B + + +docstring_branching = """ +Returns a {kind} {style} from G. + +Parameters +---------- +G : (multi)digraph-like + The graph to be searched. +attr : str + The edge attribute used to in determining optimality. +default : float + The value of the edge attribute used if an edge does not have + the attribute `attr`. +preserve_attrs : bool + If True, preserve the other attributes of the original graph (that are not + passed to `attr`) +partition : str + The key for the edge attribute containing the partition + data on the graph. Edges can be included, excluded or open using the + `EdgePartition` enum. + +Returns +------- +B : (multi)digraph-like + A {kind} {style}. +""" + +docstring_arborescence = ( + docstring_branching + + """ +Raises +------ +NetworkXException + If the graph does not contain a {kind} {style}. + +""" +) + +maximum_branching.__doc__ = docstring_branching.format( + kind="maximum", style="branching" +) + +minimum_branching.__doc__ = ( + docstring_branching.format(kind="minimum", style="branching") + + """ +See Also +-------- + minimal_branching +""" +) + +maximum_spanning_arborescence.__doc__ = docstring_arborescence.format( + kind="maximum", style="spanning arborescence" +) + +minimum_spanning_arborescence.__doc__ = docstring_arborescence.format( + kind="minimum", style="spanning arborescence" +) + + +class ArborescenceIterator: + """ + Iterate over all spanning arborescences of a graph in either increasing or + decreasing cost. + + Notes + ----- + This iterator uses the partition scheme from [1]_ (included edges, + excluded edges and open edges). It generates minimum spanning + arborescences using a modified Edmonds' Algorithm which respects the + partition of edges. For arborescences with the same weight, ties are + broken arbitrarily. + + References + ---------- + .. [1] G.K. Janssens, K. 
Sörensen, An algorithm to generate all spanning + trees in order of increasing cost, Pesquisa Operacional, 2005-08, + Vol. 25 (2), p. 219-229, + https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en + """ + + @dataclass(order=True) + class Partition: + """ + This dataclass represents a partition and stores a dict with the edge + data and the weight of the minimum spanning arborescence of the + partition dict. + """ + + mst_weight: float + partition_dict: dict = field(compare=False) + + def __copy__(self): + return ArborescenceIterator.Partition( + self.mst_weight, self.partition_dict.copy() + ) + + def __init__(self, G, weight="weight", minimum=True, init_partition=None): + """ + Initialize the iterator + + Parameters + ---------- + G : nx.DiGraph + The directed graph which we need to iterate trees over + + weight : String, default = "weight" + The edge attribute used to store the weight of the edge + + minimum : bool, default = True + Return the trees in increasing order while true and decreasing order + while false. + + init_partition : tuple, default = None + In the case that certain edges have to be included or excluded from + the arborescences, `init_partition` should be in the form + `(included_edges, excluded_edges)` where each edges is a + `(u, v)`-tuple inside an iterable such as a list or set. + + """ + self.G = G.copy() + self.weight = weight + self.minimum = minimum + self.method = ( + minimum_spanning_arborescence if minimum else maximum_spanning_arborescence + ) + # Randomly create a key for an edge attribute to hold the partition data + self.partition_key = ( + "ArborescenceIterators super secret partition attribute name" + ) + if init_partition is not None: + partition_dict = {} + for e in init_partition[0]: + partition_dict[e] = nx.EdgePartition.INCLUDED + for e in init_partition[1]: + partition_dict[e] = nx.EdgePartition.EXCLUDED + self.init_partition = ArborescenceIterator.Partition(0, partition_dict) + else: + self.init_partition = None + + def __iter__(self): + """ + Returns + ------- + ArborescenceIterator + The iterator object for this graph + """ + self.partition_queue = PriorityQueue() + self._clear_partition(self.G) + + # Write the initial partition if it exists. + if self.init_partition is not None: + self._write_partition(self.init_partition) + + mst_weight = self.method( + self.G, + self.weight, + partition=self.partition_key, + preserve_attrs=True, + ).size(weight=self.weight) + + self.partition_queue.put( + self.Partition( + mst_weight if self.minimum else -mst_weight, + ( + {} + if self.init_partition is None + else self.init_partition.partition_dict + ), + ) + ) + + return self + + def __next__(self): + """ + Returns + ------- + (multi)Graph + The spanning tree of next greatest weight, which ties broken + arbitrarily. + """ + if self.partition_queue.empty(): + del self.G, self.partition_queue + raise StopIteration + + partition = self.partition_queue.get() + self._write_partition(partition) + next_arborescence = self.method( + self.G, + self.weight, + partition=self.partition_key, + preserve_attrs=True, + ) + self._partition(partition, next_arborescence) + + self._clear_partition(next_arborescence) + return next_arborescence + + def _partition(self, partition, partition_arborescence): + """ + Create new partitions based of the minimum spanning tree of the + current minimum partition. + + Parameters + ---------- + partition : Partition + The Partition instance used to generate the current minimum spanning + tree. 
+ partition_arborescence : nx.Graph + The minimum spanning arborescence of the input partition. + """ + # create two new partitions with the data from the input partition dict + p1 = self.Partition(0, partition.partition_dict.copy()) + p2 = self.Partition(0, partition.partition_dict.copy()) + for e in partition_arborescence.edges: + # determine if the edge was open or included + if e not in partition.partition_dict: + # This is an open edge + p1.partition_dict[e] = nx.EdgePartition.EXCLUDED + p2.partition_dict[e] = nx.EdgePartition.INCLUDED + + self._write_partition(p1) + try: + p1_mst = self.method( + self.G, + self.weight, + partition=self.partition_key, + preserve_attrs=True, + ) + + p1_mst_weight = p1_mst.size(weight=self.weight) + p1.mst_weight = p1_mst_weight if self.minimum else -p1_mst_weight + self.partition_queue.put(p1.__copy__()) + except nx.NetworkXException: + pass + + p1.partition_dict = p2.partition_dict.copy() + + def _write_partition(self, partition): + """ + Writes the desired partition into the graph to calculate the minimum + spanning tree. Also, if one incoming edge is included, mark all others + as excluded so that if that vertex is merged during Edmonds' algorithm + we cannot still pick another of that vertex's included edges. + + Parameters + ---------- + partition : Partition + A Partition dataclass describing a partition on the edges of the + graph. + """ + for u, v, d in self.G.edges(data=True): + if (u, v) in partition.partition_dict: + d[self.partition_key] = partition.partition_dict[(u, v)] + else: + d[self.partition_key] = nx.EdgePartition.OPEN + nx._clear_cache(self.G) + + for n in self.G: + included_count = 0 + excluded_count = 0 + for u, v, d in self.G.in_edges(nbunch=n, data=True): + if d.get(self.partition_key) == nx.EdgePartition.INCLUDED: + included_count += 1 + elif d.get(self.partition_key) == nx.EdgePartition.EXCLUDED: + excluded_count += 1 + # Check that if there is an included edges, all other incoming ones + # are excluded. If not fix it! + if included_count == 1 and excluded_count != self.G.in_degree(n) - 1: + for u, v, d in self.G.in_edges(nbunch=n, data=True): + if d.get(self.partition_key) != nx.EdgePartition.INCLUDED: + d[self.partition_key] = nx.EdgePartition.EXCLUDED + + def _clear_partition(self, G): + """ + Removes partition data from the graph + """ + for u, v, d in G.edges(data=True): + if self.partition_key in d: + del d[self.partition_key] + nx._clear_cache(self.G) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/coding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/coding.py new file mode 100644 index 0000000000000000000000000000000000000000..276c87d0745dade0f11bb9374c3479e1f9f8bcfe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/coding.py @@ -0,0 +1,413 @@ +"""Functions for encoding and decoding trees. + +Since a tree is a highly restricted form of graph, it can be represented +concisely in several ways. This module includes functions for encoding +and decoding trees in the form of nested tuples and Prüfer +sequences. The former requires a rooted tree, whereas the latter can be +applied to unrooted trees. Furthermore, there is a bijection from Prüfer +sequences to labeled trees. 
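+
+As a quick illustrative sketch (see the individual functions below for the
+full details and caveats):
+
+>>> T = nx.path_graph(4)
+>>> nx.to_nested_tuple(T, 0)
+((((),),),)
+>>> nx.to_prufer_sequence(T)
+[1, 2]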
+ +""" + +from collections import Counter +from itertools import chain + +import networkx as nx +from networkx.utils import not_implemented_for + +__all__ = [ + "from_nested_tuple", + "from_prufer_sequence", + "NotATree", + "to_nested_tuple", + "to_prufer_sequence", +] + + +class NotATree(nx.NetworkXException): + """Raised when a function expects a tree (that is, a connected + undirected graph with no cycles) but gets a non-tree graph as input + instead. + + """ + + +@not_implemented_for("directed") +@nx._dispatchable(graphs="T") +def to_nested_tuple(T, root, canonical_form=False): + """Returns a nested tuple representation of the given tree. + + The nested tuple representation of a tree is defined + recursively. The tree with one node and no edges is represented by + the empty tuple, ``()``. A tree with ``k`` subtrees is represented + by a tuple of length ``k`` in which each element is the nested tuple + representation of a subtree. + + Parameters + ---------- + T : NetworkX graph + An undirected graph object representing a tree. + + root : node + The node in ``T`` to interpret as the root of the tree. + + canonical_form : bool + If ``True``, each tuple is sorted so that the function returns + a canonical form for rooted trees. This means "lighter" subtrees + will appear as nested tuples before "heavier" subtrees. In this + way, each isomorphic rooted tree has the same nested tuple + representation. + + Returns + ------- + tuple + A nested tuple representation of the tree. + + Notes + ----- + This function is *not* the inverse of :func:`from_nested_tuple`; the + only guarantee is that the rooted trees are isomorphic. + + See also + -------- + from_nested_tuple + to_prufer_sequence + + Examples + -------- + The tree need not be a balanced binary tree:: + + >>> T = nx.Graph() + >>> T.add_edges_from([(0, 1), (0, 2), (0, 3)]) + >>> T.add_edges_from([(1, 4), (1, 5)]) + >>> T.add_edges_from([(3, 6), (3, 7)]) + >>> root = 0 + >>> nx.to_nested_tuple(T, root) + (((), ()), (), ((), ())) + + Continuing the above example, if ``canonical_form`` is ``True``, the + nested tuples will be sorted:: + + >>> nx.to_nested_tuple(T, root, canonical_form=True) + ((), ((), ()), ((), ())) + + Even the path graph can be interpreted as a tree:: + + >>> T = nx.path_graph(4) + >>> root = 0 + >>> nx.to_nested_tuple(T, root) + ((((),),),) + + """ + + def _make_tuple(T, root, _parent): + """Recursively compute the nested tuple representation of the + given rooted tree. + + ``_parent`` is the parent node of ``root`` in the supertree in + which ``T`` is a subtree, or ``None`` if ``root`` is the root of + the supertree. This argument is used to determine which + neighbors of ``root`` are children and which is the parent. + + """ + # Get the neighbors of `root` that are not the parent node. We + # are guaranteed that `root` is always in `T` by construction. + children = set(T[root]) - {_parent} + if len(children) == 0: + return () + nested = (_make_tuple(T, v, root) for v in children) + if canonical_form: + nested = sorted(nested) + return tuple(nested) + + # Do some sanity checks on the input. + if not nx.is_tree(T): + raise nx.NotATree("provided graph is not a tree") + if root not in T: + raise nx.NodeNotFound(f"Graph {T} contains no node {root}") + + return _make_tuple(T, root, None) + + +@nx._dispatchable(graphs=None, returns_graph=True) +def from_nested_tuple(sequence, sensible_relabeling=False): + """Returns the rooted tree corresponding to the given nested tuple. 
+ + The nested tuple representation of a tree is defined + recursively. The tree with one node and no edges is represented by + the empty tuple, ``()``. A tree with ``k`` subtrees is represented + by a tuple of length ``k`` in which each element is the nested tuple + representation of a subtree. + + Parameters + ---------- + sequence : tuple + A nested tuple representing a rooted tree. + + sensible_relabeling : bool + Whether to relabel the nodes of the tree so that nodes are + labeled in increasing order according to their breadth-first + search order from the root node. + + Returns + ------- + NetworkX graph + The tree corresponding to the given nested tuple, whose root + node is node 0. If ``sensible_labeling`` is ``True``, nodes will + be labeled in breadth-first search order starting from the root + node. + + Notes + ----- + This function is *not* the inverse of :func:`to_nested_tuple`; the + only guarantee is that the rooted trees are isomorphic. + + See also + -------- + to_nested_tuple + from_prufer_sequence + + Examples + -------- + Sensible relabeling ensures that the nodes are labeled from the root + starting at 0:: + + >>> balanced = (((), ()), ((), ())) + >>> T = nx.from_nested_tuple(balanced, sensible_relabeling=True) + >>> edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)] + >>> all((u, v) in T.edges() or (v, u) in T.edges() for (u, v) in edges) + True + + """ + + def _make_tree(sequence): + """Recursively creates a tree from the given sequence of nested + tuples. + + This function employs the :func:`~networkx.tree.join` function + to recursively join subtrees into a larger tree. + + """ + # The empty sequence represents the empty tree, which is the + # (unique) graph with a single node. We mark the single node + # with an attribute that indicates that it is the root of the + # graph. + if len(sequence) == 0: + return nx.empty_graph(1) + # For a nonempty sequence, get the subtrees for each child + # sequence and join all the subtrees at their roots. After + # joining the subtrees, the root is node 0. + return nx.tree.join_trees([(_make_tree(child), 0) for child in sequence]) + + # Make the tree and remove the `is_root` node attribute added by the + # helper function. + T = _make_tree(sequence) + if sensible_relabeling: + # Relabel the nodes according to their breadth-first search + # order, starting from the root node (that is, the node 0). + bfs_nodes = chain([0], (v for u, v in nx.bfs_edges(T, 0))) + labels = {v: i for i, v in enumerate(bfs_nodes)} + # We would like to use `copy=False`, but `relabel_nodes` doesn't + # allow a relabel mapping that can't be topologically sorted. + T = nx.relabel_nodes(T, labels) + return T + + +@not_implemented_for("directed") +@nx._dispatchable(graphs="T") +def to_prufer_sequence(T): + r"""Returns the Prüfer sequence of the given tree. + + A *Prüfer sequence* is a list of *n* - 2 numbers between 0 and + *n* - 1, inclusive. The tree corresponding to a given Prüfer + sequence can be recovered by repeatedly joining a node in the + sequence with a node with the smallest potential degree according to + the sequence. + + Parameters + ---------- + T : NetworkX graph + An undirected graph object representing a tree. + + Returns + ------- + list + The Prüfer sequence of the given tree. + + Raises + ------ + NetworkXPointlessConcept + If the number of nodes in `T` is less than two. + + NotATree + If `T` is not a tree. + + KeyError + If the set of nodes in `T` is not {0, …, *n* - 1}. 
+ + Notes + ----- + There is a bijection from labeled trees to Prüfer sequences. This + function is the inverse of the :func:`from_prufer_sequence` + function. + + Sometimes Prüfer sequences use nodes labeled from 1 to *n* instead + of from 0 to *n* - 1. This function requires nodes to be labeled in + the latter form. You can use :func:`~networkx.relabel_nodes` to + relabel the nodes of your tree to the appropriate format. + + This implementation is from [1]_ and has a running time of + $O(n)$. + + See also + -------- + to_nested_tuple + from_prufer_sequence + + References + ---------- + .. [1] Wang, Xiaodong, Lei Wang, and Yingjie Wu. + "An optimal algorithm for Prufer codes." + *Journal of Software Engineering and Applications* 2.02 (2009): 111. + + + Examples + -------- + There is a bijection between Prüfer sequences and labeled trees, so + this function is the inverse of the :func:`from_prufer_sequence` + function: + + >>> edges = [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)] + >>> tree = nx.Graph(edges) + >>> sequence = nx.to_prufer_sequence(tree) + >>> sequence + [3, 3, 3, 4] + >>> tree2 = nx.from_prufer_sequence(sequence) + >>> list(tree2.edges()) == edges + True + + """ + # Perform some sanity checks on the input. + n = len(T) + if n < 2: + msg = "Prüfer sequence undefined for trees with fewer than two nodes" + raise nx.NetworkXPointlessConcept(msg) + if not nx.is_tree(T): + raise nx.NotATree("provided graph is not a tree") + if set(T) != set(range(n)): + raise KeyError("tree must have node labels {0, ..., n - 1}") + + degree = dict(T.degree()) + + def parents(u): + return next(v for v in T[u] if degree[v] > 1) + + index = u = next(k for k in range(n) if degree[k] == 1) + result = [] + for i in range(n - 2): + v = parents(u) + result.append(v) + degree[v] -= 1 + if v < index and degree[v] == 1: + u = v + else: + index = u = next(k for k in range(index + 1, n) if degree[k] == 1) + return result + + +@nx._dispatchable(graphs=None, returns_graph=True) +def from_prufer_sequence(sequence): + r"""Returns the tree corresponding to the given Prüfer sequence. + + A *Prüfer sequence* is a list of *n* - 2 numbers between 0 and + *n* - 1, inclusive. The tree corresponding to a given Prüfer + sequence can be recovered by repeatedly joining a node in the + sequence with a node with the smallest potential degree according to + the sequence. + + Parameters + ---------- + sequence : list + A Prüfer sequence, which is a list of *n* - 2 integers between + zero and *n* - 1, inclusive. + + Returns + ------- + NetworkX graph + The tree corresponding to the given Prüfer sequence. + + Raises + ------ + NetworkXError + If the Prüfer sequence is not valid. + + Notes + ----- + There is a bijection from labeled trees to Prüfer sequences. This + function is the inverse of the :func:`from_prufer_sequence` function. + + Sometimes Prüfer sequences use nodes labeled from 1 to *n* instead + of from 0 to *n* - 1. This function requires nodes to be labeled in + the latter form. You can use :func:`networkx.relabel_nodes` to + relabel the nodes of your tree to the appropriate format. + + This implementation is from [1]_ and has a running time of + $O(n)$. + + References + ---------- + .. [1] Wang, Xiaodong, Lei Wang, and Yingjie Wu. + "An optimal algorithm for Prufer codes." + *Journal of Software Engineering and Applications* 2.02 (2009): 111. 
+ + + See also + -------- + from_nested_tuple + to_prufer_sequence + + Examples + -------- + There is a bijection between Prüfer sequences and labeled trees, so + this function is the inverse of the :func:`to_prufer_sequence` + function: + + >>> edges = [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)] + >>> tree = nx.Graph(edges) + >>> sequence = nx.to_prufer_sequence(tree) + >>> sequence + [3, 3, 3, 4] + >>> tree2 = nx.from_prufer_sequence(sequence) + >>> list(tree2.edges()) == edges + True + + """ + n = len(sequence) + 2 + # `degree` stores the remaining degree (plus one) for each node. The + # degree of a node in the decoded tree is one more than the number + # of times it appears in the code. + degree = Counter(chain(sequence, range(n))) + T = nx.empty_graph(n) + # `not_orphaned` is the set of nodes that have a parent in the + # tree. After the loop, there should be exactly two nodes that are + # not in this set. + not_orphaned = set() + index = u = next(k for k in range(n) if degree[k] == 1) + for v in sequence: + # check the validity of the prufer sequence + if v < 0 or v > n - 1: + raise nx.NetworkXError( + f"Invalid Prufer sequence: Values must be between 0 and {n - 1}, got {v}" + ) + T.add_edge(u, v) + not_orphaned.add(u) + degree[v] -= 1 + if v < index and degree[v] == 1: + u = v + else: + index = u = next(k for k in range(index + 1, n) if degree[k] == 1) + # At this point, there must be exactly two orphaned nodes; join them. + orphans = set(T) - not_orphaned + u, v = orphans + T.add_edge(u, v) + return T diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/distance_measures.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/distance_measures.py new file mode 100644 index 0000000000000000000000000000000000000000..6d0b273737686d561f88504b937f1267e0fec106 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/distance_measures.py @@ -0,0 +1,219 @@ +""" +Algorithms for computing distance measures on trees. +""" + +import networkx as nx + +__all__ = [ + "center", + "centroid", +] + + +@nx.utils.not_implemented_for("directed") +def center(G): + """Returns the center of an undirected tree graph. + + The center of a tree consists of nodes that minimize the maximum eccentricity. + That is, these nodes minimize the maximum distance to all other nodes. + This implementation currently only works for unweighted edges. + + If the input graph is not a tree, results are not guaranteed to be correct and while + some non-trees will raise an ``nx.NotATree`` exception, not all non-trees will be discovered. + Thus, this function should not be used if caller is unsure whether the input graph + is a tree. Use ``nx.is_tree(G)`` to check. + + Parameters + ---------- + G : NetworkX graph + A tree graph (undirected, acyclic graph). + + Returns + ------- + center : list + A list of nodes forming the center of the tree. This can be one or two nodes. + + Raises + ------ + NetworkXNotImplemented + If the input graph is directed. + + NotATree + If the algorithm detects the input graph is not a tree. There is no guarantee + this error will always raise if a non-tree is passed. + + Notes + ----- + This algorithm iteratively removes leaves (nodes with degree 1) from the tree until + there are only 1 or 2 nodes left. The remaining nodes form the center of the tree. + + This algorithm's time complexity is ``O(N)`` where ``N`` is the number of nodes in the tree. 
+ + Examples + -------- + >>> G = nx.Graph([(1, 2), (1, 3), (2, 4), (2, 5)]) + >>> nx.tree.center(G) + [1, 2] + + >>> G = nx.path_graph(5) + >>> nx.tree.center(G) + [2] + """ + center_candidates_degree = dict(G.degree) + leaves = {node for node, degree in center_candidates_degree.items() if degree == 1} + + # It's better to fail than an infinite loop, so check leaves to ensure progress. + while len(center_candidates_degree) > 2 and leaves: + new_leaves = set() + for leaf in leaves: + del center_candidates_degree[leaf] + for neighbor in G.neighbors(leaf): + if neighbor not in center_candidates_degree: + continue + center_candidates_degree[neighbor] -= 1 + if (cddn := center_candidates_degree[neighbor]) == 1: + new_leaves.add(neighbor) + elif cddn == 0 and len(center_candidates_degree) != 1: + raise nx.NotATree("input graph is not a tree") + leaves = new_leaves + + n = len(center_candidates_degree) + # check disconnected or cyclic + if (n == 2 and not leaves) or n not in {1, 2}: + # We have detected graph is not a tree. This check does not cover all cases. + # For example, `nx.Graph([(0, 0)])` will not raise an error. + raise nx.NotATree("input graph is not a tree") + + return list(center_candidates_degree) + + +def _subtree_sizes(G, root): + """Return a `dict` of the size of each subtree, for every subtree + of a tree rooted at a given node. + + For every node in the given tree, consider the new tree that would + be created by detaching it from its parent node (if any). The + number of nodes in the resulting tree rooted at that node is then + assigned as the value for that node in the return dictionary. + + Parameters + ---------- + G : NetworkX graph + A tree. + + root : node + A node in `G`. + + Returns + ------- + s : dict + Dictionary of number of nodes in every subtree of this tree, + keyed on the root node for each subtree. + + Examples + -------- + >>> _subtree_sizes(nx.path_graph(4), 0) + {0: 4, 1: 3, 2: 2, 3: 1} + + >>> _subtree_sizes(nx.path_graph(4), 2) + {2: 4, 1: 2, 0: 1, 3: 1} + + """ + sizes = {root: 1} + stack = [root] + for parent, child in nx.dfs_edges(G, root): + while stack[-1] != parent: + descendant = stack.pop() + sizes[stack[-1]] += sizes[descendant] + stack.append(child) + sizes[child] = 1 + for child, parent in nx.utils.pairwise(reversed(stack)): + sizes[parent] += sizes[child] + return sizes + + +@nx.utils.not_implemented_for("directed") +@nx._dispatchable +def centroid(G): + """Return the centroid of an unweighted tree. + + The centroid of a tree is the set of nodes such that removing any + one of them would split the tree into a forest of subtrees, each + with at most ``N / 2`` nodes, where ``N`` is the number of nodes + in the original tree. This set may contain two nodes if removing + an edge between them results in two trees of size exactly ``N / + 2``. + + Parameters + ---------- + G : NetworkX graph + A tree. + + Returns + ------- + c : list + List of nodes in centroid of the tree. This could be one or two nodes. + + Raises + ------ + NotATree + If the input graph is not a tree. + NotImplementedException + If the input graph is directed. + NetworkXPointlessConcept + If `G` has no nodes or edges. + + Notes + ----- + This algorithm's time complexity is ``O(N)`` where ``N`` is the + number of nodes in the tree. + + In unweighted trees the centroid coincides with the barycenter, + the node or nodes that minimize the sum of distances to all other + nodes. 
However, this concept is different from that of the graph + center, which is the set of nodes minimizing the maximum distance + to all other nodes. + + Examples + -------- + >>> G = nx.path_graph(4) + >>> nx.tree.centroid(G) + [1, 2] + + A star-shaped tree with one long branch illustrates the difference + between the centroid and the center. The center lies near the + middle of the long branch, minimizing maximum distance. The + centroid, however, limits the size of any resulting subtree to at + most half the total nodes, forcing it to remain near the hub when + enough short branches are present. + + >>> G = nx.star_graph(6) + >>> nx.add_path(G, [6, 7, 8, 9, 10]) + >>> nx.tree.centroid(G), nx.tree.center(G) + ([0], [7]) + + See Also + -------- + :func:`~networkx.algorithms.distance_measures.barycenter` + :func:`~networkx.algorithms.distance_measures.center` + center : tree center + """ + if not nx.is_tree(G): + raise nx.NotATree("provided graph is not a tree") + prev, root = None, nx.utils.arbitrary_element(G) + sizes = _subtree_sizes(G, root) + total_size = G.number_of_nodes() + + def _heaviest_child(prev, root): + return max( + (x for x in G.neighbors(root) if x != prev), key=sizes.get, default=None + ) + + hc = _heaviest_child(prev, root) + while max(total_size - sizes[root], sizes.get(hc, 0)) > total_size / 2: + prev, root = root, hc + hc = _heaviest_child(prev, root) + + return [root] + [ + x for x in G.neighbors(root) if x != prev and sizes[x] == total_size / 2 + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/mst.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/mst.py new file mode 100644 index 0000000000000000000000000000000000000000..5886c6ef8a5e951550187917534e445c5a943cf5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/mst.py @@ -0,0 +1,1281 @@ +""" +Algorithms for calculating min/max spanning trees/forests. + +""" + +from dataclasses import dataclass, field +from enum import Enum +from heapq import heappop, heappush +from itertools import count +from math import isnan +from operator import itemgetter +from queue import PriorityQueue + +import networkx as nx +from networkx.utils import UnionFind, not_implemented_for, py_random_state + +__all__ = [ + "minimum_spanning_edges", + "maximum_spanning_edges", + "minimum_spanning_tree", + "maximum_spanning_tree", + "number_of_spanning_trees", + "random_spanning_tree", + "partition_spanning_tree", + "EdgePartition", + "SpanningTreeIterator", +] + + +class EdgePartition(Enum): + """ + An enum to store the state of an edge partition. The enum is written to the + edges of a graph before being pasted to `kruskal_mst_edges`. Options are: + + - EdgePartition.OPEN + - EdgePartition.INCLUDED + - EdgePartition.EXCLUDED + """ + + OPEN = 0 + INCLUDED = 1 + EXCLUDED = 2 + + +@not_implemented_for("multigraph") +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def boruvka_mst_edges( + G, minimum=True, weight="weight", keys=False, data=True, ignore_nan=False +): + """Iterate over edges of a Borůvka's algorithm min/max spanning tree. + + Parameters + ---------- + G : NetworkX Graph + The edges of `G` must have distinct weights, + otherwise the edges may not form a tree. + + minimum : bool (default: True) + Find the minimum (True) or maximum (False) spanning tree. + + weight : string (default: 'weight') + The name of the edge attribute holding the edge weights. 
+ + keys : bool (default: True) + This argument is ignored since this function is not + implemented for multigraphs; it exists only for consistency + with the other minimum spanning tree functions. + + data : bool (default: True) + Flag for whether to yield edge attribute dicts. + If True, yield edges `(u, v, d)`, where `d` is the attribute dict. + If False, yield edges `(u, v)`. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + """ + # Initialize a forest, assuming initially that it is the discrete + # partition of the nodes of the graph. + forest = UnionFind(G) + + def best_edge(component): + """Returns the optimum (minimum or maximum) edge on the edge + boundary of the given set of nodes. + + A return value of ``None`` indicates an empty boundary. + + """ + sign = 1 if minimum else -1 + minwt = float("inf") + boundary = None + for e in nx.edge_boundary(G, component, data=True): + wt = e[-1].get(weight, 1) * sign + if isnan(wt): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {e}" + raise ValueError(msg) + if wt < minwt: + minwt = wt + boundary = e + return boundary + + # Determine the optimum edge in the edge boundary of each component + # in the forest. + best_edges = (best_edge(component) for component in forest.to_sets()) + best_edges = [edge for edge in best_edges if edge is not None] + # If each entry was ``None``, that means the graph was disconnected, + # so we are done generating the forest. + while best_edges: + # Determine the optimum edge in the edge boundary of each + # component in the forest. + # + # This must be a sequence, not an iterator. In this list, the + # same edge may appear twice, in different orientations (but + # that's okay, since a union operation will be called on the + # endpoints the first time it is seen, but not the second time). + # + # Any ``None`` indicates that the edge boundary for that + # component was empty, so that part of the forest has been + # completed. + # + # TODO This can be parallelized, both in the outer loop over + # each component in the forest and in the computation of the + # minimum. (Same goes for the identical lines outside the loop.) + best_edges = (best_edge(component) for component in forest.to_sets()) + best_edges = [edge for edge in best_edges if edge is not None] + # Join trees in the forest using the best edges, and yield that + # edge, since it is part of the spanning tree. + # + # TODO This loop can be parallelized, to an extent (the union + # operation must be atomic). + for u, v, d in best_edges: + if forest[u] != forest[v]: + if data: + yield u, v, d + else: + yield u, v + forest.union(u, v) + + +@nx._dispatchable( + edge_attrs={"weight": None, "partition": None}, preserve_edge_attrs="data" +) +def kruskal_mst_edges( + G, minimum, weight="weight", keys=True, data=True, ignore_nan=False, partition=None +): + """ + Iterate over edge of a Kruskal's algorithm min/max spanning tree. + + Parameters + ---------- + G : NetworkX Graph + The graph holding the tree of interest. + + minimum : bool (default: True) + Find the minimum (True) or maximum (False) spanning tree. + + weight : string (default: 'weight') + The name of the edge attribute holding the edge weights. + + keys : bool (default: True) + If `G` is a multigraph, `keys` controls whether edge keys ar yielded. + Otherwise `keys` is ignored. + + data : bool (default: True) + Flag for whether to yield edge attribute dicts. 
+ If True, yield edges `(u, v, d)`, where `d` is the attribute dict. + If False, yield edges `(u, v)`. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + partition : string (default: None) + The name of the edge attribute holding the partition data, if it exists. + Partition data is written to the edges using the `EdgePartition` enum. + If a partition exists, all included edges and none of the excluded edges + will appear in the final tree. Open edges may or may not be used. + + Yields + ------ + edge tuple + The edges as discovered by Kruskal's method. Each edge can + take the following forms: `(u, v)`, `(u, v, d)` or `(u, v, k, d)` + depending on the `key` and `data` parameters + """ + subtrees = UnionFind() + if G.is_multigraph(): + edges = G.edges(keys=True, data=True) + else: + edges = G.edges(data=True) + + # Sort the edges of the graph with respect to the partition data. + # Edges are returned in the following order: + + # * Included edges + # * Open edges from smallest to largest weight + # * Excluded edges + included_edges = [] + open_edges = [] + for e in edges: + d = e[-1] + wt = d.get(weight, 1) + if isnan(wt): + if ignore_nan: + continue + raise ValueError(f"NaN found as an edge weight. Edge {e}") + + edge = (wt,) + e + if d.get(partition) == EdgePartition.INCLUDED: + included_edges.append(edge) + elif d.get(partition) == EdgePartition.EXCLUDED: + continue + else: + open_edges.append(edge) + + if minimum: + sorted_open_edges = sorted(open_edges, key=itemgetter(0)) + else: + sorted_open_edges = sorted(open_edges, key=itemgetter(0), reverse=True) + + # Condense the lists into one + included_edges.extend(sorted_open_edges) + sorted_edges = included_edges + del open_edges, sorted_open_edges, included_edges + + # Multigraphs need to handle edge keys in addition to edge data. + if G.is_multigraph(): + for wt, u, v, k, d in sorted_edges: + if subtrees[u] != subtrees[v]: + if keys: + if data: + yield u, v, k, d + else: + yield u, v, k + else: + if data: + yield u, v, d + else: + yield u, v + subtrees.union(u, v) + else: + for wt, u, v, d in sorted_edges: + if subtrees[u] != subtrees[v]: + if data: + yield u, v, d + else: + yield u, v + subtrees.union(u, v) + + +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def prim_mst_edges(G, minimum, weight="weight", keys=True, data=True, ignore_nan=False): + """Iterate over edges of Prim's algorithm min/max spanning tree. + + Parameters + ---------- + G : NetworkX Graph + The graph holding the tree of interest. + + minimum : bool (default: True) + Find the minimum (True) or maximum (False) spanning tree. + + weight : string (default: 'weight') + The name of the edge attribute holding the edge weights. + + keys : bool (default: True) + If `G` is a multigraph, `keys` controls whether edge keys ar yielded. + Otherwise `keys` is ignored. + + data : bool (default: True) + Flag for whether to yield edge attribute dicts. + If True, yield edges `(u, v, d)`, where `d` is the attribute dict. + If False, yield edges `(u, v)`. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. 
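+
+    Yields
+    ------
+    edge tuple
+        The edges as discovered by Prim's method.  Each edge can take the
+        following forms: `(u, v)`, `(u, v, d)`, `(u, v, k)` or
+        `(u, v, k, d)` depending on the `keys` and `data` parameters.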
+ + """ + is_multigraph = G.is_multigraph() + + nodes = set(G) + c = count() + + sign = 1 if minimum else -1 + + while nodes: + u = nodes.pop() + frontier = [] + visited = {u} + if is_multigraph: + for v, keydict in G.adj[u].items(): + for k, d in keydict.items(): + wt = d.get(weight, 1) * sign + if isnan(wt): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(u, v, k, d)}" + raise ValueError(msg) + heappush(frontier, (wt, next(c), u, v, k, d)) + else: + for v, d in G.adj[u].items(): + wt = d.get(weight, 1) * sign + if isnan(wt): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(u, v, d)}" + raise ValueError(msg) + heappush(frontier, (wt, next(c), u, v, d)) + while nodes and frontier: + if is_multigraph: + W, _, u, v, k, d = heappop(frontier) + else: + W, _, u, v, d = heappop(frontier) + if v in visited or v not in nodes: + continue + # Multigraphs need to handle edge keys in addition to edge data. + if is_multigraph and keys: + if data: + yield u, v, k, d + else: + yield u, v, k + else: + if data: + yield u, v, d + else: + yield u, v + # update frontier + visited.add(v) + nodes.discard(v) + if is_multigraph: + for w, keydict in G.adj[v].items(): + if w in visited: + continue + for k2, d2 in keydict.items(): + new_weight = d2.get(weight, 1) * sign + if isnan(new_weight): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(v, w, k2, d2)}" + raise ValueError(msg) + heappush(frontier, (new_weight, next(c), v, w, k2, d2)) + else: + for w, d2 in G.adj[v].items(): + if w in visited: + continue + new_weight = d2.get(weight, 1) * sign + if isnan(new_weight): + if ignore_nan: + continue + msg = f"NaN found as an edge weight. Edge {(v, w, d2)}" + raise ValueError(msg) + heappush(frontier, (new_weight, next(c), v, w, d2)) + + +ALGORITHMS = { + "boruvka": boruvka_mst_edges, + "borůvka": boruvka_mst_edges, + "kruskal": kruskal_mst_edges, + "prim": prim_mst_edges, +} + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def minimum_spanning_edges( + G, algorithm="kruskal", weight="weight", keys=True, data=True, ignore_nan=False +): + """Generate edges in a minimum spanning forest of an undirected + weighted graph. + + A minimum spanning tree is a subgraph of the graph (a tree) + with the minimum sum of edge weights. A spanning forest is a + union of the spanning trees for each connected component of the graph. + + Parameters + ---------- + G : undirected Graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + algorithm : string + The algorithm to use when finding a minimum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is 'kruskal'. + + weight : string + Edge data key to use for weight (default 'weight'). + + keys : bool + Whether to yield edge key in multigraphs in addition to the edge. + If `G` is not a multigraph, this is ignored. + + data : bool, optional + If True yield the edge data along with the edge. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + Returns + ------- + edges : iterator + An iterator over edges in a maximum spanning tree of `G`. 
+ Edges connecting nodes `u` and `v` are represented as tuples: + `(u, v, k, d)` or `(u, v, k)` or `(u, v, d)` or `(u, v)` + + If `G` is a multigraph, `keys` indicates whether the edge key `k` will + be reported in the third position in the edge tuple. `data` indicates + whether the edge datadict `d` will appear at the end of the edge tuple. + + If `G` is not a multigraph, the tuples are `(u, v, d)` if `data` is True + or `(u, v)` if `data` is False. + + Examples + -------- + >>> from networkx.algorithms import tree + + Find minimum spanning edges by Kruskal's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> mst = tree.minimum_spanning_edges(G, algorithm="kruskal", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [1, 2], [2, 3]] + + Find minimum spanning edges by Prim's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> mst = tree.minimum_spanning_edges(G, algorithm="prim", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [1, 2], [2, 3]] + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + Modified code from David Eppstein, April 2006 + http://www.ics.uci.edu/~eppstein/PADS/ + + """ + try: + algo = ALGORITHMS[algorithm] + except KeyError as err: + msg = f"{algorithm} is not a valid choice for an algorithm." + raise ValueError(msg) from err + + return algo( + G, minimum=True, weight=weight, keys=keys, data=data, ignore_nan=ignore_nan + ) + + +@not_implemented_for("directed") +@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data") +def maximum_spanning_edges( + G, algorithm="kruskal", weight="weight", keys=True, data=True, ignore_nan=False +): + """Generate edges in a maximum spanning forest of an undirected + weighted graph. + + A maximum spanning tree is a subgraph of the graph (a tree) + with the maximum possible sum of edge weights. A spanning forest is a + union of the spanning trees for each connected component of the graph. + + Parameters + ---------- + G : undirected Graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + algorithm : string + The algorithm to use when finding a maximum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is 'kruskal'. + + weight : string + Edge data key to use for weight (default 'weight'). + + keys : bool + Whether to yield edge key in multigraphs in addition to the edge. + If `G` is not a multigraph, this is ignored. + + data : bool, optional + If True yield the edge data along with the edge. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + Returns + ------- + edges : iterator + An iterator over edges in a maximum spanning tree of `G`. + Edges connecting nodes `u` and `v` are represented as tuples: + `(u, v, k, d)` or `(u, v, k)` or `(u, v, d)` or `(u, v)` + + If `G` is a multigraph, `keys` indicates whether the edge key `k` will + be reported in the third position in the edge tuple. `data` indicates + whether the edge datadict `d` will appear at the end of the edge tuple. 
+ + If `G` is not a multigraph, the tuples are `(u, v, d)` if `data` is True + or `(u, v)` if `data` is False. + + Examples + -------- + >>> from networkx.algorithms import tree + + Find maximum spanning edges by Kruskal's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> mst = tree.maximum_spanning_edges(G, algorithm="kruskal", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [0, 3], [1, 2]] + + Find maximum spanning edges by Prim's algorithm + + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) # assign weight 2 to edge 0-3 + >>> mst = tree.maximum_spanning_edges(G, algorithm="prim", data=False) + >>> edgelist = list(mst) + >>> sorted(sorted(e) for e in edgelist) + [[0, 1], [0, 3], [2, 3]] + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + Modified code from David Eppstein, April 2006 + http://www.ics.uci.edu/~eppstein/PADS/ + """ + try: + algo = ALGORITHMS[algorithm] + except KeyError as err: + msg = f"{algorithm} is not a valid choice for an algorithm." + raise ValueError(msg) from err + + return algo( + G, minimum=False, weight=weight, keys=keys, data=data, ignore_nan=ignore_nan + ) + + +@nx._dispatchable(preserve_all_attrs=True, returns_graph=True) +def minimum_spanning_tree(G, weight="weight", algorithm="kruskal", ignore_nan=False): + """Returns a minimum spanning tree or forest on an undirected graph `G`. + + Parameters + ---------- + G : undirected graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + weight : str + Data key to use for edge weights. + + algorithm : string + The algorithm to use when finding a minimum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is + 'kruskal'. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + Returns + ------- + G : NetworkX Graph + A minimum spanning tree or forest. + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> T = nx.minimum_spanning_tree(G) + >>> sorted(T.edges(data=True)) + [(0, 1, {}), (1, 2, {}), (2, 3, {})] + + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + There may be more than one tree with the same minimum or maximum weight. + See :mod:`networkx.tree.recognition` for more detailed definitions. + + Isolated nodes with self-loops are in the tree as edgeless isolated nodes. + + """ + edges = minimum_spanning_edges( + G, algorithm, weight, keys=True, data=True, ignore_nan=ignore_nan + ) + T = G.__class__() # Same graph class as G + T.graph.update(G.graph) + T.add_nodes_from(G.nodes.items()) + T.add_edges_from(edges) + return T + + +@nx._dispatchable(preserve_all_attrs=True, returns_graph=True) +def partition_spanning_tree( + G, minimum=True, weight="weight", partition="partition", ignore_nan=False +): + """ + Find a spanning tree while respecting a partition of edges. 
+ + Edges can be flagged as either `INCLUDED` which are required to be in the + returned tree, `EXCLUDED`, which cannot be in the returned tree and `OPEN`. + + This is used in the SpanningTreeIterator to create new partitions following + the algorithm of Sörensen and Janssens [1]_. + + Parameters + ---------- + G : undirected graph + An undirected graph. + + minimum : bool (default: True) + Determines whether the returned tree is the minimum spanning tree of + the partition of the maximum one. + + weight : str + Data key to use for edge weights. + + partition : str + The key for the edge attribute containing the partition + data on the graph. Edges can be included, excluded or open using the + `EdgePartition` enum. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + + Returns + ------- + G : NetworkX Graph + A minimum spanning tree using all of the included edges in the graph and + none of the excluded edges. + + References + ---------- + .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning + trees in order of increasing cost, Pesquisa Operacional, 2005-08, + Vol. 25 (2), p. 219-229, + https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en + """ + edges = kruskal_mst_edges( + G, + minimum, + weight, + keys=True, + data=True, + ignore_nan=ignore_nan, + partition=partition, + ) + T = G.__class__() # Same graph class as G + T.graph.update(G.graph) + T.add_nodes_from(G.nodes.items()) + T.add_edges_from(edges) + return T + + +@nx._dispatchable(preserve_all_attrs=True, returns_graph=True) +def maximum_spanning_tree(G, weight="weight", algorithm="kruskal", ignore_nan=False): + """Returns a maximum spanning tree or forest on an undirected graph `G`. + + Parameters + ---------- + G : undirected graph + An undirected graph. If `G` is connected, then the algorithm finds a + spanning tree. Otherwise, a spanning forest is found. + + weight : str + Data key to use for edge weights. + + algorithm : string + The algorithm to use when finding a maximum spanning tree. Valid + choices are 'kruskal', 'prim', or 'boruvka'. The default is + 'kruskal'. + + ignore_nan : bool (default: False) + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + + + Returns + ------- + G : NetworkX Graph + A maximum spanning tree or forest. + + + Examples + -------- + >>> G = nx.cycle_graph(4) + >>> G.add_edge(0, 3, weight=2) + >>> T = nx.maximum_spanning_tree(G) + >>> sorted(T.edges(data=True)) + [(0, 1, {}), (0, 3, {'weight': 2}), (1, 2, {})] + + + Notes + ----- + For Borůvka's algorithm, each edge must have a weight attribute, and + each edge weight must be distinct. + + For the other algorithms, if the graph edges do not have a weight + attribute a default weight of 1 will be used. + + There may be more than one tree with the same minimum or maximum weight. + See :mod:`networkx.tree.recognition` for more detailed definitions. + + Isolated nodes with self-loops are in the tree as edgeless isolated nodes. 
+ + """ + edges = maximum_spanning_edges( + G, algorithm, weight, keys=True, data=True, ignore_nan=ignore_nan + ) + edges = list(edges) + T = G.__class__() # Same graph class as G + T.graph.update(G.graph) + T.add_nodes_from(G.nodes.items()) + T.add_edges_from(edges) + return T + + +@py_random_state(3) +@nx._dispatchable(preserve_edge_attrs=True, returns_graph=True) +def random_spanning_tree(G, weight=None, *, multiplicative=True, seed=None): + """ + Sample a random spanning tree using the edges weights of `G`. + + This function supports two different methods for determining the + probability of the graph. If ``multiplicative=True``, the probability + is based on the product of edge weights, and if ``multiplicative=False`` + it is based on the sum of the edge weight. However, since it is + easier to determine the total weight of all spanning trees for the + multiplicative version, that is significantly faster and should be used if + possible. Additionally, setting `weight` to `None` will cause a spanning tree + to be selected with uniform probability. + + The function uses algorithm A8 in [1]_ . + + Parameters + ---------- + G : nx.Graph + An undirected version of the original graph. + + weight : string + The edge key for the edge attribute holding edge weight. + + multiplicative : bool, default=True + If `True`, the probability of each tree is the product of its edge weight + over the sum of the product of all the spanning trees in the graph. If + `False`, the probability is the sum of its edge weight over the sum of + the sum of weights for all spanning trees in the graph. + + seed : integer, random_state, or None (default) + Indicator of random number generation state. + See :ref:`Randomness`. + + Returns + ------- + nx.Graph + A spanning tree using the distribution defined by the weight of the tree. + + References + ---------- + .. [1] V. Kulkarni, Generating random combinatorial objects, Journal of + Algorithms, 11 (1990), pp. 185–207 + """ + + def find_node(merged_nodes, node): + """ + We can think of clusters of contracted nodes as having one + representative in the graph. Each node which is not in merged_nodes + is still its own representative. Since a representative can be later + contracted, we need to recursively search though the dict to find + the final representative, but once we know it we can use path + compression to speed up the access of the representative for next time. + + This cannot be replaced by the standard NetworkX union_find since that + data structure will merge nodes with less representing nodes into the + one with more representing nodes but this function requires we merge + them using the order that contract_edges contracts using. + + Parameters + ---------- + merged_nodes : dict + The dict storing the mapping from node to representative + node + The node whose representative we seek + + Returns + ------- + The representative of the `node` + """ + if node not in merged_nodes: + return node + else: + rep = find_node(merged_nodes, merged_nodes[node]) + merged_nodes[node] = rep + return rep + + def prepare_graph(): + """ + For the graph `G`, remove all edges not in the set `V` and then + contract all edges in the set `U`. + + Returns + ------- + A copy of `G` which has had all edges not in `V` removed and all edges + in `U` contracted. 
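+        (The function returns the pair ``(merged_nodes, result)``, where
+        ``merged_nodes`` is the merge-find dict built while contracting and
+        ``result`` is the contracted MultiGraph copy of `G`.)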
+ """ + + # The result is a MultiGraph version of G so that parallel edges are + # allowed during edge contraction + result = nx.MultiGraph(incoming_graph_data=G) + + # Remove all edges not in V + edges_to_remove = set(result.edges()).difference(V) + result.remove_edges_from(edges_to_remove) + + # Contract all edges in U + # + # Imagine that you have two edges to contract and they share an + # endpoint like this: + # [0] ----- [1] ----- [2] + # If we contract (0, 1) first, the contraction function will always + # delete the second node it is passed so the resulting graph would be + # [0] ----- [2] + # and edge (1, 2) no longer exists but (0, 2) would need to be contracted + # in its place now. That is why I use the below dict as a merge-find + # data structure with path compression to track how the nodes are merged. + merged_nodes = {} + + for u, v in U: + u_rep = find_node(merged_nodes, u) + v_rep = find_node(merged_nodes, v) + # We cannot contract a node with itself + if u_rep == v_rep: + continue + nx.contracted_nodes(result, u_rep, v_rep, self_loops=False, copy=False) + merged_nodes[v_rep] = u_rep + + return merged_nodes, result + + def spanning_tree_total_weight(G, weight): + """ + Find the sum of weights of the spanning trees of `G` using the + appropriate `method`. + + This is easy if the chosen method is 'multiplicative', since we can + use Kirchhoff's Tree Matrix Theorem directly. However, with the + 'additive' method, this process is slightly more complex and less + computationally efficient as we have to find the number of spanning + trees which contain each possible edge in the graph. + + Parameters + ---------- + G : NetworkX Graph + The graph to find the total weight of all spanning trees on. + + weight : string + The key for the weight edge attribute of the graph. + + Returns + ------- + float + The sum of either the multiplicative or additive weight for all + spanning trees in the graph. + """ + if multiplicative: + return number_of_spanning_trees(G, weight=weight) + else: + # There are two cases for the total spanning tree additive weight. + # 1. There is one edge in the graph. Then the only spanning tree is + # that edge itself, which will have a total weight of that edge + # itself. + if G.number_of_edges() == 1: + return G.edges(data=weight).__iter__().__next__()[2] + # 2. There are no edges or two or more edges in the graph. Then, we find the + # total weight of the spanning trees using the formula in the + # reference paper: take the weight of each edge and multiply it by + # the number of spanning trees which include that edge. This + # can be accomplished by contracting the edge and finding the + # multiplicative total spanning tree weight if the weight of each edge + # is assumed to be 1, which is conveniently built into networkx already, + # by calling number_of_spanning_trees with weight=None. + # Note that with no edges the returned value is just zero. 
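+            # In symbols, this additive total is the sum over edges e of
+            # w(e) * t(G / e), where t(G / e) is the unweighted number of
+            # spanning trees of the graph with edge e contracted; the loop
+            # below computes exactly this sum.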
+ else: + total = 0 + for u, v, w in G.edges(data=weight): + total += w * nx.number_of_spanning_trees( + nx.contracted_edge(G, edge=(u, v), self_loops=False), + weight=None, + ) + return total + + if G.number_of_nodes() < 2: + # no edges in the spanning tree + return nx.empty_graph(G.nodes) + + U = set() + st_cached_value = 0 + V = set(G.edges()) + shuffled_edges = list(G.edges()) + seed.shuffle(shuffled_edges) + + for u, v in shuffled_edges: + e_weight = G[u][v][weight] if weight is not None else 1 + node_map, prepared_G = prepare_graph() + G_total_tree_weight = spanning_tree_total_weight(prepared_G, weight) + # Add the edge to U so that we can compute the total tree weight + # assuming we include that edge + # Now, if (u, v) cannot exist in G because it is fully contracted out + # of existence, then it by definition cannot influence G_e's Kirchhoff + # value. But, we also cannot pick it. + rep_edge = (find_node(node_map, u), find_node(node_map, v)) + # Check to see if the 'representative edge' for the current edge is + # in prepared_G. If so, then we can pick it. + if rep_edge in prepared_G.edges: + prepared_G_e = nx.contracted_edge( + prepared_G, edge=rep_edge, self_loops=False + ) + G_e_total_tree_weight = spanning_tree_total_weight(prepared_G_e, weight) + if multiplicative: + threshold = e_weight * G_e_total_tree_weight / G_total_tree_weight + else: + numerator = (st_cached_value + e_weight) * nx.number_of_spanning_trees( + prepared_G_e + ) + G_e_total_tree_weight + denominator = ( + st_cached_value * nx.number_of_spanning_trees(prepared_G) + + G_total_tree_weight + ) + threshold = numerator / denominator + else: + threshold = 0.0 + z = seed.uniform(0.0, 1.0) + if z > threshold: + # Remove the edge from V since we did not pick it. + V.remove((u, v)) + else: + # Add the edge to U since we picked it. + st_cached_value += e_weight + U.add((u, v)) + # If we decide to keep an edge, it may complete the spanning tree. + if len(U) == G.number_of_nodes() - 1: + spanning_tree = nx.Graph() + spanning_tree.add_edges_from(U) + return spanning_tree + raise Exception(f"Something went wrong! Only {len(U)} edges in the spanning tree!") + + +class SpanningTreeIterator: + """ + Iterate over all spanning trees of a graph in either increasing or + decreasing cost. + + Notes + ----- + This iterator uses the partition scheme from [1]_ (included edges, + excluded edges and open edges) as well as a modified Kruskal's Algorithm + to generate minimum spanning trees which respect the partition of edges. + For spanning trees with the same weight, ties are broken arbitrarily. + + References + ---------- + .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning + trees in order of increasing cost, Pesquisa Operacional, 2005-08, + Vol. 25 (2), p. 219-229, + https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en + """ + + @dataclass(order=True) + class Partition: + """ + This dataclass represents a partition and stores a dict with the edge + data and the weight of the minimum spanning tree of the partition dict. 
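+
+        Ordering (``order=True``) compares only ``mst_weight``; the
+        ``partition_dict`` field is excluded from comparisons
+        (``compare=False``), so the iterator's priority queue ranks
+        partitions purely by spanning tree weight.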
+ """ + + mst_weight: float + partition_dict: dict = field(compare=False) + + def __copy__(self): + return SpanningTreeIterator.Partition( + self.mst_weight, self.partition_dict.copy() + ) + + def __init__(self, G, weight="weight", minimum=True, ignore_nan=False): + """ + Initialize the iterator + + Parameters + ---------- + G : nx.Graph + The directed graph which we need to iterate trees over + + weight : String, default = "weight" + The edge attribute used to store the weight of the edge + + minimum : bool, default = True + Return the trees in increasing order while true and decreasing order + while false. + + ignore_nan : bool, default = False + If a NaN is found as an edge weight normally an exception is raised. + If `ignore_nan is True` then that edge is ignored instead. + """ + self.G = G.copy() + self.G.__networkx_cache__ = None # Disable caching + self.weight = weight + self.minimum = minimum + self.ignore_nan = ignore_nan + # Randomly create a key for an edge attribute to hold the partition data + self.partition_key = ( + "SpanningTreeIterators super secret partition attribute name" + ) + + def __iter__(self): + """ + Returns + ------- + SpanningTreeIterator + The iterator object for this graph + """ + self.partition_queue = PriorityQueue() + self._clear_partition(self.G) + mst_weight = partition_spanning_tree( + self.G, self.minimum, self.weight, self.partition_key, self.ignore_nan + ).size(weight=self.weight) + + self.partition_queue.put( + self.Partition(mst_weight if self.minimum else -mst_weight, {}) + ) + + return self + + def __next__(self): + """ + Returns + ------- + (multi)Graph + The spanning tree of next greatest weight, which ties broken + arbitrarily. + """ + if self.partition_queue.empty(): + del self.G, self.partition_queue + raise StopIteration + + partition = self.partition_queue.get() + self._write_partition(partition) + next_tree = partition_spanning_tree( + self.G, self.minimum, self.weight, self.partition_key, self.ignore_nan + ) + self._partition(partition, next_tree) + + self._clear_partition(next_tree) + return next_tree + + def _partition(self, partition, partition_tree): + """ + Create new partitions based of the minimum spanning tree of the + current minimum partition. + + Parameters + ---------- + partition : Partition + The Partition instance used to generate the current minimum spanning + tree. + partition_tree : nx.Graph + The minimum spanning tree of the input partition. + """ + # create two new partitions with the data from the input partition dict + p1 = self.Partition(0, partition.partition_dict.copy()) + p2 = self.Partition(0, partition.partition_dict.copy()) + for e in partition_tree.edges: + # determine if the edge was open or included + if e not in partition.partition_dict: + # This is an open edge + p1.partition_dict[e] = EdgePartition.EXCLUDED + p2.partition_dict[e] = EdgePartition.INCLUDED + + self._write_partition(p1) + p1_mst = partition_spanning_tree( + self.G, + self.minimum, + self.weight, + self.partition_key, + self.ignore_nan, + ) + p1_mst_weight = p1_mst.size(weight=self.weight) + if nx.is_connected(p1_mst): + p1.mst_weight = p1_mst_weight if self.minimum else -p1_mst_weight + self.partition_queue.put(p1.__copy__()) + p1.partition_dict = p2.partition_dict.copy() + + def _write_partition(self, partition): + """ + Writes the desired partition into the graph to calculate the minimum + spanning tree. + + Parameters + ---------- + partition : Partition + A Partition dataclass describing a partition on the edges of the + graph. 
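+
+        Edges that do not appear in ``partition.partition_dict`` are written
+        as ``EdgePartition.OPEN``.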
+ """ + + partition_dict = partition.partition_dict + partition_key = self.partition_key + G = self.G + + edges = ( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + for *e, d in edges: + d[partition_key] = partition_dict.get(tuple(e), EdgePartition.OPEN) + + def _clear_partition(self, G): + """ + Removes partition data from the graph + """ + partition_key = self.partition_key + edges = ( + G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True) + ) + for *e, d in edges: + if partition_key in d: + del d[partition_key] + + +@nx._dispatchable(edge_attrs="weight") +def number_of_spanning_trees(G, *, root=None, weight=None): + """Returns the number of spanning trees in `G`. + + A spanning tree for an undirected graph is a tree that connects + all nodes in the graph. For a directed graph, the analog of a + spanning tree is called a (spanning) arborescence. The arborescence + includes a unique directed path from the `root` node to each other node. + The graph must be weakly connected, and the root must be a node + that includes all nodes as successors [3]_. Note that to avoid + discussing sink-roots and reverse-arborescences, we have reversed + the edge orientation from [3]_ and use the in-degree laplacian. + + This function (when `weight` is `None`) returns the number of + spanning trees for an undirected graph and the number of + arborescences from a single root node for a directed graph. + When `weight` is the name of an edge attribute which holds the + weight value of each edge, the function returns the sum over + all trees of the multiplicative weight of each tree. That is, + the weight of the tree is the product of its edge weights. + + Kirchoff's Tree Matrix Theorem states that any cofactor of the + Laplacian matrix of a graph is the number of spanning trees in the + graph. (Here we use cofactors for a diagonal entry so that the + cofactor becomes the determinant of the matrix with one row + and its matching column removed.) For a weighted Laplacian matrix, + the cofactor is the sum across all spanning trees of the + multiplicative weight of each tree. That is, the weight of each + tree is the product of its edge weights. The theorem is also + known as Kirchhoff's theorem [1]_ and the Matrix-Tree theorem [2]_. + + For directed graphs, a similar theorem (Tutte's Theorem) holds with + the cofactor chosen to be the one with row and column removed that + correspond to the root. The cofactor is the number of arborescences + with the specified node as root. And the weighted version gives the + sum of the arborescence weights with root `root`. The arborescence + weight is the product of its edge weights. + + Parameters + ---------- + G : NetworkX graph + + root : node + A node in the directed graph `G` that has all nodes as descendants. + (This is ignored for undirected graphs.) + + weight : string or None, optional (default=None) + The name of the edge attribute holding the edge weight. + If `None`, then each edge is assumed to have a weight of 1. + + Returns + ------- + Number + Undirected graphs: + The number of spanning trees of the graph `G`. + Or the sum of all spanning tree weights of the graph `G` + where the weight of a tree is the product of its edge weights. + Directed graphs: + The number of arborescences of `G` rooted at node `root`. + Or the sum of all arborescence weights of the graph `G` with + specified root where the weight of an arborescence is the product + of its edge weights. 
+ + Raises + ------ + NetworkXPointlessConcept + If `G` does not contain any nodes. + + NetworkXError + If the graph `G` is directed and the root node + is not specified or is not in G. + + Examples + -------- + >>> G = nx.complete_graph(5) + >>> round(nx.number_of_spanning_trees(G)) + 125 + + >>> G = nx.Graph() + >>> G.add_edge(1, 2, weight=2) + >>> G.add_edge(1, 3, weight=1) + >>> G.add_edge(2, 3, weight=1) + >>> round(nx.number_of_spanning_trees(G, weight="weight")) + 5 + + Notes + ----- + Self-loops are excluded. Multi-edges are contracted in one edge + equal to the sum of the weights. + + References + ---------- + .. [1] Wikipedia + "Kirchhoff's theorem." + https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem + .. [2] Kirchhoff, G. R. + Über die Auflösung der Gleichungen, auf welche man + bei der Untersuchung der linearen Vertheilung + Galvanischer Ströme geführt wird + Annalen der Physik und Chemie, vol. 72, pp. 497-508, 1847. + .. [3] Margoliash, J. + "Matrix-Tree Theorem for Directed Graphs" + https://www.math.uchicago.edu/~may/VIGRE/VIGRE2010/REUPapers/Margoliash.pdf + """ + import numpy as np + + if len(G) == 0: + raise nx.NetworkXPointlessConcept("Graph G must contain at least one node.") + + # undirected G + if not nx.is_directed(G): + if not nx.is_connected(G): + return 0 + G_laplacian = nx.laplacian_matrix(G, weight=weight).toarray() + return float(np.linalg.det(G_laplacian[1:, 1:])) + + # directed G + if root is None: + raise nx.NetworkXError("Input `root` must be provided when G is directed") + if root not in G: + raise nx.NetworkXError("The node root is not in the graph G.") + if not nx.is_weakly_connected(G): + return 0 + + # Compute directed Laplacian matrix + nodelist = [root] + [n for n in G if n != root] + A = nx.adjacency_matrix(G, nodelist=nodelist, weight=weight) + D = np.diag(A.sum(axis=0)) + G_laplacian = D - A + + # Compute number of spanning trees + return float(np.linalg.det(G_laplacian[1:, 1:])) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/operations.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..a75770a5968eaff5f5975d23a6af4557b52e6e7e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/operations.py @@ -0,0 +1,106 @@ +"""Operations on trees.""" + +from functools import partial +from itertools import accumulate, chain + +import networkx as nx + +__all__ = ["join_trees"] + + +# Argument types don't match dispatching, but allow manual selection of backend +@nx._dispatchable(graphs=None, returns_graph=True) +def join_trees(rooted_trees, *, label_attribute=None, first_label=0): + """Returns a new rooted tree made by joining `rooted_trees` + + Constructs a new tree by joining each tree in `rooted_trees`. + A new root node is added and connected to each of the roots + of the input trees. While copying the nodes from the trees, + relabeling to integers occurs. If the `label_attribute` is provided, + the old node labels will be stored in the new tree under this attribute. + + Parameters + ---------- + rooted_trees : list + A list of pairs in which each left element is a NetworkX graph + object representing a tree and each right element is the root + node of that tree. The nodes of these trees will be relabeled to + integers. 
+ + label_attribute : str + If provided, the old node labels will be stored in the new tree + under this node attribute. If not provided, the original labels + of the nodes in the input trees are not stored. + + first_label : int, optional (default=0) + Specifies the label for the new root node. If provided, the root node of the joined tree + will have this label. If not provided, the root node will default to a label of 0. + + Returns + ------- + NetworkX graph + The rooted tree resulting from joining the provided `rooted_trees`. The new tree has a root node + labeled as specified by `first_label` (defaulting to 0 if not provided). Subtrees from the input + `rooted_trees` are attached to this new root node. Each non-root node, if the `label_attribute` + is provided, has an attribute that indicates the original label of the node in the input tree. + + Notes + ----- + Trees are stored in NetworkX as NetworkX Graphs. There is no specific + enforcement of the fact that these are trees. Testing for each tree + can be done using :func:`networkx.is_tree`. + + Graph, edge, and node attributes are propagated from the given + rooted trees to the created tree. If there are any overlapping graph + attributes, those from later trees will overwrite those from earlier + trees in the tuple of positional arguments. + + Examples + -------- + Join two full balanced binary trees of height *h* to get a full + balanced binary tree of depth *h* + 1:: + + >>> h = 4 + >>> left = nx.balanced_tree(2, h) + >>> right = nx.balanced_tree(2, h) + >>> joined_tree = nx.join_trees([(left, 0), (right, 0)]) + >>> nx.is_isomorphic(joined_tree, nx.balanced_tree(2, h + 1)) + True + + """ + if not rooted_trees: + return nx.empty_graph(1) + + # Unzip the zipped list of (tree, root) pairs. + trees, roots = zip(*rooted_trees) + + # The join of the trees has the same type as the type of the first tree. + R = type(trees[0])() + + lengths = (len(tree) for tree in trees[:-1]) + first_labels = list(accumulate(lengths, initial=first_label + 1)) + + new_roots = [] + for tree, root, first_node in zip(trees, roots, first_labels): + new_root = first_node + list(tree.nodes()).index(root) + new_roots.append(new_root) + + # Relabel the nodes so that their union is the integers starting at first_label. + relabel = partial( + nx.convert_node_labels_to_integers, label_attribute=label_attribute + ) + new_trees = [ + relabel(tree, first_label=first_label) + for tree, first_label in zip(trees, first_labels) + ] + + # Add all sets of nodes and edges, attributes + for tree in new_trees: + R.update(tree) + + # Finally, join the subtrees at the root. We know first_label is unused by + # the way we relabeled the subtrees. + R.add_node(first_label) + R.add_edges_from((first_label, root) for root in new_roots) + + return R diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/recognition.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..a9eae98707a6889213ff8b93887c481ba59215a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/networkx/algorithms/tree/recognition.py @@ -0,0 +1,273 @@ +""" +Recognition Tests +================= + +A *forest* is an acyclic, undirected graph, and a *tree* is a connected forest. +Depending on the subfield, there are various conventions for generalizing these +definitions to directed graphs. 
+ +In one convention, directed variants of forest and tree are defined in an +identical manner, except that the direction of the edges is ignored. In effect, +each directed edge is treated as a single undirected edge. Then, additional +restrictions are imposed to define *branchings* and *arborescences*. + +In another convention, directed variants of forest and tree correspond to +the previous convention's branchings and arborescences, respectively. Then two +new terms, *polyforest* and *polytree*, are defined to correspond to the other +convention's forest and tree. + +Summarizing:: + + +-----------------------------+ + | Convention A | Convention B | + +=============================+ + | forest | polyforest | + | tree | polytree | + | branching | forest | + | arborescence | tree | + +-----------------------------+ + +Each convention has its reasons. The first convention emphasizes definitional +similarity in that directed forests and trees are only concerned with +acyclicity and do not have an in-degree constraint, just as their undirected +counterparts do not. The second convention emphasizes functional similarity +in the sense that the directed analog of a spanning tree is a spanning +arborescence. That is, take any spanning tree and choose one node as the root. +Then every edge is assigned a direction such there is a directed path from the +root to every other node. The result is a spanning arborescence. + +NetworkX follows convention "A". Explicitly, these are: + +undirected forest + An undirected graph with no undirected cycles. + +undirected tree + A connected, undirected forest. + +directed forest + A directed graph with no undirected cycles. Equivalently, the underlying + graph structure (which ignores edge orientations) is an undirected forest. + In convention B, this is known as a polyforest. + +directed tree + A weakly connected, directed forest. Equivalently, the underlying graph + structure (which ignores edge orientations) is an undirected tree. In + convention B, this is known as a polytree. + +branching + A directed forest with each node having, at most, one parent. So the maximum + in-degree is equal to 1. In convention B, this is known as a forest. + +arborescence + A directed tree with each node having, at most, one parent. So the maximum + in-degree is equal to 1. In convention B, this is known as a tree. + +For trees and arborescences, the adjective "spanning" may be added to designate +that the graph, when considered as a forest/branching, consists of a single +tree/arborescence that includes all nodes in the graph. It is true, by +definition, that every tree/arborescence is spanning with respect to the nodes +that define the tree/arborescence and so, it might seem redundant to introduce +the notion of "spanning". However, the nodes may represent a subset of +nodes from a larger graph, and it is in this context that the term "spanning" +becomes a useful notion. + +""" + +import networkx as nx + +__all__ = ["is_arborescence", "is_branching", "is_forest", "is_tree"] + + +@nx.utils.not_implemented_for("undirected") +@nx._dispatchable +def is_arborescence(G): + """ + Returns True if `G` is an arborescence. + + An arborescence is a directed tree with maximum in-degree equal to 1. + + Parameters + ---------- + G : graph + The graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is an arborescence. 
+ + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (0, 2), (2, 3), (3, 4)]) + >>> nx.is_arborescence(G) + True + >>> G.remove_edge(0, 1) + >>> G.add_edge(1, 2) # maximum in-degree is 2 + >>> nx.is_arborescence(G) + False + + Notes + ----- + In another convention, an arborescence is known as a *tree*. + + See Also + -------- + is_tree + + """ + return is_tree(G) and max(d for n, d in G.in_degree()) <= 1 + + +@nx.utils.not_implemented_for("undirected") +@nx._dispatchable +def is_branching(G): + """ + Returns True if `G` is a branching. + + A branching is a directed forest with maximum in-degree equal to 1. + + Parameters + ---------- + G : directed graph + The directed graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is a branching. + + Examples + -------- + >>> G = nx.DiGraph([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> nx.is_branching(G) + True + >>> G.remove_edge(2, 3) + >>> G.add_edge(3, 1) # maximum in-degree is 2 + >>> nx.is_branching(G) + False + + Notes + ----- + In another convention, a branching is also known as a *forest*. + + See Also + -------- + is_forest + + """ + return is_forest(G) and max(d for n, d in G.in_degree()) <= 1 + + +@nx._dispatchable +def is_forest(G): + """ + Returns True if `G` is a forest. + + A forest is a graph with no undirected cycles. + + For directed graphs, `G` is a forest if the underlying graph is a forest. + The underlying graph is obtained by treating each directed edge as a single + undirected edge in a multigraph. + + Parameters + ---------- + G : graph + The graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is a forest. + + Raises + ------ + NetworkXPointlessConcept + If `G` is empty. + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_edges_from([(1, 2), (1, 3), (2, 4), (2, 5)]) + >>> nx.is_forest(G) + True + >>> G.add_edge(4, 1) + >>> nx.is_forest(G) + False + + Notes + ----- + In another convention, a directed forest is known as a *polyforest* and + then *forest* corresponds to a *branching*. + + See Also + -------- + is_branching + + """ + if len(G) == 0: + raise nx.exception.NetworkXPointlessConcept("G has no nodes.") + + if G.is_directed(): + components = (G.subgraph(c) for c in nx.weakly_connected_components(G)) + else: + components = (G.subgraph(c) for c in nx.connected_components(G)) + + return all(len(c) - 1 == c.number_of_edges() for c in components) + + +@nx._dispatchable +def is_tree(G): + """ + Returns True if `G` is a tree. + + A tree is a connected graph with no undirected cycles. + + For directed graphs, `G` is a tree if the underlying graph is a tree. The + underlying graph is obtained by treating each directed edge as a single + undirected edge in a multigraph. + + Parameters + ---------- + G : graph + The graph to test. + + Returns + ------- + b : bool + A boolean that is True if `G` is a tree. + + Raises + ------ + NetworkXPointlessConcept + If `G` is empty. + + Examples + -------- + >>> G = nx.Graph() + >>> G.add_edges_from([(1, 2), (1, 3), (2, 4), (2, 5)]) + >>> nx.is_tree(G) # n-1 edges + True + >>> G.add_edge(3, 4) + >>> nx.is_tree(G) # n edges + False + + Notes + ----- + In another convention, a directed tree is known as a *polytree* and then + *tree* corresponds to an *arborescence*. 
+ + See Also + -------- + is_arborescence + + """ + if len(G) == 0: + raise nx.exception.NetworkXPointlessConcept("G has no nodes.") + + if G.is_directed(): + is_connected = nx.is_weakly_connected + else: + is_connected = nx.is_connected + + # A connected graph with no cycles has n-1 edges. + return len(G) - 1 == G.number_of_edges() and is_connected(G) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/altclip/__pycache__/configuration_altclip.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/altclip/__pycache__/configuration_altclip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f28fd37a2cbc9c244afd7536eed1d6756a1e20fe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/altclip/__pycache__/configuration_altclip.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/altclip/__pycache__/modeling_altclip.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/altclip/__pycache__/modeling_altclip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6eb59c3d6fa063f54a9de1273eaa23abde6622c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/altclip/__pycache__/modeling_altclip.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53e55d635cf067dad92b1e907da5affc09fd2c49 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/configuration_audio_spectrogram_transformer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/configuration_audio_spectrogram_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..228d3137a734c5a3887c211d5b50e69a72723d38 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/configuration_audio_spectrogram_transformer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/feature_extraction_audio_spectrogram_transformer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/feature_extraction_audio_spectrogram_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9993e3e3f0b28d5c84563020feca4d5e16257a9a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/feature_extraction_audio_spectrogram_transformer.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/modeling_audio_spectrogram_transformer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/modeling_audio_spectrogram_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2b235d95a2d9d0c20c85e77e73e29c356cd17d9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/audio_spectrogram_transformer/__pycache__/modeling_audio_spectrogram_transformer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..edfeb4dbe75bb53c011719d6c550b245ac814b28 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bit import * + from .image_processing_bit import * + from .image_processing_bit_fast import * + from .modeling_bit import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/configuration_bit.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/configuration_bit.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1f24fa0688fe8a62748e8fac58cf62a2b176d0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/configuration_bit.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""BiT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class BitConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BitModel`]. It is used to instantiate an BiT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the BiT + [google/bit-50](https://huggingface.co/google/bit-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embedding_size (`int`, *optional*, defaults to 64): + Dimensionality (hidden size) for the embedding layer. + hidden_sizes (`list[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`): + Dimensionality (hidden size) at each stage. + depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 3]`): + Depth (number of layers) for each stage. + layer_type (`str`, *optional*, defaults to `"preactivation"`): + The layer to use, it can be either `"preactivation"` or `"bottleneck"`. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` + are supported. + global_padding (`str`, *optional*): + Padding strategy to use for the convolutional layers. Can be either `"valid"`, `"same"`, or `None`. + num_groups (`int`, *optional*, defaults to 32): + Number of groups used for the `BitGroupNormActivation` layers. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The drop path rate for the stochastic depth. + embedding_dynamic_padding (`bool`, *optional*, defaults to `False`): + Whether or not to make use of dynamic padding for the embedding layer. + output_stride (`int`, *optional*, defaults to 32): + The output stride of the model. + width_factor (`int`, *optional*, defaults to 1): + The width factor for the model. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. 
+ + Example: + ```python + >>> from transformers import BitConfig, BitModel + + >>> # Initializing a BiT bit-50 style configuration + >>> configuration = BitConfig() + + >>> # Initializing a model (with random weights) from the bit-50 style configuration + >>> model = BitModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "bit" + layer_types = ["preactivation", "bottleneck"] + supported_padding = ["SAME", "VALID"] + + def __init__( + self, + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="preactivation", + hidden_act="relu", + global_padding=None, + num_groups=32, + drop_path_rate=0.0, + embedding_dynamic_padding=False, + output_stride=32, + width_factor=1, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + if global_padding is not None: + if global_padding.upper() in self.supported_padding: + global_padding = global_padding.upper() + else: + raise ValueError(f"Padding strategy {global_padding} not supported") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.layer_type = layer_type + self.hidden_act = hidden_act + self.global_padding = global_padding + self.num_groups = num_groups + self.drop_path_rate = drop_path_rate + self.embedding_dynamic_padding = embedding_dynamic_padding + self.output_stride = output_stride + self.width_factor = width_factor + + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["BitConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/image_processing_bit.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/image_processing_bit.py new file mode 100644 index 0000000000000000000000000000000000000000..3d32752edca86aa5105edbbc7dc566d71f79ab8f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/image_processing_bit.py @@ -0,0 +1,324 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for BiT.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class BitImageProcessor(BaseImageProcessor): + r""" + Constructs a BiT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize: + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
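+
+    Example (a minimal sketch; it assumes Pillow is available and uses a
+    blank test image rather than real data):
+
+    ```python
+    >>> from PIL import Image
+    >>> from transformers import BitImageProcessor
+
+    >>> image_processor = BitImageProcessor()
+    >>> image = Image.new("RGB", (640, 480))
+    >>> inputs = image_processor(images=image, return_tensors="np")
+    >>> inputs["pixel_values"].shape
+    (1, 3, 224, 224)
+    ```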
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
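+
+        Note:
+            If `size` contains both `"height"` and `"width"`, the image is
+            resized to exactly ``(height, width)`` rather than by matching the
+            shortest edge.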
+ """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["BitImageProcessor"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/image_processing_bit_fast.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/image_processing_bit_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..46c0e94f3d190b0b5157aa3e2a3ba7c38f3f41cc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/image_processing_bit_fast.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for BiT.""" + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...utils import auto_docstring + + +@auto_docstring +class BitImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + rescale_factor = 1 / 255 + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + +__all__ = ["BitImageProcessorFast"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/modeling_bit.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/modeling_bit.py new file mode 100644 index 0000000000000000000000000000000000000000..1e491f06eae6c8bb77de28c372aeebfc7795966f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bit/modeling_bit.py @@ -0,0 +1,821 @@ +# coding=utf-8 +# Copyright 2022 Google AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BiT model. Also supports backbone for ViT hybrid.""" + +import collections +import math +from typing import Optional + +import numpy as np +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging +from ...utils.backbone_utils import BackboneMixin +from .configuration_bit import BitConfig + + +logger = logging.get_logger(__name__) + + +def get_padding_value(padding=None, kernel_size=7, stride=1, dilation=1) -> tuple[tuple, bool]: + r""" + Utility function to get the tuple padding value given the kernel_size and padding. + + Args: + padding (Union[`str`, `int`], *optional*): + Padding value, can be either `"same"`, `"valid"`. If a different value is provided the default padding from + PyTorch is used. + kernel_size (`int`, *optional*, defaults to 7): + Kernel size of the convolution layers. + stride (`int`, *optional*, defaults to 1): + Stride value of the convolution layers. + dilation (`int`, *optional*, defaults to 1): + Dilation value of the convolution layers. + """ + dynamic = False + if padding is None: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding, dynamic + + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == "same": + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0: + # static case, no extra overhead + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == "valid": + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding, dynamic + + +class WeightStandardizedConv2d(nn.Conv2d): + """Conv2d with Weight Standardization. Includes TensorFlow compatible SAME padding. Used for ViT Hybrid model. 
+ + Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight + Standardization](https://huggingface.co/papers/1903.10520v2) + """ + + def __init__( + self, + in_channel, + out_channels, + kernel_size, + stride=1, + padding="SAME", + dilation=1, + groups=1, + bias=False, + eps=1e-6, + ): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channel, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + if is_dynamic: + self.pad = DynamicPad2d(kernel_size, stride, dilation) + else: + self.pad = None + self.eps = eps + + def forward(self, hidden_state): + if self.pad is not None: + hidden_state = self.pad(hidden_state) + weight = nn.functional.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0.0, eps=self.eps + ).reshape_as(self.weight) + hidden_state = nn.functional.conv2d( + hidden_state, weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + return hidden_state + + +class BitGroupNormActivation(nn.GroupNorm): + r""" + A module that combines group normalization with an activation function. + """ + + def __init__(self, config, num_channels, eps=1e-5, affine=True, apply_activation=True): + super().__init__(config.num_groups, num_channels, eps=eps, affine=affine) + if apply_activation: + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = nn.Identity() + + def forward(self, hidden_state): + hidden_state = nn.functional.group_norm(hidden_state, self.num_groups, self.weight, self.bias, self.eps) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class DynamicPad2d(nn.Module): + r""" + A module that wraps dynamic padding of any input, given the parameters of the convolutional layer and the input + hidden states. 
+ """ + + def __init__(self, kernel_size, stride, dilation, value=0): + super().__init__() + # Safety checkers + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.value = value + + def compute_padding(x, kernel_size, stride, dilation): + return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0) + + self.compute_padding = compute_padding + + def forward(self, input): + # Get width and height + input_height, input_width = input.size()[-2:] + + # Compute the padding values + padding_height = self.compute_padding(input_height, self.kernel_size[0], self.stride[0], self.dilation[0]) + padding_width = self.compute_padding(input_width, self.kernel_size[1], self.stride[1], self.dilation[1]) + + # apply pad + if padding_height > 0 or padding_width > 0: + input = nn.functional.pad( + input, + [ + padding_width // 2, + padding_width - padding_width // 2, + padding_height // 2, + padding_height - padding_height // 2, + ], + value=self.value, + ) + return input + + +class BitMaxPool2d(nn.MaxPool2d): + """Tensorflow like 'SAME' wrapper for 2D max pooling""" + + def __init__( + self, + kernel_size: int, + stride=None, + dilation=1, + ceil_mode=False, + padding=(0, 0), + padding_value=0, + use_dynamic_padding=True, + ): + kernel_size = kernel_size if isinstance(kernel_size, collections.abc.Iterable) else (kernel_size, kernel_size) + stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride) + dilation = dilation if isinstance(dilation, collections.abc.Iterable) else (dilation, dilation) + super().__init__(kernel_size, stride, padding, dilation, ceil_mode) + if use_dynamic_padding: + self.pad = DynamicPad2d(kernel_size, stride, dilation, padding_value) + else: + self.pad = nn.Identity() + + def forward(self, hidden_states): + hidden_states = self.pad(hidden_states) + return nn.functional.max_pool2d( + hidden_states, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode + ) + + +class BitEmbeddings(nn.Module): + """ + BiT Embeddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: BitConfig): + super().__init__() + + self.convolution = WeightStandardizedConv2d( + config.num_channels, + config.embedding_size, + kernel_size=7, + stride=2, + eps=1e-8, + padding=config.global_padding, + ) + + self.pooler = BitMaxPool2d(kernel_size=3, stride=2, use_dynamic_padding=config.embedding_dynamic_padding) + + # Use the same padding strategy as convolutional layers + if config.global_padding is not None and config.global_padding.upper() == "SAME": + self.pad = nn.Identity() + else: + self.pad = nn.ConstantPad2d(padding=(1, 1, 1, 1), value=0.0) + + if config.layer_type != "preactivation": + self.norm = BitGroupNormActivation(config, num_channels=config.embedding_size) + else: + self.norm = nn.Identity() + + self.num_channels = config.num_channels + + def forward(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) + + embedding = self.convolution(pixel_values) + + embedding = self.pad(embedding) + + embedding = self.norm(embedding) + + embedding = self.pooler(embedding) + + return embedding + + +# Copied from transformers.models.convnext.modeling_convnext.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit +class BitDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +def make_div(value, divisor=8): + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + if new_value < 0.9 * value: + new_value += divisor + return new_value + + +class BitPreActivationBottleneckLayer(nn.Module): + """Pre-activation (v2) bottleneck block. + Follows the implementation of "Identity Mappings in Deep Residual Networks": + https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua + + Except it puts the stride on 3x3 conv when available. 
+ """ + + def __init__( + self, + config, + in_channels, + out_channels=None, + bottle_ratio=0.25, + stride=1, + dilation=1, + first_dilation=None, + groups=1, + drop_path_rate=0.0, + is_first_layer=False, + ): + super().__init__() + + first_dilation = first_dilation or dilation + + out_channels = out_channels or in_channels + mid_channels = make_div(out_channels * bottle_ratio) + + if is_first_layer: + self.downsample = BitDownsampleConv( + config, + in_channels, + out_channels, + stride=stride, + preact=True, + ) + else: + self.downsample = None + + self.norm1 = BitGroupNormActivation(config, in_channels) + self.conv1 = WeightStandardizedConv2d(in_channels, mid_channels, 1, eps=1e-8, padding=config.global_padding) + + self.norm2 = BitGroupNormActivation(config, num_channels=mid_channels) + self.conv2 = WeightStandardizedConv2d( + mid_channels, mid_channels, 3, stride=stride, groups=groups, eps=1e-8, padding=config.global_padding + ) + + self.norm3 = BitGroupNormActivation(config, mid_channels) + self.conv3 = WeightStandardizedConv2d(mid_channels, out_channels, 1, eps=1e-8, padding=config.global_padding) + + self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def forward(self, hidden_states): + hidden_states_preact = self.norm1(hidden_states) + + # shortcut branch + shortcut = hidden_states + if self.downsample is not None: + shortcut = self.downsample(hidden_states_preact) + + # residual branch + hidden_states = self.conv1(hidden_states_preact) + hidden_states = self.conv2(self.norm2(hidden_states)) + hidden_states = self.conv3(self.norm3(hidden_states)) + hidden_states = self.drop_path(hidden_states) + return hidden_states + shortcut + + +class BitBottleneckLayer(nn.Module): + """Non Pre-activation bottleneck block, equivalent to V1.5/V1b bottleneck. 
Used for ViT Hybrid.""" + + def __init__( + self, + config, + in_channels, + out_channels=None, + bottle_ratio=0.25, + stride=1, + dilation=1, + first_dilation=None, + groups=1, + drop_path_rate=0.0, + is_first_layer=False, + ): + super().__init__() + first_dilation = first_dilation or dilation + + out_channels = out_channels or in_channels + mid_chs = make_div(out_channels * bottle_ratio) + + if is_first_layer: + self.downsample = BitDownsampleConv( + config, + in_channels, + out_channels, + stride=stride, + preact=False, + ) + else: + self.downsample = None + + self.conv1 = WeightStandardizedConv2d(in_channels, mid_chs, 1, eps=1e-8, padding=config.global_padding) + self.norm1 = BitGroupNormActivation(config, num_channels=mid_chs) + self.conv2 = WeightStandardizedConv2d( + mid_chs, + mid_chs, + 3, + stride=stride, + dilation=first_dilation, + groups=groups, + eps=1e-8, + padding=config.global_padding, + ) + self.norm2 = BitGroupNormActivation(config, num_channels=mid_chs) + self.conv3 = WeightStandardizedConv2d(mid_chs, out_channels, 1, eps=1e-8, padding=config.global_padding) + self.norm3 = BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False) + self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + # shortcut branch + shortcut = hidden_states + if self.downsample is not None: + shortcut = self.downsample(hidden_states) + + # residual + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + + hidden_states = self.conv3(hidden_states) + hidden_states = self.norm3(hidden_states) + + hidden_states = self.drop_path(hidden_states) + hidden_states = self.activation(hidden_states + shortcut) + return hidden_states + + +class BitDownsampleConv(nn.Module): + def __init__( + self, + config, + in_channels, + out_channels, + stride=1, + preact=True, + ): + super().__init__() + self.conv = WeightStandardizedConv2d( + in_channels, out_channels, 1, stride=stride, eps=1e-8, padding=config.global_padding + ) + self.norm = ( + nn.Identity() + if preact + else BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False) + ) + + def forward(self, x): + return self.norm(self.conv(x)) + + +class BitStage(nn.Module): + """ + A ResNet v2 stage composed by stacked layers. 
+ """ + + def __init__( + self, + config, + in_channels, + out_channels, + stride, + dilation, + depth, + bottle_ratio=0.25, + layer_dropout=None, + ): + super().__init__() + + first_dilation = 1 if dilation in (1, 2) else 2 + + # Get the layer type + if config.layer_type == "bottleneck": + layer_cls = BitBottleneckLayer + else: + layer_cls = BitPreActivationBottleneckLayer + + prev_chs = in_channels + self.layers = nn.Sequential() + for layer_idx in range(depth): + # Get the current hyper-parameters + stride, drop_path_rate, is_first_layer = self._get_updated_hyperparameters( + layer_idx, stride, layer_dropout + ) + + self.layers.add_module( + str(layer_idx), + layer_cls( + config, + prev_chs, + out_channels, + stride=stride, + dilation=dilation, + bottle_ratio=bottle_ratio, + first_dilation=first_dilation, + drop_path_rate=drop_path_rate, + is_first_layer=is_first_layer, + ), + ) + prev_chs = out_channels + first_dilation = dilation + + def _get_updated_hyperparameters(self, layer_idx, stride, layer_dropout): + r""" + Get the new hyper-parameters with respect to the previous ones and the index of the current layer. + """ + if layer_dropout: + drop_path_rate = layer_dropout[layer_idx] + else: + drop_path_rate = 0.0 + + if layer_idx != 0: + stride = 1 + + is_first_layer = layer_idx == 0 + + return stride, drop_path_rate, is_first_layer + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for _, layer in enumerate(self.layers): + hidden_state = layer(hidden_state) + return hidden_state + + +class BitEncoder(nn.Module): + def __init__(self, config: BitConfig): + super().__init__() + self.stages = nn.ModuleList([]) + + prev_chs = config.embedding_size + + # These needs to stay hardcoded + current_stride = 4 + dilation = 1 + + layer_dropouts = [ + x.tolist() + for x in torch.Tensor(np.linspace(0, config.drop_path_rate, sum(config.depths))).split(config.depths) + ] + + for stage_idx, (current_depth, current_hidden_size, layer_dropout) in enumerate( + zip(config.depths, config.hidden_sizes, layer_dropouts) + ): + # Get the updated hyper params + out_channels, stride, dilation = self._get_updated_hyperparameters( + stage_idx, current_stride, current_hidden_size, dilation, config + ) + + stage = BitStage( + config, + prev_chs, + out_channels, + stride=stride, + dilation=dilation, + depth=current_depth, + layer_dropout=layer_dropout, + ) + + prev_chs = out_channels + current_stride *= stride + + self.stages.add_module(str(stage_idx), stage) + + def _get_updated_hyperparameters(self, stage_idx, current_stride, current_hidden_size, dilation, config): + out_channels = make_div(current_hidden_size * config.width_factor) + stride = 1 if stage_idx == 0 else 2 + if current_stride >= config.output_stride: + dilation *= stride + stride = 1 + return out_channels, stride, dilation + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +@auto_docstring +class BitPreTrainedModel(PreTrainedModel): + 
config: BitConfig + base_model_prefix = "bit" + main_input_name = "pixel_values" + _no_split_modules = ["BitEmbeddings"] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +@auto_docstring +class BitModel(BitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embedder = BitEmbeddings(config) + + self.encoder = BitEncoder(config) + self.norm = ( + BitGroupNormActivation(config, num_channels=config.hidden_sizes[-1]) + if config.layer_type == "preactivation" + else nn.Identity() + ) + + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict + ) + + last_hidden_state = encoder_outputs[0] + + last_hidden_state = self.norm(last_hidden_state) + + pooled_output = self.pooler(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + BiT Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class BitForImageClassification(BitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.bit = BitModel(config) + # classification head + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(), + ) + # initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +@auto_docstring( + custom_intro=""" + BiT backbone, to be used with frameworks like DETR and MaskFormer. + """ +) +class BitBackbone(BitPreTrainedModel, BackboneMixin): + has_attentions = False + + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.bit = BitModel(config) + self.num_features = [config.embedding_size] + config.hidden_sizes + + # initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("google/bit-50") + >>> model = AutoBackbone.from_pretrained("google/bit-50") + + >>> inputs = processor(image, return_tensors="pt") + >>> outputs = model(**inputs) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.bit(pixel_values, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = ["BitForImageClassification", "BitModel", "BitPreTrainedModel", "BitBackbone"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e082d38ba7339e66ba6a75d8db2d9c8c12e5304 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/configuration_bloom.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/configuration_bloom.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..272af6388c804cccdea5ce0232b0ed157ab96f5e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/configuration_bloom.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/modeling_bloom.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/modeling_bloom.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3366a0b90acd6255197cbbe607a35cca3f3ffcd7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/modeling_bloom.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/modeling_flax_bloom.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/modeling_flax_bloom.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad34cef014ecc43e5e8d1a60f306f18633a65e63 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/modeling_flax_bloom.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/tokenization_bloom_fast.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/tokenization_bloom_fast.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..901c3a355888e2ca167444fbd343136089d65c51 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bloom/__pycache__/tokenization_bloom_fast.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bridgetower/modeling_bridgetower.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bridgetower/modeling_bridgetower.py new file mode 100644 index 0000000000000000000000000000000000000000..59c5be00c3169996c5f3490491dc262022817dc3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bridgetower/modeling_bridgetower.py @@ -0,0 +1,1875 @@ +# coding=utf-8 +# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch BridgeTower Model""" + +import math +from collections import OrderedDict +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN, QuickGELUActivation +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + MaskedLMOutput, + ModelOutput, + SequenceClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import auto_docstring, logging, torch_int +from ...utils.deprecation import deprecate_kwarg +from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig + + +logger = logging.get_logger(__name__) + +_TOKENIZER_FOR_DOC = "RobertaTokenizer" + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`BridgeTowerModel`]. + """ +) +class BridgeTowerModelOutput(ModelOutput): + r""" + text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`): + Sequence of hidden-states at the text output of the last layer of the model. + image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`): + Sequence of hidden-states at the image output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`): + Concatenation of last layer hidden-state of the first token of the text and image sequence (classification + token), respectively, after further processing through layers used for auxiliary pretraining tasks. + """ + + text_features: Optional[torch.FloatTensor] = None + image_features: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of ['BridgeTowerForContrastiveLearning'] + """ +) +class BridgeTowerContrastiveOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Image-text contrastive loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + image_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`): + The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + text_embeds: Optional[tuple[torch.FloatTensor]] = None + image_embeds: Optional[tuple[torch.FloatTensor]] = None + cross_embeds: Optional[tuple[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +class BridgeTowerResidualAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.attn = nn.MultiheadAttention(config.hidden_size, config.hidden_size // 64) + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = nn.ModuleDict( + OrderedDict( + [ + ("c_fc", nn.Linear(config.hidden_size, config.hidden_size * 4)), + ("gelu", QuickGELUActivation()), + ("c_proj", nn.Linear(config.hidden_size * 4, config.hidden_size)), + ] + ) + ) + self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn_mask = None + + def attention(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor): + if attention_mask is not None: + attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_state.device) + self.attn_mask = ( + self.attn_mask.to(dtype=hidden_state.dtype, device=hidden_state.device) + if self.attn_mask is not None + else None + ) + return self.attn( + hidden_state, + hidden_state, + hidden_state, + need_weights=False, + attn_mask=self.attn_mask, + key_padding_mask=attention_mask, + )[0] + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): + residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask) + hidden_state = self.ln_2(residual_state) + for layer in self.mlp.values(): + hidden_state = layer(hidden_state) + hidden_state = residual_state + hidden_state + return hidden_state + + +class BridgeTowerTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + if config.remove_last_layer: + self.resblocks = nn.ModuleList( + [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers - 1)] + ) + else: + self.resblocks = nn.ModuleList( + [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers)] + ) + self.stop_gradient = config.stop_gradient + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): + hidden_states = [] + for block in self.resblocks: + hidden_state = block(hidden_state, attention_mask) + if self.stop_gradient: + hidden_states.append(hidden_state.detach()) + else: + hidden_states.append(hidden_state) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->BridgeTower +class BridgeTowerVisionEmbeddings(nn.Module): + def __init__(self, config: BridgeTowerVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", 
torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." 
+ ) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class BridgeTowerVisionTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + self.embeddings = BridgeTowerVisionEmbeddings(config) + self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.transformer = BridgeTowerTransformer(config) + self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.share_layernorm = config.share_layernorm + if not config.share_layernorm: + self.ln_separate = nn.ModuleList( + [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)] + ) + + def forward( + self, + pixel_values: torch.Tensor, + attention_mask, + interpolate_pos_encoding: bool = False, + ): + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding) + hidden_states = self.ln_pre(hidden_states) + # NLD -> LND + hidden_states = hidden_states.permute(1, 0, 2) + + hidden_states = self.transformer(hidden_states, attention_mask) + # shape = [num_hidden_layers, hidden_size, *, grid ** 2] + hidden_states = torch.stack(hidden_states, dim=0) + # shape = [num_hidden_layers, *, hidden_size, grid ** 2] + hidden_states = hidden_states.permute(0, 2, 1, 3) + if self.share_layernorm: + hidden_states = self.ln_post(hidden_states) + else: + hidden_states_stack = [] + for hidden_states, ln in zip(hidden_states, self.ln_separate): + hidden_states = ln(hidden_states) + hidden_states_stack.append(hidden_states) + # shape = [num_hidden_layers, *, hidden_size, grid ** 2] + hidden_states = torch.stack(hidden_states_stack, dim=0) + return hidden_states + + def forward_pre( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ): + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.ln_pre(hidden_states) + # NLD -> LND + hidden_states = hidden_states.permute(1, 0, 2) + return hidden_states + + def forward_post(self, hidden_state: torch.Tensor): + visual_output_post = hidden_state.permute(1, 0, 2) + visual_output_post = self.ln_post(visual_output_post) + return visual_output_post + + +class BridgeTowerLinkTower(nn.Module): + def __init__(self, config): + super().__init__() + self.link_tower_type = config.link_tower_type + self.hidden_size = config.hidden_size + if config.link_tower_type in ["add", "scaled_add", "interpolate"]: + if config.link_tower_type == "scaled_add": + self.scaled_factor = nn.Parameter(torch.tensor(1.0)) + elif config.link_tower_type == "interpolate": + self.beta = nn.Parameter(torch.tensor(0.5)) + self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + else: + raise NotImplementedError(f"link_tower_type {config.link_tower_type} is not implemented") + + def forward(self, hidden_states, cross_modal_hidden_states, attention_mask): + if self.link_tower_type == "add": + return self.LayerNorm(hidden_states + cross_modal_hidden_states) + elif self.link_tower_type == "scaled_add": + return 
self.LayerNorm(hidden_states * self.scaled_factor + cross_modal_hidden_states) + elif self.link_tower_type == "interpolate": + return self.LayerNorm(hidden_states * (1 - self.beta) + cross_modal_hidden_states * self.beta) + else: + raise NotImplementedError(f"link_tower_type {self.link_tower_type} is not implemented") + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BridgeTower +class BridgeTowerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BridgeTower +class BridgeTowerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BridgeTower +class BridgeTowerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BridgeTower +class BridgeTowerPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower +class BridgeTowerSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to 
be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BridgeTowerModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +BRIDGE_TOWER_SELF_ATTENTION_CLASSES = { + "eager": BridgeTowerSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER +class BridgeTowerAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, + ) + self.output = BridgeTowerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BridgeTowerBertCrossLayer(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BridgeTowerAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + self.crossattention = BridgeTowerAttention(config, layer_idx=layer_idx) + self.intermediate = BridgeTowerIntermediate(config) + self.output = BridgeTowerOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states, + encoder_hidden_states, + attention_mask=None, + head_mask=None, + encoder_attention_mask=None, + past_key_values=None, + output_attentions=False, + cache_position=None, + ): + # decoder 
uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=None, + output_attentions=output_attentions, + past_key_values=None, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BridgeTowerTextLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BridgeTowerAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) + self.intermediate = BridgeTowerIntermediate(config) + self.output = BridgeTowerOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + 
head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + return (layer_output,) + outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText +class BridgeTowerTextEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BridgeTowerTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->BridgeTowerText +class BridgeTowerTextEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
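+ # Positions follow the RoBERTa convention and start at padding_idx + 1, so padding_idx itself is reserved for pad positions (see `create_position_ids_from_input_ids` below).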
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
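+ # For example, with padding_idx=1 and input_ids=[[7, 8, 1]]: mask=[[1, 1, 0]], cumsum(mask)*mask=[[1, 2, 0]], and the returned position ids are [[2, 3, 1]] (pad positions stay at padding_idx).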
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +@auto_docstring +class BridgeTowerPreTrainedModel(PreTrainedModel): + config: BridgeTowerConfig + base_model_prefix = "bridgetower" + supports_gradient_checkpointing = False + _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module: nn.Module): + std = self.config.initializer_factor + if isinstance(module, BridgeTowerVisionTransformer): + proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) + attn_std = self.config.hidden_size**-0.5 + fc_std = (2 * self.config.hidden_size) ** -0.5 + for block in module.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std * std) + block.attn.in_proj_bias.data.zero_() + nn.init.normal_(block.attn.out_proj.weight, std=proj_std * std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * std) + + nn.init.normal_(module.embeddings.class_embedding, std=attn_std * std) + nn.init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std) + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.05 * std) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, BridgeTowerForContrastiveLearning): + module.logit_scale.data.fill_(self.config.logit_scale_init_value) + + if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None: + module.bias.data.zero_() + + +class BridgeTowerVisionModel(BridgeTowerPreTrainedModel): + config: BridgeTowerVisionConfig + + def __init__(self, config): + super().__init__(config) + self.visual = BridgeTowerVisionTransformer(config) + + @property + def dtype(self): + return self.visual.embeddings.patch_embedding.weight.dtype + + def forward(self, image, image_mask=None, interpolate_pos_encoding=False): + return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding) + + +@auto_docstring( + custom_intro=""" + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + + ..
_*Attention is all you need*: https://huggingface.co/papers/1706.03762 + """ +) +class BridgeTowerTextModel(BridgeTowerPreTrainedModel): + config: BridgeTowerTextConfig + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = BridgeTowerTextEmbeddings(config) + self.encoder = BridgeTowerTextEncoder(config) + + self.pooler = BridgeTowerPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + 
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + The bare BridgeTower Model transformer outputting BridgeTowerModelOutput object without any specific head on + """ +) +class BridgeTowerModel(BridgeTowerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + vision_config = config.vision_config + text_config = config.text_config + + if config.share_cross_modal_transformer_layers: + self.cross_modal_text_transform = nn.Linear(text_config.hidden_size, config.hidden_size) + self.cross_modal_image_transform = nn.Linear(vision_config.hidden_size, config.hidden_size) + else: + self.cross_modal_text_transform = nn.ModuleList( + [nn.Linear(text_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)] + ) + self.cross_modal_image_transform = nn.ModuleList( + [nn.Linear(vision_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)] + ) + + self.token_type_embeddings = nn.Embedding(2, config.hidden_size) + + 
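+ # Two token types share the cross-modal space: index 0 marks text tokens and index `image_token_type_idx` (1 by default) marks image tokens.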
self.vision_model = BridgeTowerVisionModel(vision_config) + + self.text_model = BridgeTowerTextModel(text_config) + + if not vision_config.share_layernorm and config.init_layernorm_from_vision_encoder: + for ln in self.vision_model.visual.cross_modal_ln_separate: + ln.weight.data = self.vision_model.visual.ln_post.weight.data + ln.bias.data = self.vision_model.visual.ln_post.bias.data + + self.cross_modal_image_layers = nn.ModuleList( + [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.cross_modal_text_layers = nn.ModuleList( + [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + + # Class token => Linear => Tanh + self.cross_modal_image_pooler = BridgeTowerPooler(config) + self.cross_modal_text_pooler = BridgeTowerPooler(config) + + # Initialize BridgeTower Components + self.cross_modal_text_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.cross_modal_image_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.share_link_tower_layers: + self.cross_modal_text_link_tower = BridgeTowerLinkTower(config) + self.cross_modal_image_link_tower = BridgeTowerLinkTower(config) + else: + self.cross_modal_text_link_tower = nn.ModuleList( + [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)] + ) + self.cross_modal_image_link_tower = nn.ModuleList( + [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)] + ) + + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + image_token_type_idx: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[tuple[torch.Tensor], BridgeTowerModelOutput]: + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*): + Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `pixel_values` into patch embeddings. + image_token_type_idx (`int`, *optional*): + - The token type ids for images. + output_hidden_states (`bool`, *optional*): + If set to `True`, hidden states are returned as a list containing the hidden states of text, image, and + cross-modal components respectively. i.e. `(hidden_states_text, hidden_states_image, + hidden_states_cross_modal)` where each element is a list of the hidden states of the corresponding + modality. `hidden_states_txt/img` are a list of tensors corresponding to unimodal hidden states and + `hidden_states_cross_modal` is a list of tuples containing `cross_modal_text_hidden_states` and + `cross_modal_image_hidden_states` of each bridge layer.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels are currently not supported. + + Examples: + + ```python + >>> from transformers import BridgeTowerProcessor, BridgeTowerModel + >>> from PIL import Image + >>> import requests + + >>> # prepare image and text + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "hello world" + >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base") + >>> model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base") + + >>> inputs = processor(image, text, return_tensors="pt") + >>> outputs = model(**inputs) + >>> outputs.keys() + odict_keys(['text_features', 'image_features', 'pooler_output']) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + all_hidden_states_text = () if output_hidden_states else None + all_hidden_states_image = () if output_hidden_states else None + all_hidden_states_cross = () if output_hidden_states else None + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if inputs_embeds is not None and input_ids is None: + raise NotImplementedError( + "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead." + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + image_token_type_idx = image_token_type_idx if image_token_type_idx else 1 + input_shape = input_ids.size() + text_embeds = self.text_model.embeddings(input_ids=input_ids) + + if output_hidden_states: + all_hidden_states_text += (text_embeds,) + + if attention_mask is None: + attention_mask = torch.ones(input_shape, dtype=torch.long, device=input_ids.device) + extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, input_shape).to( + input_ids.device + ) + + # The split_index determines how many layers of the uni-modal encoder are applied before the cross-modal encoder + split_index = len(self.text_model.encoder.layer) - self.config.num_hidden_layers + 1 + + # Run the first 'split_index' layers of the textual encoder + for layer in self.text_model.encoder.layer[:split_index]: + text_embeds = layer(text_embeds, extend_text_masks)[0] + + if output_hidden_states: + all_hidden_states_text += (text_embeds,) + + if image_embeds is None: + image_embeds = self.vision_model.visual.forward_pre( + pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding + ) + else: + # Permute as BridgeTowerResidualAttention has batch_first=True + image_embeds = image_embeds.permute(1, 0, 2) + + if output_hidden_states: + all_hidden_states_image += (image_embeds,) + + # Run the first 'split_index' layers of the visual encoder + for block in self.vision_model.visual.transformer.resblocks[:split_index]: + image_embeds = block(image_embeds) + if output_hidden_states: + all_hidden_states_image += (image_embeds,) + + image_embeds_with_ln = self.vision_model.visual.forward_post(image_embeds.type(self.vision_model.dtype)) + + # first layer is a special case because we don't have the output from the cross-encoder yet + cross_modal_text = self.cross_modal_text_transform(text_embeds) + + text_token_type_embeddings = self.token_type_embeddings( + torch.zeros(1, dtype=torch.long, 
device=input_ids.device) + ).expand_as(cross_modal_text) + + cross_modal_text = self.cross_modal_text_layernorm(cross_modal_text + text_token_type_embeddings) + + image_embeds_with_ln = self.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings = self.token_type_embeddings( + torch.full((1,), image_token_type_idx, dtype=torch.long, device=input_ids.device) + ).expand_as(image_embeds_with_ln) + + image_embeds_with_ln = image_embeds_with_ln + image_token_type_embeddings + cross_modal_image = self.cross_modal_image_layernorm(image_embeds_with_ln) + + pixel_mask = torch.ones( + (cross_modal_image.size(0), cross_modal_image.size(1)), + dtype=torch.long, + device=input_ids.device, + ) + extend_image_masks = self.text_model.get_extended_attention_mask(pixel_mask, pixel_mask.size()).to( + input_ids.device + ) + + layer_outputs_text = self.cross_modal_text_layers[0]( + cross_modal_text, + cross_modal_image, + attention_mask=extend_text_masks, + encoder_attention_mask=extend_image_masks, + output_attentions=output_attentions, + ) + cross_text_features = layer_outputs_text[0] + + layer_outputs_image = self.cross_modal_image_layers[0]( + cross_modal_image, + cross_modal_text, + attention_mask=extend_image_masks, + encoder_attention_mask=extend_text_masks, + output_attentions=output_attentions, + ) + cross_image_features = layer_outputs_image[0] + + if output_hidden_states: + all_hidden_states_cross += ((cross_text_features, cross_image_features),) + + if output_attentions: + all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),) + + link_layer_index = 0 + + # Each of the top 6 layers of the visual and textual encoders ([split_index:]) is connected to each layer of + # the cross-modal encoder via bridge layers, which brings bottom-up alignment and fusion to the cross-modal encoder. 
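+ # In each iteration, link tower `link_layer_index` fuses the current unimodal layer output with the previous cross-modal features, and the fused features go through cross-modal layer `link_layer_index + 1` (layer 0 was already applied above).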
+ for i in range(split_index, len(self.text_model.encoder.layer)): + text_embeds = self.text_model.encoder.layer[i](text_embeds, extend_text_masks)[0] + image_embeds = self.vision_model.visual.transformer.resblocks[i](image_embeds).type( + self.vision_model.dtype + ) + image_embeds_with_ln = ( + self.cross_modal_image_transform(self.vision_model.visual.forward_post(image_embeds)) + + image_token_type_embeddings + ) + + text_link_tower = self.cross_modal_text_link_tower[link_layer_index] + image_link_tower = self.cross_modal_image_link_tower[link_layer_index] + + # Bridge layers for textual and visual encoders + cross_text_features_ = text_link_tower( + self.cross_modal_text_transform(text_embeds) + text_token_type_embeddings, + cross_text_features, + extend_text_masks, + ) + cross_image_features_ = image_link_tower(image_embeds_with_ln, cross_image_features, extend_image_masks) + + # Cross-modal encoder via bridge layers of textual and visual encoders + layer_outputs_text = self.cross_modal_text_layers[link_layer_index + 1]( + cross_text_features_, + cross_image_features_, + attention_mask=extend_text_masks, + encoder_attention_mask=extend_image_masks, + output_attentions=output_attentions, + ) + cross_text_features = layer_outputs_text[0] + + layer_outputs_image = self.cross_modal_image_layers[link_layer_index + 1]( + cross_image_features_, + cross_text_features_, + attention_mask=extend_image_masks, + encoder_attention_mask=extend_text_masks, + output_attentions=output_attentions, + ) + cross_image_features = layer_outputs_image[0] + + link_layer_index += 1 + + if output_hidden_states: + all_hidden_states_text += (text_embeds,) + all_hidden_states_image += (image_embeds,) + all_hidden_states_cross += ((cross_text_features, cross_image_features),) + + if output_attentions: + all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),) + + # Concatenate the cls token of the text and image features to get the final represtation + text_features, image_features = cross_text_features, cross_image_features + cls_features = self.get_cls_features(text_features, image_features) + + if output_hidden_states: + all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross) + + if not return_dict: + return tuple( + v + for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions] + if v is not None + ) + + return BridgeTowerModelOutput( + text_features=text_features, + image_features=image_features, + pooler_output=cls_features, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def get_cls_features(self, text_features, image_features): + cls_features_text = self.cross_modal_text_pooler(text_features) + cls_features_image = self.cross_modal_image_pooler(image_features) + return torch.cat([cls_features_text, cls_features_image], dim=-1) + + +# Copied from transformers.models.vilt.modeling_vilt.ViltPredictionHeadTransform with Vilt->BridgeTower +class BridgeTowerPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = 
self.LayerNorm(hidden_states) + return hidden_states + + +class BridgeTowerMLMHead(nn.Module): + def __init__(self, config, weight=None): + super().__init__() + self.config = config + self.transform = BridgeTowerPredictionHeadTransform(config) + self.decoder = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.text_config.vocab_size)) + if weight is not None: + self.decoder.weight = weight + + def forward(self, x): + mlm_score = self.transform(x) + mlm_score = self.decoder(mlm_score) + self.bias + return mlm_score + + +class BridgeTowerITMHead(nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.fc = nn.Linear(hidden_size, 2) + + def forward(self, x): + itm_score = self.fc(x) + return itm_score + + +@auto_docstring( + custom_intro=""" + BridgeTower Model with a language modeling head on top as done during pretraining. + """ +) +class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel): + _tied_weights_keys = ["mlm_score.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + self.bridgetower = BridgeTowerModel(config) + self.mlm_score = BridgeTowerMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.mlm_score.decoder + + def set_output_embeddings(self, new_embeddings): + self.mlm_score.decoder = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]: + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*): + Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `pixel_values` into patch embeddings. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Examples: + + ```python + >>> from transformers import BridgeTowerProcessor, BridgeTowerForMaskedLM + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000360943.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + >>> text = "a <mask> looking out of the window" + + >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm") + >>> model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm") + + >>> # prepare inputs + >>> encoding = processor(image, text, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**encoding) + + >>> results = processor.decode(outputs.logits.argmax(dim=-1).squeeze(0).tolist()) + + >>> print(results) + .a cat looking out of the window. + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.bridgetower( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + pixel_values=pixel_values, + pixel_mask=pixel_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + image_embeds=image_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0]) + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + + labels = labels.to(mlm_logits.device) + masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1)) + + if not return_dict: + output = tuple(mlm_logits) + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=mlm_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + BridgeTower Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the + [CLS] token) for image-to-text matching.
+ """ +) +class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bridgetower = BridgeTowerModel(config) + + self.itm_score = BridgeTowerITMHead(config.hidden_size * 2) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]: + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*): + Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `pixel_values` into patch embeddings. + labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match. + The pairs with 0 will be skipped for calculation. + + Examples: + + ```python + >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval + >>> import requests + >>> from PIL import Image + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"] + + >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm") + >>> model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm") + + >>> # forward pass + >>> scores = dict() + >>> for text in texts: + ... # prepare inputs + ... encoding = processor(image, text, return_tensors="pt") + ... outputs = model(**encoding) + ... 
scores[text] = outputs.logits[0, 1].item() + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bridgetower( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + pixel_values=pixel_values, + pixel_mask=pixel_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + image_embeds=image_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooler_output = outputs.pooler_output if return_dict else outputs[2] + + logits = self.itm_score(pooler_output) + + itm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + itm_loss = loss_fct(logits, labels) + + if not return_dict: + output = tuple(logits) + return ((itm_loss,) + output) if itm_loss is not None else output + + return SequenceClassifierOutput( + loss=itm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BridgeTowerContrastiveHead(nn.Module): + def __init__(self, hidden_size, embed_size): + super().__init__() + self.fc = nn.Linear(hidden_size, embed_size) + + def forward(self, x): + x = self.fc(x) + return x + + +@auto_docstring( + custom_intro=""" + BridgeTower Model with a image-text contrastive head on top computing image-text contrastive loss. + """ +) +class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bridgetower = BridgeTowerModel(config) + + self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size) + self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size) + self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size) + + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = True, + return_dict: Optional[bool] = None, + return_loss: Optional[bool] = None, + ) -> Union[BridgeTowerContrastiveOutput, tuple[torch.FloatTensor]]: + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*): + Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `pixel_values` into patch embeddings. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning + >>> import requests + >>> from PIL import Image + >>> import torch + + >>> image_urls = [ + ... "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg", + ... "http://images.cocodataset.org/val2017/000000039769.jpg", + ... 
] + >>> texts = ["two dogs in a car", "two cats sleeping on a couch"] + >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls] + + >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") + >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc") + + >>> inputs = processor(images, texts, padding=True, return_tensors="pt") + >>> loss = model(**inputs, return_loss=True).loss + + >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt") + >>> loss_swapped = model(**inputs, return_loss=True).loss + + >>> print("Loss", round(loss.item(), 4)) + Loss 0.0019 + + >>> print("Loss with swapped images", round(loss_swapped.item(), 4)) + Loss with swapped images 2.126 + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bridgetower( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + pixel_values=pixel_values, + pixel_mask=pixel_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + image_embeds=image_embeds, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + pooler_output = outputs.pooler_output if return_dict else outputs[2] + hidden_states_txt, hidden_states_img, hidden_states_cross_modal = ( + outputs.hidden_states if return_dict else outputs[3] + ) + + text_embeds = hidden_states_txt[-1] + image_embeds = hidden_states_img[-1] + + image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds) + image_token_type_embeddings = self.bridgetower.token_type_embeddings( + torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device) + ).expand_as(image_embeds_with_ln) + + image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings + + # normalized features + text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2) + image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to( + device=text_embeds.device + ) + cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to( + device=text_embeds.device + ) + + logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2) + + logit_scale = self.logit_scale.exp().to(device=text_embeds.device) + logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale + logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale + + itc_loss = None + + if return_loss: + labels = torch.arange(len(logits), device=logits.device) + text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels) + text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels) + image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels) + itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0 + + if not return_dict: + output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:] + return ((itc_loss,) + output) if itc_loss is not None else output + + return BridgeTowerContrastiveOutput( + loss=itc_loss, + logits=logits, + text_embeds=text_embeds, + image_embeds=image_embeds, + cross_embeds=cross_embeds, + hidden_states=outputs.hidden_states, + 
attentions=outputs.attentions, + ) + + +__all__ = [ + "BridgeTowerForContrastiveLearning", + "BridgeTowerForImageAndTextRetrieval", + "BridgeTowerForMaskedLM", + "BridgeTowerModel", + "BridgeTowerPreTrainedModel", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bridgetower/processing_bridgetower.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bridgetower/processing_bridgetower.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7059c4c5a5dfef87ca20117d60a4b6e9fb5f72 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/bridgetower/processing_bridgetower.py @@ -0,0 +1,73 @@ +# coding=utf-8 +# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for BridgeTower. +""" + +from typing import Optional + +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin + + +class BridgeTowerImagesKwargs(ImagesKwargs): + size_divisor: Optional[int] + + +class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: BridgeTowerImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "do_normalize": True, + "do_center_crop": True, + }, + } + + +class BridgeTowerProcessor(ProcessorMixin): + r""" + Constructs a BridgeTower processor which wraps a Roberta tokenizer and BridgeTower image processor into a single + processor. + + [`BridgeTowerProcessor`] offers all the functionalities of [`BridgeTowerImageProcessor`] and + [`RobertaTokenizerFast`]. See the docstring of [`~BridgeTowerProcessor.__call__`] and + [`~BridgeTowerProcessor.decode`] for more information. + + Args: + image_processor (`BridgeTowerImageProcessor`): + An instance of [`BridgeTowerImageProcessor`]. The image processor is a required input. + tokenizer (`RobertaTokenizerFast`): + An instance of ['RobertaTokenizerFast`]. The tokenizer is a required input. 
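+
+ Example (a minimal sketch; the checkpoint and call pattern mirror the model docstrings above):
+
+ ```python
+ >>> from transformers import BridgeTowerProcessor
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
+ >>> encoding = processor(image, "hello world", return_tensors="pt")
+ ```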
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "BridgeTowerImageProcessor" + tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast") + valid_processor_kwargs = BridgeTowerProcessorKwargs + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + +__all__ = ["BridgeTowerProcessor"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..908d32abdd13d38fb6d97cd974b52c3856078194 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/configuration_chinese_clip.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/configuration_chinese_clip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..299a389cbcc248b6c55514dd2c6bf9009e740bae Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/configuration_chinese_clip.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/feature_extraction_chinese_clip.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/feature_extraction_chinese_clip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5878a5036aec26854e318393d5157120ea90337 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/feature_extraction_chinese_clip.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/image_processing_chinese_clip.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/image_processing_chinese_clip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e732286b333b334572a019689007be6a3c159411 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/image_processing_chinese_clip.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/image_processing_chinese_clip_fast.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/image_processing_chinese_clip_fast.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad3f08505b101f27ec0f4a2a1acb806902e65ac1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/image_processing_chinese_clip_fast.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/modeling_chinese_clip.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/modeling_chinese_clip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f58a3bf8a2685f02a5221cd4c23bb7bee3df9dcf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/modeling_chinese_clip.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/processing_chinese_clip.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/processing_chinese_clip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c7de361b83c0cc20f89931c7d235b9dec97912e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/chinese_clip/__pycache__/processing_chinese_clip.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc316957eac509573bf44785209d0729ea13bb6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dinov2 import * + from .modeling_dinov2 import * + from .modeling_flax_dinov2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/configuration_dinov2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/configuration_dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..55fa0539a23bf0c02d0079ce41f0c5228b88c904 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/configuration_dinov2.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""DINOv2 model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class Dinov2Config(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dinov2Model`]. It is used to instantiate an + Dinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Dinov2 + [google/dinov2-base-patch16-224](https://huggingface.co/google/dinov2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. 
Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. + use_mask_token (`bool`, *optional*, defaults to `True`): + Whether to use mask_token in embeddings. + + Example: + + ```python + >>> from transformers import Dinov2Config, Dinov2Model + + >>> # Initializing a Dinov2 dinov2-base-patch16-224 style configuration + >>> configuration = Dinov2Config() + + >>> # Initializing a model (with random weights) from the dinov2-base-patch16-224 style configuration + >>> model = Dinov2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinov2" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + patch_size=14, + num_channels=3, + qkv_bias=True, + layerscale_value=1.0, + drop_path_rate=0.0, + use_swiglu_ffn=False, + out_features=None, + out_indices=None, + apply_layernorm=True, + reshape_hidden_states=True, + use_mask_token=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_swiglu_ffn = use_swiglu_ffn + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + self.use_mask_token = use_mask_token + + +class Dinov2OnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["Dinov2Config", "Dinov2OnnxConfig"] diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/modeling_dinov2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/modeling_dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..b01678b7f2f9d13827012caf4808d5015f8a303b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/modeling_dinov2.py @@ -0,0 +1,684 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DINOv2 model.""" + +import collections.abc +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import TransformersKwargs, auto_docstring, logging, torch_int +from ...utils.backbone_utils import BackboneMixin +from ...utils.generic import can_return_tuple, check_model_inputs +from .configuration_dinov2 import Dinov2Config + + +logger = logging.get_logger(__name__) + + +class Dinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.patch_embeddings = Dinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.use_mask_token = config.use_mask_token + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(torch.float32), + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ).to(dtype=target_dtype) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None and self.use_mask_token: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 +class Dinov2SelfAttention(nn.Module): + def __init__(self, config: Dinov2Config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def forward( + self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size + + key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2) + query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 +class Dinov2SelfOutput(nn.Module): + 
""" + The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2Config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 +class Dinov2Attention(nn.Module): + def __init__(self, config: Dinov2Config): + super().__init__() + self.attention = Dinov2SelfAttention(config) + self.output = Dinov2SelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: set[int]): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + self_attn_output, _ = self.attention(hidden_states, head_mask) + output = self.output(self_attn_output, hidden_states) + return output + + +class Dinov2LayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class Dinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class Dinov2MLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2SwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class Dinov2Layer(GradientCheckpointingLayer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Dinov2Attention(config) + self.layer_scale1 = Dinov2LayerScale(config) + self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2SwiGLUFFN(config) + else: + self.mlp = Dinov2MLP(config) + self.layer_scale2 = Dinov2LayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states_norm = self.norm1(hidden_states) + self_attention_output = self.attention(hidden_states_norm, head_mask) + self_attention_output = self.layer_scale1(self_attention_output) + + # first residual connection + hidden_states = self.drop_path(self_attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # 
second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + return layer_output + + +class Dinov2Encoder(nn.Module): + def __init__(self, config: Dinov2Config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False + ) -> BaseModelOutput: + all_hidden_states = [hidden_states] if output_hidden_states else None + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + hidden_states = layer_module(hidden_states, layer_head_mask) + if all_hidden_states: + all_hidden_states.append(hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, + ) + + +@auto_docstring +class Dinov2PreTrainedModel(PreTrainedModel): + config: Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Dinov2Layer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "attentions": Dinov2SelfAttention, + } + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2Embeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + if self.config.use_mask_token: + module.mask_token.data.zero_() + elif isinstance(module, Dinov2LayerScale): + module.lambda1.data.fill_(self.config.layerscale_value) + + +@auto_docstring +class Dinov2Model(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @check_model_inputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> BaseModelOutputWithPooling: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + """ + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs: BaseModelOutput = self.encoder( + embedding_output, head_mask=head_mask, output_hidden_states=output_hidden_states + ) + sequence_output = encoder_outputs.last_hidden_state + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """ +) +class Dinov2ForImageClassification(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2 = Dinov2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + outputs: BaseModelOutputWithPooling = self.dinov2(pixel_values, head_mask=head_mask, **kwargs) + + sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config, **kwargs) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Dinov2 backbone, to be used with frameworks like DETR and MaskFormer. + """ +) +class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + @check_model_inputs + @auto_docstring + def forward( + self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + + embedding_output = self.embeddings(pixel_values) + output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True) + hidden_states = output.hidden_states + + feature_maps = [] + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, 1:] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps.append(hidden_state) + + return BackboneOutput( + feature_maps=tuple(feature_maps), + hidden_states=hidden_states if output_hidden_states else None, + ) + + +__all__ = ["Dinov2ForImageClassification", "Dinov2Model", "Dinov2PreTrainedModel", "Dinov2Backbone"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/modeling_flax_dinov2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/modeling_flax_dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ea2eaa3ebc5ac5db192dde220dfe114ef38235 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/dinov2/modeling_flax_dinov2.py @@ -0,0 +1,801 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax DINOv2 model.""" + +import collections.abc +import math +from typing import Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_dinov2 import Dinov2Config + + +DINOV2_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`Dinov2ImageProcessor.__call__`] + for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +class FlaxDinov2PatchEmbeddings(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + image_size = self.config.image_size + patch_size = self.config.patch_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.num_patches = num_patches + self.num_channels = self.config.num_channels + self.projection = nn.Conv( + self.config.hidden_size, + kernel_size=patch_size, + strides=patch_size, + padding="VALID", + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + ) + + # Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPatchEmbeddings.__call__ + def __call__(self, pixel_values): + num_channels = pixel_values.shape[-1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values) + batch_size, _, _, channels = embeddings.shape + return jnp.reshape(embeddings, (batch_size, -1, channels)) + + +class FlaxDinov2Embeddings(nn.Module): + """Construct the CLS token, position and patch embeddings.""" + + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.cls_token = self.param( + "cls_token", + jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), + (1, 1, self.config.hidden_size), + ) + if self.config.use_mask_token: + self.mask_token = self.param( + "mask_token", + jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), + (1, self.config.hidden_size), + ) + self.patch_embeddings = FlaxDinov2PatchEmbeddings(self.config, dtype=self.dtype) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = self.param( + "position_embeddings", + jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), + (1, num_patches + 1, self.config.hidden_size), + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, config, hidden_states, height, width, position_embeddings): + num_patches = hidden_states.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = hidden_states.shape[-1] + + h = height // config.patch_size + w = width // config.patch_size + height, width = h + 0.1, w + 0.1 + + patch_pos_embed = patch_pos_embed.reshape( + (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + ) + patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 3, 1, 2)) + target_dtype = patch_pos_embed.dtype + new_height_ratio = jnp.float32(height / math.sqrt(num_positions)) + new_width_ratio = jnp.float32(width / math.sqrt(num_positions)) + + scale = jnp.array([new_height_ratio, new_width_ratio], dtype=jnp.float32) + translation = jnp.array([0.0, 0.0], dtype=jnp.float32) + + patch_pos_embed = jax.image.scale_and_translate( + patch_pos_embed.astype(jnp.float32), + 
shape=(patch_pos_embed.shape[0], patch_pos_embed.shape[1], h, w), + spatial_dims=(2, 3), + scale=scale, + translation=translation, + method="bicubic", + antialias=False, + ) + patch_pos_embed = patch_pos_embed.astype(target_dtype) + patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim)) + patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1)) + class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1)) + + return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1) + + def __call__(self, pixel_values, deterministic=True): + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.projection.dtype + height, width = pixel_values.shape[1], pixel_values.shape[2] + + embeddings = self.patch_embeddings(pixel_values.astype(target_dtype)) + + cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) + embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) + + embeddings = embeddings + self.interpolate_pos_encoding( + self.config, embeddings, height, width, self.position_embeddings + ) + + embeddings = self.dropout(embeddings, deterministic=deterministic) + return embeddings + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->Dinov2 +class FlaxDinov2SelfAttention(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:" + " {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" + ), + use_bias=self.config.qkv_bias, + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" + ), + use_bias=self.config.qkv_bias, + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" + ), + use_bias=self.config.qkv_bias, + ) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + 
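+        # The einsum contracts over the key/value length: attention weights of shape (..., heads, q_len, kv_len)
+        # combined with values of shape (..., kv_len, heads, head_dim) yield (..., q_len, heads, head_dim);
+        # the reshape below merges heads and head_dim back into hidden_size.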
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfOutput with ViT->Dinov2 +class FlaxDinov2SelfOutput(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTAttention with ViT->Dinov2 +class FlaxDinov2Attention(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxDinov2SelfAttention(self.config, dtype=self.dtype) + self.output = FlaxDinov2SelfOutput(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): + attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +def ones_with_scale(key, shape, scale, dtype=jnp.float32): + return jnp.ones(shape, dtype) * scale + + +class FlaxDinov2LayerScale(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.lambda1 = self.config.layerscale_value * self.param( + "lambda1", + jax.nn.initializers.ones, + (self.config.hidden_size,), + ) + self.lambda1 = self.lambda1 * self.config.layerscale_value + + def __call__(self, hidden_states): + return self.lambda1 * hidden_states + + +# Copied from transformers.models.beit.modeling_flax_beit.FlaxBeitDropPath with Beit -> Dinov2 +class FlaxDinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + rate: float + + @nn.module.compact + def __call__(self, inputs, deterministic: Optional[bool] = True): + if self.rate == 0.0: + return inputs + keep_prob = 1.0 - self.rate + if deterministic: + return inputs + else: + shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + rng = self.make_rng("droppath") + random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype) + binary_tensor = jnp.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + +class FlaxDinov2MLP(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.fc1 = nn.Dense( + self.config.hidden_size * self.config.mlp_ratio, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.fc2 = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, 
+ ) + if isinstance(self.config.hidden_act, str): + self.act = ACT2FN[self.config.hidden_act] + else: + self.act = self.config.hidden_act + + def __call__(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class FlaxDinov2SwiGLUFFN(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + hidden_features = int(self.config.hidden_size * self.config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Dense( + 2 * hidden_features, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.weights_out = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + hidden_states = self.weights_in(hidden_states) + x1, x2 = jnp.split(hidden_states, 2, axis=-1) + hidden = nn.silu(x1) * x2 + return self.weights_out(hidden) + + +class FlaxDinov2Layer(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.attention = FlaxDinov2Attention(self.config, dtype=self.dtype) + self.layer_scale1 = FlaxDinov2LayerScale(self.config, dtype=self.dtype) + self.drop_path = FlaxDinov2DropPath(self.config.drop_path_rate) + self.norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + if self.config.use_swiglu_ffn: + self.mlp = FlaxDinov2SwiGLUFFN(self.config, dtype=self.dtype) + else: + self.mlp = FlaxDinov2MLP(self.config, dtype=self.dtype) + + self.layer_scale2 = FlaxDinov2LayerScale(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Dinov2, layernorm is applied before self-attention + deterministic=deterministic, + output_attentions=output_attentions, + ) + + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + + outputs = self_attention_outputs[1:] + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayerCollection with ViT->Dinov2 +class FlaxDinov2LayerCollection(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxDinov2Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None 
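+        # When output_hidden_states is requested, the tuple collects the embedding output plus each
+        # layer's output, i.e. num_hidden_layers + 1 tensors in total.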
+ + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViT->Dinov2 +class FlaxDinov2Encoder(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxDinov2LayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: Dinov2Config, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + params: Optional[dict] = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + 
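+        # pixel_values arrive in NCHW layout from the image processor; Flax convolutions expect NHWC,
+        # so channels are moved to the last axis before applying the module.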
+ pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs["dropout"] = dropout_rng + rngs["droppath"] = droppath_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxDinov2Module(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embeddings = FlaxDinov2Embeddings(self.config, dtype=self.dtype) + self.encoder = FlaxDinov2Encoder(self.config, dtype=self.dtype) + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The bare Dinov2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class FlaxDinov2Model(FlaxDinov2PreTrainedModel): + module_class = FlaxDinov2Module + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxDinov2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxDinov2Model, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxDinov2Model, output_type=FlaxBaseModelOutputWithPooling, config_class=Dinov2Config +) + + +class FlaxDinov2ForImageClassificationModule(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dinov2 = FlaxDinov2Module(config=self.config, dtype=self.dtype) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + ) + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinov2( + pixel_values, + 
deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + cls_token = hidden_states[:, 0] + patch_tokens = hidden_states[:, 1:] + linear_input = jnp.concatenate([cls_token, patch_tokens.mean(axis=1)], axis=-1) + + logits = self.classifier(linear_input) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + DINOV2_START_DOCSTRING, +) +class FlaxDinov2ForImageClassification(FlaxDinov2PreTrainedModel): + module_class = FlaxDinov2ForImageClassificationModule + + +FLAX_VISION_CLASSIFICATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxDinov2ForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer") + >>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True) + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxDinov2ForImageClassification, FLAX_VISION_CLASSIFICATION_DOCSTRING) +append_replace_return_docstrings( + FlaxDinov2ForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=Dinov2Config +) + + +__all__ = ["FlaxDinov2ForImageClassification", "FlaxDinov2Model", "FlaxDinov2PreTrainedModel"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c1a55fc5bcd2fe4b083ffbac0633ad09076565 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_edgetam import *
+    from .modeling_edgetam import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/configuration_edgetam.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/configuration_edgetam.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ccee36e93209cff330ae037bc7324ac686e3c3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/configuration_edgetam.py
@@ -0,0 +1,332 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_edgetam.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+class EdgeTamVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`EdgeTamVisionModel`]. It is used to instantiate a SAM
+    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny
+    [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*):
+            Configuration for the vision backbone. This is used to instantiate the backbone using
+            `AutoModel.from_config`.
+        backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
+            The list of channel dimensions for the backbone.
+        backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`):
+            The spatial sizes of the feature maps from the backbone.
+        fpn_hidden_size (`int`, *optional*, defaults to 256):
+            The hidden dimension of the FPN.
+        fpn_kernel_size (`int`, *optional*, defaults to 1):
+            The kernel size for the convolutions in the neck.
+        fpn_stride (`int`, *optional*, defaults to 1):
+            The stride for the convolutions in the neck.
+ fpn_padding (`int`, *optional*, defaults to 0): + The padding for the convolutions in the neck. + fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`): + The levels for the top-down FPN connections. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of feature levels from the FPN to use. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the neck. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon for the layer normalization. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + """ + + base_config_key = "vision_config" + model_type = "edgetam_vision_model" + sub_configs = { + "backbone_config": AutoConfig, + } + + def __init__( + self, + backbone_config=None, + backbone_channel_list=None, + backbone_feature_sizes=None, + fpn_hidden_size=256, + fpn_kernel_size=1, + fpn_stride=1, + fpn_padding=0, + fpn_top_down_levels=None, + num_feature_levels=3, + hidden_act="gelu", + layer_norm_eps=1e-6, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + backbone_channel_list = [384, 192, 96, 48] if backbone_channel_list is None else backbone_channel_list + backbone_feature_sizes = ( + [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes + ) + fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels + + if isinstance(backbone_config, dict): + backbone_config["model_type"] = backbone_config.get("model_type", "timm_wrapper") + backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config) + elif isinstance(backbone_config, AutoConfig): + backbone_config = backbone_config + elif backbone_config is None: + backbone_config = AutoConfig.from_pretrained( + "timm/repvit_m1.dist_in1k", + model_args={"in_chans": 3, "features_only": True, "out_indices": [0, 1, 2, 3]}, + ) + + self.backbone_config = backbone_config + + # Neck + self.backbone_channel_list = backbone_channel_list + self.backbone_feature_sizes = backbone_feature_sizes + self.fpn_hidden_size = fpn_hidden_size + self.fpn_kernel_size = fpn_kernel_size + self.fpn_stride = fpn_stride + self.fpn_padding = fpn_padding + self.fpn_top_down_levels = fpn_top_down_levels + self.num_feature_levels = num_feature_levels + + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + + +class EdgeTamPromptEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamPromptEncoder`]. The [`EdgeTamPromptEncoder`] + module is used to encode the input 2D points and bounding boxes. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. 
+ hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + scale (`float`, *optional*, defaults to 1): + The scale factor for the prompt encoder. + """ + + base_config_key = "prompt_encoder_config" + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + scale=1, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.scale = scale + + +class EdgeTamMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamMaskDecoder`]. It is used to instantiate a EDGETAM + memory encoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the EDGETAM mask decoder. + mlp_dim (`int`, *optional*, defaults to 2048): + The dimension of the MLP in the two-way transformer. + num_hidden_layers (`int`, *optional*, defaults to 2): + The number of hidden layers in the two-way transformer. + num_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads in the two-way transformer. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsample rate for the attention layers. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of multimask outputs. + iou_head_depth (`int`, *optional*, defaults to 3): + The depth of the IoU head. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The hidden dimension of the IoU head. + dynamic_multimask_via_stability (`bool`, *optional*, defaults to `True`): + Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (`float`, *optional*, defaults to 0.05): + The stability delta for the dynamic multimask. + dynamic_multimask_stability_thresh (`float`, *optional*, defaults to 0.98): + The stability threshold for the dynamic multimask. 
+
+    """
+
+    base_config_key = "mask_decoder_config"
+
+    def __init__(
+        self,
+        hidden_size=256,
+        hidden_act="gelu",
+        mlp_dim=2048,
+        num_hidden_layers=2,
+        num_attention_heads=8,
+        attention_downsample_rate=2,
+        num_multimask_outputs=3,
+        iou_head_depth=3,
+        iou_head_hidden_dim=256,
+        dynamic_multimask_via_stability=True,
+        dynamic_multimask_stability_delta=0.05,
+        dynamic_multimask_stability_thresh=0.98,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_multimask_outputs = num_multimask_outputs
+        self.hidden_act = hidden_act
+        self.iou_head_depth = iou_head_depth
+        self.iou_head_hidden_dim = iou_head_hidden_dim
+        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+        # TwoWayTransformer configuration
+        self.num_hidden_layers = num_hidden_layers
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.mlp_dim = mlp_dim
+        self.attention_downsample_rate = attention_downsample_rate
+
+
+class EdgeTamConfig(PretrainedConfig):
+    r"""
+    [`EdgeTamConfig`] is the configuration class to store the configuration of a [`EdgeTamModel`]. It is used to instantiate an
+    EDGETAM model according to the specified arguments, defining the vision encoder, prompt encoder, and mask decoder
+    configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the SAM 2.1 Hiera-tiny
+    [facebook/edgetam.1-hiera-tiny](https://huggingface.co/facebook/edgetam.1-hiera-tiny) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vision_config (Union[`dict`, `EdgeTamVisionConfig`], *optional*):
+            Dictionary of configuration options used to initialize [`EdgeTamVisionConfig`].
+        prompt_encoder_config (Union[`dict`, `EdgeTamPromptEncoderConfig`], *optional*):
+            Dictionary of configuration options used to initialize [`EdgeTamPromptEncoderConfig`].
+        mask_decoder_config (Union[`dict`, `EdgeTamMaskDecoderConfig`], *optional*):
+            Dictionary of configuration options used to initialize [`EdgeTamMaskDecoderConfig`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            Standard deviation for parameter initialization.
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     EdgeTamVisionConfig,
+    ...     EdgeTamPromptEncoderConfig,
+    ...     EdgeTamMaskDecoderConfig,
+    ...     EdgeTamModel,
+    ... )
+
+    >>> # Initializing an EdgeTamConfig with `"facebook/edgetam.1_hiera_tiny"` style configuration
+    >>> configuration = EdgeTamConfig()
+
+    >>> # Initializing an EdgeTamModel (with random weights) from the `"facebook/edgetam.1_hiera_tiny"` style configuration
+    >>> model = EdgeTamModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize an EdgeTamConfig from an EdgeTamVisionConfig, EdgeTamPromptEncoderConfig, and EdgeTamMaskDecoderConfig
+
+    >>> # Initializing EDGETAM vision encoder, prompt encoder, and mask decoder configurations
+    >>> vision_config = EdgeTamVisionConfig()
+    >>> prompt_encoder_config = EdgeTamPromptEncoderConfig()
+    >>> mask_decoder_config = EdgeTamMaskDecoderConfig()
+
+    >>> config = EdgeTamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+    ```"""
+
+    model_type = "edgetam"
+    sub_configs = {
+        "vision_config": AutoConfig,
+        "prompt_encoder_config": EdgeTamPromptEncoderConfig,
+        "mask_decoder_config": EdgeTamMaskDecoderConfig,
+    }
+
+    def __init__(
+        self,
+        vision_config=None,
+        prompt_encoder_config=None,
+        mask_decoder_config=None,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        vision_config = vision_config if vision_config is not None else {}
+        prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+        mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+
+        if isinstance(vision_config, dict):
+            vision_config["model_type"] = vision_config.get("model_type", "edgetam_vision_model")
+            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+        if isinstance(prompt_encoder_config, EdgeTamPromptEncoderConfig):
+            prompt_encoder_config = prompt_encoder_config.to_dict()
+        if isinstance(mask_decoder_config, EdgeTamMaskDecoderConfig):
+            mask_decoder_config = mask_decoder_config.to_dict()
+
+        self.vision_config = vision_config
+        self.prompt_encoder_config = EdgeTamPromptEncoderConfig(**prompt_encoder_config)
+        self.mask_decoder_config = EdgeTamMaskDecoderConfig(**mask_decoder_config)
+
+        self.initializer_range = initializer_range
+
+
+__all__ = ["EdgeTamConfig", "EdgeTamVisionConfig", "EdgeTamPromptEncoderConfig", "EdgeTamMaskDecoderConfig"]
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/modeling_edgetam.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/modeling_edgetam.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7e3ee6009cf31d642af72a40f966d35cf11e367
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/modeling_edgetam.py
@@ -0,0 +1,1252 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_edgetam.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ModelOutput, auto_docstring +from ..auto import AutoModel +from .configuration_edgetam import ( + EdgeTamConfig, + EdgeTamMaskDecoderConfig, + EdgeTamPromptEncoderConfig, + EdgeTamVisionConfig, +) + + +# fix this in modular +if True: + from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel + + +class EdgeTamLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +@dataclass +@auto_docstring(custom_intro="Base class for the vision encoder's outputs.") +class EdgeTamVisionEncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + fpn_hidden_states (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. + fpn_position_encoding (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the + model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + fpn_hidden_states: Optional[torch.FloatTensor] = None + fpn_position_encoding: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class EdgeTamAttention(nn.Module): + """ + EDGETAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
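+
+    With a downsample rate of `d`, the queries, keys, and values are projected from `hidden_size` down to
+    `hidden_size // d` before the attention computation and projected back to `hidden_size` by the output projection.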
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + self.config = config + self.hidden_size = config.hidden_size + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.internal_dim // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_similarity: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class EdgeTamTwoWayAttentionBlock(nn.Module): + def __init__(self, config: EdgeTamMaskDecoderConfig, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`EdgeTamMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. 
+ """ + super().__init__() + self.self_attn = EdgeTamAttention(config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(config.hidden_size) + + self.cross_attn_token_to_image = EdgeTamAttention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size) + + self.mlp = EdgeTamFeedForward( + config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers + ) + self.layer_norm3 = nn.LayerNorm(config.hidden_size) + + self.layer_norm4 = nn.LayerNorm(config.hidden_size) + self.cross_attn_image_to_token = EdgeTamAttention(config) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + # Self attention block + if self.skip_first_layer_pe: + queries, _ = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out, _ = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + return queries, keys, attn_out + + +class EdgeTamFeedForward(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: str = "relu", + sigmoid_output: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.activation = ACT2FN[activation] + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +@auto_docstring +class EdgeTamPreTrainedModel(PreTrainedModel): + config_class = EdgeTamConfig + base_model_prefix = "edgetam" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): + 
module.weight.data.fill_(1.0) + module.bias.data.zero_() + if isinstance(module, EdgeTamModel): + if module.no_memory_embedding is not None: + module.no_memory_embedding.data.zero_() + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class EdgeTamSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + @compile_compatible_method_lru_cache(maxsize=1) + def forward( + self, + shape: torch.Size, + device: Union[torch.device, str], + dtype: torch.dtype, + mask: Optional[Tensor] = None, + ) -> Tensor: + if mask is None: + mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) + not_mask = (~mask).to(dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class EdgeTamVisionNeck(nn.Module): + def __init__(self, config: EdgeTamVisionConfig): + super().__init__() + self.config = config + + self.position_encoding = EdgeTamSinePositionEmbedding( + num_pos_feats=config.fpn_hidden_size // 2, normalize=True + ) + self.convs = nn.ModuleList() + for in_channels in config.backbone_channel_list: + self.convs.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=config.fpn_hidden_size, + kernel_size=config.fpn_kernel_size, + stride=config.fpn_stride, + padding=config.fpn_padding, + ), + ) + self.fpn_top_down_levels = config.fpn_top_down_levels + + def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: + fpn_hidden_states = () + fpn_position_encoding = () + + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + lateral_features = hidden_states[i].permute(0, 3, 1, 2) + lateral_features = self.convs[n - i](lateral_features) + if i not in self.fpn_top_down_levels or i == n: + prev_features = lateral_features + else: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode="nearest", + align_corners=None, + antialias=False, + ).to(lateral_features.dtype) + prev_features = lateral_features + top_down_features + + prev_position_encoding = self.position_encoding( + prev_features.shape, prev_features.device, prev_features.dtype + ).to(prev_features.dtype) + + fpn_hidden_states += 
(prev_features,) + fpn_position_encoding += (prev_position_encoding,) + + return fpn_hidden_states, fpn_position_encoding + + +@auto_docstring( + custom_intro=""" + The vision model from EdgeTAM without any head or projection on top. + """ +) +class EdgeTamVisionModel(EdgeTamPreTrainedModel): + config_class = EdgeTamVisionConfig + main_input_name = "pixel_values" + _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel} + + def __init__(self, config: EdgeTamVisionConfig): + super().__init__(config) + self.config = config + + self.backbone = AutoModel.from_config(config.backbone_config) + + self.neck = EdgeTamVisionNeck(config) + self.num_feature_levels = config.num_feature_levels + + self.post_init() + + @check_model_inputs + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, EdgeTamVisionEncoderOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Forward through backbone + backbone_output = self.backbone(pixel_values) + intermediate_hidden_states = backbone_output.last_hidden_state + intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states] + + fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) + # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution + fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1] + fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1] + + return EdgeTamVisionEncoderOutput( + last_hidden_state=intermediate_hidden_states[-1], + fpn_hidden_states=fpn_hidden_states, + fpn_position_encoding=fpn_position_encoding, + ) + + +@dataclass +@auto_docstring(custom_intro="Base class for the EdgeTam model's output.") +class EdgeTamImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`): + The Intersection over Union (IoU) scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): + The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed + by the processor to be brought to the original image size. + object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`): + Logits for the object score, indicating if an object is present. + image_embeddings (`tuple(torch.FloatTensor)`): + The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each + tensor has shape `(batch_size, channels, height, width)`. + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. + Hidden-states of the vision model at the output of each stage. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the vision model. 
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the mask decoder. + """ + + iou_scores: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + object_score_logits: Optional[torch.FloatTensor] = None + image_embeddings: tuple[torch.FloatTensor, ...] = None + vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class EdgeTamPositionalEmbedding(nn.Module): + def __init__(self, config: EdgeTamPromptEncoderConfig): + super().__init__() + self.scale = config.scale + positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2)) + self.register_buffer("positional_embedding", positional_embedding) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + coordinates.to(torch.float32) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class EdgeTamMaskEmbedding(nn.Module): + def __init__(self, config: EdgeTamPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = EdgeTamLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = EdgeTamLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class EdgeTamPromptEncoder(nn.Module): + def __init__(self, config: EdgeTamPromptEncoderConfig): + super().__init__() + self.shared_embedding = EdgeTamPositionalEmbedding(config) + self.mask_embed = EdgeTamMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) + self.input_image_size = config.image_size + + self.point_embed = nn.Embedding(config.num_point_embeddings, 
config.hidden_size) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0) + labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.zeros_like(point_embedding), + ) + + # Add point embeddings for labels >= 0 + point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes += 0.5 # Shift to center of pixel + coords = boxes.view(*boxes.shape[:2], 2, 2) + # add padding point for consistency with the original implementation + coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0) + corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size)) + corner_embedding[:, :, 0, :] += self.point_embed.weight[2] + corner_embedding[:, :, 1, :] += self.point_embed.weight[3] + corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :]) + return corner_embedding + + def forward( + self, + input_points: Optional[tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. 
+ boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + if input_points is not None: + batch_size = input_points.shape[0] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class EdgeTamTwoWayTransformer(nn.Module): + def __init__(self, config: EdgeTamMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(EdgeTamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = EdgeTamAttention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, _ = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + **kwargs, + ) + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys + + +class EdgeTamMaskDecoder(nn.Module): + def __init__(self, config: EdgeTamMaskDecoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = EdgeTamTwoWayTransformer(config) + + # should we create a new class for this? 
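+        # The two transposed convolutions below upsample the mask embeddings by 2x each (4x overall).
+        # In `forward`, the intermediate maps are summed with the high-resolution FPN features
+        # (`feat_s1`, then `feat_s0`), and the hypernetwork MLPs map each mask token to weights that are
+        # applied to the upscaled embedding to obtain the mask logits.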
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = EdgeTamLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [EdgeTamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + self.iou_prediction_head = EdgeTamFeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + sigmoid_output=True, + ) + + self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) + + self.obj_score_token = nn.Embedding(1, self.hidden_size) + self.pred_obj_score_head = EdgeTamFeedForward(self.hidden_size, self.hidden_size, 1, 3) + + self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + high_resolution_features: list[torch.Tensor], + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embeddings (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (`bool`): + Whether to return multiple masks or a single mask. + high_resolution_features (`list[torch.Tensor]`, *optional*): + The high-resolution features from the vision encoder. + attention_similarity (`torch.Tensor`, *optional*): + The attention similarity tensor. + target_embedding (`torch.Tensor`, *optional*): + The target embedding. 
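+
+        Returns:
+            A tuple `(masks, iou_pred, sam_tokens_out, object_score_logits)`, where `masks` has shape
+            `(batch_size, point_batch_size, num_masks, height, width)` and `iou_pred` has shape
+            `(batch_size, point_batch_size, num_masks)`.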
+ """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.shape[0] != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-mask + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + # Run the transformer + point_embeddings, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + iou_token_out = point_embeddings[:, :, 1, :] + mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).view( + batch_size * point_batch_size, num_channels, height, width + ) + + feat_s0, feat_s1 = high_resolution_features + feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) + feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) + upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) + + hyper_in_list: list[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + elif self.dynamic_multimask_via_stability and not self.training: + mask_slice = slice(0, 1) + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape + + return masks, iou_pred, sam_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. 
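+        Concretely, the score is `area(mask_logits > delta) / area(mask_logits > -delta)` with
+        `delta = dynamic_multimask_stability_delta`, and it defaults to 1.0 when the lower-threshold area is zero.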
+ """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, :, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, :, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P] + best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + best_scores_inds_expanded = best_scores_inds_expanded.expand( + -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1) + ) + best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] + best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, :, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, :, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + +@auto_docstring( + custom_intro=""" + Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and + input points and labels, boxes, or masks. 
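+    The image is embedded once by the vision encoder and its FPN neck; the prompt encoder embeds point, box, and
+    mask prompts, and the mask decoder fuses both through a two-way transformer to predict mask logits and IoU scores.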
+ """ +) +class EdgeTamModel(EdgeTamPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamTwoWayAttentionBlock, index=2)} + _keys_to_ignore_on_load_unexpected = [ + r"^memory_.*", + r"^mask_downsample.*", + r"spatial_perceiver.*", + r"^object_pointer_proj.*", + r"^temporal_positional_encoding_projection_layer.*", + "no_memory_positional_encoding", + "no_object_pointer", + "occlusion_spatial_embedding_parameter", + ] + + def __init__(self, config: EdgeTamConfig): + super().__init__(config) + self.shared_image_embedding = EdgeTamPositionalEmbedding(config.prompt_encoder_config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + self.prompt_encoder = EdgeTamPromptEncoder(config.prompt_encoder_config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = EdgeTamMaskDecoder(config.mask_decoder_config) + + self.num_feature_levels = config.vision_config.num_feature_levels + self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes + # a single token to indicate no memory embedding from previous frames + self.hidden_dim = config.vision_config.fpn_hidden_size + self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + + self.post_init() + + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + + def get_image_wide_positional_embeddings(self) -> torch.Tensor: + size = self.prompt_encoder.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones(size, device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size[0] + x_embed = x_embed / size[1] + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> list[torch.Tensor]: + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. 
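+        The returned list contains one tensor per feature level, reshaped to `(batch_size, hidden_size, height, width)`
+        following `config.vision_config.backbone_feature_sizes`, with the learned no-memory embedding added to the
+        lowest-resolution feature map.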
+ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + """ + batch_size = pixel_values.shape[0] + feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + return image_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @check_model_inputs + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> EdgeTamImageSegmentationOutput: + r""" + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. 
If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). 
+ + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny") + >>> processor = AutoProcessor.from_pretrained("danelcsb/edgetam.1_hiera_tiny") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get segmentation mask + >>> outputs = model(**inputs) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... ) + ``` + """ + if not ((pixel_values is None) ^ (image_embeddings is None)): + raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.") + if input_points is not None and input_boxes is not None: + if input_points.shape[1] != input_boxes.shape[1]: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}." + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features( + pixel_values, + **kwargs, + ) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + if input_points is None and input_boxes is None: + # If no points are provide, pad with an empty point (with label -1) + input_points = torch.zeros( + batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device + ) + input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device) + + if input_masks is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size: + input_masks = F.interpolate( + input_masks.float(), + size=self.prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ).to(input_masks.dtype) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder( + image_embeddings=image_embeddings[-1], + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + 
dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + high_resolution_features=image_embeddings[:-1], + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + + return EdgeTamImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_multimasks, + object_score_logits=object_score_logits, + image_embeddings=image_embeddings, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ + list[torch.Tensor], + list[torch.Tensor], + Optional[tuple[torch.FloatTensor, ...]], + Optional[tuple[torch.FloatTensor, ...]], + ]: + r""" + Extract and preprocess image features using the vision encoder. + + Args: + pixel_values (`torch.FloatTensor`): + Input pixel values of shape `(batch_size, num_channels, height, width)`. + + Returns: + `tuple`: A tuple containing: + - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels. + - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level. + - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder. + - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder. + """ + vision_outputs: EdgeTamVisionEncoderOutput = self.vision_encoder( + pixel_values, + **kwargs, + ) + + feature_maps = vision_outputs.fpn_hidden_states + feature_maps_position_embeddings = vision_outputs.fpn_position_encoding + + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + feature_maps = list(feature_maps) + feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0]) + feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1]) + + # flatten NxCxHxW to HWxNxC + feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps] + feature_maps_position_embeddings = [ + feature_map_position_embedding.flatten(2).permute(2, 0, 1) + for feature_map_position_embedding in feature_maps_position_embeddings + ] + + return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions + + +__all__ = ["EdgeTamModel", "EdgeTamVisionModel", "EdgeTamPreTrainedModel"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/modular_edgetam.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/modular_edgetam.py new file mode 100644 index 0000000000000000000000000000000000000000..e26d58d96b812e56e0befdcef05de612c5cdce41 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/edgetam/modular_edgetam.py @@ -0,0 +1,261 @@ +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
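+# --- Editorial sketch (not part of the upstream file) ----------------------------------
+# The config and model classes below wrap a timm backbone behind an FPN neck. A minimal
+# usage sketch, assuming the default `timm/repvit_m1.dist_in1k` backbone config is
+# resolvable (Hub or local cache), `timm` is installed, and a 1024x1024 input size
+# (chosen here only to match the default [[256, 256], [128, 128], [64, 64]] feature sizes):
+#
+#   import torch
+#   from transformers import EdgeTamVisionConfig, EdgeTamVisionModel
+#
+#   config = EdgeTamVisionConfig()             # defaults: fpn_hidden_size=256, 3 feature levels
+#   model = EdgeTamVisionModel(config).eval()
+#   with torch.no_grad():
+#       out = model(pixel_values=torch.randn(1, 3, 1024, 1024))
+#   # out.fpn_hidden_states / out.fpn_position_encoding hold the multi-scale features
+#   # (high to low resolution) that EdgeTamModel consumes.
+# ----------------------------------------------------------------------------------------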
+"""PyTorch SAM 2 model.""" + +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from transformers.models.sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig +from transformers.models.sam2.modeling_sam2 import ( + Sam2Attention, + Sam2FeedForward, + Sam2LayerNorm, + Sam2Model, + Sam2PreTrainedModel, + Sam2TwoWayAttentionBlock, + Sam2VisionEncoderOutput, + Sam2VisionModel, +) +from transformers.utils.generic import TransformersKwargs, check_model_inputs + +from ...configuration_utils import PretrainedConfig +from ...processing_utils import Unpack +from ...utils import ( + auto_docstring, +) +from ..auto import CONFIG_MAPPING, AutoConfig + + +# fix this in modular +if True: + from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel + + +class EdgeTamVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamVisionModel`]. It is used to instantiate a SAM + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny + [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*): + Configuration for the vision backbone. This is used to instantiate the backbone using + `AutoModel.from_config`. + backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`): + The list of channel dimensions for the backbone. + backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`): + The spatial sizes of the feature maps from the backbone. + fpn_hidden_size (`int`, *optional*, defaults to 256): + The hidden dimension of the FPN. + fpn_kernel_size (`int`, *optional*, defaults to 1): + The kernel size for the convolutions in the neck. + fpn_stride (`int`, *optional*, defaults to 1): + The stride for the convolutions in the neck. + fpn_padding (`int`, *optional*, defaults to 0): + The padding for the convolutions in the neck. + fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`): + The levels for the top-down FPN connections. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of feature levels from the FPN to use. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the neck. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon for the layer normalization. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + """ + + base_config_key = "vision_config" + model_type = "edgetam_vision_model" + sub_configs = { + "backbone_config": AutoConfig, + } + + def __init__( + self, + backbone_config=None, + backbone_channel_list=None, + backbone_feature_sizes=None, + fpn_hidden_size=256, + fpn_kernel_size=1, + fpn_stride=1, + fpn_padding=0, + fpn_top_down_levels=None, + num_feature_levels=3, + hidden_act="gelu", + layer_norm_eps=1e-6, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + backbone_channel_list = [384, 192, 96, 48] if backbone_channel_list is None else backbone_channel_list + backbone_feature_sizes = ( + [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes + ) + fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels + + if isinstance(backbone_config, dict): + backbone_config["model_type"] = backbone_config.get("model_type", "timm_wrapper") + backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config) + elif isinstance(backbone_config, AutoConfig): + backbone_config = backbone_config + elif backbone_config is None: + backbone_config = AutoConfig.from_pretrained( + "timm/repvit_m1.dist_in1k", + model_args={"in_chans": 3, "features_only": True, "out_indices": [0, 1, 2, 3]}, + ) + + self.backbone_config = backbone_config + + # Neck + self.backbone_channel_list = backbone_channel_list + self.backbone_feature_sizes = backbone_feature_sizes + self.fpn_hidden_size = fpn_hidden_size + self.fpn_kernel_size = fpn_kernel_size + self.fpn_stride = fpn_stride + self.fpn_padding = fpn_padding + self.fpn_top_down_levels = fpn_top_down_levels + self.num_feature_levels = num_feature_levels + + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + + +class EdgeTamPromptEncoderConfig(Sam2PromptEncoderConfig): + pass + + +class EdgeTamMaskDecoderConfig(Sam2MaskDecoderConfig): + pass + + +class EdgeTamConfig(Sam2Config): + pass + + +class EdgeTamLayerNorm(Sam2LayerNorm): + pass + + +class EdgeTamVisionEncoderOutput(Sam2VisionEncoderOutput): + pass + + +class EdgeTamAttention(Sam2Attention): + pass + + +class EdgeTamTwoWayAttentionBlock(Sam2TwoWayAttentionBlock): + pass + + +class EdgeTamFeedForward(Sam2FeedForward): + pass + + +@auto_docstring +class EdgeTamPreTrainedModel(Sam2PreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + if isinstance(module, EdgeTamModel): + if module.no_memory_embedding is not None: + module.no_memory_embedding.data.zero_() + + +@auto_docstring( + custom_intro=""" + The vision model from EdgeTAM without any head or projection on top. 
+ """ +) +class EdgeTamVisionModel(Sam2VisionModel): + config_class = EdgeTamVisionConfig + main_input_name = "pixel_values" + _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel} + + def get_input_embeddings(self): + raise NotImplementedError("Can't get input embeddings from timm wrapper model") + + @check_model_inputs + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, EdgeTamVisionEncoderOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Forward through backbone + backbone_output = self.backbone(pixel_values) + intermediate_hidden_states = backbone_output.last_hidden_state + intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states] + + fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) + # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution + fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1] + fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1] + + return EdgeTamVisionEncoderOutput( + last_hidden_state=intermediate_hidden_states[-1], + fpn_hidden_states=fpn_hidden_states, + fpn_position_encoding=fpn_position_encoding, + ) + + +class EdgeTamModel(Sam2Model): + _keys_to_ignore_on_load_unexpected = [ + r"^memory_.*", + r"^mask_downsample.*", + r"spatial_perceiver.*", + r"^object_pointer_proj.*", + r"^temporal_positional_encoding_projection_layer.*", + "no_memory_positional_encoding", + "no_object_pointer", + "occlusion_spatial_embedding_parameter", + ] + + def get_input_embeddings(self): + raise NotImplementedError("Can't get input embeddings from timm wrapper model") + + +__all__ = [ + "EdgeTamModel", + "EdgeTamVisionModel", + "EdgeTamPreTrainedModel", + "EdgeTamConfig", + "EdgeTamVisionConfig", + "EdgeTamPromptEncoderConfig", + "EdgeTamMaskDecoderConfig", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8555f58d1866451c38abb5559ef5bef9545f0b0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_emu3 import * + from .image_processing_emu3 import * + from .modeling_emu3 import * + from .processing_emu3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/configuration_emu3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/configuration_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..79407bfbf0d73cfac40297ccfa472f5e451f352d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/configuration_emu3.py @@ -0,0 +1,328 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Emu3VQVAEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3VQVAE`]. It is used to instantiate an VQ-VAE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a configuration to the VQ model presented in Emu3 paper. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + codebook_size (`int`, *optional*, defaults to 32768): + Codebook size of the VQ model. + embed_dim (`int`, *optional*, defaults to 4): + Dimension of the quantized vector in codebook. + latent_channels (`int`, *optional*, defaults to 4): + Dimension of the output channel of encoder and the input channel of decoder + double_latent (`bool`, *optional*, defaults to `False`): + Whether double the output dim of the encoder. + in_channels (`int`, *optional*, defaults to 3): + Input channel of encoder. + out_channels (`int`, *optional*, defaults to 3): + Output channel of decoder. + temporal_downsample_factor (`int`, *optional*, defaults to 4): + Temporal downsample factor. + base_channels (`int`, *optional*, defaults to 256): + Basic channel number of the intermediate blocks. + channel_multiplier (`list[int]`, *optional*, defaults to `[1, 2, 2, 4]`): + Channel scaling factor of the intermediate blocks. + num_res_blocks (`int`, *optional*, defaults to 2): + Residual block number in each stage. + attn_resolutions (`list[int]`, *optional*, defaults to `[3]`): + Stage indices to apply attention. + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations in the attention layer. 
+ num_attention_heads (`int`, *optional*, defaults to 1): + Number of attention heads for each attention layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Emu3VQVAE, Emu3VQVAEConfig + + >>> # Initializing a video VQ model of Emu3 configuration + >>> configuration = Emu3VQVAEConfig() + + >>> # Initializing a model from the Emu3 VQ model style configuration + >>> model = Emu3VQVAE(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_vqgan" + base_config_key = "vq_config" + + def __init__( + self, + codebook_size: int = 32768, + embed_dim: int = 4, + latent_channels: int = 4, + double_latent: bool = False, + in_channels: int = 3, + out_channels: int = 3, + temporal_downsample_factor: int = 4, + base_channels: int = 256, + channel_multiplier: list[int] = [1, 2, 2, 4], + num_res_blocks: int = 2, + attn_resolutions: list[int] = [3], + hidden_size: int = 1024, + num_attention_heads: int = 1, + attention_dropout: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.codebook_size = codebook_size + self.embed_dim = embed_dim + self.latent_channels = latent_channels + self.double_latent = double_latent + self.in_channels = in_channels + self.out_channels = out_channels + self.temporal_downsample_factor = temporal_downsample_factor + self.base_channels = base_channels + self.channel_multiplier = channel_multiplier + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + + +class Emu3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3TextModel`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 184622): + Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Emu3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 9216): + The maximum sequence length that this model might ever be used with. Emu supports up to 9216 tokens, + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 151643): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 151849): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 151850): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + + ```python + >>> from transformers import Emu3Model, Emu3Config + + >>> # Initializing a Emu3-community/Emu3-Chat-hf style configuration + >>> configuration = Emu3Config() + + >>> # Initializing a model from the Emu3-community/Emu3-Chat-hf style configuration + >>> model = Emu3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_text_model" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 184622, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = 8, + hidden_act: str = "silu", + max_position_embeddings: int = 9216, + rms_norm_eps: float = 1e-5, + use_cache: bool = True, + pad_token_id: int = 151643, + bos_token_id: int = 151849, + eos_token_id: int = 151850, + tie_word_embeddings: bool = False, + rope_theta: float = 1000000.0, + rope_scaling: Optional[dict[str, Any]] = None, + mlp_bias=False, + attention_bias=False, + attention_dropout: float = 0.1, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.mlp_bias = mlp_bias + self.attention_bias = attention_bias + self.initializer_range = initializer_range + rope_config_validation(self) + + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Emu3Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vq_config (`Union[Dict, Emu3VQVAEConfig]`, *optional*): + Emu3VQVAEConfig instance containing the configuration for the VQ-VAE model. 
+ text_config (`Union[Dict, Emu3TextConfig]``, *optional*): + Emu3TextConfig instance containing the configuration for the language model. + vocabulary_map (`dict`, *optional*): + A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. + """ + + model_type = "emu3" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"text_config": Emu3TextConfig, "vq_config": Emu3VQVAEConfig} + + def __init__( + self, + vq_config: Union[dict, Emu3VQVAEConfig] = None, + text_config: Union[dict, Emu3TextConfig] = None, + vocabulary_map: Optional[dict[int, int]] = None, + **kwargs, + ): + if vq_config is None: + vq_config = Emu3VQVAEConfig() + elif isinstance(vq_config, dict): + vq_config = Emu3VQVAEConfig(**vq_config) + + if text_config is None: + text_config = Emu3TextConfig() + elif isinstance(text_config, dict): + text_config = Emu3TextConfig(**text_config) + + self.vq_config = vq_config + self.text_config = text_config + self.vocabulary_map = vocabulary_map + self.image_token_id = vocabulary_map.get("") if vocabulary_map is not None else None + + super().__init__(**kwargs) + + +__all__ = ["Emu3Config", "Emu3TextConfig", "Emu3VQVAEConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/image_processing_emu3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/image_processing_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..50ce82e01de8ae5f17b009e40b1c28071e7537fb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/image_processing_emu3.py @@ -0,0 +1,529 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Iterable +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + make_nested_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +if is_vision_available(): + from PIL import Image + +logger = logging.get_logger(__name__) + + +def smart_resize( + height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. 
+ + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class Emu3ImageProcessor(BaseImageProcessor): + r""" + Constructs a Emu3 image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + min_pixels (`int`, *optional*, defaults to `512 * 512`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `1024 * 1024`): + The max pixels of the image to resize the image. 
+ spatial_factor (`int`, *optional*, defaults to 8): + The spatial downsample factor the image will be downsampled in feature extracting phase + """ + + model_input_names = ["pixel_values", "image_sizes"] + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + do_pad: bool = True, + min_pixels: int = 512 * 512, + max_pixels: int = 1024 * 1024, + spatial_factor: int = 8, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.spatial_factor = spatial_factor + self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + vision_info (`list[Dict]`, *optional*): + Optional list of dictionaries containing additional information about vision inputs. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. 
Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_flat_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.spatial_factor, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + images = np.array(processed_images) + return images + + def _pad_for_batching( + self, + pixel_values: list[np.ndarray], + image_sizes: list[list[int]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + + Args: + pixel_values (`list[np.ndarray]`): + An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`) + image_sizes (`list[list[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + list[`np.ndarray`]: The padded images. + """ + + max_shape = ( + max(size[0] for size in image_sizes), + max(size[1] for size in image_sizes), + ) + pixel_values = [ + pad( + image, + padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])), + data_format=data_format, + input_data_format=input_data_format, + ) + for image, size in zip(pixel_values, image_sizes) + ] + return pixel_values + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + do_pad: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + + if images is not None: + images = self.fetch_images(images) + images = make_nested_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + pixel_values = [] + for image in images: + if image: + image = self._preprocess( + image, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(image) + + image_sizes = [image.shape[-2:] for image in pixel_values] + if do_pad: + pixel_values = self._pad_for_batching(pixel_values, image_sizes) + pixel_values = np.array(pixel_values) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": image_sizes}, tensor_type=return_tensors + ) + + def postprocess( + self, + images: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Union[str, TensorType] = "PIL.Image.Image", + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess. + The parameters should be same as in preprocess. + Args: + images (`ImageInput`): + Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1. 
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_flat_list_of_images(images) + if isinstance(images[0], Image.Image): + return images if len(images) > 1 else images[0] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_values = [] + for image in images: + image = to_numpy_array(image) + if do_normalize: + image = self.unnormalize( + image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + image = image.clip(0, 255).astype(np.uint8) + + if do_normalize and do_rescale and return_tensors == "PIL.Image.Image": + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format) + pixel_values.append(Image.fromarray(image)) + else: + pixel_values.extend(image) + + data = {"pixel_values": pixel_values} + return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None + + return BatchFeature(data=data, tensor_type=return_tensors) + + def unnormalize( + self, + image: np.ndarray, + image_mean: Union[float, Iterable[float]], + image_std: Union[float, Iterable[float]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`. 
+ image = (image * image_std) + image_mean + Args: + image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`): + Batch of pixel values to postprocess. + image_mean (`float` or `Iterable[float]`): + The mean to use for unnormalization. + image_std (`float` or `Iterable[float]`): + The standard deviation to use for unnormalization. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + num_channels = 3 + + if isinstance(image_mean, Iterable): + if len(image_mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}") + else: + image_mean = [image_mean] * num_channels + + if isinstance(image_std, Iterable): + if len(image_std) != num_channels: + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}") + else: + image_std = [image_std] * num_channels + + rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std)) + rev_image_std = tuple(1 / std for std in image_std) + image = self.normalize( + image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format + ) + return image + + +__all__ = ["Emu3ImageProcessor"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/modeling_emu3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/modeling_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..5e791d1042f6b5316b34a67f30477bdd0b83d825 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/modeling_emu3.py @@ -0,0 +1,1638 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/emu3/modular_emu3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_emu3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from functools import cached_property +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Emu3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, 
cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Emu3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Emu3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Emu3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Emu3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Emu3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Emu3MLP(config) + self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dropout = nn.Dropout(config.attention_dropout) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states 
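+
+
+# NOTE: a minimal, illustrative sketch and not part of the Emu3 implementation. It only checks the
+# `torch.repeat_interleave` equivalence stated in the `repeat_kv` docstring above, assuming (as an
+# example) 2 key/value heads shared across 4 attention heads, i.e. `n_rep=2`. Tensor shapes are
+# arbitrary illustrative values.
+def _repeat_kv_equivalence_sketch() -> torch.Tensor:
+    kv = torch.randn(1, 2, 5, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
+    expanded = repeat_kv(kv, n_rep=2)  # each key/value head is duplicated -> (1, 4, 5, 8)
+    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=2, dim=1))
+    return expanded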
+
+
+class Emu3VQVAEVectorQuantizer(nn.Module):
+    """
+    A module for vector quantization using learned embedding vectors.
+
+    This module implements the quantization process similar to the one described in
+    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
+    input vectors into discrete codebook vectors, which are learned during training.
+    Current implementation improves over previous ones by avoiding costly matrix multiplications
+    and allowing for post-hoc remapping of indices.
+    """
+
+    def __init__(self, config: Emu3VQVAEConfig):
+        super().__init__()
+        self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
+        self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
+
+    def forward(self, hidden_state: torch.Tensor):
+        batch_size, temporal, channels, height, width = hidden_state.shape
+        hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous()
+        hidden_state_flattened = hidden_state.view(-1, channels)
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
+        embedding_sum = torch.sum(self.embedding.weight**2, dim=1)
+
+        # "bd,dn->bn",
+        distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1))
+        distances = hidden_state_sum + embedding_sum - distances
+
+        min_encoding_indices = torch.argmin(distances, dim=1)
+        min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width)
+        return min_encoding_indices
+
+
+class Emu3VQVAEEncoderConvDownsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, hidden_states):
+        # no asymmetric padding in torch conv, must do it ourselves
+        hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAEEncoderConvUpsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_states):
+        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAEConv3d(nn.Module):
+    def __init__(
+        self,
+        in_channel: int,
+        out_channel: int,
+        kernel_size: tuple[int],
+        stride: tuple[int],
+    ):
+        super().__init__()
+
+        padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])]
+        self.padding = ()
+        for pad_size in padding_sizes[::-1]:
+            self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2)
+        self.padding += (2, 0)
+
+        self.conv = nn.Conv3d(
+            in_channel,
+            out_channel,
+            kernel_size,
+            stride=stride,
+        )
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = F.pad(hidden_states, self.padding)
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAESpatialNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(
+            num_channels=out_channels,
+            num_groups=32,
+            eps=1e-6,
+            affine=True,
+        )
+
+        self.conv_y = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+        self.conv_b = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+    def 
forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if 
self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + # for compatibility with the attention interface + self.num_key_value_groups = 1 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. 
This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + 
hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + 
+ # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv + Taigman](https://huggingface.co/papers/2203.13131). 
+ """ +) +class Emu3VQVAE(PreTrainedModel): + config: Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = quant.shape[-1] + quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous() + post_quant = self.post_quant_conv(quant) + + quant = quant.permute(0, 2, 1, 3, 4) + post_quant = post_quant.permute(0, 2, 1, 3, 4) 
+
+        video = self.decoder(post_quant, quant)
+        video = video.reshape(
+            batch_size,
+            temporal * self.config.temporal_downsample_factor,
+            self.config.out_channels,
+            height * self.spatial_scale_factor,
+            width * self.spatial_scale_factor,
+        )
+        return video[:, 0] if is_image else video
+
+
+class Emu3ImageVocabularyMapping:
+    """
+    A class for mapping discrete image tokens from VQGAN to BPE tokens.
+    """
+
+    def __init__(self, vocab_map):
+        self.vocab_map = vocab_map
+        self.eol_token_id = vocab_map.get("<|extra_200|>")
+        self.image_token_id = vocab_map.get("<image>")
+
+    @cached_property
+    def image_tokens(self):
+        return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
+
+    @cached_property
+    def image_tokens_str(self):
+        return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
+
+    @cached_property
+    def img2bpe(self):
+        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}
+
+    @cached_property
+    def bpe2img(self):
+        return {v: k for k, v in self.img2bpe.items()}
+
+    @cached_property
+    def bpe2img_mapping_tensor(self):
+        mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
+        for k, v in self.bpe2img.items():
+            mapping[k] = v
+        return mapping
+
+    @cached_property
+    def img2bpe_mapping_tensor(self):
+        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
+        for k, v in self.img2bpe.items():
+            mapping[k] = v
+        return mapping
+
+    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
+        device = img_batch.device
+        eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
+        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
+        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
+        return img_tokens.to(device)
+
+    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
+        device = img_batch.device
+        img_batch = img_batch[..., :-1]  # remove last row of EOL tokens
+        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
+        return img_tokens.to(device)
+
+
+@auto_docstring
+class Emu3PreTrainedModel(PreTrainedModel):
+    config: Emu3Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = [
+        "Emu3DecoderLayer",
+    ]
+    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    _can_compile_fullgraph = True
+    _supports_param_buffer_assignment = False
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+
+class Emu3RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Emu3Config, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Emu3TextModel(Emu3PreTrainedModel): + _can_record_outputs = { + "hidden_states": Emu3DecoderLayer, + "attentions": Emu3Attention, + } + + def __init__(self, config: Emu3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Emu3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": 
(["hidden_states"], ["logits"])} + config: Emu3TextConfig + + def __init__(self, config): + super().__init__(config) + self.model = Emu3TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Emu3Model(Emu3PreTrainedModel): + _checkpoint_conversion_mapping = {"text_model.model": "text_model"} + + def __init__(self, config): + super().__init__(config) + self.text_model = Emu3TextModel._from_config(config.text_config) + self.vqmodel = Emu3VQVAE(config.vq_config) + self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.text_model = decoder + + def get_decoder(self): + return self.text_model + + def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. 
+ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. + """ + image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes) + bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list] + bpe_tokens = torch.cat(bpe_tokens_list) + return bpe_tokens + + def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module and embeds + them with text embeddings layer + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + split_sizes = [ + (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1) + for height, width in image_sizes + ] + image_features = self.get_input_embeddings()(image_tokens) + image_features = torch.split(image_features, split_sizes) + return image_features + + @torch.no_grad + def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): + """ + Decodes generated image tokens from language model to continuous pixel values + with VQGAN module via upsampling. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): + The tensors corresponding to the input images. + height (`int`): + Height of the generated image before upsampling. + width (`int`): + Width of the generated image before upsampling. + """ + sequences = image_tokens[:, :-3].view(-1, height, width + 1) + image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences) + image = self.vqmodel.decode(image_tokens) + return image + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_sizes) + image_embeds = torch.cat(image_embeds, dim=0) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.text_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return outputs + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + base_model_prefix = "" + _tied_weights_keys = ["lm_head.weight"] + _checkpoint_conversion_mapping = { + "^text_model.model": "model.text_model", + "^vqmodel": "model.vqmodel", + "^text_model.lm_head": "lm_head", + } + + def __init__(self, config): + super().__init__(config) + self.model = Emu3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + # Make 
modules available through conditional class for BC + @property + def text_model(self): + return self.model.text_model + + @property + def vqmodel(self): + return self.model.vqmodel + + @property + def vocabulary_mapping(self): + return self.model.vocabulary_mapping + + def decode_image_tokens(self, **kwargs): + return self.model.decode_image_tokens(**kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> conversation = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."}, + ... ], + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "Please describe the image."}, + ... ], + ... }, + ... 
] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + + +__all__ = [ + "Emu3ForConditionalGeneration", + "Emu3ForCausalLM", + "Emu3TextModel", + "Emu3PreTrainedModel", + "Emu3VQVAE", + "Emu3Model", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/modular_emu3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/modular_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..32599727b24c12dadc5dd237b63bcebcaecd1316 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/modular_emu3.py @@ -0,0 +1,1222 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from functools import cached_property +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ..chameleon.modeling_chameleon import ( + ChameleonPreTrainedModel, + ChameleonVQVAEEncoderConvDownsample, +) +from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, TransformersKwargs +from ..siglip.modeling_siglip import SiglipAttention +from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig + + +logger = logging.get_logger(__name__) + + +class Emu3Attention(LlamaAttention): + pass + + +# Has extra dropout which no other model in the library has +class Emu3DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__(config, layer_idx) + self.dropout = nn.Dropout(config.attention_dropout) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class Emu3VQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. 
+ """ + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) + + def forward(self, hidden_state: torch.Tensor): + batch_size, temporal, channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous() + hidden_state_flattened = hidden_state.view(-1, channels) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + embedding_sum = torch.sum(self.embedding.weight**2, dim=1) + + # "bd,dn->bn", + distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + distances = hidden_state_sum + embedding_sum - distances + + min_encoding_indices = torch.argmin(distances, dim=1) + min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width) + return min_encoding_indices + + +class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample): + pass + + +class Emu3VQVAEEncoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEConv3d(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: tuple[int], + stride: tuple[int], + ): + super().__init__() + + padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])] + self.padding = () + for pad_size in padding_sizes[::-1]: + self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2) + self.padding += (2, 0) + + self.conv = nn.Conv3d( + in_channel, + out_channel, + kernel_size, + stride=stride, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = F.pad(hidden_states, self.padding) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAESpatialNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm( + num_channels=out_channels, + num_groups=32, + eps=1e-6, + affine=True, + ) + + self.conv_y = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.conv_b = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + 
hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(SiglipAttention): + def __init__(self, config: Emu3VQVAEConfig): + 
super().__init__(config) + + # for compatibility with the attention interface + self.num_key_value_groups = 1 + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = 
hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = 
Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + + # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. 
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv + Taigman](https://huggingface.co/papers/2203.13131). + """ +) +class Emu3VQVAE(PreTrainedModel): + config: Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = 
quant.shape[-1]
+        quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
+        post_quant = self.post_quant_conv(quant)
+
+        quant = quant.permute(0, 2, 1, 3, 4)
+        post_quant = post_quant.permute(0, 2, 1, 3, 4)
+
+        video = self.decoder(post_quant, quant)
+        video = video.reshape(
+            batch_size,
+            temporal * self.config.temporal_downsample_factor,
+            self.config.out_channels,
+            height * self.spatial_scale_factor,
+            width * self.spatial_scale_factor,
+        )
+        return video[:, 0] if is_image else video
+
+
+class Emu3ImageVocabularyMapping:
+    """
+    A class for mapping discrete image tokens from VQGAN to BPE tokens.
+    """
+
+    def __init__(self, vocab_map):
+        self.vocab_map = vocab_map
+        self.eol_token_id = vocab_map.get("<|extra_200|>")
+        self.image_token_id = vocab_map.get("<image>")
+
+    @cached_property
+    def image_tokens(self):
+        return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
+
+    @cached_property
+    def image_tokens_str(self):
+        return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
+
+    @cached_property
+    def img2bpe(self):
+        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}
+
+    @cached_property
+    def bpe2img(self):
+        return {v: k for k, v in self.img2bpe.items()}
+
+    @cached_property
+    def bpe2img_mapping_tensor(self):
+        mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
+        for k, v in self.bpe2img.items():
+            mapping[k] = v
+        return mapping
+
+    @cached_property
+    def img2bpe_mapping_tensor(self):
+        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
+        for k, v in self.img2bpe.items():
+            mapping[k] = v
+        return mapping
+
+    def convert_img2bpe(self, img_batch: list[torch.Tensor]) -> torch.Tensor:
+        device = img_batch.device
+        eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
+        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
+        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
+        return img_tokens.to(device)
+
+    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
+        device = img_batch.device
+        img_batch = img_batch[..., :-1]  # remove last row of EOL tokens
+        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
+        return img_tokens.to(device)
+
+
+class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
+    _no_split_modules = [
+        "Emu3DecoderLayer",
+    ]
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+
+class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
+    _can_record_outputs = {
+        "hidden_states": Emu3DecoderLayer,
+        "attentions": Emu3Attention,
+    }
+
+    def __init__(self, config: Emu3Config):
+        super().__init__(config)
+        self.layers = nn.ModuleList(
+            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+
+
+class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
+    config: Emu3TextConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Emu3TextModel(config)
+
+    def forward(**super_kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
+        >>> import torch
+        >>> import requests
+        >>> from PIL import Image
+
+        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
+        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")
+
+        >>> inputs = processor(text=["Can you write me a poem about winter."], 
return_tensors="pt").to(model.device) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + super().forward() + + +class Emu3Model(Emu3PreTrainedModel): + _checkpoint_conversion_mapping = {"text_model.model": "text_model"} + + def __init__(self, config): + super().__init__(config) + self.text_model = Emu3TextModel._from_config(config.text_config) + self.vqmodel = Emu3VQVAE(config.vq_config) + self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.text_model = decoder + + def get_decoder(self): + return self.text_model + + def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. + """ + image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes) + bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list] + bpe_tokens = torch.cat(bpe_tokens_list) + return bpe_tokens + + def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module and embeds + them with text embeddings layer + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + split_sizes = [ + (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1) + for height, width in image_sizes + ] + image_features = self.get_input_embeddings()(image_tokens) + image_features = torch.split(image_features, split_sizes) + return image_features + + @torch.no_grad + def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): + """ + Decodes generated image tokens from language model to continuous pixel values + with VQGAN module via upsampling. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): + The tensors corresponding to the input images. + height (`int`): + Height of the generated image before upsampling. + width (`int`): + Width of the generated image before upsampling. + """ + sequences = image_tokens[:, :-3].view(-1, height, width + 1) + image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences) + image = self.vqmodel.decode(image_tokens) + return image + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. 
If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_sizes) + image_embeds = torch.cat(image_embeds, dim=0) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.text_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return outputs + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + base_model_prefix = "" + _tied_weights_keys = ["lm_head.weight"] + _checkpoint_conversion_mapping = { + "^text_model.model": "model.text_model", + "^vqmodel": "model.vqmodel", + "^text_model.lm_head": "lm_head", + } + + def __init__(self, config): + super().__init__(config) + self.model = Emu3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def 
get_decoder(self): + return self.model.get_decoder() + + # Make modules available through conditional class for BC + @property + def text_model(self): + return self.model.text_model + + @property + def vqmodel(self): + return self.model.vqmodel + + @property + def vocabulary_mapping(self): + return self.model.vocabulary_mapping + + def decode_image_tokens(self, **kwargs): + return self.model.decode_image_tokens(**kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> conversation = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."}, + ... ], + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "Please describe the image."}, + ... ], + ... }, + ... 
] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + + +__all__ = [ + "Emu3ForConditionalGeneration", + "Emu3ForCausalLM", + "Emu3TextModel", + "Emu3PreTrainedModel", + "Emu3VQVAE", + "Emu3Model", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/processing_emu3.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/processing_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..67ccab795733563359dbba53e17591cab2878506 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/emu3/processing_emu3.py @@ -0,0 +1,248 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_vision_available + + +if is_vision_available(): + from .image_processing_emu3 import smart_resize + + +class Emu3TextKwargs(TextKwargs, total=False): + return_for_image_generation: bool + + +class Emu3ImagesKwargs(ImagesKwargs, total=False): + ratio: str + image_area: int + + +class Emu3ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: Emu3TextKwargs + images_kwargs: Emu3ImagesKwargs + _defaults = { + "text_kwargs": { + "return_for_image_generation": False, + "return_mm_token_type_ids": False, + }, + "images_kwargs": { + "ratio": "1:1", + "image_area": 518400, + }, + } + + +class Emu3Processor(ProcessorMixin): + r""" + Constructs a Emu3 processor which wraps a Emu3 image processor and a GPT2 tokenizer into a single + processor. + + [`Emu3Processor`] offers all the functionalities of [`Emu3ImageProcessor`] and [`GPT2TokenizerFast`]. + See the [`~Emu3Processor.__call__`] and [`~Emu3Processor.decode`] for more information. + + Args: + image_processor ([`Emu3ImageProcessor`]): + The image processor is a required input. + tokenizer ([`Emu3TokenizerFast`]): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast") + image_processor_class = "Emu3ImageProcessor" + + def __init__( + self, + image_processor, + tokenizer, + chat_template=None, + **kwargs, + ): + self.image_token = tokenizer.image_token # image_token as placeholder to be replaced by vq-vae tokens + self.image_token_id = tokenizer.image_token_id + self.image_start_token = tokenizer.boi_token # "<|image start|>" fixed tokens for start and end of image + self.image_end_token = tokenizer.eoi_token # "<|image end|>" + self.fake_token_around_image = tokenizer.image_wrapper_token # "<|image token|>" every image starts with it + self.eof_token = tokenizer.eof_token # "<|extra_201|>" + self.bos_token = tokenizer.bos_token + self.downsample_ratio = 8 + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Emu3ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Emu3TokenizerFast's [`~Emu3TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. 
Both channels-first and channels-last formats are supported.
+            text (`str`, `list[str]`, `list[list[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+        """
+        # check if images and text inputs are reversed for BC
+
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list) and not isinstance(text[0], str):
+            raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+        output_kwargs = self._merge_kwargs(
+            Emu3ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        return_for_image_generation = output_kwargs["text_kwargs"].pop("return_for_image_generation", False)
+        ratio = output_kwargs["images_kwargs"].pop("ratio", None)
+        image_area = output_kwargs["images_kwargs"].pop("image_area", None)
+
+        if return_for_image_generation and images is not None:
+            raise ValueError("You should not provide `images` when `return_for_image_generation=True`")
+
+        if not return_for_image_generation and text is None and images is None:
+            raise ValueError("You must provide either text or images when `return_for_image_generation=False`")
+
+        image_features = {}
+        image_start_tokens = f"{self.image_start_token}"
+        image_end_tokens = f"{self.eof_token}{self.image_end_token}"
+
+        # generate text from image + text input, so we add placeholders for image tokens
+        if not return_for_image_generation and images is not None:
+            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+            image_sizes = iter(image_features.image_sizes)
+
+            prompt_strings = []
+            for sample in text:
+                while self.image_token in sample:
+                    image_size = next(image_sizes)
+                    height, width = image_size
+                    height = height // self.downsample_ratio
+                    width = width // self.downsample_ratio
+                    image_seq_length = height * (width + 1)  # +1 for extra row when converting to BPE in modeling code
+
+                    image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'<placeholder>' * image_seq_length}{image_end_tokens}"
+                    sample = sample.replace(self.image_token, image_placeholder, 1)
+                    sample = f"{self.bos_token}{sample}"  # add BOS because GPT tokenizer doesn't add it
+                prompt_strings.append(sample)
+            text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
+
+        # generate image from text input, so we add begin-of-image tokens from where image generation starts
+        elif return_for_image_generation:
+            height, width = self.calculate_generate_size(ratio, image_area, 
self.downsample_ratio) + image_prompt = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}" + text = [f"{self.bos_token}{sample}{image_prompt}" for sample in text] + image_features["image_sizes"] = [[height, width]] * len(text) + + # else just generate from text-only input, and we do no special treatment for text + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_features}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + num_image_tokens = [] + for height, width in image_sizes: + height, width = smart_resize( + height, + width, + self.image_processor.spatial_factor, + self.image_processor.min_pixels, + self.image_processor.max_pixels, + ) + height = height // self.downsample_ratio + width = width // self.downsample_ratio + image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code + num_image_tokens.append(image_seq_length) + + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + def calculate_generate_size(self, ratio, image_area, spatial_factor): + width, height = map(int, ratio.split(":")) + current_area = width * height + target_ratio = (image_area / current_area) ** 0.5 + + token_height = int(round(height * target_ratio / spatial_factor)) + token_width = int(round(width * target_ratio / spatial_factor)) + return token_height, token_width + + def postprocess(self, images: ImageInput, **kwargs): + return self.image_processor.postprocess(images, **kwargs) + + +__all__ = ["Emu3Processor"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65fb1ca5edef4398753377796fc2609d0f17e1e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gemma import * + from .modeling_flax_gemma import * + from .modeling_gemma import * + from .tokenization_gemma import * + from .tokenization_gemma_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/configuration_gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/configuration_gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..363af5c3ffc4ccfe856763d332fe6fbce5886bae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/configuration_gemma.py @@ -0,0 +1,160 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig + + +class GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GemmaModel`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. 
+ num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The legacy activation function. It is overwritten by the `hidden_activation`. + hidden_activation (`str` or `function`, *optional*): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ ```python + >>> from transformers import GemmaModel, GemmaConfig + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation=None, + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["GemmaConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modeling_flax_gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modeling_flax_gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..0addcd7dde7a01f3bd11316596e41e903db59438 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modeling_flax_gemma.py @@ -0,0 +1,777 @@ +# coding=utf-8 +# Copyright 2024 Google Inc., and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Flax Gemma model.""" + +from typing import Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_gemma import GemmaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GemmaConfig" +_CHECKPOINT_FOR_DOC = "google/gemma-2b" +_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2" + +GEMMA_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`GemmaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or + `jax.numpy.bfloat16`. + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +GEMMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def create_sinusoidal_positions(num_pos, dim): + inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2)[: (dim // 2)] / dim)) + freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") + + emb = np.concatenate((freqs, freqs), axis=-1) + out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) + return jnp.array(out[:, :, :num_pos]) + + +# Copied from transformers.models.llama.modeling_flax_llama.rotate_half +def rotate_half(tensor): + """Rotates half the hidden dims of the input.""" + rotate_half_tensor = jnp.concatenate( + (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 + ) + return rotate_half_tensor + + +# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): + return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) + + +class FlaxGemmaRMSNorm(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.epsilon = self.config.rms_norm_eps + self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size) + + def __call__(self, hidden_states): + variance = jnp.asarray(hidden_states, dtype=jnp.float32) + variance = jnp.power(variance, 2) + variance = variance.mean(-1, keepdims=True) + # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` + hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) + + return (1 + self.weight) * jnp.asarray(hidden_states, dtype=self.dtype) + + +# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Gemma +class FlaxGemmaRotaryEmbedding(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + + # Ignore copy + def setup(self): + 
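# Precompute the rotary sin/cos table once for every position up to `max_position_embeddings`; `__call__` gathers the rows selected by `position_ids` and applies them to the key/query states. +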
head_dim = self.config.head_dim + self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim) + + def __call__(self, key, query, position_ids): + sincos = self.sincos[position_ids] + sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) + + key = apply_rotary_pos_emb(key, sin_pos, cos_pos) + query = apply_rotary_pos_emb(query, sin_pos, cos_pos) + + key = jnp.asarray(key, dtype=self.dtype) + query = jnp.asarray(query, dtype=self.dtype) + + return key, query + + +class FlaxGemmaAttention(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + causal: bool = True + is_cross_attention: bool = False + + def setup(self): + config = self.config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 + + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + kernel = jax.nn.initializers.normal(self.config.initializer_range) + self.q_proj = nn.Dense( + self.num_heads * self.head_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel + ) + self.k_proj = nn.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + dtype=self.dtype, + kernel_init=kernel, + ) + self.v_proj = nn.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=config.attention_bias, + dtype=self.dtype, + kernel_init=kernel, + ) + self.o_proj = nn.Dense(self.embed_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel) + + self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") + self.rotary_emb = FlaxGemmaRotaryEmbedding(config, dtype=self.dtype) + + def _split_heads(self, hidden_states, num_heads): + return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads * self.head_dim,)) + + @nn.compact + # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
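+ # The cached key/value arrays are created lazily as zeros shaped like the incoming states; each decoding step writes its new slice at `cache_index` via `lax.dynamic_update_slice` and then advances the index.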
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + position_ids, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_heads) + key = self._split_heads(key, self.num_key_value_heads) + value = self._split_heads(value, self.num_key_value_heads) + + key, query = self.rotary_emb(key, query, position_ids) + + query_length, key_length = query.shape[1], key.shape[1] + + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + + batch_size = hidden_states.shape[0] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + + dropout_rng = None + if not deterministic and self.config.attention_dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
+ if self.has_variable("cache", "cached_key") or init_cache: + key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) + + # transform boolean mask into float mask + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + + key = jnp.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = jnp.repeat(value, repeats=self.num_key_value_groups, axis=2) + + # usual dot product attention + attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype + attn_weights = dot_product_attention_weights( + query, + key, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_dropout, + deterministic=deterministic, + dtype=attention_dtype, + ) + + if self.attention_softmax_in_fp32: + attn_weights = attn_weights.astype(self.dtype) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) + attn_output = self._merge_heads(attn_output) + attn_output = self.o_proj(attn_output) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxGemmaMLP(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + embed_dim = self.config.hidden_size + inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim + + kernel_init = jax.nn.initializers.normal(self.config.initializer_range) + if self.config.hidden_activation is None: + logger.warning_once( + "Gemma's activation function should be approximate GeLU and not exact GeLU. " + "Changing the activation function to `gelu_pytorch_tanh`." + f"if you want to use the legacy `{self.config.hidden_act}`, " + f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` " + " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details." 
+ ) + hidden_activation = "gelu_pytorch_tanh" + else: + hidden_activation = self.config.hidden_activation + self.act = ACT2FN[hidden_activation] + + self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) + self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) + self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) + + def __call__(self, hidden_states): + up_proj_states = self.up_proj(hidden_states) + gate_states = self.act(self.gate_proj(hidden_states)) + + hidden_states = self.down_proj(up_proj_states * gate_states) + return hidden_states + + +# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Gemma +class FlaxGemmaDecoderLayer(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.input_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) + self.self_attn = FlaxGemmaAttention(self.config, dtype=self.dtype) + self.post_attention_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) + self.mlp = FlaxGemmaMLP(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + outputs = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + # residual connection + attn_output = outputs[0] + hidden_states = residual + attn_output + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + hidden_states + + return (hidden_states,) + outputs[1:] + + +# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Gemma, GPT_NEO->GEMMA, transformer->model +class FlaxGemmaPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = GemmaConfig + base_model_prefix = "model" + module_class: nn.Module = None + + def __init__( + self, + config: GemmaConfig, + input_shape: tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + params: Optional[dict] = None, + past_key_values: Optional[dict] = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_ids.shape + + if position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") + + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGemmaAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + False, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Gemma +class FlaxGemmaLayerCollection(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.blocks = [ + FlaxGemmaDecoderLayer(self.config, dtype=self.dtype, name=str(i)) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for block in self.blocks: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = block( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # this contains possible `None` values - `FlaxGemmaModule` will filter them out + outputs = (hidden_states, all_hidden_states, all_attentions) + + return outputs + + +# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Gemma +class FlaxGemmaModule(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.hidden_size = self.config.hidden_size + embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.hidden_size, + embedding_init=embedding_init, + dtype=self.dtype, + ) + self.layers = FlaxGemmaLayerCollection(self.config, dtype=self.dtype) + self.norm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) + + # Ignore copy + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic=True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + input_embeds = self.embed_tokens(input_ids.astype("i4")) + + input_embeds = input_embeds * (self.config.hidden_size**0.5) + + outputs = self.layers( + input_embeds, + position_ids=position_ids, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.norm(hidden_states) + + if 
output_hidden_states: + all_hidden_states = outputs[1] + (hidden_states,) + outputs = (hidden_states, all_hidden_states) + outputs[2:] + else: + outputs = (hidden_states,) + outputs[1:] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs[1], + attentions=outputs[-1], + ) + + +@add_start_docstrings( + "The bare Gemma Model transformer outputting raw hidden-states without any specific head on top.", + GEMMA_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModel with Llama->Gemma +class FlaxGemmaModel(FlaxGemmaPreTrainedModel): + module_class = FlaxGemmaModule + + +append_call_sample_docstring( + FlaxGemmaModel, + _CHECKPOINT_FOR_DOC, + FlaxBaseModelOutput, + _CONFIG_FOR_DOC, + real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, +) + + +# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Gemma +class FlaxGemmaForCausalLMModule(nn.Module): + config: GemmaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.model = FlaxGemmaModule(self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + # Ignore copy + def __call__( + self, + input_ids, + attention_mask=None, + position_ids=None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.model( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T + lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + outputs[1:] + + return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + +@add_start_docstrings( + """ + The Gemma Model transformer with a language modeling head (linear layer) on top. + """, + GEMMA_START_DOCSTRING, +) +# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Gemma +class FlaxGemmaForCausalLM(FlaxGemmaPreTrainedModel): + module_class = FlaxGemmaForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since Gemma uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxGemmaForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutput, + _CONFIG_FOR_DOC, + real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, +) + + +__all__ = ["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..04d27b309a405010b0b797a45fbda03cc45a4e9b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py @@ -0,0 +1,511 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs +from .configuration_gemma import GemmaConfig + + +class GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class GemmaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: GemmaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class GemmaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, 
cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class GemmaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) + + self.mlp = GemmaMLP(config) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class GemmaPreTrainedModel(PreTrainedModel): + config: GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GemmaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": GemmaDecoderLayer, + "attentions": GemmaAttention, + } + + def _init_weights(self, module): + super()._init_weights(module) + + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + if "RMSNorm" in module.__class__.__name__: + module.weight.data.zero_() + + +@auto_docstring +class GemmaModel(GemmaPreTrainedModel): + def __init__(self, config: GemmaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + 
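# Note that the token embeddings are not scaled here; `forward` multiplies `inputs_embeds` by sqrt(hidden_size) before the first decoder layer (see the `normalizer` computation there).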
+ # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = GemmaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM 
+ + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class GemmaForSequenceClassification(GenericForSequenceClassification, GemmaPreTrainedModel): + pass + + +class GemmaForTokenClassification(GenericForTokenClassification, GemmaPreTrainedModel): + pass + + +__all__ = [ + "GemmaModel", + "GemmaForCausalLM", + "GemmaForSequenceClassification", + "GemmaForTokenClassification", + "GemmaPreTrainedModel", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modular_gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modular_gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f9c7dc40567353a81383929d1fca35b2cc8be9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/modular_gemma.py @@ -0,0 +1,492 @@ +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING, Any, Optional + +import sentencepiece as spm +import torch +from torch import nn + +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import TransformersKwargs, logging +from ..llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRotaryEmbedding, +) +from ..llama.tokenization_llama import LlamaTokenizer + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +SPIECE_UNDERLINE = "▁" + + +logger = logging.get_logger(__name__) + + +class GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GemmaModel`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The legacy activation function. It is overwritten by the `hidden_activation`. + hidden_activation (`str` or `function`, *optional*): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. 
+ max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import GemmaModel, GemmaConfig + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation=None, + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + 
tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer): + """ + Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Gemma should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. 
+    """
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<bos>",
+        eos_token="<eos>",
+        pad_token="<pad>",
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        use_default_system_prompt=False,
+        spaces_between_special_tokens=False,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.use_default_system_prompt = use_default_system_prompt
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        PreTrainedTokenizer.__init__(
+            self,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            use_default_system_prompt=use_default_system_prompt,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            **kwargs,
+        )
+
+    def get_spm_processor(self):
+        raise AttributeError("Not needed for Gemma")
+
+    def unk_token_length(self):
+        raise AttributeError("Not needed for Gemma")
+
+    def tokenize(self, text: "TextInput", **kwargs) -> list[str]:
+        """
+        Args:
+            text: TextInput
+        Simply calls PreTrainedTokenizer's method
+        """
+        return PreTrainedTokenizer.tokenize(self, text, **kwargs)
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Args:
+            text: TextInput
+        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
+ """ + return self.sp_model.encode(text, out_type=str) + + def _decode( + self, + token_ids: list[int], + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + sub_texts = [] + current_sub_text = [] + for ids in token_ids: + if skip_special_tokens and ids in self.all_special_ids: + continue + if ids in self._added_tokens_decoder: + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + sub_texts.append(self._added_tokens_decoder[ids].content) + current_sub_text = [] + else: + current_sub_text.append(ids) + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + + if spaces_between_special_tokens: + sub_texts = " ".join(sub_texts) + else: + sub_texts = "".join(sub_texts) + + return sub_texts.replace(SPIECE_UNDERLINE, " ") + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self._added_tokens_encoder: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + +class GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class GemmaMLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class GemmaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class GemmaPreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + if "RMSNorm" in module.__class__.__name__: + module.weight.data.zero_() + + +class GemmaModel(LlamaModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + 
cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class GemmaForCausalLM(LlamaForCausalLM): + def forward(**super_kwargs): + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + return super().forward(**super_kwargs) + + +class GemmaForSequenceClassification(LlamaForSequenceClassification): + pass + + +class GemmaForTokenClassification(LlamaForTokenClassification): + pass + + +__all__ = [ + "GemmaConfig", + "GemmaTokenizer", + "GemmaModel", + "GemmaForCausalLM", + "GemmaForSequenceClassification", + "GemmaForTokenClassification", + "GemmaPreTrainedModel", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/tokenization_gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/tokenization_gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..3320968c2915dee8e4584470848d5480fb7f40c6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/tokenization_gemma.py @@ -0,0 +1,335 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. 
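Two numerically sensitive details in the modular model code above are easy to miss before reading the generated tokenizer file that follows: `GemmaRMSNorm` normalizes in float32 and scales by `(1 + weight)` before casting back (the opposite order from Llama), and `GemmaModel.forward` materializes the `sqrt(hidden_size)` embedding multiplier in the activation dtype so that low-precision rounding matches the reference checkpoints. A small sketch of both, assuming nothing beyond plain PyTorch:

```python
import torch

hidden_size = 3072


# Gemma-style RMSNorm: compute in float32, scale by (1 + weight), cast back last.
def gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + weight.float())).type_as(x)


x = torch.randn(2, 4, hidden_size, dtype=torch.bfloat16)
w = torch.zeros(hidden_size)  # zero-initialized weight, i.e. an effective scale of 1
y = gemma_rmsnorm(x, w)

# Embedding scaling: creating the multiplier in the low-precision dtype first bakes in the
# rounding mentioned in the comment above (sqrt(3072) = 55.4256... is not exactly representable).
normalizer = torch.tensor(hidden_size**0.5, dtype=x.dtype)
scaled_embeddings = x * normalizer
```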
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +SPIECE_UNDERLINE = "▁" + + +@requires(backends=("sentencepiece",)) +class GemmaTokenizer(PreTrainedTokenizer): + """ + Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Gemma should be used. 
+        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not to add spaces between special tokens.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<bos>",
+        eos_token="<eos>",
+        pad_token="<pad>",
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        use_default_system_prompt=False,
+        spaces_between_special_tokens=False,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.use_default_system_prompt = use_default_system_prompt
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            use_default_system_prompt=use_default_system_prompt,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            **kwargs,
+        )
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+    @property
+    def vocab_size(self):
+        """Returns vocab size"""
+        return self.sp_model.get_piece_size()
+
+    def get_vocab(self):
+        """Returns vocab as a dict"""
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def tokenize(self, text: "TextInput", **kwargs) -> list[str]:
+        """
+        Args:
+            text: TextInput
+        Simply calls PreTrainedTokenizer's method
+        """
+        return super().tokenize(text, **kwargs)
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Args:
+            text: TextInput
+        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
+ """ + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self._added_tokens_encoder: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`list[int]`): + List of ids. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + def _decode( + self, + token_ids: list[int], + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + sub_texts = [] + current_sub_text = [] + for ids in token_ids: + if skip_special_tokens and ids in self.all_special_ids: + continue + if ids in self._added_tokens_decoder: + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + sub_texts.append(self._added_tokens_decoder[ids].content) + current_sub_text = [] + else: + current_sub_text.append(ids) + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + + if spaces_between_special_tokens: + sub_texts = " ".join(sub_texts) + else: + sub_texts = "".join(sub_texts) + + return sub_texts.replace(SPIECE_UNDERLINE, " ") + + +__all__ = ["GemmaTokenizer"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/tokenization_gemma_fast.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/tokenization_gemma_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc6e3d3593b9ff5403973af5a08d86820a0c2d3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma/tokenization_gemma_fast.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
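Before the fast tokenizer, it is worth pinning down what the slow tokenizer's `build_inputs_with_special_tokens`, `get_special_tokens_mask`, and `create_token_type_ids_from_sequences` above actually emit. A worked example with made-up content ids (the BOS/EOS ids, 2 and 1, follow the `GemmaConfig` defaults earlier in this diff) under the default `add_bos_token=True`, `add_eos_token=False`:

```python
# Hypothetical content ids; only the special-token layout is the point here.
bos, eos = [2], []                  # eos stays empty because add_eos_token defaults to False
seq_a = [10, 11, 12]
seq_b = [20, 21]

single = bos + seq_a + eos                        # [2, 10, 11, 12]
pair = bos + seq_a + eos + bos + seq_b + eos      # [2, 10, 11, 12, 2, 20, 21]

# get_special_tokens_mask marks the injected BOS ids; real content stays 0.
special_tokens_mask = [1] + [0] * len(seq_a) + [1] + [0] * len(seq_b)          # [1, 0, 0, 0, 1, 0, 0]

# create_token_type_ids_from_sequences: 0s for the first segment, 1s for the second.
token_type_ids = [0] * len(bos + seq_a + eos) + [1] * len(bos + seq_b + eos)   # [0, 0, 0, 0, 1, 1, 1]
```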
+import os +from shutil import copyfile +from typing import Optional + +from tokenizers import processors + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_gemma import GemmaTokenizer +else: + GemmaTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} + + +class GemmaTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a Gemma tokenizer fast. Based on byte-level Byte-Pair-Encoding. + + This uses notably ByteFallback and no prefix space. Normalization is applied to replace `" "` with `"▁"` + + ```python + >>> from transformers import GemmaTokenizerFast + + >>> tokenizer = GemmaTokenizerFast.from_pretrained("hf-internal-testing/dummy-gemma") + >>> tokenizer.encode("Hello this is a test") + [2, 4521, 736, 603, 476, 2121] + ``` + + If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or + call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the + values of the first token and final token of an encoded sequence will not be correct). For more details, checkout + [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`, *optional*): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that + contains the vocabulary necessary to instantiate a tokenizer. + tokenizer_file (`str`, *optional*): + [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that + contains everything needed to load the tokenizer. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `""`): + The padding token + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. 
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = GemmaTokenizer
+    padding_side = "left"
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        clean_up_tokenization_spaces=False,
+        unk_token="<unk>",
+        bos_token="<bos>",
+        eos_token="<eos>",
+        pad_token="<pad>",
+        add_bos_token=True,
+        add_eos_token=False,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            **kwargs,
+        )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
+        self.vocab_file = vocab_file
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
+    def update_post_processor(self):
+        """
+        Updates the underlying post processor with the current `bos_token` and `eos_token`.
+        """
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+        if bos is None and self.add_bos_token:
+            raise ValueError("add_bos_token = True but bos_token = None")
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+        if eos is None and self.add_eos_token:
+            raise ValueError("add_eos_token = True but eos_token = None")
+
+        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
+        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + +__all__ = ["GemmaTokenizerFast"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/configuration_gemma3n.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/configuration_gemma3n.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab62269af6bfe1f3001c012bff47f9a4a67aae7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/configuration_gemma3n.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/feature_extraction_gemma3n.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/feature_extraction_gemma3n.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..617a9d060e6aa20fef118f0b74c647058d975a7c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/feature_extraction_gemma3n.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/processing_gemma3n.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/processing_gemma3n.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..168c12e845065b5ec459bdfa53e3d37d2effd24c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gemma3n/__pycache__/processing_gemma3n.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_neox/__pycache__/modeling_gpt_neox.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_neox/__pycache__/modeling_gpt_neox.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bfbcfa567c3e5469fd6bd592142fac4ff85a4d6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_neox/__pycache__/modeling_gpt_neox.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8de9469bff0f15782afdd599448eda8bcccd609f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9bc238961bb017c2f53052d28d8f53d6e3d4da2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85333f70b5ee330c9798da1ad624c44c65fc8d4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_granitemoe import * + from .modeling_granitemoe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/configuration_granitemoe.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/configuration_granitemoe.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb8dbe16f7d322c1a6823dffebd3bf8a9d568b0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/configuration_granitemoe.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
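The `granitemoe/__init__.py` just added relies on the package's `_LazyModule` indirection: importing the package is nearly free, and the configuration/modeling submodules are only imported when one of their attributes is first touched (the `TYPE_CHECKING` branch exists solely so static analyzers still see the real symbols). A rough, self-contained sketch of that idea, not the transformers implementation itself:

```python
import importlib
import types


class LazyPackage(types.ModuleType):
    """Toy stand-in for the idea behind `_LazyModule`: resolve names on first access."""

    def __init__(self, name, attr_to_module):
        super().__init__(name)
        # e.g. {"GraniteMoeConfig": "transformers.models.granitemoe.configuration_granitemoe"}
        self._attr_to_module = dict(attr_to_module)

    def __getattr__(self, item):
        mapping = self.__dict__.get("_attr_to_module", {})
        if item not in mapping:
            raise AttributeError(item)
        value = getattr(importlib.import_module(mapping[item]), item)
        setattr(self, item, value)  # cache so the import machinery is hit only once
        return value
```

Registering such an object into `sys.modules[__name__]`, as the real `__init__.py` does with `_LazyModule`, is what makes `from transformers.models.granitemoe import GraniteMoeConfig` trigger the deferred import.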
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GraniteMoe model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class GraniteMoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GraniteMoeModel`]. It is used to instantiate an GraniteMoe + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the GraniteMoe-3B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the GraniteMoe model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GraniteMoeModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + embedding_multiplier (`float`, *optional*, defaults to 1.0): embedding multiplier + logits_scaling (`float`, *optional*, defaults to 1.0): divisor for output logits + residual_multiplier (`float`, *optional*, defaults to 1.0): residual multiplier + attention_multiplier (`float`, *optional*, defaults to 1.0): attention multiplier + num_local_experts (`int`, *optional*, defaults to 8): total number of experts + num_experts_per_tok (`int`, *optional*, defaults to 2): number of experts per token + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. 
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient + + ```python + >>> from transformers import GraniteMoeModel, GraniteMoeConfig + + >>> # Initializing a GraniteMoe granitemoe-3b style configuration + >>> configuration = GraniteMoeConfig() + + >>> # Initializing a model from the granitemoe-7b style configuration + >>> model = GraniteMoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "granitemoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + embedding_multiplier=1.0, + logits_scaling=1.0, + residual_multiplier=1.0, + attention_multiplier=1.0, + num_local_experts=8, + num_experts_per_tok=2, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # this model has rope embedding type, hardcoded for BC + self.position_embedding_type = "rope" + + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.embedding_multiplier = embedding_multiplier + self.logits_scaling = logits_scaling + self.residual_multiplier = residual_multiplier + self.attention_multiplier = attention_multiplier + + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + rope_config_validation(self) + + +__all__ = ["GraniteMoeConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/modeling_granitemoe.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/modeling_granitemoe.py new file mode 100644 index 0000000000000000000000000000000000000000..29c23a356509d828ed98160a3bc0f9ebf8118b21 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/granitemoe/modeling_granitemoe.py @@ -0,0 +1,999 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
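With the configuration class in place, the file that follows supplies the actual GraniteMoe modules. As a quick sanity check of how the pieces fit together, here is a sketch that builds a deliberately tiny config and model; every hyperparameter value is made up for illustration, and it assumes a transformers release that exports the GraniteMoe classes:

```python
from transformers import GraniteMoeConfig, GraniteMoeModel

# Deliberately tiny, made-up hyperparameters; real checkpoints ship their own values,
# including the embedding/residual/attention multipliers and logits_scaling.
config = GraniteMoeConfig(
    vocab_size=1024,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    output_router_logits=True,   # also return per-layer router logits for the aux loss
)

model = GraniteMoeModel(config)
print(sum(p.numel() for p in model.parameters()))
```

Note that `rope_config_validation(self)` runs at the end of `__init__`, so an invalid `rope_scaling` dictionary fails fast at configuration time rather than deep inside the rotary embedding.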
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_granitemoe import GraniteMoeConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + device_index = routing_weights.device.index if routing_weights.device.index is not None else 0 + rank = routing_weights.shape[1] * int(device_index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) + return overall_loss * num_experts + + +# Copied from transformers.models.granite.modeling_granite.GraniteRMSNorm with Granite->GraniteMoe +class GraniteMoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GraniteMoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.granite.modeling_granite.GraniteRotaryEmbedding with Granite->GraniteMoe +class GraniteMoeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: GraniteMoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", 
config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.granite.modeling_granite.rotate_half with Granite->GraniteMoe +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.granite.modeling_granite.apply_rotary_pos_emb with Granite->GraniteMoe +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.jetmoe.modeling_jetmoe.JetMoeParallelExperts with JetMoe->GraniteMoe +class GraniteMoeParallelExperts(nn.Module): + def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: + """ + Initialize the GraniteMoeParallelExperts module. + The experts weights are stored in [num_experts, output_size, input_size] format. 
Such that it's compatible with + many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and + [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the + [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py) + used in vllm. + + Args: + num_experts (int): + Number of experts. + input_size (int): + Size of the input. + output_size (int): + Size of the output. + """ + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def forward(self, inputs, expert_size): + """ + Forward pass of the GraniteMoeParallelExperts module. + + Args: + inputs (Tensor): + Input tensor. + expert_size: + Expert size information. + + Returns: + Tensor: Output tensor. + """ + input_list = inputs.split(expert_size, dim=0) + output_list = [] + for i in range(self.num_experts): + output_list.append(F.linear(input_list[i], self.weight[i])) + results = torch.cat(output_list, dim=0) + return results + + +# Copied from transformers.models.jetmoe.modeling_jetmoe.JetMoeTopKGating with JetMoe->GraniteMoe +class GraniteMoeTopKGating(nn.Module): + def __init__(self, input_size: int, num_experts: int, top_k: int): + """ + Initialize the top-k gating mechanism. + Args: + input_size (`int`): + Size of the input. + num_experts (`int`): + Number of experts. + top_k (`int`): + Number of top experts to select. + """ + super().__init__() + + self.num_experts = num_experts + self.input_size = input_size + self.top_k = top_k + + self.layer = nn.Linear(input_size, num_experts, bias=False) + + def forward(self, hidden_states): + # compute the top_k routing decision + logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts] + top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k] + top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k] + + # compute number of input given to each expert + zeros = torch.zeros( + [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device + ) # [num_tokens, num_experts] + gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts] + expert_size = gates.long().sum(0) # [num_experts,] + # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`) + # (and `DataDependentOutputException`) + expert_size = expert_size.tolist() + + # sort and group input tokens according to expert assignment + top_k_experts = top_k_indices.flatten() # [num_tokens * top_k] + _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k] + batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k] + + # gather the gate values for grouped input tokens + top_k_gates = top_k_gates.flatten() # [num_tokens * top_k] + batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k] + + return index_sorted_experts, batch_index, batch_gates, expert_size, logits + + +class GraniteMoeMoE(nn.Module): + """ + A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. + + Args: + config: + Configuration object with model hyperparameters. 
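+
+    Example (an illustrative sketch, not a tested doctest; the tiny configuration values below are
+    hypothetical and chosen only to keep the tensors small):
+
+    ```python
+    >>> import torch
+    >>> from transformers import GraniteMoeConfig
+
+    >>> config = GraniteMoeConfig(
+    ...     hidden_size=64, intermediate_size=128, num_local_experts=4, num_experts_per_tok=2
+    ... )
+    >>> moe = GraniteMoeMoE(config)
+    >>> hidden = torch.randn(2, 8, config.hidden_size)  # (batch, seq_len, hidden_size)
+    >>> layer_output, router_logits = moe(hidden)
+    >>> layer_output.shape  # same shape as the input
+    torch.Size([2, 8, 64])
+    >>> router_logits.shape  # (batch * seq_len, num_local_experts)
+    torch.Size([16, 4])
+    ```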
+ """ + + def __init__(self, config: GraniteMoeConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = GraniteMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2) + self.output_linear = GraniteMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size) + + self.router = GraniteMoeTopKGating( + input_size=self.input_size, + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + ) + + def forward(self, layer_input): + """ + Forward pass of the mixture of experts layer. + + Args: + layer_input (Tensor): + Input tensor. + + Returns: + Tensor: + Output tensor. + Tensor: + Router logits. + """ + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) + + expert_inputs = layer_input[batch_index] + hidden_states = self.input_linear(expert_inputs, expert_size) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + expert_outputs = self.output_linear(hidden_states, expert_size) + + expert_outputs = expert_outputs * batch_gates[:, None] + + zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) + layer_output = zeros.index_add(0, batch_index, expert_outputs) + layer_output = layer_output.view(bsz, length, self.input_size) + return layer_output, router_logits + + +# Copied from transformers.models.granite.modeling_granite.repeat_kv with Granite->GraniteMoe +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoe +# no longer copied after attention refactors +class GraniteMoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GraniteMoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + + self.scaling = config.attention_multiplier + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings if position_embeddings is not None else (None, None) + if position_embeddings is not None: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + 
if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class GraniteMoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GraniteMoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx) + if config.num_local_experts > 0: + self.block_sparse_moe = GraniteMoeMoE(config) + self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + output_router_logits: Optional[bool] = False, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + + hidden_states = residual + hidden_states * self.residual_multiplier + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@auto_docstring +class GraniteMoePreTrainedModel(PreTrainedModel): + config: GraniteMoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GraniteMoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, GraniteMoeParallelExperts): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class GraniteMoeModel(GraniteMoePreTrainedModel): + def __init__(self, config: GraniteMoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GraniteMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.embedding_multiplier = config.embedding_multiplier + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.position_embedding_type = config.position_embedding_type + self.rotary_emb = GraniteMoeRotaryEmbedding(config) if self.position_embedding_type == "rope" else None + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + position_embeddings = None + # create position embeddings to be shared across the decoder layers + if self.rotary_emb is not None: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + output_router_logits=output_router_logits, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = 
make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
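+
+        Example (an illustrative sketch of the shape transformation only; this is a private helper, so the
+        call below is for exposition, and the padding mask is a made-up toy input):
+
+        ```python
+        >>> import torch
+        >>> from transformers import GraniteMoeModel
+
+        >>> mask_2d = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])  # hypothetical batch of 2 with left padding
+        >>> causal_4d = GraniteMoeModel._prepare_4d_causal_attention_mask_with_cache_position(
+        ...     mask_2d,
+        ...     sequence_length=4,
+        ...     target_length=4,
+        ...     dtype=torch.float32,
+        ...     cache_position=torch.arange(4),
+        ...     batch_size=2,
+        ... )
+        >>> causal_4d.shape
+        torch.Size([2, 1, 4, 4])
+        ```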
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GraniteMoeConfig): + super().__init__(config) + self.model = GraniteMoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[tuple, MoeCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GraniteMoeForCausalLM + + >>> model = GraniteMoeForCausalLM.from_pretrained("ibm/PowerMoE-3b") + >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + # Only compute necessary logits + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = logits / self.config.logits_scaling + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Flatten the tokens + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["GraniteMoeForCausalLM", "GraniteMoeModel", "GraniteMoePreTrainedModel"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7a901eab235ecd14d2127a9eb78022c720c0f3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_informer import * + from .modeling_informer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/configuration_informer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/configuration_informer.py new file mode 100644 index 0000000000000000000000000000000000000000..f62417358c82ce0412336f86f2c9d32d1c019de9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/configuration_informer.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Informer model configuration""" + +from typing import Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class InformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`InformerModel`]. It is used to instantiate an + Informer model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Informer + [huggingface/informer-tourism-monthly](https://huggingface.co/huggingface/informer-tourism-monthly) architecture. + + Configuration objects inherit from [`PretrainedConfig`] can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + prediction_length (`int`): + The prediction length for the decoder. In other words, the prediction horizon of the model. This value is + typically dictated by the dataset and we recommend to set it appropriately. + context_length (`int`, *optional*, defaults to `prediction_length`): + The context length for the encoder. If `None`, the context length will be the same as the + `prediction_length`. + distribution_output (`string`, *optional*, defaults to `"student_t"`): + The distribution emission head for the model. Could be either "student_t", "normal" or "negative_binomial". + loss (`string`, *optional*, defaults to `"nll"`): + The loss function for the model corresponding to the `distribution_output` head. For parametric + distributions it is the negative log likelihood (nll) - which currently is the only supported one. + input_size (`int`, *optional*, defaults to 1): + The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of + multivariate targets. 
+ scaling (`string` or `bool`, *optional* defaults to `"mean"`): + Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the + scaler is set to "mean". + lags_sequence (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 5, 6, 7]`): + The lags of the input time series as covariates often dictated by the frequency of the data. Default is + `[1, 2, 3, 4, 5, 6, 7]` but we recommend to change it based on the dataset appropriately. + num_time_features (`int`, *optional*, defaults to 0): + The number of time features in the input time series. + num_dynamic_real_features (`int`, *optional*, defaults to 0): + The number of dynamic real valued features. + num_static_categorical_features (`int`, *optional*, defaults to 0): + The number of static categorical features. + num_static_real_features (`int`, *optional*, defaults to 0): + The number of static real valued features. + cardinality (`list[int]`, *optional*): + The cardinality (number of different values) for each of the static categorical features. Should be a list + of integers, having the same length as `num_static_categorical_features`. Cannot be `None` if + `num_static_categorical_features` is > 0. + embedding_dimension (`list[int]`, *optional*): + The dimension of the embedding for each of the static categorical features. Should be a list of integers, + having the same length as `num_static_categorical_features`. Cannot be `None` if + `num_static_categorical_features` is > 0. + d_model (`int`, *optional*, defaults to 64): + Dimensionality of the transformer layers. + encoder_layers (`int`, *optional*, defaults to 2): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 2): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 2): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 2): + Number of attention heads for each attention layer in the Transformer decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 32): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + decoder_ffn_dim (`int`, *optional*, defaults to 32): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and decoder. If string, `"gelu"` and + `"relu"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the encoder, and decoder. + encoder_layerdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention and fully connected layers for each encoder layer. + decoder_layerdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention and fully connected layers for each decoder layer. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability used between the two layers of the feed-forward networks. + num_parallel_samples (`int`, *optional*, defaults to 100): + The number of samples to generate in parallel for each time step of inference. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal weight initialization distribution. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether to use the past key/values attentions (if applicable to the model) to speed up decoding. + attention_type (`str`, *optional*, defaults to "prob"): + Attention used in encoder. This can be set to "prob" (Informer's ProbAttention) or "full" (vanilla + transformer's canonical self-attention). + sampling_factor (`int`, *optional*, defaults to 5): + ProbSparse sampling factor (only makes affect when `attention_type`="prob"). It is used to control the + reduced query matrix (Q_reduce) input length. + distil (`bool`, *optional*, defaults to `True`): + Whether to use distilling in encoder. + + Example: + + ```python + >>> from transformers import InformerConfig, InformerModel + + >>> # Initializing an Informer configuration with 12 time steps for prediction + >>> configuration = InformerConfig(prediction_length=12) + + >>> # Randomly initializing a model (with random weights) from the configuration + >>> model = InformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "informer" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + "num_hidden_layers": "encoder_layers", + "initializer_range": "init_std", + } + + def __init__( + self, + prediction_length: Optional[int] = None, + context_length: Optional[int] = None, + distribution_output: str = "student_t", + loss: str = "nll", + input_size: int = 1, + lags_sequence: Optional[list[int]] = None, + scaling: Optional[Union[str, bool]] = "mean", + num_dynamic_real_features: int = 0, + num_static_real_features: int = 0, + num_static_categorical_features: int = 0, + num_time_features: int = 0, + cardinality: Optional[list[int]] = None, + embedding_dimension: Optional[list[int]] = None, + d_model: int = 64, + encoder_ffn_dim: int = 32, + decoder_ffn_dim: int = 32, + encoder_attention_heads: int = 2, + decoder_attention_heads: int = 2, + encoder_layers: int = 2, + decoder_layers: int = 2, + is_encoder_decoder: bool = True, + activation_function: str = "gelu", + dropout: float = 0.05, + encoder_layerdrop: float = 0.1, + decoder_layerdrop: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + num_parallel_samples: int = 100, + init_std: float = 0.02, + use_cache=True, + # Informer arguments + attention_type: str = "prob", + sampling_factor: int = 5, + distil: bool = True, + **kwargs, + ): + # time series specific configuration + self.prediction_length = prediction_length + self.context_length = context_length or prediction_length + self.distribution_output = distribution_output + self.loss = loss + self.input_size = input_size + self.num_time_features = num_time_features + self.lags_sequence = lags_sequence if lags_sequence is not None else [1, 2, 3, 4, 5, 6, 7] + self.scaling = scaling + self.num_dynamic_real_features = num_dynamic_real_features + self.num_static_real_features = num_static_real_features + self.num_static_categorical_features = num_static_categorical_features + + # set cardinality + if cardinality and num_static_categorical_features > 0: + if len(cardinality) != num_static_categorical_features: + raise ValueError( + "The cardinality should be a list of the same length as `num_static_categorical_features`" + ) + self.cardinality = cardinality + else: + self.cardinality = [0] + + # set embedding_dimension + if embedding_dimension and num_static_categorical_features > 0: + if len(embedding_dimension) != 
num_static_categorical_features: + raise ValueError( + "The embedding dimension should be a list of the same length as `num_static_categorical_features`" + ) + self.embedding_dimension = embedding_dimension + else: + self.embedding_dimension = [min(50, (cat + 1) // 2) for cat in self.cardinality] + + self.num_parallel_samples = num_parallel_samples + + # Transformer architecture configuration + self.feature_size = input_size * len(self.lags_sequence) + self._number_of_features + self.d_model = d_model + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + + self.activation_function = activation_function + self.init_std = init_std + + self.use_cache = use_cache + + # Informer + self.attention_type = attention_type + self.sampling_factor = sampling_factor + self.distil = distil + + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def _number_of_features(self) -> int: + return ( + sum(self.embedding_dimension) + + self.num_dynamic_real_features + + self.num_time_features + + self.num_static_real_features + + self.input_size * 2 # the log1p(abs(loc)) and log(scale) features + ) + + +__all__ = ["InformerConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/modeling_informer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/modeling_informer.py new file mode 100644 index 0000000000000000000000000000000000000000..b0fc4964a9aea30e23019dbfcfb602c6b4946558 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/modeling_informer.py @@ -0,0 +1,2160 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/informer/modular_informer.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_informer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
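+#
+# This file defines the Informer encoder-decoder time-series model. The encoder can use either vanilla
+# ("full") self-attention or Informer's ProbSparse ("prob") attention, selected through
+# `config.attention_type`. A rough construction sketch (illustrative; the hyper-parameters are assumptions):
+#
+#     from transformers import InformerConfig, InformerModel
+#     model = InformerModel(InformerConfig(prediction_length=12, attention_type="full"))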
+ +from typing import Callable, Optional, Union + +import numpy as np +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + SampleTSPredictionOutput, + Seq2SeqTSModelOutput, + Seq2SeqTSPredictionOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_informer import InformerConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +class InformerFeatureEmbedder(nn.Module): + """ + Embed a sequence of categorical features. + + Args: + cardinalities (`list[int]`): + List of cardinalities of the categorical features. + embedding_dims (`list[int]`): + List of embedding dimensions of the categorical features. + """ + + def __init__(self, cardinalities: list[int], embedding_dims: list[int]) -> None: + super().__init__() + + self.num_features = len(cardinalities) + self.embedders = nn.ModuleList([nn.Embedding(c, d) for c, d in zip(cardinalities, embedding_dims)]) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + if self.num_features > 1: + # we slice the last dimension, giving an array of length + # self.num_features with shape (N,T) or (N) + cat_feature_slices = torch.chunk(features, self.num_features, dim=-1) + else: + cat_feature_slices = [features] + + return torch.cat( + [ + embed(cat_feature_slice.squeeze(-1)) + for embed, cat_feature_slice in zip(self.embedders, cat_feature_slices) + ], + dim=-1, + ) + + +class InformerStdScaler(nn.Module): + """ + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. + """ + + def __init__(self, config: InformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5 + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) + denominator = denominator.clamp_min(1.0) + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator + + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + scale = torch.sqrt(variance + self.minimum_scale) + return (data - loc) / scale, loc, scale + + +class InformerMeanScaler(nn.Module): + """ + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data + accordingly. + """ + + def __init__(self, config: InformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) + num_observed = observed_indicator.sum(self.dim, keepdim=True) + + scale = ts_sum / torch.clamp(num_observed, min=1) + + # If `default_scale` is provided, we use it, otherwise we use the scale + # of the batch. + if self.default_scale is None: + batch_sum = ts_sum.sum(dim=0) + batch_observations = torch.clamp(num_observed.sum(0), min=1) + default_scale = torch.squeeze(batch_sum / batch_observations) + else: + default_scale = self.default_scale * torch.ones_like(scale) + + # apply default scale where there are no observations + scale = torch.where(num_observed > 0, scale, default_scale) + + # ensure the scale is at least `self.minimum_scale` + scale = torch.clamp(scale, min=self.minimum_scale) + scaled_data = data / scale + + if not self.keepdim: + scale = scale.squeeze(dim=self.dim) + + return scaled_data, torch.zeros_like(scale), scale + + +class InformerNOPScaler(nn.Module): + """ + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. 
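+    For an input of shape `(batch_size, sequence_length, num_input_channels)` it returns the data unchanged
+    together with an all-zeros `loc` and an all-ones `scale` (by default reduced over `dim=1` with
+    `keepdim=True`), so callers can treat the three scaler variants uniformly.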
+ """ + + def __init__(self, config: InformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + + def forward( + self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + return data, loc, scale + + +class InformerSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: + super().__init__(num_positions, embedding_dim) + + def _init_weight(self): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = self.weight.shape + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + self.weight = nn.Parameter(out, requires_grad=False) + + @torch.no_grad() + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """`input_ids_shape` is expected to be [bsz x seqlen].""" + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) + + +class InformerValueEmbedding(nn.Module): + def __init__(self, feature_size, d_model): + super().__init__() + self.value_projection = nn.Linear(in_features=feature_size, out_features=d_model, bias=False) + + def forward(self, x): + return self.value_projection(x) + + +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config: InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module): + super()._init_weights(module) + if isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that 
requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class InformerAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[InformerConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + 
**kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class InformerProbSparseAttention(nn.Module): + """Probabilistic Attention mechanism to select the "active" + queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and + memory requirements of vanilla attention""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + sampling_factor: int = 5, + bias: bool = True, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.factor = sampling_factor + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be 
re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + key_states_time_length = key_states.size(1) # L_K + log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K + + query_states_time_length = query_states.size(1) # L_Q + log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q + + u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length) + u = min(self.factor * log_query_states_time_length, query_states_time_length) + + if key_states_time_length > 0: + index_sample = torch.randint(0, key_states_time_length, (u_part,)) + k_sample = key_states[:, index_sample, :] + else: + k_sample = key_states + + queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled + + # find the Top_k query with sparsity measurement + if u > 0: + sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div( + queries_keys_sample.sum(dim=-1), key_states_time_length + ) # M + top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top + + # calculate q_reduce: query_states[:, top_u_sparsity_measurement] + dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1) + q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement] + else: + q_reduce = query_states + top_u_sparsity_measurement = None + + # Use q_reduce to calculate attention weights + attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2)) + + src_len = key_states.size(1) + if attn_weights.size() != (bsz * self.num_heads, u, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape( + bsz * self.num_heads, tgt_len, src_len + ) + + if top_u_sparsity_measurement is not None: + dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1) + prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :] + + attn_weights = attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view( + bsz, self.num_heads, u, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = 
attn_weights.view(bsz * self.num_heads, u, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.bmm(attn_probs, value_states) + + # calculate context for updating the attn_output, based on: + # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74 + if self.is_decoder: + # cast to float32 before operation to avoid overflow + context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype) + else: + v_mean_dim_time = value_states.mean(dim=-2) + context = ( + v_mean_dim_time.unsqueeze(dim=1) + .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1)) + .clone() + ) + + if top_u_sparsity_measurement is not None: + # update context: copy the attention output to the context at top_u_sparsity_measurement index + dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1) + context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output + attn_output = context + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
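+        # Editorial sketch (comment only, sizes are assumed for illustration): the ProbSparse
+        # mechanism above scores only u = sampling_factor * ceil(ln(1 + L_Q)) "active" queries
+        # against all keys, after estimating each query's sparsity
+        #     M(q_i, K) ~= max_j(q_i @ k_j.T) - mean_j(q_i @ k_j.T)
+        # on a random sample of u_part = sampling_factor * L_Q * ceil(ln(1 + L_K)) keys
+        # (both quantities are capped at the corresponding sequence length). For example, with
+        # L_Q = L_K = 96 and sampling_factor = 5: ceil(ln(97)) = 5, so
+        #     u_part = min(5 * 96 * 5, 96) = 96 and u = min(5 * 5, 96) = 25,
+        # i.e. only 25 of the 96 queries receive full attention; the remaining output rows keep the
+        # `context` estimate (a cumulative sum of the values in the decoder, their mean otherwise).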
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py +class InformerConvLayer(GradientCheckpointingLayer): + def __init__(self, c_in): + super().__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=1, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class InformerEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InformerConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class InformerDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = InformerAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class InformerEncoder(InformerPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`InformerEncoderLayer`]. 
+ + Args: + config: InformerConfig + """ + + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.gradient_checkpointing = False + + if config.distil: + self.conv_layers = nn.ModuleList( + [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)] + ) + self.conv_layers.append(None) + else: + self.conv_layers = [None] * config.encoder_layers + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class InformerDecoder(InformerPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a + [`InformerDecoderLayer`] + + Args: + config: InformerConfig + """ + + def __init__(self, config: InformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = inputs_embeds.size()[:-1] + # initialize `past_key_values` + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) + + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class InformerModel(InformerPreTrainedModel): + def __init__(self, config: InformerConfig): + super().__init__(config) + + if config.scaling == "mean" or config.scaling is True: + self.scaler = InformerMeanScaler(config) + elif config.scaling == "std": + self.scaler = InformerStdScaler(config) + else: + self.scaler = InformerNOPScaler(config) + + if config.num_static_categorical_features > 0: + self.embedder = InformerFeatureEmbedder( + cardinalities=config.cardinality, + embedding_dims=config.embedding_dimension, + ) + + # transformer encoder-decoder and mask initializer + self.encoder = InformerEncoder(config) + self.decoder = InformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @property + def _past_length(self) -> int: + return self.config.context_length + max(self.config.lags_sequence) + + def get_lagged_subsequences( + self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 + ) -> torch.Tensor: + """ + Returns lagged subsequences of a given sequence. Returns a tensor of shape (N, S, C, I), + where S = subsequences_length and I = len(indices), containing lagged subsequences. Specifically, lagged[i, + j, :, k] = sequence[i, -indices[k]-S+j, :]. + + Args: + sequence: Tensor + The sequence from which lagged subsequences should be extracted. Shape: (N, T, C). + subsequences_length : int + Length of the subsequences to be extracted. + shift: int + Shift the lags by this amount back. 
+ """ + sequence_length = sequence.shape[1] + indices = [lag - shift for lag in self.config.lags_sequence] + + if max(indices) + subsequences_length > sequence_length: + raise ValueError( + f"lags cannot go further than history length, found lag {max(indices)} " + f"while history length is only {sequence_length}" + ) + + lagged_values = [] + for lag_index in indices: + begin_index = -lag_index - subsequences_length + end_index = -lag_index if lag_index > 0 else None + lagged_values.append(sequence[:, begin_index:end_index, ...]) + return torch.stack(lagged_values, dim=-1) + + def create_network_inputs( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + past_observed_mask: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + future_time_features: Optional[torch.Tensor] = None, + ): + # time feature + time_feat = ( + torch.cat( + ( + past_time_features[:, self._past_length - self.config.context_length :, ...], + future_time_features, + ), + dim=1, + ) + if future_values is not None + else past_time_features[:, self._past_length - self.config.context_length :, ...] + ) + + # target + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) + + context = past_values[:, -self.config.context_length :] + observed_context = past_observed_mask[:, -self.config.context_length :] + _, loc, scale = self.scaler(context, observed_context) + + inputs = ( + (torch.cat((past_values, future_values), dim=1) - loc) / scale + if future_values is not None + else (past_values - loc) / scale + ) + + # static features + if loc.ndim == 3: + squeezed_loc = loc.squeeze(1) + squeezed_scale = scale.squeeze(1) + else: + squeezed_loc = loc + squeezed_scale = scale + log_abs_loc = squeezed_loc.abs().log1p() + log_scale = squeezed_scale.log() + static_feat = torch.cat((log_abs_loc, log_scale), dim=1) + + if static_real_features is not None: + static_feat = torch.cat((static_real_features, static_feat), dim=1) + if static_categorical_features is not None: + embedded_cat = self.embedder(static_categorical_features) + static_feat = torch.cat((embedded_cat, static_feat), dim=1) + expanded_static_feat = static_feat.unsqueeze(1).expand(-1, time_feat.shape[1], -1) + + # all features + features = torch.cat((expanded_static_feat, time_feat), dim=-1) + + # lagged features + subsequences_length = ( + self.config.context_length + self.config.prediction_length + if future_values is not None + else self.config.context_length + ) + lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length) + lags_shape = lagged_sequence.shape + reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1) + + if reshaped_lagged_sequence.shape[1] != time_feat.shape[1]: + raise ValueError( + f"input length {reshaped_lagged_sequence.shape[1]} and time feature lengths {time_feat.shape[1]} does not match" + ) + + # transformer inputs + transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) + + return transformer_inputs, loc, scale, static_feat + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + past_observed_mask: torch.Tensor, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + future_values: 
Optional[torch.Tensor] = None, + future_time_features: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Seq2SeqTSModelOutput, tuple]: + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e.
NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_inputs, loc, scale, static_feat = self.create_network_inputs( + past_values=past_values, + past_time_features=past_time_features, + past_observed_mask=past_observed_mask, + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + future_values=future_values, + future_time_features=future_time_features, + ) + + if encoder_outputs is None: + enc_input = transformer_inputs[:, : self.config.context_length, ...] + encoder_outputs = self.encoder( + inputs_embeds=enc_input, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Avoid empty tensors and instead create a zeroes tensor which + # will be treated the same in torch, i.e. matmul with empty == all 0s + if self.config.context_length >= transformer_inputs.shape[1]: + bsz, _, dim = transformer_inputs.shape + dec_input = torch.zeros( + size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype + ) + else: + dec_input = transformer_inputs[:, self.config.context_length :, ...] 
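+        # Editorial note (comment only, example sizes are assumed): `transformer_inputs` spans
+        # `context_length + prediction_length` steps when `future_values` is provided, so the split
+        # above feeds the first `context_length` steps to the encoder and the remaining
+        # `prediction_length` steps to the decoder. E.g. with context_length=24, prediction_length=12
+        # and feature dimension D, enc_input is (bsz, 24, D) and dec_input is (bsz, 12, D); at
+        # inference time (no `future_values`) `dec_input` falls back to a single all-zeros step.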
+ + decoder_outputs = self.decoder( + inputs_embeds=dec_input, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + (loc, scale, static_feat) + + return Seq2SeqTSModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + loc=loc, + scale=scale, + static_features=static_feat, + ) + + +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. + dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. + """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. 
+ """ + return -input.log_prob(target) + + +@auto_docstring +class InformerForPrediction(InformerPreTrainedModel): + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.model = InformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + def output_params(self, dec_output): + return self.parameter_projection(dec_output) + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @torch.jit.ignore + def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution: + sliced_params = params + if trailing_n is not None: + sliced_params = [p[:, -trailing_n:] for p in params] + return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale) + + @auto_docstring + def forward( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + past_observed_mask: torch.Tensor, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + future_time_features: Optional[torch.Tensor] = None, + future_observed_mask: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Seq2SeqTSModelOutput, tuple]: + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. 
+
+            The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as
+            `static_categorical_features`, `static_real_features`, `past_time_features` and lags).
+
+            Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`.
+
+            For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+            variates in the time series per time step.
+        past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`):
+            Required time features, which the model internally will add to `past_values`. These could be things like
+            "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
+            could also be so-called "age" features, which basically help the model know "at which point in life" a
+            time-series is. Age features have small values for distant past time steps and increase monotonically the
+            more we approach the current time step. Holiday features are also a good example of time features.
+
+            These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+            the position encodings are learned from scratch internally as parameters of the model, the Time Series
+            Transformer requires to provide additional time features. The Time Series Transformer only learns
+            additional embeddings for `static_categorical_features`.
+
+            Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+            must be known at prediction time.
+
+            The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in
+            `[0, 1]`:
+
+            - 1 for values that are **observed**,
+            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+        static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*):
+            Optional static categorical features for which the model will learn an embedding, which it will add to the
+            values of the time series.
+
+            Static categorical features are features which have the same value for all time steps (static over time).
+
+            A typical example of a static categorical feature is a time series ID.
+        static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*):
+            Optional static real features which the model will add to the values of the time series.
+
+            Static real features are features which have the same value for all time steps (static over time).
+
+            A typical example of a static real feature is promotion information.
+        future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*):
+            Future values of the time series, that serve as labels for the model. The `future_values` is what the
+            Transformer needs during training to learn to output, given the `past_values`.
+
+            The sequence length here is equal to `prediction_length`.
+
+            See the demo notebook and code snippets for details.
+
+            Optionally, during training any missing values need to be replaced with zeros and indicated via the
+            `future_observed_mask`.
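+
+            As an illustrative sketch (the name `raw_future` is made up and not part of the API), such inputs can be
+            built from a tensor that may contain NaNs:
+
+            ```python
+            >>> import torch
+
+            >>> raw_future = torch.tensor([[1.0, float("nan"), 3.0]])  # made-up example data
+            >>> future_observed_mask = ~torch.isnan(raw_future)
+            >>> future_values = torch.nan_to_num(raw_future, nan=0.0)
+            ```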
+
+            For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+            variates in the time series per time step.
+        future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`):
+            Required time features for the prediction window, which the model internally will add to `future_values`.
+            These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as
+            Fourier features). These could also be so-called "age" features, which basically help the model know "at
+            which point in life" a time-series is. Age features have small values for distant past time steps and
+            increase monotonically the more we approach the current time step. Holiday features are also a good example
+            of time features.
+
+            These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+            the position encodings are learned from scratch internally as parameters of the model, the Time Series
+            Transformer requires to provide additional time features. The Time Series Transformer only learns
+            additional embeddings for `static_categorical_features`.
+
+            Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+            must be known at prediction time.
+
+            The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+        future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+            Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
+            in `[0, 1]`:
+
+            - 1 for values that are **observed**,
+            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+
+            This mask is used to filter out missing values for the final loss calculation.
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+
+        Examples:
+
+        ```python
+        >>> from huggingface_hub import hf_hub_download
+        >>> import torch
+        >>> from transformers import InformerForPrediction
+
+        >>> file = hf_hub_download(
+        ...     repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
+        ... )
+        >>> batch = torch.load(file)
+
+        >>> model = InformerForPrediction.from_pretrained(
+        ...     "huggingface/informer-tourism-monthly"
+        ... )
+
+        >>> # during training, one provides both past and future values
+        >>> # as well as possible additional features
+        >>> outputs = model(
+        ...     past_values=batch["past_values"],
+        ...     past_time_features=batch["past_time_features"],
+        ...     past_observed_mask=batch["past_observed_mask"],
+        ...     static_categorical_features=batch["static_categorical_features"],
+        ...     static_real_features=batch["static_real_features"],
+        ...     future_values=batch["future_values"],
+        ...     future_time_features=batch["future_time_features"],
+        ...
) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if future_values is not None: + use_cache = False + + outputs = self.model( + past_values=past_values, + past_time_features=past_time_features, + past_observed_mask=past_observed_mask, + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + future_values=future_values, + future_time_features=future_time_features, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + use_cache=use_cache, + return_dict=return_dict, + cache_position=cache_position, + ) + + prediction_loss = None + params = None + if future_values is not None: + params = self.output_params(outputs[0]) # outputs.last_hidden_state + # loc is 3rd last and scale is 2nd last output + distribution = self.output_distribution(params, loc=outputs[-3], scale=outputs[-2]) + + loss = self.loss(distribution, future_values) + + if future_observed_mask is None: + future_observed_mask = torch.ones_like(future_values) + + if len(self.target_shape) == 0: + loss_weights = future_observed_mask + else: + loss_weights, _ = future_observed_mask.min(dim=-1, keepdim=False) + + prediction_loss = weighted_average(loss, weights=loss_weights) + + if not return_dict: + outputs = ((params,) + outputs[1:]) if params is not None else outputs[1:] + return ((prediction_loss,) + outputs) if prediction_loss is not None else outputs + + return Seq2SeqTSPredictionOutput( + loss=prediction_loss, + params=params, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + loc=outputs.loc, + scale=outputs.scale, + static_features=outputs.static_features, + ) + + @torch.no_grad() + def generate( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + future_time_features: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SampleTSPredictionOutput: + r""" + Greedily generate sequences of sample predictions from a model with a probability distribution head. 
+ + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size + of this tensor must be larger than the `context_length` of the model, since the model will use the + larger size to construct lag features, i.e. additional values from the past which are added in order to + serve as "extra context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if + no `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length + of the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, + such as `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number + of variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things + like "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). + These could also be so-called "age" features, which basically help the model know "at which point in + life" a time-series is. Age features have small values for distant past time steps and increase + monotonically the more we approach the current time step. Holiday features are also a good example of + time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, + where the position encodings are learned from scratch internally as parameters of the model, the Time + Series Transformer requires to provide additional time features. The Time Series Transformer only + learns additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these + features must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to sampled + predictions. These could be things like "month of year", "day of the month", etc. encoded as vectors + (for instance as Fourier features). These could also be so-called "age" features, which basically help + the model know "at which point in life" a time-series is. Age features have small values for distant + past time steps and increase monotonically the more we approach the current time step. Holiday features + are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, + where the position encodings are learned from scratch internally as parameters of the model, the Time + Series Transformer requires to provide additional time features. 
The Time Series Transformer only + learns additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these + features must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to + the values of the time series. + + Static categorical features are features which have the same value for all time steps (static over + time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + + Return: + [`SampleTSPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of + samples, prediction_length)` or `(batch_size, number of samples, prediction_length, input_size)` for + multivariate predictions. 
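+
+        As an illustrative sketch (not additional API), the returned samples can be reduced to point and interval
+        forecasts; here `outputs` is assumed to be the result of `generate` as in the usage example of
+        [`InformerForPrediction.forward`]:
+
+        ```python
+        >>> samples = outputs.sequences  # (batch_size, num_parallel_samples, prediction_length)
+        >>> point_forecast = samples.mean(dim=1)
+        >>> lower, upper = samples.quantile(torch.tensor([0.1, 0.9]), dim=1)  # rough 80% prediction band
+        ```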
+ """ + outputs = self( + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + past_time_features=past_time_features, + past_values=past_values, + past_observed_mask=past_observed_mask, + future_time_features=future_time_features, + future_values=None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + use_cache=True, + ) + + decoder = self.model.get_decoder() + enc_last_hidden = outputs.encoder_last_hidden_state + loc = outputs.loc + scale = outputs.scale + static_feat = outputs.static_features + + num_parallel_samples = self.config.num_parallel_samples + repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0) + repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0) + + repeated_past_values = ( + past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc + ) / repeated_scale + + expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1) + features = torch.cat((expanded_static_feat, future_time_features), dim=-1) + repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0) + + repeated_enc_last_hidden = enc_last_hidden.repeat_interleave(repeats=num_parallel_samples, dim=0) + + future_samples = [] + + # greedy decoding + for k in range(self.config.prediction_length): + lagged_sequence = self.model.get_lagged_subsequences( + sequence=repeated_past_values, + subsequences_length=1 + k, + shift=1, + ) + + lags_shape = lagged_sequence.shape + reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1) + + decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1) + + dec_output = decoder(inputs_embeds=decoder_input, encoder_hidden_states=repeated_enc_last_hidden) + dec_last_hidden = dec_output.last_hidden_state + + params = self.parameter_projection(dec_last_hidden[:, -1:]) + distr = self.output_distribution(params, loc=repeated_loc, scale=repeated_scale) + next_sample = distr.sample() + + repeated_past_values = torch.cat( + (repeated_past_values, (next_sample - repeated_loc) / repeated_scale), dim=1 + ) + future_samples.append(next_sample) + + concat_future_samples = torch.cat(future_samples, dim=1) + + return SampleTSPredictionOutput( + sequences=concat_future_samples.reshape( + (-1, num_parallel_samples, self.config.prediction_length) + self.target_shape, + ) + ) + + +__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/modular_informer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/modular_informer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6e2ad30a9f858393bff04dd5b3f797574bfacf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/informer/modular_informer.py @@ -0,0 +1,981 @@ +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Informer model.""" + +from typing import Optional, Union + +import numpy as np +import torch +from torch import nn + +from ...cache_utils import Cache, EncoderDecoderCache +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, +) +from ...utils.deprecation import deprecate_kwarg +from ..bart.modeling_bart import BartAttention +from ..time_series_transformer.modeling_time_series_transformer import ( + TimeSeriesFeatureEmbedder, + TimeSeriesMeanScaler, + TimeSeriesNOPScaler, + TimeSeriesSinusoidalPositionalEmbedding, + TimeSeriesStdScaler, + TimeSeriesTransformerDecoder, + TimeSeriesTransformerDecoderLayer, + TimeSeriesTransformerEncoder, + TimeSeriesTransformerEncoderLayer, + TimeSeriesTransformerForPrediction, + TimeSeriesTransformerModel, + TimeSeriesValueEmbedding, +) +from .configuration_informer import InformerConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +class InformerFeatureEmbedder(TimeSeriesFeatureEmbedder): + pass + + +class InformerStdScaler(TimeSeriesStdScaler): + pass + + +class InformerMeanScaler(TimeSeriesMeanScaler): + pass + + +class InformerNOPScaler(TimeSeriesNOPScaler): + pass + + +class InformerSinusoidalPositionalEmbedding(TimeSeriesSinusoidalPositionalEmbedding): + pass + + +class InformerValueEmbedding(TimeSeriesValueEmbedding): + pass + + +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config: InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module): + super()._init_weights(module) + if isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class InformerAttention(BartAttention): + pass + + +class InformerProbSparseAttention(nn.Module): + """Probabilistic Attention mechanism to select the "active" + queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and + memory requirements of vanilla attention""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + sampling_factor: int = 5, + bias: bool = True, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.factor = sampling_factor + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if 
is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + key_states_time_length = key_states.size(1) # L_K + log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K + + query_states_time_length = query_states.size(1) # L_Q + log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q + + u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length) + u = min(self.factor * log_query_states_time_length, query_states_time_length) + + if key_states_time_length > 0: + index_sample = torch.randint(0, key_states_time_length, (u_part,)) + k_sample = key_states[:, index_sample, :] + else: + k_sample = key_states + + queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled + + # find the Top_k query with sparsity measurement + if u > 0: + sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div( + queries_keys_sample.sum(dim=-1), key_states_time_length + ) # M + top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top + + # calculate q_reduce: query_states[:, top_u_sparsity_measurement] + dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1) + q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement] + else: + q_reduce = query_states + top_u_sparsity_measurement = None + + # Use q_reduce to calculate attention weights + attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2)) + + src_len = key_states.size(1) + if attn_weights.size() != (bsz * self.num_heads, u, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape( + bsz * self.num_heads, tgt_len, src_len + ) + + if top_u_sparsity_measurement is not None: + dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1) + prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :] + + attn_weights = 
attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view( + bsz, self.num_heads, u, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.bmm(attn_probs, value_states) + + # calculate context for updating the attn_output, based on: + # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74 + if self.is_decoder: + # cast to float32 before operation to avoid overflow + context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype) + else: + v_mean_dim_time = value_states.mean(dim=-2) + context = ( + v_mean_dim_time.unsqueeze(dim=1) + .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1)) + .clone() + ) + + if top_u_sparsity_measurement is not None: + # update context: copy the attention output to the context at top_u_sparsity_measurement index + dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1) + context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output + attn_output = context + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py +class InformerConvLayer(GradientCheckpointingLayer): + def __init__(self, c_in): + super().__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=1, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class InformerEncoderLayer(TimeSeriesTransformerEncoderLayer): + def __init__(self, config: InformerConfig): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class InformerDecoderLayer(TimeSeriesTransformerDecoderLayer): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + + +class InformerEncoder(TimeSeriesTransformerEncoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.gradient_checkpointing = False + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + if config.distil: + self.conv_layers = nn.ModuleList( + [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)] + ) + self.conv_layers.append(None) + else: + self.conv_layers = [None] * config.encoder_layers + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, 
*optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class InformerDecoder(TimeSeriesTransformerDecoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + +class InformerModel(TimeSeriesTransformerModel): + def __init__(self, config: InformerConfig): + PreTrainedModel.__init__(self, config) + + if config.scaling == "mean" or config.scaling is True: + self.scaler = InformerMeanScaler(config) + elif config.scaling == "std": + self.scaler = InformerStdScaler(config) + else: + self.scaler = InformerNOPScaler(config) + + if config.num_static_categorical_features > 0: + self.embedder = InformerFeatureEmbedder( + cardinalities=config.cardinality, + embedding_dims=config.embedding_dimension, + ) + + # transformer encoder-decoder and mask initializer + self.encoder = InformerEncoder(config) + self.decoder = InformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". 
+ + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. 
The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... 
) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + super().forward(**super_kwargs) + + +class InformerForPrediction(TimeSeriesTransformerForPrediction): + def __init__(self, config: InformerConfig): + PreTrainedModel.__init__(self, config) + + self.model = InformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + @auto_docstring + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. 
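+
+            As a schematic, made-up example (this is not the library's own feature pipeline), Fourier month-of-year
+            features and an "age" feature for 31 monthly steps could be built as:
+
+            ```python
+            >>> import torch
+
+            >>> seq_len = 31  # arbitrary illustration
+            >>> month = torch.arange(seq_len) % 12
+            >>> fourier = torch.stack(
+            ...     [torch.sin(2 * torch.pi * month / 12), torch.cos(2 * torch.pi * month / 12)], dim=-1
+            ... )
+            >>> age = torch.log1p(torch.arange(seq_len, dtype=torch.float32)).unsqueeze(-1)
+            >>> time_features = torch.cat([fourier, age], dim=-1)  # (seq_len, 3); add a batch dimension in practice
+            ```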
+ + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. 
+ future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + This mask is used to filter out missing values for the final loss calculation. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*). + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerForPrediction.from_pretrained( + ... "huggingface/informer-tourism-monthly" + ... ) + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ...
) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + super().forward(**super_kwargs) + + +__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8948cf5baa3035ca0f5a9235349f68aea8c594c0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e803c5d8e69be941067d04fb5c0e45acb578b143 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1227092224ef6d4e9cb5fce2361e02b72381f7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3b5c95ac8c618512ee56013a90ed7b79715be48 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/configuration_mllama.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/configuration_mllama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f4f5f63f52fce13d3c1ba766fef248a3f2eff25 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/configuration_mllama.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/image_processing_mllama.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/image_processing_mllama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..939215afbbb7e09e214025771c46b90b996e7e49 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/image_processing_mllama.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/modeling_mllama.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/modeling_mllama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4489923572da3e6ed73b5955f4059733910cf3b6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/modeling_mllama.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/processing_mllama.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/processing_mllama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdbf823a42f6e8debcb55b68c462ec4553cc9df1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mllama/__pycache__/processing_mllama.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a86af2944a4ee7050d949817f934beb5c3f15c4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/configuration_mpt.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/configuration_mpt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5097d1f730f58697204810cb9e49f0ba6b0a469 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/configuration_mpt.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/modeling_mpt.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/modeling_mpt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11e8ec5a67acaf157b95d42d0ae046866e84c711 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/mpt/__pycache__/modeling_mpt.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c185e77fc322c7fc20ae3cd01f19a7e8c60a2c7d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/configuration_speech_encoder_decoder.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/configuration_speech_encoder_decoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64ec7d29ba9870ed7f46e847585a2c0e5a4e2cb3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/configuration_speech_encoder_decoder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/modeling_flax_speech_encoder_decoder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/modeling_flax_speech_encoder_decoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..031d1b4cfc456d14f73441638a3827a67f256663 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/modeling_flax_speech_encoder_decoder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/modeling_speech_encoder_decoder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/modeling_speech_encoder_decoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35d73e452ff24ee8260c2203aeb1dd5a20d7ba12 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/speech_encoder_decoder/__pycache__/modeling_speech_encoder_decoder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8099e267826725cd4e6abcbb5098e4f37cc96e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
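+# Note on structure: this package uses transformers' lazy-import pattern. Under TYPE_CHECKING the
+# submodules are imported eagerly for static analysis; at runtime the module object is swapped for a
+# _LazyModule that resolves attributes from the files in this directory on first access.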
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_t5gemma import * + from .modeling_t5gemma import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/configuration_t5gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/configuration_t5gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..2085cc8aa5178692ac21bfb0d0715fc681621601 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/configuration_t5gemma.py @@ -0,0 +1,327 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +from ...configuration_utils import PretrainedConfig, layer_type_validation + + +class T5GemmaModuleConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5GemmaModuleModel`]. It is used to instantiate a T5GemmaModule + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the T5GemmaModule-7B. + e.g. [google/t5_gemma_module-7b](https://huggingface.co/google/t5_gemma_module-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the T5GemmaModule model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`T5GemmaModuleModel`]. + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA); otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, it will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie the input and output word embeddings. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores. + sliding_window (`int`, *optional*, defaults to 4096): + In T5GemmaModule, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + Scaling factor when applying tanh softcapping on the attention scores.
+ + ```python + >>> from transformers import T5GemmaModuleModel, T5GemmaModuleConfig + >>> # Initializing a T5GemmaModule t5_gemma_module-7b style configuration + >>> configuration = T5GemmaModuleConfig() + >>> # Initializing a model from the t5_gemma_module-7b style configuration + >>> model = T5GemmaModuleModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "t5_gemma_module" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=256000, + hidden_size=2304, + intermediate_size=9216, + num_hidden_layers=26, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + query_pre_attn_scalar=256, + sliding_window=4096, + layer_types=None, + final_logit_softcapping=30.0, + attn_logit_softcapping=50.0, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.layer_types = layer_types + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + +class T5GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5GemmaModel`]. It is used to instantiate an T5Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to a hypothetical balanced Gemma2 encoder-decoder model. + e.g. 
[google/t5gemma-2b-2b-prefixlm-it](https://huggingface.co/google/t5gemma-2b-2b-prefixlm-it) + ```python + >>> from transformers import T5GemmaConfig, T5GemmaModel + >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-2b-2b-prefixlm-it") + >>> model = T5GemmaModel(t5gemma_config) + ``` + Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the + documentation from [PretrainedConfig] for more information. + Args: + encoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*): + Configuration for the encoder. + decoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*): + Configuration for the decoder. + is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether tie input and output embeddings. + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the T5Gemma model (the same as Gemma 2). + kwargs (additional keyword arguments, optional, *optional*): + Will be passed to the PretrainedConfig base class. + """ + + model_type = "t5gemma" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + # encoder + "encoder.layers.*.self_attn.q_proj": "colwise", + "encoder.layers.*.self_attn.k_proj": "colwise", + "encoder.layers.*.self_attn.v_proj": "colwise", + "encoder.layers.*.self_attn.o_proj": "rowwise", + "encoder.layers.*.mlp.gate_proj": "colwise", + "encoder.layers.*.mlp.up_proj": "colwise", + "encoder.layers.*.mlp.down_proj": "rowwise", + # decoder + "decoder.layers.*.self_attn.q_proj": "colwise", + "decoder.layers.*.self_attn.k_proj": "colwise", + "decoder.layers.*.self_attn.v_proj": "colwise", + "decoder.layers.*.self_attn.o_proj": "rowwise", + "decoder.layers.*.cross_attn.q_proj": "colwise", + "decoder.layers.*.cross_attn.k_proj": "colwise", + "decoder.layers.*.cross_attn.v_proj": "colwise", + "decoder.layers.*.cross_attn.o_proj": "rowwise", + "decoder.layers.*.mlp.gate_proj": "colwise", + "decoder.layers.*.mlp.up_proj": "colwise", + "decoder.layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + # encoder + "encoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "encoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "encoder.norm": (["hidden_states"], ["hidden_states"]), + # decoder + "decoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "decoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "decoder.norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + encoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None, + decoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None, + is_encoder_decoder: bool = True, + dropout_rate: float = 0.0, + classifier_dropout_rate: float = 0.0, + attention_dropout: float = 0.0, + tie_word_embeddings: bool = True, + vocab_size: int = 256000, + **kwargs, + ): + if isinstance(encoder, dict): + encoder = T5GemmaModuleConfig(**encoder) + elif encoder is None: + encoder = T5GemmaModuleConfig() + else: + assert isinstance(encoder, T5GemmaModuleConfig), f"{type(encoder)} is not 
supported." + + if isinstance(decoder, dict): + decoder = T5GemmaModuleConfig(**decoder) + elif decoder is None: + decoder = encoder + else: + assert isinstance(decoder, T5GemmaModuleConfig), f"{type(decoder)} is not supported." + + encoder = T5GemmaModuleConfig(**encoder.to_dict()) + decoder = T5GemmaModuleConfig(**decoder.to_dict()) + + encoder.is_decoder = False + encoder.dropout_rate = dropout_rate + encoder.attention_dropout = attention_dropout + self.encoder = encoder + + decoder.is_decoder = True + decoder.use_cache = True + decoder.dropout_rate = dropout_rate + decoder.attention_dropout = attention_dropout + decoder.cross_attention_hidden_size = encoder.hidden_size + self.decoder = decoder + + for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]: + if special_token_key not in kwargs: + kwargs[special_token_key] = getattr(decoder, special_token_key) + + super().__init__(**kwargs) + + self.is_encoder_decoder = is_encoder_decoder + self.use_cache = kwargs.get("use_cache", decoder.use_cache) + self.initializer_range = kwargs.get("initializer_range", decoder.initializer_range) + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.classifier_dropout_rate = classifier_dropout_rate + self.tie_word_embeddings = tie_word_embeddings + + # Used in pipeline generation. + self.vocab_size = vocab_size + + def __setattr__(self, key, value): + shared_attr_with_submodules = [ + "output_hidden_states", + "output_attentions", + "_attn_implementation", + "dropout_rate", + "attention_dropout", + "vocab_size", + ] + + if key in shared_attr_with_submodules: + setattr(self.encoder, key, value) + setattr(self.decoder, key, value) + super().__setattr__(key, value) + + +__all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/modeling_t5gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/modeling_t5gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..b6be86e9cdd70b5640604b6496096d7419707fcf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/modeling_t5gemma.py @@ -0,0 +1,1385 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig + + +logger = logging.get_logger(__name__) + + +class T5GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst T5Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class T5GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5GemmaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class T5GemmaSelfAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + # Required by flash attention: encoder selfattention is non-causal + self.is_causal = config.is_decoder + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = 
(*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class T5GemmaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = False + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + + self.k_proj = nn.Linear( + config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + + if config.cross_attention_hidden_size is None: + raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.") + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + encoder_hidden_states: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if encoder_hidden_states is None: + raise ValueError("Encoder hidden state is required for cross attention.") + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if past_key_values is not None: + is_updated = past_key_values.is_updated.get(self.layer_idx) + curr_past_key_value = past_key_values.cross_attention_cache + + if past_key_values is None or not is_updated: + 
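+ # Project the encoder states to keys/values only when the cross-attention cache has not been
+ # filled yet; subsequent decoding steps fall through to the cached keys/values below.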
encoder_input_shape = encoder_hidden_states.shape[:-1] + encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim) + key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + + if past_key_values is not None: + key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) + past_key_values.is_updated[self.layer_idx] = True + else: + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=None, + softcap=self.attn_logit_softcapping, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class T5GemmaEncoderLayer(GradientCheckpointingLayer): + """Encoder sub-layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.config = config + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + + self.self_attn = T5GemmaSelfAttention( + config=config, + layer_idx=layer_idx, + ) + self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.mlp = T5GemmaMLP(config) + self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor,]: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5GemmaDecoderLayer(T5GemmaEncoderLayer): + """Decoder sub-layer: an extra cross-attention layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx) + self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + 
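+ # Forward pass order: pre-norm self-attention, pre-norm cross-attention over the encoder output,
+ # then the pre-norm MLP, each followed by a post-norm and dropout before the residual addition.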
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_cross_attn_layernorm(hidden_states) + hidden_states, _ = self.cross_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + hidden_states = self.post_cross_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5GemmaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0): + super().__init__() + self.dropout = nn.Dropout(p=classifier_dropout_rate) + self.out_proj = nn.Linear(hidden_size, num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class T5GemmaLMHead(nn.Module): + """Head for language modeling (generation) tasks.""" + + def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False): + super().__init__() + self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.out_proj(hidden_states) + return logits + + +class T5GemmaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5GemmaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, 
bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@auto_docstring +class T5GemmaPreTrainedModel(PreTrainedModel): + config: T5GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": T5GemmaDecoderLayer, + "attentions": T5GemmaAttention, + } + + def _init_weights(self, module): + # TODO: support initialization for encoders and decoders separately(?) 
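+ # The heads below use a normal init scaled by the inverse square root of their output dimension;
+ # RMSNorm weights are zero-initialized because the forward pass computes (1 + weight), so zero
+ # corresponds to an identity scale.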
+ super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, T5GemmaClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, T5GemmaLMHead): + if not self.config.tie_word_embeddings: + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif "RMSNorm" in module.__class__.__name__: + module.weight.data.zero_() + + def _shift_right(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_start_token_id = self.config.decoder.bos_token_id + pad_token_id = self.config.decoder.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def bidirectional_mask_function(attention_mask: Optional[torch.Tensor]) -> Callable: + """ + This creates bidirectional attention mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if attention_mask is None: + return torch.ones((), dtype=torch.bool) + return attention_mask[batch_idx, kv_idx].to(torch.bool) + + return inner_mask + + +def sliding_window_bidirectional_mask_function(sliding_window: int) -> Callable: + """ + This creates bidirectional attention mask with sliding window. 
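+ A key/value position is visible to a query iff it lies within `sliding_window - 1` positions of the
+ query index on either side; no causal constraint is applied here.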
+ """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return (q_idx - sliding_window < kv_idx) & (kv_idx < q_idx + sliding_window) + + return inner_mask + + +def make_default_2d_attention_mask( + token_ids: Optional[torch.LongTensor], + hidden_states: torch.Tensor, + pad_token_id: Optional[int], +) -> torch.Tensor: + """Construct the default attention mask.""" + if token_ids is not None: + if pad_token_id is None: + raise ValueError("`pad_token_id` is required for padding information.") + attention_mask = (token_ids != pad_token_id).to(hidden_states.device, torch.long) + else: + attention_mask = torch.ones( + (hidden_states.shape[0], hidden_states.shape[1]), device=hidden_states.device, dtype=torch.long + ) + return attention_mask + + +class T5GemmaEncoder(T5GemmaPreTrainedModel): + _can_record_outputs = { + "attentions": T5GemmaSelfAttention, + "hidden_states": T5GemmaEncoderLayer, + } + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = T5GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": position_ids, + } + self_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(attention_mask), + ), + "sliding_attention": create_sliding_window_causal_mask( + **mask_kwargs, + or_mask_function=sliding_window_bidirectional_mask_function(self.config.sliding_window), + and_mask_function=bidirectional_mask_function(attention_mask), + ), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + hidden_states = self.dropout(hidden_states) + + for layer_module in 
self.layers[: self.config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class T5GemmaDecoder(T5GemmaEncoder): + _can_record_outputs = { + "attentions": OutputRecorder(T5GemmaSelfAttention, index=1), + "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1), + "hidden_states": T5GemmaDecoderLayer, + } + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList( + [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.post_init() + + @check_model_inputs + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPastAndCrossAttentions: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states` must be given in decoder") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.training and use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None and past_key_values is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, + } + self_attn_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": encoder_hidden_states, + "attention_mask": encoder_attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": None, + } + cross_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(encoder_attention_mask), + ), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + normalizer = torch.tensor(self.config.hidden_size**0.5, 
dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + past_key_values, + use_cache, + cache_position, + encoder_hidden_states, + cross_attn_mask_mapping["full_attention"], + **kwargs, + ) + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class T5GemmaModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if not config.is_encoder_decoder: + raise ValueError("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.decoder = T5GemmaDecoder(config.decoder) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) + """ + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states + if kwargs.get("output_hidden_states", False) + else (decoder_outputs.last_hidden_state,), + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class T5GemmaEncoderModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if config.is_encoder_decoder: + raise ValueError("T5GemmaEncoderModel only supports encoder-only model. Use `T5GemmaModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.post_init() + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + return encoder_outputs + + +class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"] + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5GemmaConfig): + config.is_encoder_decoder = True + super().__init__(config) + + self.model = T5GemmaModel(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLM" + + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def _tie_weights(self): + # Decoder input and output embeddings are tied. 
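+        # The LM head's projection shares its weight with the decoder token embedding,
+        # but only when `config.tie_word_embeddings` is enabled.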
+ if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings()) + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ """ + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + decoder_config = self.get_decoder().config + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +@auto_docstring +class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None): + r""" + is_encoder_decoder (`Optional`, *optional*): + Whether use encoder_decoder for sequence classification. When set to False, only encoder is used. 
+ """ + if is_encoder_decoder is not None: + config.is_encoder_decoder = is_encoder_decoder + super().__init__(config) + self.num_labels = config.num_labels + + if config.is_encoder_decoder: + self.model = T5GemmaModel(config) + else: + self.model = T5GemmaEncoderModel(config) + + hidden_size = config.encoder.hidden_size + if config.is_encoder_decoder: + hidden_size = config.decoder.hidden_size + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None): + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode." + ) + + # Following T5, we automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided + if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None): + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + if self.config.is_encoder_decoder: + outputs: Seq2SeqModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + else: + outputs: BaseModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.hidden_states + attentions = outputs.attentions + + logits = self.score(last_hidden_state) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + + if self.config.is_encoder_decoder: + last_non_pad_token += 1 # due to the right shift. + last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@auto_docstring +class T5GemmaForTokenClassification(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None): + r""" + is_encoder_decoder (`Optional`, *optional*): + Whether use encoder_decoder for token classification. When set to False, only encoder is used. 
+ """ + if is_encoder_decoder is not None: + config.is_encoder_decoder = is_encoder_decoder + super().__init__(config) + self.num_labels = config.num_labels + + if config.is_encoder_decoder: + self.model = T5GemmaModel(config) + else: + self.model = T5GemmaEncoderModel(config) + + hidden_size = config.encoder.hidden_size + if config.is_encoder_decoder: + hidden_size = config.decoder.hidden_size + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None): + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode." + ) + + if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None): + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + if self.config.is_encoder_decoder: + outputs: Seq2SeqModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + else: + outputs: BaseModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.hidden_states + attentions = outputs.attentions + + logits = self.score(last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = [ + "T5GemmaForConditionalGeneration", + "T5GemmaModel", + "T5GemmaEncoderModel", + "T5GemmaPreTrainedModel", + "T5GemmaForSequenceClassification", + "T5GemmaForTokenClassification", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/modular_t5gemma.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/modular_t5gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b11f9a4f3d0696954e4dd5118b7ccf3adce1a3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/t5gemma/modular_t5gemma.py @@ -0,0 +1,1239 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
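+# Note: in transformers, `modular_*.py` files such as this one are the handwritten
+# source from which the corresponding `modeling_*.py` file is auto-generated, which
+# is why the two files largely mirror each other.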
+from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import OutputRecorder, check_model_inputs +from ..gemma2.configuration_gemma2 import Gemma2Config +from ..gemma2.modeling_gemma2 import ( + Gemma2Attention, + Gemma2MLP, + Gemma2PreTrainedModel, + Gemma2RMSNorm, + Gemma2RotaryEmbedding, + create_causal_mask, + create_sliding_window_causal_mask, + eager_attention_forward, +) + + +_CHECKPOINT_FOR_DOC = "google/t5gemma-2b-2b-prefixlm-it" + + +logger = logging.get_logger(__name__) + + +class T5GemmaModuleConfig(Gemma2Config): + pass + + +class T5GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5GemmaModel`]. It is used to instantiate an T5Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to a hypothetical balanced Gemma2 encoder-decoder model. + e.g. [google/t5gemma-2b-2b-prefixlm-it](https://huggingface.co/google/t5gemma-2b-2b-prefixlm-it) + ```python + >>> from transformers import T5GemmaConfig, T5GemmaModel + >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-2b-2b-prefixlm-it") + >>> model = T5GemmaModel(t5gemma_config) + ``` + Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the + documentation from [PretrainedConfig] for more information. + Args: + encoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*): + Configuration for the encoder. + decoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*): + Configuration for the decoder. + is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether tie input and output embeddings. + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the T5Gemma model (the same as Gemma 2). + kwargs (additional keyword arguments, optional, *optional*): + Will be passed to the PretrainedConfig base class. 
+ """ + + model_type = "t5gemma" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + # encoder + "encoder.layers.*.self_attn.q_proj": "colwise", + "encoder.layers.*.self_attn.k_proj": "colwise", + "encoder.layers.*.self_attn.v_proj": "colwise", + "encoder.layers.*.self_attn.o_proj": "rowwise", + "encoder.layers.*.mlp.gate_proj": "colwise", + "encoder.layers.*.mlp.up_proj": "colwise", + "encoder.layers.*.mlp.down_proj": "rowwise", + # decoder + "decoder.layers.*.self_attn.q_proj": "colwise", + "decoder.layers.*.self_attn.k_proj": "colwise", + "decoder.layers.*.self_attn.v_proj": "colwise", + "decoder.layers.*.self_attn.o_proj": "rowwise", + "decoder.layers.*.cross_attn.q_proj": "colwise", + "decoder.layers.*.cross_attn.k_proj": "colwise", + "decoder.layers.*.cross_attn.v_proj": "colwise", + "decoder.layers.*.cross_attn.o_proj": "rowwise", + "decoder.layers.*.mlp.gate_proj": "colwise", + "decoder.layers.*.mlp.up_proj": "colwise", + "decoder.layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + # encoder + "encoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "encoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "encoder.norm": (["hidden_states"], ["hidden_states"]), + # decoder + "decoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "decoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "decoder.norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + encoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None, + decoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None, + is_encoder_decoder: bool = True, + dropout_rate: float = 0.0, + classifier_dropout_rate: float = 0.0, + attention_dropout: float = 0.0, + tie_word_embeddings: bool = True, + vocab_size: int = 256000, + **kwargs, + ): + if isinstance(encoder, dict): + encoder = T5GemmaModuleConfig(**encoder) + elif encoder is None: + encoder = T5GemmaModuleConfig() + else: + assert isinstance(encoder, T5GemmaModuleConfig), f"{type(encoder)} is not supported." + + if isinstance(decoder, dict): + decoder = T5GemmaModuleConfig(**decoder) + elif decoder is None: + decoder = encoder + else: + assert isinstance(decoder, T5GemmaModuleConfig), f"{type(decoder)} is not supported." + + encoder = T5GemmaModuleConfig(**encoder.to_dict()) + decoder = T5GemmaModuleConfig(**decoder.to_dict()) + + encoder.is_decoder = False + encoder.dropout_rate = dropout_rate + encoder.attention_dropout = attention_dropout + self.encoder = encoder + + decoder.is_decoder = True + decoder.use_cache = True + decoder.dropout_rate = dropout_rate + decoder.attention_dropout = attention_dropout + decoder.cross_attention_hidden_size = encoder.hidden_size + self.decoder = decoder + + for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]: + if special_token_key not in kwargs: + kwargs[special_token_key] = getattr(decoder, special_token_key) + + super().__init__(**kwargs) + + self.is_encoder_decoder = is_encoder_decoder + self.use_cache = kwargs.get("use_cache", decoder.use_cache) + self.initializer_range = kwargs.get("initializer_range", decoder.initializer_range) + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.classifier_dropout_rate = classifier_dropout_rate + self.tie_word_embeddings = tie_word_embeddings + + # Used in pipeline generation. 
+ self.vocab_size = vocab_size + + def __setattr__(self, key, value): + shared_attr_with_submodules = [ + "output_hidden_states", + "output_attentions", + "_attn_implementation", + "dropout_rate", + "attention_dropout", + "vocab_size", + ] + + if key in shared_attr_with_submodules: + setattr(self.encoder, key, value) + setattr(self.decoder, key, value) + super().__setattr__(key, value) + + +class T5GemmaRMSNorm(Gemma2RMSNorm): + pass + + +class T5GemmaMLP(Gemma2MLP): + def __init__(self, config): + super().__init__(config) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5GemmaRotaryEmbedding(Gemma2RotaryEmbedding): + def __init__(self, config, device=None): + super().__init__(config, device) + + +class T5GemmaSelfAttention(Gemma2Attention): + def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): + super().__init__(config, layer_idx) + # Required by flash attention: encoder selfattention is non-causal + self.is_causal = config.is_decoder + + +class T5GemmaCrossAttention(Gemma2Attention): + def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): + super().__init__(config, layer_idx) + del self.sliding_window + self.is_causal = False + + if config.cross_attention_hidden_size is None: + raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.") + + self.k_proj = nn.Linear( + config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + encoder_hidden_states: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if encoder_hidden_states is None: + raise ValueError("Encoder hidden state is required for cross attention.") + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if past_key_values is not None: + is_updated = past_key_values.is_updated.get(self.layer_idx) + curr_past_key_value = past_key_values.cross_attention_cache + + if past_key_values is None or not is_updated: + encoder_input_shape = encoder_hidden_states.shape[:-1] + encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim) + key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + + if past_key_values is not None: + key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) + past_key_values.is_updated[self.layer_idx] = True + else: + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights 
= attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=None, + softcap=self.attn_logit_softcapping, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def bidirectional_mask_function(attention_mask: Optional[torch.Tensor]) -> Callable: + """ + This creates bidirectional attention mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if attention_mask is None: + return torch.ones((), dtype=torch.bool) + return attention_mask[batch_idx, kv_idx].to(torch.bool) + + return inner_mask + + +def sliding_window_bidirectional_mask_function(sliding_window: int) -> Callable: + """ + This creates bidirectional attention mask with sliding window. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return (q_idx - sliding_window < kv_idx) & (kv_idx < q_idx + sliding_window) + + return inner_mask + + +class T5GemmaEncoderLayer(GradientCheckpointingLayer): + """Encoder sub-layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.config = config + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + + self.self_attn = T5GemmaSelfAttention( + config=config, + layer_idx=layer_idx, + ) + self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.mlp = T5GemmaMLP(config) + self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor,]: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5GemmaDecoderLayer(T5GemmaEncoderLayer): + """Decoder sub-layer: an extra cross-attention layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx) + self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + 
position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_cross_attn_layernorm(hidden_states) + hidden_states, _ = self.cross_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + hidden_states = self.post_cross_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5GemmaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0): + super().__init__() + self.dropout = nn.Dropout(p=classifier_dropout_rate) + self.out_proj = nn.Linear(hidden_size, num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class T5GemmaLMHead(nn.Module): + """Head for language modeling (generation) tasks.""" + + def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False): + super().__init__() + self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.out_proj(hidden_states) + return logits + + +@auto_docstring +class T5GemmaPreTrainedModel(Gemma2PreTrainedModel): + config: T5GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"] + + def _init_weights(self, module): + # TODO: support initialization for encoders and decoders separately(?) 
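+        # Generic modules are initialized by the base class first; the T5Gemma-specific
+        # heads are then handled below: classification and (untied) LM heads use a
+        # normal init with std scaled by out_features ** -0.5, and RMSNorm weights are
+        # zeroed because the norm applies (1 + weight).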
+ PreTrainedModel._init_weights(self, module) + std = self.config.initializer_range + if isinstance(module, T5GemmaClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, T5GemmaLMHead): + if not self.config.tie_word_embeddings: + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif "RMSNorm" in module.__class__.__name__: + module.weight.data.zero_() + + def _shift_right(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_start_token_id = self.config.decoder.bos_token_id + pad_token_id = self.config.decoder.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def make_default_2d_attention_mask( + token_ids: Optional[torch.LongTensor], + hidden_states: torch.Tensor, + pad_token_id: Optional[int], +) -> torch.Tensor: + """Construct the default attention mask.""" + if token_ids is not None: + if pad_token_id is None: + raise ValueError("`pad_token_id` is required for padding information.") + attention_mask = (token_ids != pad_token_id).to(hidden_states.device, torch.long) + else: + attention_mask = torch.ones( + (hidden_states.shape[0], hidden_states.shape[1]), device=hidden_states.device, dtype=torch.long + ) + return attention_mask + + +class T5GemmaEncoder(T5GemmaPreTrainedModel): + _can_record_outputs = { + "attentions": T5GemmaSelfAttention, + "hidden_states": T5GemmaEncoderLayer, + } + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = T5GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids 
or inputs_embeds") + + # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": position_ids, + } + self_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(attention_mask), + ), + "sliding_attention": create_sliding_window_causal_mask( + **mask_kwargs, + or_mask_function=sliding_window_bidirectional_mask_function(self.config.sliding_window), + and_mask_function=bidirectional_mask_function(attention_mask), + ), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class T5GemmaDecoder(T5GemmaEncoder): + _can_record_outputs = { + "attentions": OutputRecorder(T5GemmaSelfAttention, index=1), + "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1), + "hidden_states": T5GemmaDecoderLayer, + } + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList( + [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.post_init() + + @check_model_inputs + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPastAndCrossAttentions: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states` must be given in decoder") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.training and use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 
+ cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None and past_key_values is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, + } + self_attn_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": encoder_hidden_states, + "attention_mask": encoder_attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": None, + } + cross_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(encoder_attention_mask), + ), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + past_key_values, + use_cache, + cache_position, + encoder_hidden_states, + cross_attn_mask_mapping["full_attention"], + **kwargs, + ) + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class T5GemmaModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if not config.is_encoder_decoder: + raise ValueError("T5GemmaModel only support encoder-decoder modeling. 
Use `T5GemmaEncoderModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.decoder = T5GemmaDecoder(config.decoder) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + """ + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states + if kwargs.get("output_hidden_states", False) + else (decoder_outputs.last_hidden_state,), + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class T5GemmaEncoderModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if config.is_encoder_decoder: + raise ValueError("T5GemmaEncoderModel only supports encoder-only model. 
Use `T5GemmaModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.post_init() + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + return encoder_outputs + + +class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"] + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5GemmaConfig): + config.is_encoder_decoder = True + super().__init__(config) + + self.model = T5GemmaModel(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLM" + + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def _tie_weights(self): + # Decoder input and output embeddings are tied. + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings()) + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ """ + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + decoder_config = self.get_decoder().config + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +@auto_docstring +class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None): + r""" + is_encoder_decoder (`Optional`, *optional*): + Whether use encoder_decoder for sequence classification. When set to False, only encoder is used. 
+ """ + if is_encoder_decoder is not None: + config.is_encoder_decoder = is_encoder_decoder + super().__init__(config) + self.num_labels = config.num_labels + + if config.is_encoder_decoder: + self.model = T5GemmaModel(config) + else: + self.model = T5GemmaEncoderModel(config) + + hidden_size = config.encoder.hidden_size + if config.is_encoder_decoder: + hidden_size = config.decoder.hidden_size + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None): + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode." + ) + + # Following T5, we automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided + if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None): + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + if self.config.is_encoder_decoder: + outputs: Seq2SeqModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + else: + outputs: BaseModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.hidden_states + attentions = outputs.attentions + + logits = self.score(last_hidden_state) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + + if self.config.is_encoder_decoder: + last_non_pad_token += 1 # due to the right shift. + last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@auto_docstring +class T5GemmaForTokenClassification(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None): + r""" + is_encoder_decoder (`Optional`, *optional*): + Whether use encoder_decoder for token classification. When set to False, only encoder is used. 
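+
+ Example (illustrative sketch, analogous to the sequence-classification case above; all values are placeholders):
+
+ ```python
+ >>> import torch
+ >>> from transformers import T5GemmaConfig, T5GemmaForTokenClassification
+
+ >>> config = T5GemmaConfig()  # hypothetical default configuration
+ >>> config.num_labels = 5
+ >>> model = T5GemmaForTokenClassification(config, is_encoder_decoder=False)
+ >>> input_ids = torch.randint(0, 100, (1, 12))  # dummy token ids
+ >>> logits = model(input_ids=input_ids).logits  # one score per token and label: shape (1, 12, 5)
+ ```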
+ """ + if is_encoder_decoder is not None: + config.is_encoder_decoder = is_encoder_decoder + super().__init__(config) + self.num_labels = config.num_labels + + if config.is_encoder_decoder: + self.model = T5GemmaModel(config) + else: + self.model = T5GemmaEncoderModel(config) + + hidden_size = config.encoder.hidden_size + if config.is_encoder_decoder: + hidden_size = config.decoder.hidden_size + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None): + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode." + ) + + if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None): + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + if self.config.is_encoder_decoder: + outputs: Seq2SeqModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + else: + outputs: BaseModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.hidden_states + attentions = outputs.attentions + + logits = self.score(last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = [ + "T5GemmaConfig", + "T5GemmaModuleConfig", + "T5GemmaForConditionalGeneration", + "T5GemmaModel", + "T5GemmaEncoderModel", + "T5GemmaPreTrainedModel", + "T5GemmaForSequenceClassification", + "T5GemmaForTokenClassification", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b19384859126d6d2a3d4f3cc6ab845e1d9e6da2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/configuration_timesformer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/configuration_timesformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8299823459e507f871003fa05e04f40c62a8fc5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/configuration_timesformer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/modeling_timesformer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/modeling_timesformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00bba1b4924126845f9633e37eef68982d4e7ee4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/timesformer/__pycache__/modeling_timesformer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dee32c4758f0e4c54406473ec7146caa5514c413 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_upernet import * + from .modeling_upernet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/configuration_upernet.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/configuration_upernet.py new file mode 100644 index 0000000000000000000000000000000000000000..d116b22fcfba681d5a1b25d52a597ab6b28f9362 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/configuration_upernet.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""UperNet model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class UperNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`UperNetForSemanticSegmentation`]. It is used to + instantiate an UperNet model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the UperNet + [openmmlab/upernet-convnext-tiny](https://huggingface.co/openmmlab/upernet-convnext-tiny) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`): + The configuration of the backbone model. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. 
If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + hidden_size (`int`, *optional*, defaults to 512): + The number of hidden units in the convolutional layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + pool_scales (`tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`): + Pooling scales used in Pooling Pyramid Module applied on the last feature map. + use_auxiliary_head (`bool`, *optional*, defaults to `True`): + Whether to use an auxiliary head during training. + auxiliary_loss_weight (`float`, *optional*, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + auxiliary_channels (`int`, *optional*, defaults to 256): + Number of channels to use in the auxiliary head. + auxiliary_num_convs (`int`, *optional*, defaults to 1): + Number of convolutional layers to use in the auxiliary head. + auxiliary_concat_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the output of the auxiliary head with the input before the classification layer. + loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function. + + Examples: + + ```python + >>> from transformers import UperNetConfig, UperNetForSemanticSegmentation + + >>> # Initializing a configuration + >>> configuration = UperNetConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = UperNetForSemanticSegmentation(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "upernet" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + hidden_size=512, + initializer_range=0.02, + pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + auxiliary_in_channels=None, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, + loss_ignore_index=255, + **kwargs, + ): + super().__init__(**kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. 
Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.pool_scales = pool_scales + self.use_auxiliary_head = use_auxiliary_head + self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_in_channels = auxiliary_in_channels + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input + self.loss_ignore_index = loss_ignore_index + + @property + def sub_configs(self): + return ( + {"backbone_config": type(self.backbone_config)} + if getattr(self, "backbone_config", None) is not None + else {} + ) + + +__all__ = ["UperNetConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/modeling_upernet.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/modeling_upernet.py new file mode 100644 index 0000000000000000000000000000000000000000..36dc90c30adbed46231881d3a62582316e8e46f0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/upernet/modeling_upernet.py @@ -0,0 +1,388 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UperNet model. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.""" + +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...modeling_outputs import SemanticSegmenterOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring +from ...utils.backbone_utils import load_backbone +from .configuration_upernet import UperNetConfig + + +class UperNetConvModule(nn.Module): + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, tuple[int, int]], + padding: Union[int, tuple[int, int], str] = 0, + bias: bool = False, + dilation: Union[int, tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + dilation=dilation, + ) + self.batch_norm = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.conv(input) + output = self.batch_norm(output) + output = self.activation(output) + + return output + + +class UperNetPyramidPoolingBlock(nn.Module): + def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None: + super().__init__() + self.layers = [ + nn.AdaptiveAvgPool2d(pool_scale), + UperNetConvModule(in_channels, channels, kernel_size=1), + ] + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class UperNetPyramidPoolingModule(nn.Module): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (`tuple[int]`): + Pooling scales used in Pooling Pyramid Module. + in_channels (`int`): + Input channels. + channels (`int`): + Channels after modules, before conv_seg. + align_corners (`bool`): + align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None: + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.blocks = [] + for i, pool_scale in enumerate(pool_scales): + block = UperNetPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels) + self.blocks.append(block) + self.add_module(str(i), block) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + ppm_outs = [] + for ppm in self.blocks: + ppm_out = ppm(x) + upsampled_ppm_out = nn.functional.interpolate( + ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners + ) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +class UperNetHead(nn.Module): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://huggingface.co/papers/1807.10221). + """ + + def __init__(self, config, in_channels): + super().__init__() + + self.config = config + self.pool_scales = config.pool_scales # e.g. 
(1, 2, 3, 6) + self.in_channels = in_channels + self.channels = config.hidden_size + self.align_corners = False + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + # PSP Module + self.psp_modules = UperNetPyramidPoolingModule( + self.pool_scales, + self.in_channels[-1], + self.channels, + align_corners=self.align_corners, + ) + self.bottleneck = UperNetConvModule( + self.in_channels[-1] + len(self.pool_scales) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = UperNetConvModule(in_channels, self.channels, kernel_size=1) + fpn_conv = UperNetConvModule(self.channels, self.channels, kernel_size=3, padding=1) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = UperNetConvModule( + len(self.in_channels) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + + def psp_forward(self, inputs): + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # build laterals + laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] + + laterals.append(self.psp_forward(encoder_hidden_states)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate( + laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners + ) + + # build outputs + fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = nn.functional.interpolate( + fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners + ) + fpn_outs = torch.cat(fpn_outs, dim=1) + output = self.fpn_bottleneck(fpn_outs) + output = self.classifier(output) + + return output + + +class UperNetFCNHead(nn.Module): + """ + Fully Convolution Networks for Semantic Segmentation. This head is the implementation of + [FCNNet](https://huggingface.co/papers/1411.4038>). + + Args: + config: + Configuration. + in_channels (int): + Number of input channels. + kernel_size (int): + The kernel size for convs in the head. Default: 3. + dilation (int): + The dilation rate for convs in the head. Default: 1. 
+ """ + + def __init__( + self, config, in_channels, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, tuple[int, int]] = 1 + ) -> None: + super().__init__() + + self.config = config + self.in_channels = ( + in_channels[in_index] if config.auxiliary_in_channels is None else config.auxiliary_in_channels + ) + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + UperNetConvModule( + self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + for i in range(self.num_convs - 1): + convs.append( + UperNetConvModule( + self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = UperNetConvModule( + self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) + if self.concat_input: + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) + output = self.classifier(output) + return output + + +@auto_docstring +class UperNetPreTrainedModel(PreTrainedModel): + config: UperNetConfig + main_input_name = "pixel_values" + _no_split_modules = [] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.BatchNorm2d): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +@auto_docstring( + custom_intro=""" + UperNet framework leveraging any vision backbone e.g. for ADE20k, CityScapes. + """ +) +class UperNetForSemanticSegmentation(UperNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.backbone = load_backbone(config) + + # Semantic segmentation head(s) + self.decode_head = UperNetHead(config, in_channels=self.backbone.channels) + self.auxiliary_head = ( + UperNetFCNHead(config, in_channels=self.backbone.channels) if config.use_auxiliary_head else None + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
+ + Examples: + ```python + >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation + >>> from PIL import Image + >>> from huggingface_hub import hf_hub_download + + >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-tiny") + >>> model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny") + + >>> filepath = hf_hub_download( + ... repo_id="hf-internal-testing/fixtures_ade20k", filename="ADE_val_00000001.jpg", repo_type="dataset" + ... ) + >>> image = Image.open(filepath).convert("RGB") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> logits = outputs.logits # shape (batch_size, num_labels, height, width) + >>> list(logits.shape) + [1, 150, 512, 512] + ```""" + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + features = outputs.feature_maps + + logits = self.decode_head(features) + logits = nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False + ) + + loss = None + if labels is not None: + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index) + loss = loss_fct(logits, labels) + if auxiliary_logits is not None: + auxiliary_loss = loss_fct(auxiliary_logits, labels) + loss += self.config.auxiliary_loss_weight * auxiliary_loss + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["UperNetForSemanticSegmentation", "UperNetPreTrainedModel"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5f5971a167997ffc89de76d5cc5ba3d8216b4b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_video_llava import * + from .image_processing_video_llava import * + from .modeling_video_llava import * + from .processing_video_llava import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/configuration_video_llava.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/configuration_video_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..561d2e3a9254932ec795ca31b9d51e9fc6ee0d1d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/configuration_video_llava.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VideoLlava model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class VideoLlavaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VideoLlavaForConditionalGeneration`]. It is used to instantiate an + VideoLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the like LanguageBind/Video-LLaVA-7B-hf. + + e.g. [LanguageBind/Video-LLaVA-7B-hf](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`VideoLlavaVisionConfig`, *optional*): + Custom vision config or dict. Defaults to `CLIPVisionConfig` if not indicated. + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. + Defaults to `LlamaConfig` if not indicated. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. 
+ video_token_index (`int`, *optional*, defaults to 32001): + The video token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the CLIP backbone. + Can be either "full" to select all features or "default" to select features without `CLS`. + vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_seq_length (`int`, *optional*, defaults to 256): + Sequence length of one image embedding. + video_seq_length (`int`, *optional*, defaults to 2056): + Sequence length of one video embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. + + Example: + + ```python + >>> from transformers import VideoLlavaForConditionalGeneration, VideoLlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a VideoLlava video_llava-1.5-7b style configuration + >>> configuration = VideoLlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the video_llava-1.5-7b style configuration + >>> model = VideoLlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "video_llava" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + video_token_index=32001, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=256, + video_seq_length=2056, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.video_token_index = video_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.image_seq_length = image_seq_length + self.video_seq_length = video_seq_length + self.multimodal_projector_bias = multimodal_projector_bias + + self.vision_config = vision_config + + if isinstance(self.vision_config, dict): + if "model_type" not in vision_config: + vision_config["model_type"] = "clip_vision_model" + logger.warning("Key=`model_type` not found in vision config, setting it to `clip_vision_model`") + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=224, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + if isinstance(text_config, dict): + if "model_type" not in text_config: + text_config["model_type"] = "llama" + logger.warning("Key=`model_type` not found in text config, setting it to `llama`") + text_config = 
CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + super().__init__(**kwargs) + + +__all__ = ["VideoLlavaConfig"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/image_processing_video_llava.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/image_processing_video_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed8f911af8e8bbf2f7921f8d560db4ece9f0e76 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/image_processing_video_llava.py @@ -0,0 +1,391 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Video-LLaVA.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...video_utils import VideoInput, make_batched_videos + + +logger = logging.get_logger(__name__) + + +class VideoLlavaImageProcessor(BaseImageProcessor): + r""" + Constructs a CLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. 
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Optional[dict[str, int]] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: Optional[list[ImageInput]] = None, + videos: Optional[list[VideoInput]] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`, *optional*): + List of images to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + videos (`VideoInput`, *optional*): + List of videos to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If + passing in videos with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. 
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + if images is not None: + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + data = {} + if videos is not None: + logger.warning( + "`VideoLlavaImageProcessor` works only with image inputs and doesn't process videos anymore. " + "This is a deprecated behavior and will be removed in v5.0. " + "Your videos should be forwarded to `VideoLlavaVideoProcessor`. 
" + ) + videos = make_batched_videos(videos) + pixel_values_videos = [ + [ + self._preprocess_image( + image=frame, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for frame in video + ] + for video in videos + ] + data["pixel_values_videos"] = pixel_values_videos + + if images is not None: + pixel_values_images = [ + self._preprocess_image( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for image in images + ] + data["pixel_values_images"] = pixel_values_images + + encoded_outputs = BatchFeature(data, tensor_type=return_tensors) + + return encoded_outputs + + def _preprocess_image( + self, + image: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + image = convert_to_rgb(image) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images/video frames. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + +__all__ = ["VideoLlavaImageProcessor"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/modeling_video_llava.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/modeling_video_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..2db424455087ea69e9abaed8177e80c617889f35 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/models/video_llava/modeling_video_llava.py @@ -0,0 +1,718 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch VideoLlava model.""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ..auto import AutoModel +from .configuration_video_llava import VideoLlavaConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VideoLlava base model outputs. + """ +) +class VideoLlavaModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`. 
+        video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    attentions: Optional[tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[torch.FloatTensor] = None
+    video_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for VideoLlava causal language model (or autoregressive) outputs.
+    """
+)
+class VideoLlavaCausalLMOutputWithPast(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Language modeling loss (for next-token prediction).
+    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+        `past_key_values` input) to speed up sequential decoding.
+    image_hidden_states (`torch.FloatTensor`, *optional*):
+        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+    video_hidden_states (`torch.FloatTensor`, *optional*):
+        A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+        video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->VideoLlava +class VideoLlavaMultiModalProjector(nn.Module): + def __init__(self, config: VideoLlavaConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@auto_docstring +class VideoLlavaPreTrainedModel(PreTrainedModel): + config: VideoLlavaConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["VideoLlavaVisionAttention"] + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +@auto_docstring( + custom_intro=""" + The VideoLlava model which consists of a vision backbone and a language model without language modeling head. 
+ """, +) +class VideoLlavaModel(VideoLlavaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: VideoLlavaConfig): + super().__init__(config) + self.video_tower = AutoModel.from_config(config.vision_config) + self.image_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = VideoLlavaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModel.from_config(config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values_images: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + + image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + image_outputs = image_outputs.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + image_outputs = image_outputs[:, 1:] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + # For default; crop CLS from each hidden state in the hidden state pool + if vision_feature_select_strategy == "default": + hs_pool = [hs[:, 1:] for hs in hs_pool] + image_outputs = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(image_outputs) + + return image_features + + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + ): + """ + Obtains video last hidden states from the vision tower and apply multimodal projection. 
+
+        Args:
+            pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
+                The tensors corresponding to the input videos.
+            vision_feature_layer (`Union[int, list[int]]`, *optional*):
+                The index of the layer to select the vision feature. If multiple indices are provided,
+                the vision feature of the corresponding indices will be concatenated to form the
+                vision features.
+        Returns:
+            video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`.
+            frames (`int`): Number of frames the videos have.
+        """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+
+        batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
+
+        pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
+        video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
+
+        # If we have one vision feature layer, return the corresponding hidden states,
+        # otherwise, select the hidden states of each feature layer and concatenate them
+        if isinstance(vision_feature_layer, int):
+            video_features = video_outputs.hidden_states[vision_feature_layer]
+        else:
+            hs_pool = [video_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+            video_features = torch.cat(hs_pool, dim=-1)
+
+        video_features = self.multi_modal_projector(video_features)
+
+        return video_features, num_frames
+
+    def get_placeholder_mask(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: torch.FloatTensor,
+        image_features: Optional[torch.FloatTensor] = None,
+        video_features: Optional[torch.FloatTensor] = None,
+    ):
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+        equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0] * video_features.shape[1]}" + ) + + return special_image_mask, special_video_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values_images: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, VideoLlavaModelOutputWithPast]: + r""" + pixel_values_images (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`VideoLlavaImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`VideoLlavaImageProcessor`] for processing images). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values_images is not None: + image_features = self.get_image_features( + pixel_values_images, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + if pixel_values_videos is not None: + video_features, num_frames = self.get_video_features( + pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return VideoLlavaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values_images is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, + ) + + +@auto_docstring( + custom_intro=""" + The VideoLlava model which consists of a vision backbone and a language model. 
+ """ +) +class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^image_tower": "model.image_tower", + "^video_tower": "model.video_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: VideoLlavaConfig): + super().__init__(config) + self.model = VideoLlavaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values_images: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + return self.model.get_image_features( + pixel_values_images=pixel_values_images, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def video_tower(self): + return self.model.video_tower + + @property + def image_tower(self): + return self.model.image_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values_images: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, VideoLlavaCausalLMOutputWithPast]: + r""" + pixel_values_images (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`VideoLlavaImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`VideoLlavaImageProcessor`] for processing images). + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> import numpy as np + >>> import av + >>> from huggingface_hub import hf_hub_download + >>> from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration + + + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`list[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + >>> model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf") + >>> processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf") + + >>> prompt = "USER: