diff --git a/License.txt b/License.txt new file mode 100644 index 0000000000000000000000000000000000000000..80f49c0857176a9aa8119d72f73f7784fe58c9a9 --- /dev/null +++ b/License.txt @@ -0,0 +1,335 @@ +Tencent is pleased to support the open source community by making Seed-X available. + +Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. + +Seed-X is licensed under the Apache License Version 2.0 except for the third-party components listed below. + + +Terms of the Apache License Version 2.0: +-------------------------------------------------------------------- +Apache License + +Version 2.0, January 2004 + +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of this License; and + +You must cause any modified files to carry prominent notices stating that You changed the files; and + +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + +If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + + + +Other dependencies and licenses: + + +Open Source Software Licensed under the Apache License Version 2.0: +-------------------------------------------------------------------- +1. transformers +Copyright 2018- The Hugging Face team. 
All rights reserved. +Source code of this software can be obtained from: https://github.com/huggingface/transformers/blob/v4.30.2/ + +2. diffusers +Copyright 2023 The HuggingFace Team. All rights reserved. +Source code of this software can be obtained from: https://github.com/huggingface/diffusers/blob/v0.25.0/ + +A copy of Apache 2.0 has been included in this file. + + + +Open Source Software Licensed under the BSD 3-Clause License: +-------------------------------------------------------------------- +1. torchvision +Copyright (c) Soumith Chintala 2016, +All rights reserved. + +Terms of the BSD 3-Clause License: +-------------------------------------------------------------------- +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. numpy +Copyright (c) 2005-2021, NumPy Developers. +All rights reserved. + +A copy of the BSD 3-Clause License is included in this file. + +For the license of other third party components, please refer to the following URL: +https://github.com/numpy/numpy/blob/v1.20.1/LICENSES_bundled.txt + + + +Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. torch +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +A copy of the BSD 3-Clause License is included in this file. 
+ +For the license of other third party components, please refer to the following URL: +https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE + + + +Open Source Software Licensed under the LLAMA 2 Community License: +-------------------------------------------------------------------- +1. Llama 2 +Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +Terms of the LLAMA 2 COMMUNITY LICENSE AGREEMENT: +-------------------------------------------------------------------- +LLAMA 2 COMMUNITY LICENSE AGREEMENT +Llama 2 Version Release Date: July 18, 2023 + +"Agreement" means the terms and conditions for use, reproduction, distribution and +modification of the Llama Materials set forth herein. + +"Documentation" means the specifications, manuals and documentation +accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- +libraries/llama-downloads/. + +"Licensee" or "you" means you, or your employer or any other person or entity (if +you are entering into this Agreement on such person or entity's behalf), of the age +required under applicable laws, rules or regulations to provide legal consent and that +has legal authority to bind your employer or such other person or entity if you are +entering in this Agreement on their behalf. + +"Llama 2" means the foundational large language models and software and +algorithms, including machine-learning model code, trained model weights, +inference-enabling code, training-enabling code, fine-tuning enabling code and other +elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- +libraries/llama-downloads/. + +"Llama Materials" means, collectively, Meta's proprietary Llama 2 and +Documentation (and any portion thereof) made available under this Agreement. + +"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you +are an entity, your principal place of business is in the EEA or Switzerland) and Meta +Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +By clicking "I Accept" below or by using or distributing any portion or element of the +Llama Materials, you agree to be bound by this Agreement. + +1. License Rights and Redistribution. + + a. Grant of Rights. You are granted a non-exclusive, worldwide, non- +transferable and royalty-free limited license under Meta's intellectual property or +other rights owned by Meta embodied in the Llama Materials to use, reproduce, +distribute, copy, create derivative works of, and make modifications to the Llama +Materials. + + b. Redistribution and Use. + + i. If you distribute or make the Llama Materials, or any derivative works +thereof, available to a third party, you shall provide a copy of this Agreement to such +third party. + ii. If you receive Llama Materials, or any derivative works thereof, from +a Licensee as part of an integrated end user product, then Section 2 of this +Agreement will not apply to you. + + iii. You must retain in all copies of the Llama Materials that you +distribute the following attribution notice within a "Notice" text file distributed as a +part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, +Copyright (c) Meta Platforms, Inc. All Rights Reserved." + + iv. 
Your use of the Llama Materials must comply with applicable laws +and regulations (including trade compliance laws and regulations) and adhere to the +Acceptable Use Policy for the Llama Materials (available at +https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into +this Agreement. + + v. You will not use the Llama Materials or any output or results of the +Llama Materials to improve any other large language model (excluding Llama 2 or +derivative works thereof). + +2. Additional Commercial Terms. If, on the Llama 2 version release date, the +monthly active users of the products or services made available by or for Licensee, +or Licensee's affiliates, is greater than 700 million monthly active users in the +preceding calendar month, you must request a license from Meta, which Meta may +grant to you in its sole discretion, and you are not authorized to exercise any of the +rights under this Agreement unless or until Meta otherwise expressly grants you +such rights. + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE +LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE +PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY +WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR +FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE +FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING +THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR +USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE +LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, +NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS +AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, +CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN +IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF +ANY OF THE FOREGOING. + +5. Intellectual Property. + + a. No trademark licenses are granted under this Agreement, and in +connection with the Llama Materials, neither Meta nor Licensee may use any name +or mark owned by or associated with the other or any of its affiliates, except as +required for reasonable and customary use in describing and redistributing the +Llama Materials. + + b. Subject to Meta's ownership of Llama Materials and derivatives made by or +for Meta, with respect to any derivative works and modifications of the Llama +Materials that are made by you, as between you and Meta, you are and will be the +owner of such derivative works and modifications. + + c. If you institute litigation or other proceedings against Meta or any entity +(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama +Materials or Llama 2 outputs or results, or any portion of any of the foregoing, +constitutes infringement of intellectual property or other rights owned or licensable +by you, then any licenses granted to you under this Agreement shall terminate as of +the date such litigation or claim is filed or instituted. You will indemnify and hold +harmless Meta from and against any claim by any third party arising out of or related +to your use or distribution of the Llama Materials. + +6. Term and Termination. The term of this Agreement will commence upon your +acceptance of this Agreement or access to the Llama Materials and will continue in +full force and effect until terminated in accordance with the terms and conditions +herein. 
Meta may terminate this Agreement if you are in breach of any term or +condition of this Agreement. Upon termination of this Agreement, you shall delete +and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the +termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and +construed under the laws of the State of California without regard to choice of law +principles, and the UN Convention on Contracts for the International Sale of Goods +does not apply to this Agreement. The courts of California shall have exclusive +jurisdiction of any dispute arising out of this Agreement. + + + +Open Source Software Licensed under the Tongyi Qianwen LICENSE AGREEMENT: +-------------------------------------------------------------------- +1. Qwen-VL +Copyright (c) Alibaba Cloud. All Rights Reserved. + + +Terms of the Tongyi Qianwen LICENSE AGREEMENT: +-------------------------------------------------------------------- +Tongyi Qianwen LICENSE AGREEMENT + +Tongyi Qianwen Release Date: August 23, 2023 + +By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately. + +1. Definitions + a. This Tongyi Qianwen LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement. + b. "We"(or "Us") shall mean Alibaba Cloud. + c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use. + d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You. + e. "Tongyi Qianwen" shall mean the large language models (including Qwen-VL model and Qwen-VL-Chat model), and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us. + f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement. + g. "Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files. + h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, + and conversions to other media types. + +2. Grant of Rights +You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials. + +3. Redistribution +You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement; + b. You shall cause any modified files to carry prominent notices stating that You changed the files; + c. 
You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and + d. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement. + +4. Restrictions +If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us. You cannot exercise your rights under this Agreement without our express authorization. + +5. Rules of use + a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials. + b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof). + +6. Intellectual Property + a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications. + b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials. + c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought. + +7. Disclaimer of Warranty and Limitation of Liability + + a. We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto. + b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM. + c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED. + d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials. + +8. Survival and Termination. + a. 
The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. + b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 7 and 9 shall survive the termination of this Agreement. + +9. Governing Law and Jurisdiction. + a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. + b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement. diff --git a/configs/.DS_Store b/configs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..1739b5bbd55ff60b238650150f8057aab2a71a55 Binary files /dev/null and b/configs/.DS_Store differ diff --git a/configs/clm_models/.DS_Store b/configs/clm_models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..23b0f8fba4b24c073efa630c38d8398c2002ee40 Binary files /dev/null and b/configs/clm_models/.DS_Store differ diff --git a/configs/clm_models/agent_seed_x_i.yaml b/configs/clm_models/agent_seed_x_i.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec9d0b57093cfdf62b09f5202490f1d71224738f --- /dev/null +++ b/configs/clm_models/agent_seed_x_i.yaml @@ -0,0 +1,23 @@ +_target_: src.models.mllm.seed_x.ContinuousLVLM.from_pretrained +input_resampler: + _target_: src.models.tokenizer.qwen_visual.Resampler + grid_size: 8 + embed_dim: 5120 + num_heads: 32 + kv_dim: 4096 + +output_resampler: + _target_: src.models.tokenizer.qwen_visual.Resampler + grid_size: 8 + embed_dim: 4096 + num_heads: 32 + kv_dim: 5120 + +add_patch_pos: True +vit_down: True +mse: True + +lm_loss_scale: 1.0 +rec_loss_scale: 6.0 + +pretrained_model_path: https://huggingface.co/AILab-CVC/SEED-X-17B/blob/main/seed_x_i/agent/pytorch_model.bin diff --git a/configs/clm_models/llm_seed_x_i.yaml b/configs/clm_models/llm_seed_x_i.yaml new file mode 100644 index 0000000000000000000000000000000000000000..09f7f72c41e7585179472eddb93c274bbf2d1abf --- /dev/null +++ b/configs/clm_models/llm_seed_x_i.yaml @@ -0,0 +1,3 @@ +_target_: src.models.mllm.modeling_llama_xformer.LlamaForCausalLM.from_pretrained +pretrained_model_name_or_path: https://huggingface.co/AILab-CVC/SEED-X-17B/tree/main/seed_x_i/llm +low_cpu_mem_usage: True diff --git a/configs/discrete_model/.DS_Store b/configs/discrete_model/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/configs/discrete_model/.DS_Store differ diff --git a/configs/discrete_model/discrete_identity.yaml b/configs/discrete_model/discrete_identity.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17f1d02cc8d0db208f6bc9b4bca0a852507f61ee --- /dev/null +++ b/configs/discrete_model/discrete_identity.yaml @@ -0,0 +1 @@ +_target_: src.models.tokenizer.discrete_models.DiscreteModleIdentity diff --git a/configs/processer/.DS_Store b/configs/processer/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/configs/processer/.DS_Store differ diff --git 
a/configs/processer/qwen_448_transform.yaml b/configs/processer/qwen_448_transform.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68954ee0e81fbd3f40f43979e0d317c60f5891dd --- /dev/null +++ b/configs/processer/qwen_448_transform.yaml @@ -0,0 +1,4 @@ +_target_: src.processer.transforms.get_transform +type: clip +image_size: 448 +keep_ratio: False diff --git a/configs/sdxl_adapter/.DS_Store b/configs/sdxl_adapter/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/configs/sdxl_adapter/.DS_Store differ diff --git a/configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml b/configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e758eac2429f9aa51007c155cd360d1702555680 --- /dev/null +++ b/configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml @@ -0,0 +1,20 @@ +_target_: src.models.detokenizer.adapter_modules.SDXLAdapterWithLatentImage.from_pretrained + +resampler: + _target_: src.models.detokenizer.resampler.ResamplerXLV2 + dim: 1024 + depth: 4 + dim_head: 64 + heads: 16 + num_queries: 64 + embedding_dim: 4096 + output1_dim: 768 + output2_dim: 1280 + ff_mult: 4 + normalize: False + +full_ft: True +set_trainable_late: False + +vit_down: True +pretrained_model_path: pretrained/seed_detokenizer/second_stage/pytorch_model.bin diff --git a/configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml b/configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8205d43efee1203f53395ed04dae8171bc0a1271 --- /dev/null +++ b/configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml @@ -0,0 +1,18 @@ +_target_: src.models.detokenizer.adapter_modules.SDXLAdapter.from_pretrained + +resampler: + _target_: src.models.detokenizer.resampler.ResamplerXLV2 + dim: 1024 + depth: 4 + dim_head: 64 + heads: 16 + num_queries: 64 + embedding_dim: 4096 + output1_dim: 768 + output2_dim: 1280 + ff_mult: 4 + normalize: False + +vit_down: True + +pretrained_model_path: https://huggingface.co/AILab-CVC/SEED-X-17B/blob/main/seed_detokenizer/first_stage/pytorch_model.bin \ No newline at end of file diff --git a/configs/tokenizer/.DS_Store b/configs/tokenizer/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/configs/tokenizer/.DS_Store differ diff --git a/configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml b/configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b5ea7970a632d5106b573c209d0cdd11aa557ff --- /dev/null +++ b/configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml @@ -0,0 +1,2 @@ +_target_: transformers.LlamaTokenizer.from_pretrained +pretrained_model_name_or_path: https://huggingface.co/AILab-CVC/SEED-X-17B/tree/main/cvlm_llama2_tokenizer_100img_and_224loc_addpatch diff --git a/configs/visual_encoder/.DS_Store b/configs/visual_encoder/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/configs/visual_encoder/.DS_Store differ diff --git a/configs/visual_encoder/qwen_vitg_448.yaml 
b/configs/visual_encoder/qwen_vitg_448.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b497ff6030445e67b100c853afb1559877731ea --- /dev/null +++ b/configs/visual_encoder/qwen_vitg_448.yaml @@ -0,0 +1,11 @@ +_target_: src.models.tokenizer.qwen_visual.VisionTransformerWithAttnPool.from_pretrained +heads: 16 +image_size: 448 +image_start_id": 151857 +layers: 48 +mlp_ratio: 4.9231 +output_dim: 4096 +patch_size: 14 +width: 1664 + +pretrained_model_path: pretrained/QwenViT/qwen_vit_G.pt diff --git a/pretrained/QwenViT/qwen_vit_G.pt b/pretrained/QwenViT/qwen_vit_G.pt new file mode 100644 index 0000000000000000000000000000000000000000..810c0207e1578ff6ddc7c23efeae64c08a973148 --- /dev/null +++ b/pretrained/QwenViT/qwen_vit_G.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d951083fc79b07bdb84be61944eb263b8e14572fe2dc4fa80b0447f83064463c +size 3871440281 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b36f32c979bdcfd18c9a77a290908d0b6e89bb3a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +torch==2.0.1 +hydra-core +transformers==4.30.2 +diffusers==0.25.0 +sentencepiece +opencv-python +deepspeed +pyrootutils +xformers>=0.0.20 +accelerate +transformers_stream_generator \ No newline at end of file diff --git a/seed_x/arrow.jpg b/seed_x/arrow.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e7c63baa122927ae73bd5c84e46c041c8ae19e51 Binary files /dev/null and b/seed_x/arrow.jpg differ diff --git a/seed_x/bank.png b/seed_x/bank.png new file mode 100644 index 0000000000000000000000000000000000000000..a0088aa5476c6f20d231bac3d20f1d7821b9e7a6 Binary files /dev/null and b/seed_x/bank.png differ diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..925c0c9f6b6c063f2ead6f8d7689e2841eb7ceb7 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/demo/__pycache__/conversation.cpython-311.pyc b/src/demo/__pycache__/conversation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f3f9688d6bbccaaadd370863c712fcc20334c00 Binary files /dev/null and b/src/demo/__pycache__/conversation.cpython-311.pyc differ diff --git a/src/demo/__pycache__/conversation.cpython-38.pyc b/src/demo/__pycache__/conversation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e5e0bd3395399b2283444afc756cf224336d1d8 Binary files /dev/null and b/src/demo/__pycache__/conversation.cpython-38.pyc differ diff --git a/src/demo/__pycache__/utils.cpython-311.pyc b/src/demo/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a15a24ff9fec316c5ff80712ac0bda48625661e3 Binary files /dev/null and b/src/demo/__pycache__/utils.cpython-311.pyc differ diff --git a/src/demo/__pycache__/utils.cpython-38.pyc b/src/demo/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96950b8767543d992615da6ab2dc223c4cf1f11b Binary files /dev/null and b/src/demo/__pycache__/utils.cpython-38.pyc differ diff --git a/src/demo/configs/agent_13b_anyres_out_64_pretrain_merged.yaml b/src/demo/configs/agent_13b_anyres_out_64_pretrain_merged.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d8015fe96aee7251459d9afed49d29e113d0fb3 --- /dev/null +++ b/src/demo/configs/agent_13b_anyres_out_64_pretrain_merged.yaml @@ -0,0 +1,29 @@ +_target_: 
src.models_clm.models.ContinuousLVLM.from_pretrained +input_resampler: + _target_: src.models.qwen_visual.Resampler + grid_size: 8 + embed_dim: 5120 + num_heads: 32 + kv_dim: 4096 + +output_resampler: + _target_: src.models.qwen_visual.Resampler + grid_size: 8 + embed_dim: 4096 + num_heads: 32 + kv_dim: 5120 + +add_patch_pos: True +vit_down: True +mse: True + +lm_loss_scale: 1.0 +rec_loss_scale: 6.0 + +#pretrained_model_path: /chat_sh/share_300719895/user/jinguozhu/codes/work_dirs/sft_exp_new_acc4/checkpoint-2000/pytorch_model.bin +#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-9000/pytorch_model.bin +#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-8000-merged/agent/pytorch_model.bin +#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_09_any_res_sft_editing_from_merged_10k_32a100_new_data/checkpoint-6000-merged/agent/pytorch_model.bin +#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_editing_from_merged_H800_23k_16_gpu_2_new/checkpoint-6000-merged/agent/pytorch_model.bin +#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_com_gen_from_merged_H800_23k/checkpoint-15000-merged/agent/pytorch_model.bin +pretrained_model_path: /group/40034/yuyingge/SEED_X_inference/pretrained/seed_x_i/agent/pytorch_model.bin \ No newline at end of file diff --git a/src/demo/configs/agent_13b_in100_out64_rs5_merged_pretrain.yaml b/src/demo/configs/agent_13b_in100_out64_rs5_merged_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34de8fcd78965eed8a0fc18877e114a27d28d361 --- /dev/null +++ b/src/demo/configs/agent_13b_in100_out64_rs5_merged_pretrain.yaml @@ -0,0 +1,22 @@ +_target_: src.models_clm.models.ContinuousLVLM.from_pretrained +input_resampler: + _target_: src.models.qwen_visual.Resampler + grid_size: 10 + embed_dim: 5120 + num_heads: 32 + kv_dim: 4096 + +output_resampler: + _target_: src.models.qwen_visual.Resampler + grid_size: 16 + embed_dim: 4096 + num_heads: 32 + kv_dim: 5120 + +lm_loss_scale: 1.0 +rec_loss_scale: 5.0 + +# pretrained_model_path: /chat_sh/share_300719895/user/sijiezhao/Program/2023/DiscreteLearning/train_output_clm_sh/1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro/checkpoint-27000/pytorch_model.bin +# pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-10000-merged/agent/pytorch_model.bin +# pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-5000-merged/agent/pytorch_model.bin +pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1211_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_grounding_27k/ckpt-4000-merged/agent/pytorch_model.bin \ No newline at end of file diff --git a/src/demo/configs/llama2chat13b_merged_100imgtokens.yaml b/src/demo/configs/llama2chat13b_merged_100imgtokens.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..e1062c3c8f9de28068fff2b15057b6e1dcb4c856 --- /dev/null +++ b/src/demo/configs/llama2chat13b_merged_100imgtokens.yaml @@ -0,0 +1,12 @@ + +_target_: src.models_clm.modeling_llama_xformer.LlamaForCausalLM.from_pretrained +# _target_: transformers.LlamaForCausalLM.from_pretrained +# pretrained_model_name_or_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-10000-merged/llm +# pretrained_model_name_or_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-5000-merged/llm +#pretrained_model_name_or_path: /chat_sh/share_300719895/user/jinguozhu/codes/work_dirs/pretraining_anyres_newexp_v2/checkpoint-10000-merged/llm +#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-8000-merged/llm +#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_09_any_res_sft_editing_from_merged_10k_32a100_new_data/checkpoint-6000-merged/llm +#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_editing_from_merged_H800_23k_16_gpu_2_new/checkpoint-6000-merged/llm +#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_com_gen_from_merged_H800_23k/checkpoint-15000-merged/llm +pretrained_model_name_or_path: /group/40034/yuyingge/SEED_X_inference/pretrained/seed_x_i/llm +low_cpu_mem_usage: True diff --git a/src/demo/conversation.py b/src/demo/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a5ccf5584e0527ee916aa8b2bd5325492349d9 --- /dev/null +++ b/src/demo/conversation.py @@ -0,0 +1,182 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple + +import io +import base64 +import os +from PIL import Image +import copy + +IMG_FLAG = '' + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + + +def decode_image(encoded_image: str) -> Image: + decoded_bytes = base64.b64decode(encoded_image.encode('utf-8')) + buffer = io.BytesIO(decoded_bytes) + image = Image.open(buffer) + return image + + +def encode_image(image: Image.Image, format: str = 'PNG') -> str: + with io.BytesIO() as buffer: + image.save(buffer, format=format) + encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8') + return encoded_image + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[dict] # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str} + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = copy.deepcopy(self.messages) + if self.sep_style == SeparatorStyle.SINGLE: + if self.system is None or self.system == '': + text = '' + else: + text = self.system + self.sep + images = [] + for message in messages: + text += message['role'] + ": " + message['message']['text'] + self.sep + for image_path in 
message['message']['images']: + image = Image.open(image_path).resize((256, 256)) + image_base64 = encode_image(image) + images.append(image_base64) + + text += self.roles[1] + ":" + elif self.sep_style == SeparatorStyle.LLAMA_2: + b_token = "[INST] " + e_token = " [/INST]" + if self.system is None or self.system == '': + text = '' + else: + text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n" + images = [] + for idx, message in enumerate(messages): + # text += message['role'] + ": " + message['message']['text'] + self.sep + if idx % 2 == 0: + text += b_token + message['message']['text'] + e_token + self.sep + else: + text += message['message']['text'] + self.sep + + for image_path in message['message']['images']: + image = Image.open(image_path) + image_base64 = encode_image(image) + images.append(image_base64) + else: + raise NotImplementedError + + return {'text': text, 'images': images} + + # def update_image_ids(self, images_ids): + # image_count = 0 + # for message in self.messages: + # for idx in range(len(message['message']['images_ids'])): + # if message['message']["images_ids"][idx] is None: + # message['message']["images_ids"][idx] = images_ids[image_count] + # image_count += 1 + + # assert len(images_ids) == image_count, print(len(images_ids), image_count) + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + dialog = [] + for i, single_turn in enumerate(self.messages[self.offset:]): + single_turn = single_turn['message'] + text_list = single_turn['text'].split(IMG_FLAG) + assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images'])) + message = '' + for image_idx in range(len(single_turn['images'])): + # image = single_turn['images'][image_idx] + # image_base64 = encode_image(image) + # image_str = f'user upload image' + image_path = single_turn['images'][image_idx] + if image_path == '': + message += text_list[image_idx] + '' + else: + message += text_list[image_idx] + f'![](file={image_path})' + message += text_list[-1] + + if i % 2 == 0: + dialog.append([message, None]) + else: + dialog[-1][-1] = message + + return dialog + + def copy(self): + return Conversation(system=self.system, + roles=self.roles, + messages=copy.deepcopy(self.messages), + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + messages = copy.deepcopy(self.messages) + for message in messages: + for i in range(len(message['message']['images'])): + message['message']['images'][i] = os.path.basename(message['message']['images'][i]) + return { + "system": self.system, + "roles": self.roles, + "messages": messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_seed_vicuna = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep='\n', +) + +conv_seed_vicuna_system = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. 
", + roles=("USER", "ASSISTANT"), + version="v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep='\n', +) + +conv_seed_llama2 = Conversation( + system="", + roles=("[INST]", "[/INST]"), + version="v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep='\n', +) diff --git a/src/demo/seed_llama_flask.py b/src/demo/seed_llama_flask.py new file mode 100644 index 0000000000000000000000000000000000000000..db8aa1bf187156ee593c1f46de64579374dc2463 --- /dev/null +++ b/src/demo/seed_llama_flask.py @@ -0,0 +1,379 @@ +import hydra +import pyrootutils +import torch +import re +import time +from omegaconf import OmegaConf +from flask import Flask, request +from typing import Optional +import transformers +from dataclasses import dataclass, field +import io +import base64 +from PIL import Image +import numpy as np +import cv2 +from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler + + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from src.data.any_res import process_anyres_image + +BOI_TOKEN = '' +BOP_TOKEN = '' +EOI_TOKEN = '' +EOP_TOKEN = '' +IMG_TOKEN = '' + +IMG_FLAG = '' +num_img_in_tokens = 64 +num_img_out_tokens = 64 + +resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1', '2x2', '2x3', '3x2', '2x4', '4x2'] +base_resolution = 448 + +app = Flask(__name__) + + +def decode_image(encoded_image: str) -> Image: + decoded_bytes = base64.b64decode(encoded_image.encode('utf-8')) + buffer = io.BytesIO(decoded_bytes) + image = Image.open(buffer) + return image + + +def encode_image(image: Image.Image, format: str = 'PNG') -> str: + with io.BytesIO() as buffer: + image.save(buffer, format=format) + encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8') + return encoded_image + + +@dataclass +class Arguments: + image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"}) + tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"}) + llm: Optional[str] = field(default=None, metadata={"help": "config path of llm"}) + visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"}) + sd_adapter: Optional[str] = field(default=None, metadata={"help": "config path of sd adapter"}) + agent: Optional[str] = field(default=None, metadata={"help": "config path of agent model"}) + diffusion_path: Optional[str] = field(default=None, metadata={"help": "diffusion model path"}) + has_bbox: Optional[bool] = field(default=False, metadata={"help": "visualize the box"}) + + port: Optional[str] = field(default=80, metadata={"help": "network port"}) + llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"}) + vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"}) + dtype: Optional[str] = field(default='fp16', metadata={"help": "mix percision"}) + + multi_resolution: Optional[bool] = field(default=False, metadata={"help": "multi resolution"}) + + +parser = transformers.HfArgumentParser(Arguments) +args, = parser.parse_args_into_dataclasses() + +def extract_box(output_str): + boxes = re.findall('(.*?)', output_str) + if len(boxes) >0: + bboxes = [[int(num) for num in re.findall('', box)] for box in boxes] + else: + bboxes = None + + return bboxes + + +def visualize_bbox(image, bboxes): + img_width, img_height = image.size + image = np.array(image) + 
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + for bbox in bboxes: + x_center, y_center, box_width, box_height = bbox + + x_center = x_center / 224 * img_width + y_center = y_center / 224 * img_height + + box_width = box_width /224 * img_width + box_height = box_height / 224 * img_height + + x1 = int(x_center - box_width / 2) + y1 = int(y_center - box_height / 2) + x2 = int(x_center + box_width / 2) + y2 = int(y_center + box_height / 2) + + cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 4) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = Image.fromarray(image) + + + return image + + + + +class LLMService: + + def __init__(self, args) -> None: + + self.llm_device = args.llm_device + self.vit_sd_device = args.vit_sd_device + + dtype = args.dtype + if dtype == 'fp16': + self.dtype = torch.float16 + elif dtype == 'bf16': + self.dtype = torch.bfloat16 + else: + raise ValueError + + image_transform_cfg = OmegaConf.load(args.image_transform) + self.image_transform = hydra.utils.instantiate(image_transform_cfg) + + tokenizer_cfg = OmegaConf.load(args.tokenizer) + self.tokenizer = hydra.utils.instantiate(tokenizer_cfg) + + visual_encoder_cfg = OmegaConf.load(args.visual_encoder) + self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg) + self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype) + print('Init visual encoder done') + + llm_cfg = OmegaConf.load(args.llm) + llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype) + print('Init llm done.') + + agent_cfg = OmegaConf.load(args.agent) + self.agent = hydra.utils.instantiate(agent_cfg, llm=llm) + + self.agent.eval().to(self.llm_device, dtype=self.dtype) + print('Init agent mdoel Done') + + noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler") + + vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device, dtype=self.dtype) + + unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(dtype=self.dtype) + + sd_adapter_cfg = OmegaConf.load(args.sd_adapter) + + self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(dtype=self.dtype) + + self.sd_adapter.init_pipe(vae=vae, + scheduler=noise_scheduler, + visual_encoder=self.visual_encoder.to("cpu"), + image_transform=self.image_transform, + discrete_model=None, + dtype=self.dtype, + device="cpu") + + print('Init sd adapter pipe done.') + + self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype) + + self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0] + self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0] + + self.bop_token_id = self.tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0] + self.eop_token_id = self.tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0] + + self.multi_resolution = args.multi_resolution + if self.multi_resolution: + self.base_resolution = base_resolution + grid_pinpoints = [] + for scale in resolution_grids: + s1, s2 = scale.split('x') + grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution]) + self.grid_pinpoints = grid_pinpoints + + +service = LLMService(args) + + +@app.route('/generate', methods=['GET', 'POST']) +def generate(): + with torch.no_grad(): + request_info = request.get_json() + + text_list = request_info['text'].split(IMG_FLAG) + image_list = request_info['images'] + max_new_tokens = request_info.get('max_new_tokens', 256) + top_p = 0.5 + force_boi = request_info.get('force_boi', False) + 
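# Editor's note (descriptive comment, not part of the original diff): the request body is expected to look
# roughly like {'text': IMG_FLAG + 'What is in this image?', 'images': ['<base64-encoded PNG>'],
#               'max_new_tokens': 256, 'force_boi': False, 'force_bbox': False},
# with exactly one IMG_FLAG placeholder in 'text' per base64-encoded image in 'images', which is what the
# assert below enforces.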
force_bbox = request_info.get('force_bbox', False) + + assert len(text_list) == len(image_list) + 1 + + image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN + + input_images = [] + if len(image_list) > 0: + image_tensor_list = [] + embeds_cmp_mask = [] + embeds_gen_mask = [] + + if service.multi_resolution: + patch_pos = [] + image_patch_length = [] + image_size_list = [] + + for idx, image_item in enumerate(image_list): + if isinstance(image_item, str): + image = decode_image(image_item) + print('after decode image size:', image.size) + input_images.append(image) + + if service.multi_resolution: + image_size_list.append(image.size) + print('image size:', image.size) + image_tensor, patch_pos_tensor = process_anyres_image(image, service.image_transform, service.grid_pinpoints, service.base_resolution) + image_tensor_list.append(image_tensor) + patch_pos.append(patch_pos_tensor) + image_patch_length.append(image_tensor.shape[0]) + print('image_patch_length', image_patch_length) + embeds_cmp_mask.extend([True]*image_tensor.shape[0]) + embeds_gen_mask.extend([False]*image_tensor.shape[0]) + + else: + image_tensor = service.image_transform(image) + + image_tensor_list.append(image_tensor) + embeds_cmp_mask.append(True) + embeds_gen_mask.append(False) + else: + raise ValueError + + if service.multi_resolution: + pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype) + patch_position = torch.cat(patch_pos, dim=0) + + image_tokens_list = [] + for patch_length in image_patch_length: + image_tokens = '' + for _ in range(patch_length-1): + image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN + image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN + image_tokens_list.append(image_tokens) + else: + pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype) + + image_embeds = service.visual_encoder(pixel_values) + image_embeds = image_embeds.to(service.llm_device) + + embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device) + embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device) + + else: + image_embeds = None + patch_position = 0 + embeds_cmp_mask = None + embeds_gen_mask = None + + if service.multi_resolution: + input_text = '' + for i, c in enumerate(text_list[:-1]): + input_text += c + image_tokens_list[i] + input_text += text_list[-1] + + else: + input_text = image_tokens.join(text_list) + + if force_boi: + input_text = input_text + BOI_TOKEN + + if force_bbox: + input_text = input_text + '[[ ' + print('input_text:', input_text) + input_ids = service.tokenizer.encode(input_text, add_special_tokens=False) + input_ids = [service.tokenizer.bos_token_id] + input_ids + + input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long) + ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device) + ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device) + + if service.multi_resolution: + boi_indices = torch.where(torch.logical_or(input_ids == service.boi_token_id, input_ids == service.bop_token_id))[0].tolist() + eoi_indices = torch.where(torch.logical_or(input_ids == service.eoi_token_id, input_ids == service.eop_token_id))[0].tolist() + + else: + + boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist() + 
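# Editor's note (descriptive comment, not part of the original diff): the BOI/BOP and EOI/EOP index lists are
# zipped below so that, for each opening/closing image (or patch) token pair, the 64 placeholder positions
# strictly between them are flagged in ids_cmp_mask; the agent presumably uses this mask to splice the
# ViT/resampler image embeddings into exactly those positions, while ids_gen_mask stays all False here because
# image-generation tokens are produced by the model itself during decoding.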
eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist() + + for boi_idx, eoi_idx in zip(boi_indices, eoi_indices): + ids_cmp_mask[boi_idx + 1:eoi_idx] = True + + input_ids = input_ids.unsqueeze(0) + ids_cmp_mask = ids_cmp_mask.unsqueeze(0) + ids_gen_mask = ids_gen_mask.unsqueeze(0) + + error_msg = [] + + if service.multi_resolution: + output = service.agent.generate( + tokenizer=service.tokenizer, + input_ids=input_ids, + image_embeds=image_embeds, + patch_positions=patch_position, + embeds_cmp_mask=embeds_cmp_mask, + ids_cmp_mask=ids_cmp_mask, + num_img_gen_tokens=num_img_out_tokens, + max_new_tokens=max_new_tokens, + dtype=service.dtype, + device=service.llm_device, + top_p=top_p, + ) + else: + output = service.agent.generate( + tokenizer=service.tokenizer, + input_ids=input_ids, + image_embeds=image_embeds, + embeds_cmp_mask=embeds_cmp_mask, + ids_cmp_mask=ids_cmp_mask, + num_img_gen_tokens=num_img_out_tokens, + max_new_tokens=max_new_tokens, + dtype=service.dtype, + device=service.llm_device, + top_p=top_p, + ) + + gen_imgs_base64_list = [] + generated_text = output['text'] + generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '') + + if output['has_img_output']: + print('loading visual encoder and llm to CPU, and sd to GPU') + a = time.time() + service.agent = service.agent.to("cpu") + service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype) + print("Loading finished: ", time.time() - a) + + img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype) + + for img_idx in range(output['num_gen_imgs']): + img_feat = img_gen_feat[img_idx:img_idx + 1] + generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0] + image_base64 = encode_image(generated_image) + gen_imgs_base64_list.append(image_base64) + + print('loading visual encoder and llm to GPU, and sd to CPU') + a = time.time() + service.sd_adapter = service.sd_adapter.to("cpu") + service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype) + service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype) + print("Loading finished: ", time.time() - a) + + if args.has_bbox: + bboxes = extract_box(generated_text) + + if bboxes is not None and len(input_images) > 0: + image_viz = visualize_bbox(input_images[0], bboxes) + image_base64 = encode_image(image_viz) + gen_imgs_base64_list.append(image_base64) + generated_text = re.sub(r'\[\[ .*?.*?\]\]', 'the green bounding box', generated_text) + generated_text += IMG_FLAG + print(input_text + generated_text) + + return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg} + + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=args.port) diff --git a/src/demo/seed_llama_gradio.py b/src/demo/seed_llama_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..2d55dfe059d14124e43c64dee77c0b397e107184 --- /dev/null +++ b/src/demo/seed_llama_gradio.py @@ -0,0 +1,465 @@ +import os +import numpy as np +import datetime +import json +from typing import Optional +import transformers +from dataclasses import dataclass, field +import io +import base64 +from PIL import Image +import gradio as gr +import time +import hashlib +import requests + +from utils import build_logger +from conversation import conv_seed_llama2 + +IMG_FLAG = '' +LOGDIR = 'log' + +logger = build_logger("gradio_seed_x", LOGDIR) +headers = {"User-Agent": "SEED-X Client"} + 
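+
+# Added for illustration: a minimal, standalone sketch of the request/response contract with the
+# Flask '/generate' endpoint implemented by the server script above. The payload keys mirror what
+# http_bot() sends below; the prompt text and the default URL here are placeholders, and this
+# helper is not wired into the UI.
+def _example_generate_request(server_url: str = 'http://127.0.0.1:7890/generate') -> dict:
+    payload = {
+        'text': 'Use your imagination to design a concept image for AGI. Show me an image.',
+        'images': [],               # one base64-encoded image per IMG_FLAG occurrence in 'text'
+        'max_new_tokens': 512,
+        'force_boi': False,         # append the begin-of-image token to force image generation
+        'force_bbox': False,        # append the bounding-box prefix to force a grounded answer
+    }
+    # The server replies with the generated 'text', a list of base64-encoded 'images',
+    # and an 'error_msg' list.
+    return requests.post(server_url, headers=headers, json=payload).json()
+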
+no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + + +@dataclass +class Arguments: + server_port: Optional[int] = field(default=7860, metadata={"help": "network port"}) + server_name: Optional[str] = field(default='0.0.0.0', metadata={"help": "network address"}) + request_address: Optional[str] = field(default='http://127.0.0.1:7890/generate', + metadata={"help": "request address"}) + + +parser = transformers.HfArgumentParser(Arguments) +args, = parser.parse_args_into_dataclasses() +conv_seed_llama = conv_seed_llama2 + + +def decode_image(encoded_image: str) -> Image: + decoded_bytes = base64.b64decode(encoded_image.encode('utf-8')) + buffer = io.BytesIO(decoded_bytes) + image = Image.open(buffer) + return image + + +def encode_image(image: Image.Image, format: str = 'PNG') -> str: + with io.BytesIO() as buffer: + image.save(buffer, format=format) + encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8') + return encoded_image + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_conv_image_dir(): + name = os.path.join(LOGDIR, 'images') + os.makedirs(name, exist_ok=True) + return name + + +def get_image_name(image, image_dir=None): + buffer = io.BytesIO() + image.save(buffer, format='PNG') + image_bytes = buffer.getvalue() + md5 = hashlib.md5(image_bytes).hexdigest() + + if image_dir is not None: + image_name = os.path.join(image_dir, md5 + '.png') + else: + image_name = md5 + '.png' + + return image_name + + +def resize_image_square(image, target_size=448): + resized_image = image.resize((target_size, target_size)) + return resized_image + + +def resize_image(image, max_size=512): + width, height = image.size + aspect_ratio = float(width) / float(height) + + if width > height: + new_width = max_size + new_height = int(new_width / aspect_ratio) + else: + new_height = max_size + new_width = int(new_height * aspect_ratio) + + resized_image = image.resize((new_width, new_height)) + return resized_image + + +def center_crop_image(image, max_aspect_ratio=1.5): + width, height = image.size + aspect_ratio = max(width, height) / min(width, height) + + if aspect_ratio >= max_aspect_ratio: + if width > height: + new_width = int(height * max_aspect_ratio) + left = (width - new_width) // 2 + right = (width + new_width) // 2 + top = 0 + bottom = height + else: + new_height = int(width * max_aspect_ratio) + left = 0 + right = width + top = (height - new_height) // 2 + bottom = (height + new_height) // 2 + + cropped_image = image.crop((left, top, right, bottom)) + return cropped_image + else: + return image + + +def vote_last_response(state, vote_type, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", request) + return (disable_btn,) * 2 + + +def downvote_last_response(state, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", request) + return (disable_btn,) * 2 + + +def regenerate(dialog_state, request: gr.Request): + logger.info(f"regenerate. 
ip: {request.client.host}") + if dialog_state.messages[-1]['role'] == dialog_state.roles[1]: + dialog_state.messages.pop() + return ( + dialog_state, + dialog_state.to_gradio_chatbot(), + ) + (disable_btn,) * 4 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + dialog_state = conv_seed_llama.copy() + input_state = init_input_state() + return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4 + + +def init_input_state(): + return {'images': [], 'text': ''} + + +def add_text(dialog_state, input_state, text, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}.") + # if len(input_state['text']) == 0: + if text is None or len(text) == 0: + # dialog_state.skip_next = True + return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4 + input_state['text'] += text + + + if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]: + dialog_state.messages[-1]['message'] = input_state + else: + dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state}) + print('add_text: ', dialog_state.to_gradio_chatbot()) + + return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4 + + +def is_blank(image): + image_array = np.array(image) + unique_colors = np.unique(image_array) + print('unique_colors', len(unique_colors)) + return len(unique_colors) == 1 + + +def add_image(dialog_state, input_state, image, request: gr.Request): + logger.info(f"add_image. ip: {request.client.host}.") + if image is None: + return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4 + + image = image.convert('RGB') + + print('image size:', image.size) + + image = center_crop_image(image, max_aspect_ratio=10) + + image_dir = get_conv_image_dir() + image_path = get_image_name(image=image, image_dir=image_dir) + if not os.path.exists(image_path): + image.save(image_path) + input_state['images'].append(image_path) + input_state['text'] += IMG_FLAG + + if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]: + dialog_state.messages[-1]['message'] = input_state + else: + dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state}) + + print('add_image:', dialog_state) + + return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4 + + +def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox, + request: gr.Request): + logger.info(f"http_bot. ip: {request.client.host}") + print('input_state:', input_state) + + if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len( + dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0: + return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4 + + if len(dialog_state.messages) > max_turns * 2: + output_state = init_input_state() + output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.' 
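+        # Added note: the four button states returned on this path map to
+        # (upvote, downvote, regenerate, clear); only "Clear history" stays enabled
+        # so the user can reset the conversation once the round limit is hit.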
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state}) + input_state = init_input_state() + return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,) + + prompt = dialog_state.get_prompt() + payload = { + 'text': prompt['text'], + 'max_new_tokens': int(max_new_tokens), + 'images': prompt['images'], + 'force_boi': force_image_gen, + 'force_bbox': force_bbox, + } + + print( + 'request: ', { + 'text': prompt['text'], + 'max_new_tokens': int(max_new_tokens), + }) + print('request_address', args.request_address) + response = requests.request(method="POST", url=args.request_address, headers=headers, json=payload) + results = response.json() + print('response: ', {'text': results['text'], 'error_msg': results['error_msg']}) + + output_state = init_input_state() + image_dir = get_conv_image_dir() + output_state['text'] = results['text'] + + for image_base64 in results['images']: + if image_base64 == '': + image_path = '' + else: + image = decode_image(image_base64) + image = image.convert('RGB') + image_path = get_image_name(image=image, image_dir=image_dir) + if not os.path.exists(image_path): + image.save(image_path) + output_state['images'].append(image_path) + + dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state}) + + vote_last_response(dialog_state, 'common', request) + input_state = init_input_state() + chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg']) + return (dialog_state, input_state, chatbot) + (enable_btn,) * 4 + + +def update_error_msg(chatbot, error_msg): + if len(error_msg) > 0: + info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join( + error_msg) + chatbot[-1][-1] = chatbot[-1][-1] + info + + return chatbot + + +def load_demo(request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}") + dialog_state = conv_seed_llama.copy() + input_state = init_input_state() + return dialog_state, input_state + + +title = (""" +# SEED-X-I +[[Paper]](https://arxiv.org/abs/2404.14396) [[Code]](https://github.com/AILab-CVC/SEED-X) + +Demo of a general instruction-tuned model SEED-X-I (17B) from the foundation model SEED-X. + +SEED-X-I can follow multimodal instruction (including images with **dynamic resolutions**) and make responses with **images, texts and bounding boxes** in multi-turn conversation. + +SEED-X-I **does not support image manipulation**. If you want to experience **SEED-X-Edit** for high-precision image editing, please refer to [[Inference Code]](https://github.com/AILab-CVC/SEED-X). + +Due to insufficient GPU memory, when generating images, we need to offload the LLM to the CPU and move the de-tokenizer to the CPU, which will **result in a long processing time**. If you want to experience the normal model inference speed, you can run [[Inference Code]](https://github.com/AILab-CVC/SEED-X) locally. + + +## Tips: +* Check out the conversation examples (at the bottom) for inspiration. + +* You can adjust "Max History Rounds" to try a conversation with up to five rounds. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference. + +* Our demo supports a mix of images and texts as input. You can freely upload an image or enter text, and then click on "Add Image/Text". You can repeat the former step multiple times, and click on "Submit" for model inference at last. 
+
+* You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
+
+* You can click "Force Bounding Box" to compel the model to produce bounding boxes for object detection.
+
+* SEED-X was trained with English-only data. It may handle other languages thanks to the inherent capabilities of LLaMA, but the results may not be stable.
+
+""")
+
+css = """
+img {
+    font-family: 'Helvetica';
+    font-weight: 300;
+    line-height: 2;
+    text-align: center;
+
+    width: auto;
+    height: auto;
+    display: block;
+    position: relative;
+}
+
+img:before {
+    content: " ";
+    display: block;
+
+    position: absolute;
+    top: -10px;
+    left: 0;
+    height: calc(100% + 10px);
+    width: 100%;
+    background-color: rgb(230, 230, 230);
+    border: 2px dotted rgb(200, 200, 200);
+    border-radius: 5px;
+}
+
+img:after {
+    content: " ";
+    display: block;
+    font-size: 16px;
+    font-style: normal;
+    font-family: FontAwesome;
+    color: rgb(100, 100, 100);
+
+    position: absolute;
+    top: 5px;
+    left: 0;
+    width: 100%;
+    text-align: center;
+}
+
+"""
+
+if __name__ == '__main__':
+
+    examples_mix = [
+        ['seed_x/bank.png', 'Can I connect with an advisor on Sunday?'],
+        ['seed_x/ground.png',
+         'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'],
+        ['seed_x/arrow.jpg', 'What is the object pointed to by the red arrow?'],
+        ['seed_x/shanghai.png', 'Where was this image taken? Explain your answer.'],
+        ['seed_x/GPT4.png', 'How long does it take to make GPT-4 safer?'],
+        ['seed_x/twitter.png',
+         'Please provide a comprehensive description of this image.'],
+    ]
+
+    examples_text = [
+        ['I want to build a two-story cabin in the woods, with many commanding windows. Can you show me a picture?'],
+        ['Use your imagination to design a concept image for Artificial General Intelligence (AGI). Show me an image.'],
+        [
+            'Can you design an illustration for “The Three-Body Problem” to depict a scene from the novel? Show me a picture.'],
+        [
+            'My four-year-old son loves toy trains. Can you design a fancy birthday cake for him? 
Please generate a picture.'], + [ + 'Generate an image of a portrait of young nordic girl, age 25, freckled skin, neck tatoo, blue eyes 35mm lens, photography, ultra details.'], + ['Generate an impressionist painting of an astronaut in a jungle.'] + ] + with gr.Blocks(css=css) as demo: + gr.Markdown(title) + dialog_state = gr.State() + input_state = gr.State() + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(): + image = gr.Image(type='pil', label='input_image') + with gr.Row(): + text = gr.Textbox(lines=5, + show_label=False, + label='input_text', + elem_id='textbox', + placeholder="Enter text or add image, and press submit,").style(container=False) + with gr.Row(): + add_image_btn = gr.Button("Add Image") + add_text_btn = gr.Button("Add Text") + + submit_btn = gr.Button("Submit") + + with gr.Row(): + max_new_tokens = gr.Slider(minimum=64, + maximum=1024, + value=768, + step=64, + interactive=True, + label="Max Output Tokens") + max_turns = gr.Slider(minimum=1, maximum=9, value=3, step=1, interactive=True, + label="Max History Rounds") + force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation') + force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box') + + with gr.Column(scale=7): + chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I").style(height=700) + with gr.Row(): + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + + with gr.Row(): + with gr.Column(scale=0.7): + gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text]) + with gr.Column(scale=0.3): + gr.Examples(examples=examples_text, label='Input examples', inputs=[text]) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn] + upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn]) + downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn]) + + regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then( + http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox], + [dialog_state, input_state, chatbot] + btn_list) + add_image_btn.click(add_image, [dialog_state, input_state, image], + [dialog_state, input_state, image, chatbot] + btn_list) + + add_text_btn.click(add_text, [dialog_state, input_state, text], + [dialog_state, input_state, text, chatbot] + btn_list) + + submit_btn.click( + add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then( + add_text, [dialog_state, input_state, text], + [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then( + http_bot, + [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox], + [dialog_state, input_state, chatbot] + btn_list) + clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list) + + demo.load(load_demo, None, [dialog_state, input_state]) + + demo.launch(server_name=args.server_name, server_port=args.server_port, enable_queue=True) diff --git a/src/demo/utils.py b/src/demo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c69f0942ba28b6bb64f0eb43f0e248a96e12df88 --- /dev/null +++ b/src/demo/utils.py @@ -0,0 +1,83 @@ +import datetime +import 
logging +import logging.handlers +import os +import sys + +handler = None + + +def build_logger(logger_name, logger_dir): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(logger_dir, exist_ok=True) + filename = os.path.join(logger_dir, logger_name + '.log') + handler = logging.handlers.TimedRotatingFileHandler(filename, when='D', utc=True) + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. 
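+            # Added note: only lines that already end in '\n' are forwarded to the logger;
+            # any trailing partial line stays in self.linebuf until the next write() or flush().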
+ if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' diff --git a/src/inference/.DS_Store b/src/inference/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/src/inference/.DS_Store differ diff --git a/src/inference/__pycache__/any_res.cpython-311.pyc b/src/inference/__pycache__/any_res.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7787d734649c4e58eaec99eb14ef898d3796b7d3 Binary files /dev/null and b/src/inference/__pycache__/any_res.cpython-311.pyc differ diff --git a/src/inference/__pycache__/any_res.cpython-38.pyc b/src/inference/__pycache__/any_res.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13ba06212a52fdc2676316db59144d9de812cccc Binary files /dev/null and b/src/inference/__pycache__/any_res.cpython-38.pyc differ diff --git a/src/inference/any_res.py b/src/inference/any_res.py new file mode 100644 index 0000000000000000000000000000000000000000..8587dfe13f361d57dd665e64640307ceb12d4ad7 --- /dev/null +++ b/src/inference/any_res.py @@ -0,0 +1,257 @@ +import base64 +import torch +import math +import ast +from PIL import Image +from io import BytesIO + + +def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def select_best_resolution_v2(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size and aspect ratio. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). 
+ """ + original_width, original_height = original_size + original_aspect_ratio = original_height / original_width + original_area = original_width * original_height + best_fit = None + min_aspect_ratio_diff = float('inf') + min_area_ratio = float('inf') + + for width, height in possible_resolutions: + aspect_ratio = height / width + area = width * height + aspect_ratio_diff = max(aspect_ratio, original_aspect_ratio) / min(aspect_ratio, original_aspect_ratio) + area_ratio = max(area, original_area) / min(area, original_area) + + if aspect_ratio_diff < min_aspect_ratio_diff or (aspect_ratio_diff == min_aspect_ratio_diff and area_ratio < min_area_ratio): + min_aspect_ratio_diff = aspect_ratio_diff + min_area_ratio = area_ratio + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution, keep_ratio=False): + """ + Resize and pad an image to a target resolution + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ + original_width, original_height = image.size + target_width, target_height = target_resolution + + if keep_ratio: + # maintaining aspect ratio + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + else: + # not maintaining aspect ratio + new_image = image.resize((target_width, target_height)) + + return new_image + + +def divide_to_patches(image, patch_size): + """ + Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). 
+ """ + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width1, height1 = select_best_resolution(image_size, possible_resolutions) + width2, height2 = select_best_resolution_v2(image_size, possible_resolutions) + if width1*height1 > width2*height2: + width, height = width2, height2 + else: + width, height = width1, height1 + return width // patch_size, height // patch_size + + +def process_anyres_image(image, image_transform, grid_pinpoints, base_image_size): + """ + Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + image_transform: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + # best_resolution = select_best_resolution(image.size, possible_resolutions) + width1, height1 = select_best_resolution(image.size, possible_resolutions) + width2, height2 = select_best_resolution_v2(image.size, possible_resolutions) + if width1*height1 > width2*height2: + width, height = width2, height2 + else: + width, height = width1, height1 + best_resolution = [width, height] + + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, base_image_size) + + image_original_resize = image.resize((base_image_size, base_image_size)) + + image_patches = patches + [image_original_resize] # add the original image as the last patch + image_patches = [image_transform(image_patch) + for image_patch in image_patches] + + patch_grid = (best_resolution[0]//base_image_size, best_resolution[1]//base_image_size) + x_index = (torch.arange(patch_grid[0]).repeat(patch_grid[1], 1) + 0.5)/patch_grid[0] + y_index = (torch.arange(patch_grid[1]).unsqueeze(1).repeat(1, patch_grid[0]) + 0.5)/patch_grid[1] + patch_pos = torch.stack([x_index, y_index], dim=-1).flatten(0, 1) # h*w, 2 + + origin_pos = torch.tensor([[0.5, 0.5]]) + patch_pos = torch.cat([patch_pos, origin_pos], dim=0) # h*w+1, 2 + + return torch.stack(image_patches, dim=0), patch_pos + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def anyres_data_collate(batch, tokenizer, dataset_name=None): + results = {} + keys = batch[0].keys() + + for key in keys: + cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None] + if len(cur) == 0: + results[key] = None + elif isinstance(cur[0], torch.Tensor): + if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']: + results[key] = torch.cat(cur, dim=0) + else: + if key in ['input_ids']: + results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=tokenizer.pad_token_id) + elif key in ['attention_mask']: + results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=0) + elif key in ['labels']: + results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=-100) + elif key in ['ids_gen_mask', 'ids_cmp_mask']: + results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=False) + + else: + results[key] = torch.stack(cur, dim=0) + else: + results[key] = cur + + results['dataset_name'] = dataset_name + + return results + + +def 
anyres_data_collate_old(batch, dataset_name=None): + results = {} + keys = batch[0].keys() + + for key in keys: + cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None] + if len(cur) == 0: + results[key] = None + elif isinstance(cur[0], torch.Tensor): + if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']: + results[key] = torch.cat(cur, dim=0) + else: + results[key] = torch.stack(cur, dim=0) + else: + results[key] = cur + + results['dataset_name'] = dataset_name + + return results diff --git a/src/inference/eval_img2edit_seed_x.py b/src/inference/eval_img2edit_seed_x.py new file mode 100644 index 0000000000000000000000000000000000000000..bed3f63e965869d8dc9773d6fa10c2bb3661ff07 --- /dev/null +++ b/src/inference/eval_img2edit_seed_x.py @@ -0,0 +1,155 @@ +import hydra +import torch +import os +import re +import pyrootutils +from PIL import Image +from omegaconf import OmegaConf +from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, Transformer2DModel +from any_res import process_anyres_image + +pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True) + +BOI_TOKEN = '' +BOP_TOKEN = '' +EOI_TOKEN = '' +EOP_TOKEN = '' +IMG_TOKEN = '' + +resolution_grids = ['1x1'] +base_resolution = 448 + +device = 'cuda:0' +device1 = 'cuda:1' +dtype = torch.float16 +dtype_str = 'fp16' +num_img_in_tokens = 64 +num_img_out_tokens = 64 +instruction_prompt = '[INST] {instruction} [/INST]\n' + +save_dir = 'vis' +os.makedirs(save_dir, exist_ok=True) + +tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml' +image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml' +visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml' +llm_cfg_path = 'configs/clm_models/llm_seed_x_edit.yaml' +agent_cfg_path = 'configs/clm_models/agent_seed_x_edit.yaml' +adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml' +discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml' + +diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0' + +tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path) +tokenizer = hydra.utils.instantiate(tokenizer_cfg) + +image_transform_cfg = OmegaConf.load(image_transform_cfg_path) +image_transform = hydra.utils.instantiate(image_transform_cfg) + +visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path) +visual_encoder = hydra.utils.instantiate(visual_encoder_cfg) +visual_encoder.eval().to(device1, dtype=dtype) +print('Init visual encoder done') + +llm_cfg = OmegaConf.load(llm_cfg_path) +llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype) +print('Init llm done.') + +agent_model_cfg = OmegaConf.load(agent_cfg_path) +agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm) + +agent_model.eval().to(device, dtype=dtype) +print('Init agent mdoel Done') + +noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler") +print('init vae') +vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device1, dtype=dtype) +print('init unet') +unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device1, dtype=dtype) + +adapter_cfg = OmegaConf.load(adapter_cfg_path) +adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device1, dtype=dtype).eval() + +discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path) +discrete_model = 
hydra.utils.instantiate(discrete_model_cfg).to(device1).eval() +print('Init adapter done') + +adapter.init_pipe(vae=vae, + scheduler=noise_scheduler, + visual_encoder=visual_encoder, + image_transform=image_transform, + dtype=dtype, + device=device1) + +print('Init adapter pipe done') +boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0] +eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0] + +bop_token_id = tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0] +eop_token_id = tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0] + +grid_pinpoints = [] +for scale in resolution_grids: + s1, s2 = scale.split('x') + grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution]) +grid_pinpoints = grid_pinpoints + + +image_path = 'demo_images/car.jpg' +instruction = 'Make it under the sunset' + +image = Image.open(image_path).convert('RGB') +source_image = image.resize((1024, 1024)) + +image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution) +embeds_cmp_mask = torch.tensor([True]*image_tensor.shape[0]).to(device, dtype=torch.bool) + +patch_pos = [patch_pos_tensor] +patch_position = torch.cat(patch_pos, dim=0) + +image_tensor = image_tensor.to(device1, dtype=dtype) + +patch_length = image_tensor.shape[0] +image_tokens = '' +for _ in range(patch_length-1): + image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN +image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN + +prompt = instruction_prompt.format_map({'instruction': image_tokens + instruction}) + +input_ids = tokenizer.encode(prompt, add_special_tokens=False) +input_ids = [tokenizer.bos_token_id] + input_ids + +input_ids = torch.tensor(input_ids).to(device, dtype=torch.long) + +ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool) + +boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist() +eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist() + +for boi_idx, eoi_idx in zip(boi_indices, eoi_indices): + ids_cmp_mask[boi_idx + 1:eoi_idx] = True + +input_ids = input_ids.unsqueeze(0) +ids_cmp_mask = ids_cmp_mask.unsqueeze(0) + +with torch.no_grad(): + image_embeds = visual_encoder(image_tensor) + image_embeds = image_embeds.to(device) + output = agent_model.generate(tokenizer=tokenizer, + input_ids=input_ids, + image_embeds=image_embeds, + embeds_cmp_mask=embeds_cmp_mask, + patch_positions=patch_position, + ids_cmp_mask=ids_cmp_mask, + max_new_tokens=512, + num_img_gen_tokens=num_img_out_tokens) +text = re.sub('<[^>]*>', '', output['text']) +print(text) + +if output['has_img_output']: + images = adapter.generate(image_embeds=output['img_gen_feat'].to(device1), latent_image=source_image, num_inference_steps=50) + + save_path = os.path.join(save_dir, str(len(os.listdir(save_dir))) + '_' + instruction + '.jpg') + images[0].save(save_path) +torch.cuda.empty_cache() diff --git a/src/inference/eval_img2text_seed_x.py b/src/inference/eval_img2text_seed_x.py new file mode 100644 index 0000000000000000000000000000000000000000..36a47b798149c08f151d2170cb358388cebd439d --- /dev/null +++ b/src/inference/eval_img2text_seed_x.py @@ -0,0 +1,235 @@ +import hydra +import torch +import os +import pyrootutils +from PIL import Image +import re +import cv2 +import numpy as np +from omegaconf import OmegaConf +from diffusers 
import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler +from any_res import process_anyres_image + + +pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True) + +def visualize_bbox(image, bboxes, save_path): + img_width, img_height = image.size + image = np.array(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + for bbox in bboxes: + x_center, y_center, box_width, box_height = bbox + + x_center = x_center / 224 * img_width + y_center = y_center / 224 * img_height + + box_width = box_width /224 * img_width + box_height = box_height / 224 * img_height + + x1 = int(x_center - box_width / 2) + y1 = int(y_center - box_height / 2) + x2 = int(x_center + box_width / 2) + y2 = int(y_center + box_height / 2) + + cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2) + + cv2.imwrite(save_path, image) + + +def extract_box(output_str): + boxes = re.findall('(.*?)', output_str) + if len(boxes) >0: + bboxes = [[int(num) for num in re.findall('', box)] for box in boxes] + else: + bboxes = None + + return bboxes + + +BOI_TOKEN = '' +BOP_TOKEN = '' +EOI_TOKEN = '' +EOP_TOKEN = '' +IMG_TOKEN = '' + +instruction_prompt = '[INST] {instruction} [/INST]\n' + +resolution_grids = ['1x1', '1x2', '1x3', '2x1', '3x1', '1x4', '4x1', '2x2'] +base_resolution = 448 + +device = 'cuda:0' +device1 = 'cuda:1' +dtype = torch.float16 +dtype_str = 'fp16' +num_img_in_tokens = 64 +num_img_out_tokens = 64 + +tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml' +image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml' +visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml' +llm_cfg_path = 'configs/clm_models/llm_seed_x_i.yaml' +agent_cfg_path = 'configs/clm_models/agent_seed_x_i.yaml' +adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml' +discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml' + +diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0' + +tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path) +tokenizer = hydra.utils.instantiate(tokenizer_cfg) + +image_transform_cfg = OmegaConf.load(image_transform_cfg_path) +image_transform = hydra.utils.instantiate(image_transform_cfg) + +visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path) +visual_encoder = hydra.utils.instantiate(visual_encoder_cfg) +visual_encoder.eval().to(device1, dtype=dtype) +print('Init visual encoder done') + +llm_cfg = OmegaConf.load(llm_cfg_path) +llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype) +print('Init llm done.') + +agent_model_cfg = OmegaConf.load(agent_cfg_path) +agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm) + +agent_model.eval().to(device, dtype=dtype) +print('Init agent mdoel Done') + +noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler") +print('init vae') +vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device1, dtype=dtype) +print('init unet') +unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device1, dtype=dtype) + +adapter_cfg = OmegaConf.load(adapter_cfg_path) +adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device1, dtype=dtype).eval() + +discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path) +discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device1).eval() +print('Init adapter done') + +adapter.init_pipe(vae=vae, + scheduler=noise_scheduler, + visual_encoder=visual_encoder, + 
image_transform=image_transform, + discrete_model=discrete_model, + dtype=dtype, + device=device1) + +print('Init adapter pipe done') +boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0] +eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0] + +bop_token_id = tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0] +eop_token_id = tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0] + +grid_pinpoints = [] +for scale in resolution_grids: + s1, s2 = scale.split('x') + grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution]) +grid_pinpoints = grid_pinpoints + +# image comprehension +image_path = 'demo_images/advisor.png' +image = Image.open(image_path).convert('RGB') +image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution) +embeds_cmp_mask = torch.tensor([True]*image_tensor.shape[0]).to(device, dtype=torch.bool) + +patch_pos = [patch_pos_tensor] +patch_position = torch.cat(patch_pos, dim=0) + +image_tensor = image_tensor.to(device1, dtype=dtype) + +patch_length = image_tensor.shape[0] +image_tokens = '' +for _ in range(patch_length-1): + image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN +image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN + +question = 'Can I conntect with an advisor on Sunday?' +prompt = instruction_prompt.format_map({'instruction': image_tokens + question}) + +input_ids = tokenizer.encode(prompt, add_special_tokens=False) +input_ids = [tokenizer.bos_token_id] + input_ids + +input_ids = torch.tensor(input_ids).to(device, dtype=torch.long) + +ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool) + +boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist() +eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist() + +for boi_idx, eoi_idx in zip(boi_indices, eoi_indices): + ids_cmp_mask[boi_idx + 1:eoi_idx] = True + +input_ids = input_ids.unsqueeze(0) +ids_cmp_mask = ids_cmp_mask.unsqueeze(0) + +with torch.no_grad(): + image_embeds = visual_encoder(image_tensor) + image_embeds = image_embeds.to(device) + output = agent_model.generate(tokenizer=tokenizer, + input_ids=input_ids, + image_embeds=image_embeds, + embeds_cmp_mask=embeds_cmp_mask, + patch_positions=patch_position, + ids_cmp_mask=ids_cmp_mask, + max_new_tokens=512, + num_img_gen_tokens=num_img_out_tokens) + +text = re.sub('<[^>]*>', '', output['text']) +print(text) + +# detection +image_path = 'demo_images/ground.png' +image = Image.open(image_path).convert('RGB') +image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution) +embeds_cmp_mask = torch.tensor([True]*image_tensor.shape[0]).to(device, dtype=torch.bool) + +patch_pos = [patch_pos_tensor] +patch_position = torch.cat(patch_pos, dim=0) + +image_tensor = image_tensor.to(device1, dtype=dtype) + +patch_length = image_tensor.shape[0] +image_tokens = '' +for _ in range(patch_length-1): + image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN +image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN + +question = 'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.' 
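+# Added note: for grounding questions like the one above, the model is expected to answer with
+# box coordinates (x_center, y_center, width, height) normalized to a 224-point grid;
+# extract_box() pulls those numbers out of the generated text and visualize_bbox() rescales
+# them to the original image resolution before drawing the rectangles.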
+prompt = instruction_prompt.format_map({'instruction': image_tokens + question}) + +input_ids = tokenizer.encode(prompt, add_special_tokens=False) +input_ids = [tokenizer.bos_token_id] + input_ids + +input_ids = torch.tensor(input_ids).to(device, dtype=torch.long) + +ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool) + +boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist() +eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist() + +for boi_idx, eoi_idx in zip(boi_indices, eoi_indices): + ids_cmp_mask[boi_idx + 1:eoi_idx] = True + +input_ids = input_ids.unsqueeze(0) +ids_cmp_mask = ids_cmp_mask.unsqueeze(0) + +with torch.no_grad(): + image_embeds = visual_encoder(image_tensor) + image_embeds = image_embeds.to(device) + output = agent_model.generate(tokenizer=tokenizer, + input_ids=input_ids, + image_embeds=image_embeds, + embeds_cmp_mask=embeds_cmp_mask, + patch_positions=patch_position, + ids_cmp_mask=ids_cmp_mask, + max_new_tokens=512, + num_img_gen_tokens=num_img_out_tokens) +print(output['text']) +bbox = extract_box(output['text']) +if bbox is not None: + save_path = 'vis/ground.png' + visualize_bbox(image, bbox, save_path) + \ No newline at end of file diff --git a/src/inference/eval_text2img_seed_x.py b/src/inference/eval_text2img_seed_x.py new file mode 100644 index 0000000000000000000000000000000000000000..0bed25df41b4bbc206201bdecdeaa785ece9fed3 --- /dev/null +++ b/src/inference/eval_text2img_seed_x.py @@ -0,0 +1,94 @@ +import hydra +import torch +import os +import pyrootutils +from PIL import Image +from omegaconf import OmegaConf +from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler + + +pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True) + +BOI_TOKEN = '' +EOI_TOKEN = '' +IMG_TOKEN = '' + +device = 'cuda:0' +device_2 = 'cuda:1' +dtype = torch.float16 +dtype_str = 'fp16' +num_img_in_tokens = 64 +num_img_out_tokens = 64 + +instruction_prompt = '[INST] Generate an image: {caption} [/INST]\n' + +tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml' +image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml' +visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml' +llm_cfg_path = 'configs/clm_models/llm_seed_x_i.yaml' +agent_cfg_path = 'configs/clm_models/agent_seed_x_i.yaml' +adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml' +discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml' + +diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0' + +save_dir = 'vis' +os.makedirs(save_dir, exist_ok=True) + +tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path) +tokenizer = hydra.utils.instantiate(tokenizer_cfg) + +image_transform_cfg = OmegaConf.load(image_transform_cfg_path) +image_transform = hydra.utils.instantiate(image_transform_cfg) + +visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path) +visual_encoder = hydra.utils.instantiate(visual_encoder_cfg) +visual_encoder.eval().to(device_2, dtype=dtype) +print('Init visual encoder done') + +llm_cfg = OmegaConf.load(llm_cfg_path) +llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype) +print('Init llm done.') + +agent_model_cfg = OmegaConf.load(agent_cfg_path) +agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm) + +agent_model.eval().to(device, dtype=dtype) +print('Init agent mdoel Done') + +noise_scheduler = 
EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler") +print('init vae') +vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device_2, dtype=dtype) +print('init unet') +unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device_2, dtype=dtype) + +adapter_cfg = OmegaConf.load(adapter_cfg_path) +adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device_2, dtype=dtype).eval() + +discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path) +discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device_2).eval() +print('Init adapter done') + +adapter.init_pipe(vae=vae, + scheduler=noise_scheduler, + visual_encoder=visual_encoder, + image_transform=image_transform, + discrete_model=discrete_model, + dtype=dtype, + device=device_2) + +print('Init adapter pipe done') + +caption = 'A cybernetic soldier, enhanced with advanced weapons systems and tactical analysis software, on a mission behind enemy lines.' +prompt = instruction_prompt.format_map({'caption': caption}) +prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) +input_ids = torch.tensor([tokenizer.bos_token_id] + prompt_ids).to(device, dtype=torch.long).unsqueeze(0) +output = agent_model.generate(tokenizer=tokenizer, input_ids=input_ids, num_img_gen_tokens=num_img_out_tokens) +print(output['has_img_output']) +print(output['text']) + +if output['has_img_output']: + images = adapter.generate(image_embeds=output['img_gen_feat'].to(device_2), num_inference_steps=50) + save_path = os.path.join(save_dir, caption.replace('.', '') + '.png') + images[0].save(save_path) +torch.cuda.empty_cache() diff --git a/src/models/detokenizer/__init__.py b/src/models/detokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/models/detokenizer/__init__.py @@ -0,0 +1 @@ + diff --git a/src/models/detokenizer/__pycache__/__init__.cpython-311.pyc b/src/models/detokenizer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d637b8682a4681e8cd5b5a8e1f0e1ebbe2bbd46 Binary files /dev/null and b/src/models/detokenizer/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/models/detokenizer/__pycache__/__init__.cpython-38.pyc b/src/models/detokenizer/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d71264e396e90773ca1a5e2f6a8a3660ae8d4c80 Binary files /dev/null and b/src/models/detokenizer/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/models/detokenizer/__pycache__/adapter_modules.cpython-311.pyc b/src/models/detokenizer/__pycache__/adapter_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44e3b0afb03e243028ab8cf9b6b178dd33a22271 Binary files /dev/null and b/src/models/detokenizer/__pycache__/adapter_modules.cpython-311.pyc differ diff --git a/src/models/detokenizer/__pycache__/adapter_modules.cpython-38.pyc b/src/models/detokenizer/__pycache__/adapter_modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b376b4468d9b8be3f98289ea5e23d8222f4408dd Binary files /dev/null and b/src/models/detokenizer/__pycache__/adapter_modules.cpython-38.pyc differ diff --git a/src/models/detokenizer/__pycache__/attention_processor.cpython-38.pyc b/src/models/detokenizer/__pycache__/attention_processor.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e007224a6104997e14dbec2b07cdcfe8ea8ed428 Binary files /dev/null and b/src/models/detokenizer/__pycache__/attention_processor.cpython-38.pyc differ diff --git a/src/models/detokenizer/__pycache__/ipa_utils.cpython-38.pyc b/src/models/detokenizer/__pycache__/ipa_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc50adc1e52a09187ddfa3b35c3e01cf5c1a4f7f Binary files /dev/null and b/src/models/detokenizer/__pycache__/ipa_utils.cpython-38.pyc differ diff --git a/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_t2i_edit.cpython-38.pyc b/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_t2i_edit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c662330e686302c175a154498bac329b460dd396 Binary files /dev/null and b/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_t2i_edit.cpython-38.pyc differ diff --git a/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-311.pyc b/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..654f9d206e7a068ec31a2ce1ef266b7929cf85b2 Binary files /dev/null and b/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-311.pyc differ diff --git a/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-38.pyc b/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9a1b8646226b7d459ae352cf368135e1472f7d4 Binary files /dev/null and b/src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-38.pyc differ diff --git a/src/models/detokenizer/__pycache__/resampler.cpython-311.pyc b/src/models/detokenizer/__pycache__/resampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f6082979ff35efa5e133f8c42635e5bf55941a3 Binary files /dev/null and b/src/models/detokenizer/__pycache__/resampler.cpython-311.pyc differ diff --git a/src/models/detokenizer/__pycache__/resampler.cpython-38.pyc b/src/models/detokenizer/__pycache__/resampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5744cc6e049d574dcfd90a66bb9d15f3f08c0df1 Binary files /dev/null and b/src/models/detokenizer/__pycache__/resampler.cpython-38.pyc differ diff --git a/src/models/detokenizer/adapter_modules.py b/src/models/detokenizer/adapter_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..6a46a3e33aaa3442ad2308a5f68d1ee993929e14 --- /dev/null +++ b/src/models/detokenizer/adapter_modules.py @@ -0,0 +1,288 @@ +import torch +import torch.nn as nn +import itertools +import torch.nn.functional as F +from typing import List +from diffusers import StableDiffusionXLPipeline +from PIL import Image +from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline + + +class SDXLAdapter(nn.Module): + + def __init__(self, unet, resampler, full_ft=False, vit_down=False) -> None: + super().__init__() + self.unet = unet + self.resampler = resampler + self.full_ft = full_ft + self.set_trainable_v2() + self.vit_down = vit_down + + def set_trainable_v2(self): + self.resampler.requires_grad_(True) + adapter_parameters = [] + if self.full_ft: + self.unet.requires_grad_(True) + adapter_parameters.extend(self.unet.parameters()) + else: + 
self.unet.requires_grad_(False) + for name, module in self.unet.named_modules(): + if name.endswith('to_k') or name.endswith('to_v'): + if module is not None: + adapter_parameters.extend(module.parameters()) + self.adapter_parameters = adapter_parameters + + + def params_to_opt(self): + return itertools.chain(self.resampler.parameters(), self.adapter_parameters) + + def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids): + + image_embeds, pooled_image_embeds = self.resampler(image_embeds) + + unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_image_embeds} + + noise_pred = self.unet(noisy_latents, timesteps, image_embeds, added_cond_kwargs=unet_added_conditions).sample + + # if noise is not None: + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + # else: + # loss = torch.tensor(0.0, device=noisy_latents) + + return {'total_loss': loss, 'noise_pred': noise_pred} + + def encode_image_embeds(self, image_embeds): + image_embeds, pooled_image_embeds = self.resampler(image_embeds) + + return image_embeds, pooled_image_embeds + + @classmethod + def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs): + model = cls(unet=unet, resampler=resampler, **kwargs) + if pretrained_model_path is not None: + ckpt = torch.load(pretrained_model_path, map_location='cpu') + missing, unexpected = model.load_state_dict(ckpt, strict=False) + print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) + return model + + def init_pipe(self, + vae, + scheduler, + visual_encoder, + image_transform, + discrete_model=None, + dtype=torch.float16, + device='cuda'): + self.device = device + self.dtype = dtype + sdxl_pipe = StableDiffusionXLPipeline(tokenizer=None, + tokenizer_2=None, + text_encoder=None, + text_encoder_2=None, + vae=vae, + unet=self.unet, + scheduler=scheduler) + + self.sdxl_pipe = sdxl_pipe #.to(self.device, dtype=self.dtype) + # print(sdxl_pipe.text_encoder_2, sdxl_pipe.text_encoder) + + self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype) + if discrete_model is not None: + self.discrete_model = discrete_model.to(self.device, dtype=self.dtype) + else: + self.discrete_model = None + self.image_transform = image_transform + + @torch.inference_mode() + def get_image_embeds(self, image_pil=None, image_tensor=None, image_embeds=None, return_negative=True, image_size=448): + assert int(image_pil is not None) + int(image_tensor is not None) + int(image_embeds is not None) == 1 + + if image_pil is not None: + image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype) + + if image_tensor is not None: + if return_negative: + image_tensor_neg = torch.zeros_like(image_tensor) + image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0) + + image_embeds = self.visual_encoder(image_tensor) + elif return_negative: + image_tensor_neg = torch.zeros(1, 3, image_size, image_size).to(image_embeds.device, dtype=image_embeds.dtype) + image_embeds_neg = self.visual_encoder(image_tensor_neg) + if self.vit_down: + image_embeds_neg = image_embeds_neg.permute(0, 2, 1) # NLD -> NDL + image_embeds_neg = F.avg_pool1d(image_embeds_neg, kernel_size=4, stride=4) + image_embeds_neg = image_embeds_neg.permute(0, 2, 1) + image_embeds = torch.cat([image_embeds, image_embeds_neg], dim=0) + + if self.discrete_model is not None: + image_embeds = self.discrete_model.encode_image_embeds(image_embeds) + image_embeds, pooled_image_embeds = self.encode_image_embeds(image_embeds) 
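+        # Added note: when return_negative is set, an all-zero image was stacked onto the batch
+        # above to serve as the unconditional input; the chunk(2) calls below split the
+        # conditional and negative embeddings so the SDXL pipeline can run classifier-free guidance.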
+ + if return_negative: + image_embeds, image_embeds_neg = image_embeds.chunk(2) + pooled_image_embeds, pooled_image_embeds_neg = pooled_image_embeds.chunk(2) + + else: + image_embeds_neg = None + pooled_image_embeds_neg = None + + return image_embeds, image_embeds_neg, pooled_image_embeds, pooled_image_embeds_neg + + def generate(self, + image_pil=None, + image_tensor=None, + image_embeds=None, + seed=42, + height=1024, + width=1024, + guidance_scale=7.5, + num_inference_steps=30, + input_image_size=448, + **kwargs): + if image_pil is not None: + assert isinstance(image_pil, Image.Image) + + image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds( + image_pil=image_pil, + image_tensor=image_tensor, + image_embeds=image_embeds, + return_negative=True, + image_size=input_image_size, + ) + # print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape) + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + + images = self.sdxl_pipe( + prompt_embeds=image_prompt_embeds, + negative_prompt_embeds=uncond_image_prompt_embeds, + pooled_prompt_embeds=pooled_image_prompt_embeds, + negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + height=height, + width=width, + **kwargs, + ).images + + return images + + +class SDXLAdapterWithLatentImage(SDXLAdapter): + def __init__(self, unet, resampler, full_ft=False, set_trainable_late=False, vit_down=False) -> None: + nn.Module.__init__(self) + self.unet = unet + self.resampler = resampler + self.full_ft = full_ft + if not set_trainable_late: + self.set_trainable() + self.vit_down = vit_down + + + def set_trainable(self): + self.resampler.requires_grad_(True) + adapter_parameters = [] + + in_channels = 8 + out_channels = self.unet.conv_in.out_channels + self.unet.register_to_config(in_channels=in_channels) + self.unet.requires_grad_(False) + with torch.no_grad(): + new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, + self.unet.conv_in.padding) + + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) + self.unet.conv_in = new_conv_in + self.unet.conv_in.requires_grad_(True) + + if self.full_ft: + self.unet.requires_grad_(True) + adapter_parameters.extend(self.unet.parameters()) + else: + adapter_parameters.extend(self.unet.conv_in.parameters()) + for name, module in self.unet.named_modules(): + if name.endswith('to_k') or name.endswith('to_v'): + if module is not None: + adapter_parameters.extend(module.parameters()) + self.adapter_parameters = adapter_parameters + + @classmethod + def from_pretrained(cls, unet, resampler, pretrained_model_path=None, set_trainable_late=False, **kwargs): + model = cls(unet=unet, resampler=resampler, set_trainable_late=set_trainable_late, **kwargs) + if pretrained_model_path is not None: + ckpt = torch.load(pretrained_model_path, map_location='cpu') + missing, unexpected = model.load_state_dict(ckpt, strict=False) + print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) + if set_trainable_late: + model.set_trainable() + return model + + def init_pipe(self, + vae, + scheduler, + visual_encoder, + image_transform, + dtype=torch.float16, + device='cuda'): + self.device = device + self.dtype = dtype + + sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline( + tokenizer=None, + 
tokenizer_2=None, + text_encoder=None, + text_encoder_2=None, + vae=vae, + unet=self.unet, + scheduler=scheduler, + ) + + self.sdxl_pipe = sdxl_pipe + self.sdxl_pipe.to(device, dtype=dtype) + self.discrete_model = None + + self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype) + self.image_transform = image_transform + + def generate(self, + image_pil=None, + image_tensor=None, + image_embeds=None, + latent_image=None, + seed=42, + height=1024, + width=1024, + guidance_scale=7.5, + num_inference_steps=30, + input_image_size=448, + **kwargs): + if image_pil is not None: + assert isinstance(image_pil, Image.Image) + + image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds( + image_pil=image_pil, + image_tensor=image_tensor, + image_embeds=image_embeds, + return_negative=True, + image_size=input_image_size, + ) + # print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape) + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + + images = self.sdxl_pipe( + image=latent_image, + prompt_embeds=image_prompt_embeds, + negative_prompt_embeds=uncond_image_prompt_embeds, + pooled_prompt_embeds=pooled_image_prompt_embeds, + negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + height=height, + width=width, + **kwargs, + ).images + return images + \ No newline at end of file diff --git a/src/models/detokenizer/pipeline_stable_diffusion_xl_t2i_edit.py b/src/models/detokenizer/pipeline_stable_diffusion_xl_t2i_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..8b81715e327bad11ab39f13031198fdd23a28070 --- /dev/null +++ b/src/models/detokenizer/pipeline_stable_diffusion_xl_t2i_edit.py @@ -0,0 +1,994 @@ +# Copyright 2023 Harutatsu Akiyama and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
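A minimal inference sketch for the adapter classes above, assuming the modules added by this diff are importable from the repository root; the adapter module name, checkpoint path, visual encoder, preprocessing transform, and ResamplerXL configuration are placeholders rather than values taken from this patch:

import torch
from PIL import Image
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel

from src.models.detokenizer.resampler import ResamplerXL
# Module name below is an assumption; use whichever file defines SDXLAdapter in this repo.
from src.models.detokenizer.adapter_modules import SDXLAdapter

base = "stabilityai/stable-diffusion-xl-base-1.0"
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
vae = AutoencoderKL.from_pretrained(base, subfolder="vae")
scheduler = EulerDiscreteScheduler.from_pretrained(base, subfolder="scheduler")

# ResamplerXL returns (prompt_embeds, pooled_prompt_embeds), which is what
# SDXLAdapter.encode_image_embeds expects; these hyperparameters are illustrative.
resampler = ResamplerXL(dim=1024, depth=4, dim_head=64, heads=16,
                        num_queries=64, embedding_dim=1024,
                        output1_dim=768, output2_dim=1280)

visual_encoder = ...   # placeholder: a ViT producing (bs, n_tokens, 1024) features
image_transform = ...  # placeholder: the matching preprocessing transform

adapter = SDXLAdapter.from_pretrained(unet=unet, resampler=resampler,
                                      pretrained_model_path="checkpoints/detokenizer.pt")  # hypothetical path
adapter = adapter.to("cuda", dtype=torch.float16)
adapter.init_pipe(vae=vae, scheduler=scheduler,
                  visual_encoder=visual_encoder, image_transform=image_transform)

# Reconstruct/generate an image conditioned only on visual features of the input image.
images = adapter.generate(image_pil=Image.open("input.jpg"), num_inference_steps=30)
images[0].save("reconstruction.jpg")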
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLInstructPix2PixPipeline + >>> from diffusers.utils import load_image + + >>> resolution = 768 + >>> image = load_image( + ... "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + ... ).resize((resolution, resolution)) + >>> edit_instruction = "Turn sky into a cloudy one" + + >>> pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + ... "diffusers/sdxl-instructpix2pix-768", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> edited_image = pipe( + ... prompt=edit_instruction, + ... image=image, + ... height=resolution, + ... width=resolution, + ... guidance_scale=3.0, + ... image_guidance_scale=1.5, + ... num_inference_steps=30, + ... ).images[0] + >>> edited_image + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLText2ImageAndEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin): + r""" + Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. 
+ """ + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. 
If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ([self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 
tokenizer.model_max_length - 1:-1]) + logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}") + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}.") + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`.") + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, 
num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}.") + + if callback_on_step_end_tensor_inputs is not None and not all(k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError(f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two.") + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.") + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError(f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}.") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators.") + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents(self, + image, + batch_size, + num_images_per_prompt, + dtype, + device, + do_classifier_free_guidance, + generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}") + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators.") + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i:i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image.float()).latent_dist.mode() + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning.") + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.") + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + if image_latents.dtype != self.vae.dtype: + image_latents = image_latents.to(dtype=self.vae.dtype) + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = (self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. 
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 100, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Image guidance scale is to push the generated image towards the inital image `image`. Image guidance + scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to + generate images that are closely linked to the source image `image`, usually at the expense of lower + image quality. This pipeline requires a value of at least `1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. 
+ + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # if image is None: + # raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 + # check if scheduler is in sigmas space + scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt + text_encoder_lora_scale = (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if image is not None: + # 4. Preprocess image + image = self.image_processor.preprocess(image).to(device) + # 6. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + generator, + ) + else: + image_latents = None + + # 7. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Check that shapes of latents and image match the UNet channels + # num_channels_image = image_latents.shape[1] + # if num_channels_latents + num_channels_image != self.unet.config.in_channels: + # raise ValueError(f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + # f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + # f" `num_channels_image`: {num_channels_image} " + # f" = {num_channels_latents + num_channels_image}. 
Please verify the config of" + # " `pipeline.unet` or your `image` input.") + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if do_classifier_free_guidance: + # The extra concat similar to how it's done in SD InstructPix2Pix. + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round(self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps))) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if image_latents is None: + image_latents = torch.zeros_like(scaled_latent_model_input) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. So we need to compute the + # predicted_original_sample here if we are using a karras style scheduler. 
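+                # For sigma-space schedulers the branch below converts the eps prediction into a
+                # predicted_original_sample via x0 = x_t - sigma * eps, applies guidance on x0,
+                # and (further down in the loop) maps the result back to an eps-style prediction
+                # so that scheduler.step() receives what it expects.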
+ if scheduler_is_in_sigma_space: + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() + sigma = self.scheduler.sigmas[step_index] + noise_pred = latent_model_input - sigma * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = (noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond)) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. But the scheduler.step function + # expects the noise_pred and computes the predicted_original_sample internally. So we + # need to overwrite the noise_pred here such that the value of the computed + # predicted_original_sample is correct. + if scheduler_is_in_sigma_space: + noise_pred = (noise_pred - latents) / (-sigma) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, ) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/models/detokenizer/resampler.py b/src/models/detokenizer/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..82c61bdbd6ac5f8abe73ee3cbe791f30c8474e3f --- /dev/null +++ b/src/models/detokenizer/resampler.py @@ -0,0 +1,309 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # 
(bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class AttentionPool2d(nn.Module): + + def __init__(self, seq_len: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(seq_len + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x, return_all_tokens=False): + # x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = x.permute(1, 0, 2) # (N(HW)C) => (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward(query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + if return_all_tokens: + return x + else: + return x[0] + + +class Resampler(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.in_dim = dim + self.out_dim = output_dim + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x): + + 
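+        # x: (bs, seq_len, embedding_dim) visual features; returns (bs, num_queries, output_dim)
+        # resampled query embeddings.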
latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + output_embeds = self.norm_out(latents) + + return output_embeds + + +class ResamplerXL(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output1_dim=768, + output2_dim=1280, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + # self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(dim) + + self.in_dim = dim + self.out_dim = output1_dim + output2_dim + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim) + self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim) + self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + hidden_embeds = self.norm_out(latents) + + encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768] + encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280] + prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048] + pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280] + + return prompt_embeds, pooled_prompt_embeds + + +class ResamplerXLV2(nn.Module): + + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output1_dim=768, + output2_dim=1280, + ff_mult=4, + normalize=True + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.normalize = normalize + self.proj_in = nn.Linear(embedding_dim, dim) + + # self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(dim) + + self.in_dim = dim + self.out_dim = output1_dim + output2_dim + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim) + self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim) + self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim) + + def forward(self, x,pooled_text_embeds=None): + + latents = self.latents.repeat(x.size(0), 1, 1) + + if self.normalize: + x = F.normalize(x) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + hidden_embeds = self.norm_out(latents) + + encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768] + encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280] + prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048] + pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280] + + return prompt_embeds, pooled_prompt_embeds + +class ResamplerXLIdentity(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, 
pooled_text_embeds=None): + return x, pooled_text_embeds + + +if __name__ == '__main__': + image_proj_model = Resampler(dim=1024, + depth=4, + dim_head=64, + heads=12, + num_queries=1024, + embedding_dim=1024, + output_dim=1024, + ff_mult=4) + numel = 0 + for name, param in image_proj_model.named_parameters(): + numel += param.numel() + + print(f'Total params: {numel}') diff --git a/src/models/mllm/__init__.py b/src/models/mllm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/mllm/__pycache__/__init__.cpython-311.pyc b/src/models/mllm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d54781c5b238a37059b42e95a1e9d25105b5b7d9 Binary files /dev/null and b/src/models/mllm/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/models/mllm/__pycache__/__init__.cpython-38.pyc b/src/models/mllm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..924ca2c81529592bd0759421ee25017ff9f30109 Binary files /dev/null and b/src/models/mllm/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/models/mllm/__pycache__/generation.cpython-311.pyc b/src/models/mllm/__pycache__/generation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa8d8b76dd556e6cdaddc8eee78722306df11d1b Binary files /dev/null and b/src/models/mllm/__pycache__/generation.cpython-311.pyc differ diff --git a/src/models/mllm/__pycache__/generation.cpython-38.pyc b/src/models/mllm/__pycache__/generation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fca1b66caf0fb1ddc564bade1d9fb0c02678806e Binary files /dev/null and b/src/models/mllm/__pycache__/generation.cpython-38.pyc differ diff --git a/src/models/mllm/__pycache__/modeling_llama_xformer.cpython-311.pyc b/src/models/mllm/__pycache__/modeling_llama_xformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29f01e452aa07a9fa0cd27929a7de3ec61c2f7be Binary files /dev/null and b/src/models/mllm/__pycache__/modeling_llama_xformer.cpython-311.pyc differ diff --git a/src/models/mllm/__pycache__/modeling_llama_xformer.cpython-38.pyc b/src/models/mllm/__pycache__/modeling_llama_xformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..816b45d7dd76a32c8f0ca326645968eac6799f1c Binary files /dev/null and b/src/models/mllm/__pycache__/modeling_llama_xformer.cpython-38.pyc differ diff --git a/src/models/mllm/__pycache__/models.cpython-38.pyc b/src/models/mllm/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3983860dcf1c3fe5791ba1636c5cf382b786cd26 Binary files /dev/null and b/src/models/mllm/__pycache__/models.cpython-38.pyc differ diff --git a/src/models/mllm/__pycache__/peft_models.cpython-38.pyc b/src/models/mllm/__pycache__/peft_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f54d498826dd5f6cabdc6465b9f2af208fe537c Binary files /dev/null and b/src/models/mllm/__pycache__/peft_models.cpython-38.pyc differ diff --git a/src/models/mllm/__pycache__/seed_x.cpython-311.pyc b/src/models/mllm/__pycache__/seed_x.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf277cf0ab23ec13f61ed439b685f3f656113c24 Binary files /dev/null and b/src/models/mllm/__pycache__/seed_x.cpython-311.pyc differ diff --git 
a/src/models/mllm/__pycache__/seed_x.cpython-38.pyc b/src/models/mllm/__pycache__/seed_x.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57120f6aa7e8734336859443c5f9f8d206ae61dd Binary files /dev/null and b/src/models/mllm/__pycache__/seed_x.cpython-38.pyc differ diff --git a/src/models/mllm/__pycache__/utils.cpython-311.pyc b/src/models/mllm/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24655500e8c33c6e94f04d20ed37219b07504d6c Binary files /dev/null and b/src/models/mllm/__pycache__/utils.cpython-311.pyc differ diff --git a/src/models/mllm/__pycache__/utils.cpython-38.pyc b/src/models/mllm/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afa8cbf9fc3c0cb77f4e681f3a4929b310c1f139 Binary files /dev/null and b/src/models/mllm/__pycache__/utils.cpython-38.pyc differ diff --git a/src/models/mllm/generation.py b/src/models/mllm/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..0a828375ff93c03f5e2c004352daf95a38e9d71d --- /dev/null +++ b/src/models/mllm/generation.py @@ -0,0 +1,31 @@ +import torch +from transformers import LogitsProcessor + + +BOI_TOKEN = '' +EOI_TOKEN = '' +IMG_TOKEN = '' + +class AutoImageTokenGenerationProcessor(LogitsProcessor): + + def __init__(self, tokenizer, num_img_gen_tokens=64) -> None: + super().__init__() + # self.boi_token_id = tokenizer.encode(BOI_TOKEN)[0] + # self.eoi_token_id = tokenizer.encode(EOI_TOKEN)[0] + img_all_token_str = ''.join([BOI_TOKEN] + [IMG_TOKEN.format(int(item)) + for item in range(num_img_gen_tokens)] + [EOI_TOKEN]) + self.img_ids_list = tokenizer.encode(img_all_token_str, add_special_tokens=False) + + def __call__(self, input_ids, scores): + bz = input_ids.shape[0] + for i in range(bz): + cur_input_id = input_ids[i, -1].item() + if cur_input_id in self.img_ids_list[:-1]: + + output_id = self.img_ids_list[self.img_ids_list.index(cur_input_id) + 1] + scores[i, ..., output_id] = scores[i, ...].max() + 10. + else: + + scores[i, ..., torch.tensor(self.img_ids_list[1:]).to(dtype=torch.long)] = 0.0 + + return scores diff --git a/src/models/mllm/modeling_llama_xformer.py b/src/models/mllm/modeling_llama_xformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d975c99c14f5dde12f2b0e7fca9d775b4fa634fe --- /dev/null +++ b/src/models/mllm/modeling_llama_xformer.py @@ -0,0 +1,919 @@ +# coding=utf-8 +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
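+# This module mirrors the Hugging Face LLaMA implementation, with the attention
+# computation replaced by xformers memory-efficient attention (see LlamaAttention
+# below); the remaining modelling code stays close to the upstream transformers version.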
+""" PyTorch LLaMA model.""" + +import torch +import torch.utils.checkpoint +import xformers.ops as xops +from torch import nn +from typing import List, Optional, Tuple, Union +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.llama.configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full( + (tgt_len, tgt_len), + torch.tensor(torch.finfo(dtype).min, device=device), + device=device, + ) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +class LlamaRotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange( + self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, 
self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if self.training: + attn_output = xops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xops.LowerTriangularMask(), + ) + else: + attn_output = xops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=None if attention_mask.sum() == 0 else xops.LowerTriangularMask(), + ) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + # if input_ids is not None and inputs_embeds is not None: + # raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + # elif input_ids is not None: + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + 
device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + # check if nan or inf in hidden_states + if torch.isnan(hidden_states).any(): + print(f'nan in hidden_states') + elif torch.isinf(hidden_states).any(): + print(f'inf in hidden_states') + + logits = self.lm_head(hidden_states) + + # check if nan or inf in logits + if torch.isnan(logits).any(): + print(f'nan in logits') + elif torch.isinf(logits).any(): + print(f'inf in logits') + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + # fp16 + loss = loss_fct(shift_logits, shift_labels) + # loss = loss_fct(shift_logits.float(), shift_labels).type_as(shift_logits) + + # check if nan or inf in loss + if torch.isnan(loss).any(): + print(f'nan in loss') + elif torch.isinf(loss).any(): + print(f'inf in loss') + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def 
prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({ + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +if __name__ == "__main__": + from transformers import LlamaTokenizer + + model = LlamaForCausalLM.from_pretrained("luodian/llama-7b-hf", device_map="auto") + tokenizer = LlamaTokenizer.from_pretrained("luodian/llama-7b-hf") + prompt = "Hey, are you consciours? Can you talk to me?" 
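+    # Quick smoke test: tokenize the prompt, generate up to 30 tokens, and print
+    # the decoded completion.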
+ inputs = tokenizer(prompt, return_tensors="pt") + generate_ids = model.generate(inputs.input_ids, max_length=30) + print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]) diff --git a/src/models/mllm/seed_x.py b/src/models/mllm/seed_x.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8da6f01bd6752c3d0663e71111d49862b8f25c --- /dev/null +++ b/src/models/mllm/seed_x.py @@ -0,0 +1,238 @@ +import os +import torch +import torch.nn as nn +from torch.nn import functional as F +from transformers import LogitsProcessorList +from .generation import AutoImageTokenGenerationProcessor +from .utils import load_zero3_checkpoint + + +BOI_TOKEN = '' +EOI_TOKEN = '' +IMG_TOKEN = '' + + +def cosine_loss(rec, target): + target = target / target.norm(dim=-1, keepdim=True) + rec = rec / rec.norm(dim=-1, keepdim=True) + rec_loss = (1 - (target * rec).sum(-1)).mean() + return rec_loss + + +class ContinuousLVLM(nn.Module): + + def __init__(self, llm, input_resampler, output_resampler, lm_loss_scale=1.0, rec_loss_scale=1.0, add_patch_pos=False, vit_down=False, mse=False) -> None: + super().__init__() + self.llm = llm + self.input_resampler = input_resampler + self.output_resampler = output_resampler + self.lm_loss_scale = lm_loss_scale + self.rec_loss_scale = rec_loss_scale + self.add_patch_pos = add_patch_pos + + self.vit_down = vit_down + if self.vit_down: + self.pool_size = 4 + self.stride = 4 + + self.mse = mse + if self.mse: + self.mse_loss = torch.nn.MSELoss() + + self.add_patch_pos = add_patch_pos + if self.add_patch_pos: + patch_dim = self.input_resampler.embed_dim + self.patch_pos_embed = nn.Parameter((patch_dim**-0.5) * torch.randn(4, patch_dim)) + + + def forward(self, input_ids, attention_mask, labels, image_embeds, embeds_gen_mask, embeds_cmp_mask, ids_gen_mask, + ids_cmp_mask, patch_positions=None): + + input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096 + + bz, sq, dim = input_embeds.shape + + if image_embeds is not None: + image_embeds_cmp = image_embeds[embeds_cmp_mask] # num_imgs_in_batch x nq_in x dim_in, 4 x 64 x 4096 + if patch_positions is not None: + patch_positions = patch_positions[embeds_cmp_mask] + + + if image_embeds is not None and image_embeds_cmp.shape[0] > 0: + image_embeds_lm = self.input_resampler(image_embeds_cmp) # num_imgs_in_batch x nq x dim, 4 x 64 x 4096 + if self.add_patch_pos and patch_positions is not None: + # assert patch_positions is not None + patch_positions = patch_positions.to( + image_embeds_lm + ) + rel_pos_embed = torch.mm(torch.cat([patch_positions, 1-patch_positions], dim=-1)/2, self.patch_pos_embed).unsqueeze(1) + image_embeds_lm = image_embeds_lm + rel_pos_embed + has_image_cmp = True + else: + image_embeds_cmp_fake = torch.randn( 1 , self.output_resampler.num_queries, + self.output_resampler.embed_dim).to(input_embeds.device, dtype=input_embeds.dtype) + + # image_embeds = torch.randn(bz, self.output_resampler.num_queries, + # self.output_resampler.embed_dim).to(input_embeds.device, dtype=input_embeds.dtype) + image_embeds_lm = self.input_resampler(image_embeds_cmp_fake) + if self.add_patch_pos: + rel_pos_embed = self.patch_pos_embed.mean(0, keepdim=True).unsqueeze(1) # 1, 1, dim + image_embeds_lm = image_embeds_lm + rel_pos_embed + + has_image_cmp = False + + has_image_input = image_embeds is not None and embeds_cmp_mask.sum().item() > 0 + has_image_output = image_embeds is not None and embeds_gen_mask.sum().item() > 0 + + if 
has_image_input: + input_embeds[ids_cmp_mask] = image_embeds_lm.reshape(-1, dim) # eg, 128 x 4096 + # zero_loss = 0.0 + else: + input_embeds[:1, :self.input_resampler.num_queries, :] += 0.0 * image_embeds_lm[:1, :, :] + + output_lm = self.llm(attention_mask=attention_mask, + inputs_embeds=input_embeds, + labels=labels, + output_hidden_states=True, + return_dict=True) + lm_loss = output_lm['loss'] + + last_hidden_state = output_lm.hidden_states[-1] # 4 x 160 x 4096 + + if has_image_output: + target_embeds = image_embeds[embeds_gen_mask] # num_imgs_gen_target x nq_in x dim_in, 2 x 256 x 4096 + + if self.vit_down: + target_embeds = target_embeds.permute(0, 2, 1) # NLD -> NDL + target_embeds = F.avg_pool1d(target_embeds, kernel_size=self.pool_size, stride=self.stride) + target_embeds = target_embeds.permute(0, 2, 1) + + num_imgs_for_rec = target_embeds.shape[0] + output_image_embeds = last_hidden_state[ids_gen_mask].view(num_imgs_for_rec, -1, dim) # 128 x 4096 -> 2 x 64 x 4096 + + recon_image_embeds = self.output_resampler(output_image_embeds) # 2 x 256 x 4096 + + if self.mse: + # rec_loss = self.mse_loss(recon_image_embeds, target_embeds.detach()) + rec_loss = F.mse_loss(recon_image_embeds, target_embeds.detach()) # for zero3 compatibility + else: + rec_loss = cosine_loss(recon_image_embeds, target_embeds.detach()) + + else: + output_image_embeds = torch.randn(1, self.input_resampler.num_queries, + self.input_resampler.embed_dim).to(input_embeds.device, dtype=input_embeds.dtype) + 0.0 * last_hidden_state[0, :self.input_resampler.num_queries, :] + recon_image_embeds = self.output_resampler(output_image_embeds) + # target_embeds = torch.randn(1, self.output_resampler.num_queries, + # self.output_resampler.embed_dim).to(input_embeds.device, dtype=input_embeds.dtype) + # rec_loss = cosine_loss(recon_image_embeds, target_embeds.detach) * 0.0 + rec_loss = 0.0 * recon_image_embeds.sum() + + total_loss = self.lm_loss_scale * lm_loss + self.rec_loss_scale * rec_loss + + return {'total_loss': total_loss, 'lm_loss': lm_loss, 'rec_loss': rec_loss} + + def generate(self, + tokenizer, + prompt=None, + input_ids=None, + image_embeds=None, + embeds_cmp_mask=None, + ids_cmp_mask=None, + logits_processor=None, + num_img_gen_tokens=64, + temperature=0.7, + num_beams=1, + max_new_tokens=120, + top_p=0.5, + dtype=torch.float16, + device='cuda', + patch_positions=None): + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append( + AutoImageTokenGenerationProcessor(tokenizer=tokenizer, num_img_gen_tokens=num_img_gen_tokens)) + + if prompt is not None: + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + if isinstance(input_ids, list): + input_ids = torch.tensor(input_ids) + + input_ids = input_ids.to(device=device) + input_embeds = self.llm.get_input_embeddings()(input_ids) + bz, sq, dim = input_embeds.shape + + if image_embeds is not None: + assert embeds_cmp_mask is not None and ids_cmp_mask is not None + with torch.no_grad(): + image_embeds_lm = self.input_resampler(image_embeds) + if self.add_patch_pos: + assert patch_positions is not None + patch_positions = patch_positions.to( + image_embeds_lm + ) + rel_pos_embed = torch.mm(torch.cat([patch_positions, 1-patch_positions], dim=-1)/2, self.patch_pos_embed).unsqueeze(1) + image_embeds_lm = image_embeds_lm + rel_pos_embed + #print(input_embeds.shape, ids_cmp_mask.shape, image_embeds_lm.shape, embeds_cmp_mask.shape) + input_embeds[ids_cmp_mask] = image_embeds_lm[embeds_cmp_mask].view(-1, dim) + + 
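+        # Decode with the image-token logits processor attached: once the LLM opens
+        # an image span (BOI token), the processor forces the fixed run of image
+        # tokens, and the hidden states aligned with those positions are later fed
+        # through the output resampler to recover image-generation features.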
generation_config = { + 'temperature': temperature, + 'num_beams': num_beams, + 'max_new_tokens': max_new_tokens, + 'top_p': top_p, + 'do_sample': False + } + + # generate_ids = self.llm.generate(input_ids=input_ids, **generation_config) + output = self.llm.generate(input_ids=input_ids, + inputs_embeds=input_embeds, + output_hidden_states=True, + return_dict_in_generate=True, + logits_processor=logits_processor, + **generation_config) + + generate_ids = output.sequences[0][input_ids.shape[1]:] + generate_id_list = generate_ids.tolist() + boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0] + eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0] + + last_hidden_states = torch.cat([hidden_state[-1] for hidden_state in output.hidden_states], + dim=1)[0, input_ids.shape[1]:, :] + + eoi_indices = torch.where(generate_ids == eoi_token_id)[0].tolist() + num_gen_imgs = len(eoi_indices) + text_mask = torch.ones_like(generate_ids, dtype=torch.bool) + has_img_output = num_gen_imgs > 0 + if has_img_output: + img_gen_feats = [] + for eoi_idx in eoi_indices: + img_gen_feats.append(last_hidden_states[eoi_idx - num_img_gen_tokens:eoi_idx]) + text_mask[eoi_idx - num_img_gen_tokens:eoi_idx] = False + + img_gen_feats = torch.stack(img_gen_feats) + img_gen_feat = self.output_resampler(img_gen_feats) + else: + img_gen_feat = None + + text_mask[generate_ids == boi_token_id] = False + generate_ids = generate_ids[text_mask] + generate_text = tokenizer.decode(generate_ids, skip_special_tokens=False) + + return { + 'text': generate_text, + 'has_img_output': has_img_output, + 'img_gen_feat': img_gen_feat, + 'num_gen_imgs': num_gen_imgs + } + + @classmethod + def from_pretrained(cls, llm, input_resampler, output_resampler, pretrained_model_path=None, **kwargs): + model = cls(llm=llm, input_resampler=input_resampler, output_resampler=output_resampler, **kwargs) + if os.environ.get('DEBUG_FLAG', 'False') == 'True': + return model + + if pretrained_model_path is not None: + ckpt = torch.load(pretrained_model_path, map_location='cpu') + load_zero3_checkpoint(model, ckpt) + return model diff --git a/src/models/mllm/utils.py b/src/models/mllm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f50356b1c54878cb34ba095bd948782121946d98 --- /dev/null +++ b/src/models/mllm/utils.py @@ -0,0 +1,84 @@ +import deepspeed +from transformers import AutoConfig +from transformers.deepspeed import is_deepspeed_zero3_enabled +from torch import nn + + +def remove_mismatched_weights(model, pretrained_state_dict): + own_state = model.state_dict() + mismatch_keys = [] + + for name in list(pretrained_state_dict.keys()): + if name not in own_state or own_state[name].shape != pretrained_state_dict[name].shape: + mismatch_keys.append(name) + pretrained_state_dict.pop(name) + + return pretrained_state_dict, mismatch_keys + + +def load_zero3_checkpoint(module: nn.Module, state_dict, prefix="", error_msgs = [], top=True): + # check if zero3 + + zero3_enabled = is_deepspeed_zero3_enabled() + # print(f'zero3_enabled: {zero3_enabled}') + + if not is_deepspeed_zero3_enabled(): + + state_dict, mismatch_keys = remove_mismatched_weights(module, state_dict) + + + + info = module.load_state_dict(state_dict, strict=False) + + + if len(mismatch_keys) > 0: + print("shape mismatch keys: ", mismatch_keys) + + + if len(info.missing_keys) > 0: + print("missing keys: ", info.missing_keys) + + if len(info.unexpected_keys) > 0: + print("unexpected keys: ", info.unexpected_keys) + + else: + # error_msgs = 
[] + local_metadata = {} + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + + named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) + params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] + params_name = [k for k in state_dict.keys() if k in named_parameters] + ## named buffer for layers like batchnorm + named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False)) + buffers_to_gather = [named_buffers[k] for k in state_dict.keys() if k in named_buffers] + + if len(params_to_gather) > 0 or len(buffers_to_gather)>0: + # if len(buffers_to_gather)>0: + # print("loading buffers") + with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): + # if torch.distributed.get_rank() == 0: + # if only rank0, then module's buffer will not be syncd + # for k, v in zip(params_name, params_to_gather): + # log the shape of the loaded weights + # print(f'loading {k} with shape {v.shape}') + module._load_from_state_dict(*args) + + + # if len (error_msgs) > 0: + # print(error_msgs) + + for name, child in module._modules.items(): + if child is not None: + load_zero3_checkpoint(child, state_dict, prefix + name + ".", top=False) + + if top: + if len(error_msgs) > 0: + print('loading zero3 model weights meets error messages!') + print(error_msgs) + else: + print('loading zero3 model weights success!') + \ No newline at end of file diff --git a/src/models/tokenizer/__init__.py b/src/models/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/tokenizer/__pycache__/__init__.cpython-311.pyc b/src/models/tokenizer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a05d801976870770064ae51d5516509de75a5750 Binary files /dev/null and b/src/models/tokenizer/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/models/tokenizer/__pycache__/__init__.cpython-38.pyc b/src/models/tokenizer/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5602298751884a1569f0c32a9c943336b6d6b0a7 Binary files /dev/null and b/src/models/tokenizer/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/models/tokenizer/__pycache__/discrete_models.cpython-311.pyc b/src/models/tokenizer/__pycache__/discrete_models.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a5223cb307208db2f3fc556cccb426fc0814cb1 Binary files /dev/null and b/src/models/tokenizer/__pycache__/discrete_models.cpython-311.pyc differ diff --git a/src/models/tokenizer/__pycache__/discrete_models.cpython-38.pyc b/src/models/tokenizer/__pycache__/discrete_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..036fb0a60cb41a0b6635fbb4702732a13215326f Binary files /dev/null and b/src/models/tokenizer/__pycache__/discrete_models.cpython-38.pyc differ diff --git a/src/models/tokenizer/__pycache__/qwen_visual.cpython-311.pyc b/src/models/tokenizer/__pycache__/qwen_visual.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec9685843cfb3510eda7f1285bc7bbe3dac7fa7b Binary files /dev/null and b/src/models/tokenizer/__pycache__/qwen_visual.cpython-311.pyc differ 
diff --git a/src/models/tokenizer/__pycache__/qwen_visual.cpython-38.pyc b/src/models/tokenizer/__pycache__/qwen_visual.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..213dc91818a8a2cfaf540bddc31c71be7c29fbcc Binary files /dev/null and b/src/models/tokenizer/__pycache__/qwen_visual.cpython-38.pyc differ diff --git a/src/models/tokenizer/discrete_models.py b/src/models/tokenizer/discrete_models.py new file mode 100644 index 0000000000000000000000000000000000000000..7f62d7404912a90928f37f23f6bbe313dcc6069b --- /dev/null +++ b/src/models/tokenizer/discrete_models.py @@ -0,0 +1,17 @@ +import torch.nn as nn +import pyrootutils + + +pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True) + +class DiscreteModleIdentity(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.model = nn.Identity() + + def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None): + return + + def encode_image_embeds(self, image_embeds): + return self.model(image_embeds) diff --git a/src/models/tokenizer/qwen_visual.py b/src/models/tokenizer/qwen_visual.py new file mode 100644 index 0000000000000000000000000000000000000000..8a6ee399412e0f808dee1f79a13c6599ee3dd16f --- /dev/null +++ b/src/models/tokenizer/qwen_visual.py @@ -0,0 +1,538 @@ +# Tongyi Qianwen is licensed under the Tongyi Qianwen +# LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. +# All Rights Reserved. + +import os +import math +import requests +import torch +import numpy as np + +from collections import OrderedDict +from functools import partial +from PIL import Image +from typing import Callable, Optional, List + +from torch import nn +from torch.nn import functional as F +from torch.nn.init import trunc_normal_ +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from src.models.mllm.utils import load_zero3_checkpoint + + +def get_abs_pos(abs_pos, tgt_size): + # abs_pos: L, C + # tgt_size: M + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + return F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + else: + return abs_pos + + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each 
position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class Resampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__(self, grid_size, embed_dim, num_heads, kv_dim=None, norm_layer=nn.LayerNorm): + super().__init__() + self.num_queries = grid_size**2 + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.pos_embed = nn.Parameter(torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, + grid_size)).float()).requires_grad_(False) + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=.02) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.out_dim = kv_dim + else: + self.kv_proj = nn.Identity() + self.out_dim = embed_dim + + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x, attn_mask=None): + + pos_embed = get_abs_pos(self.pos_embed, x.size(1)) + + x = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), x + pos_embed.unsqueeze(1), x, attn_mask=attn_mask)[0] + return out.permute(1, 0, 2) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class VisualAttention(nn.Module): + """self-attention layer class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): + super(VisualAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + assert embed_dim % num_heads == 0 + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. 
+ assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently' + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward(self, query, key, value, attn_mask=None): + # query/key/value: [sq, b, h] + sq, b, _ = query.size() + + assert query is key, 'Only Support Self-Attention Currently' + sk = sq + mixed_x_layer = self.in_proj(query) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(sk, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view(sk, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(b, self.num_attention_heads_per_partition, sq, self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.out_proj(context_layer) + + return output + + +class VisualAttentionBlock(nn.Module): + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = nn.Sequential( + OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model))])) + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) 
if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class TransformerBlock(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList( + [VisualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) for _ in range(layers)]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + + gradient_checkpointing = True + for r in self.resblocks: + if gradient_checkpointing and self.training: + x = torch.utils.checkpoint.checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + + # for r in self.resblocks: + # x = r(x, attn_mask=attn_mask) + # return x + + +class VisionTransformerWithAttnPool(nn.Module): + + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + patch_pos: bool = False, + **kwargs): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + self.image_transform = transforms.Compose([ + transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock( + width, + layers, + heads, + mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.attn_pool = Resampler( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + ) + + self.patch_pos = patch_pos + if patch_pos: + # 4*dim for the 4 corners of the image + self.patch_pos_embed = nn.Parameter((output_dim**-0.5) * torch.randn(4, output_dim)) + + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter((output_dim**-0.5) * torch.randn(output_dim, output_dim)) + + def forward(self, x: torch.Tensor, patch_positions: Optional[torch.Tensor] = None): + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + # shape = [*, width, grid ** 2] + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, x.size(1)) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = 
self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + if self.patch_pos: + patch_positions = patch_positions.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + rel_posembed = torch.mm(torch.cat([patch_positions, 1 - patch_positions], dim=-1) / 2, self.patch_pos_embed).unsqueeze(1) + x = x + rel_posembed + x = self.ln_post(x) + x = x @ self.proj + + return x + + def encode(self, image_paths: List[str]): + images = [] + for image_path in image_paths: + if image_path.startswith("http://") or image_path.startswith("https://"): + image = Image.open(requests.get(image_path, stream=True).raw) + else: + image = Image.open(image_path) + image = image.convert("RGB") + images.append(self.image_transform(image)) + images = torch.stack(images, dim=0) + return self(images) + + @classmethod + def from_pretrained(cls, pretrained_model_path=None, **kawrgs): + if os.environ.get('DEBUG_FLAG', 'False') == 'True': + print('DEBUG_FLAG is set to True, return a random initialized model') + kawrgs.update( + { + "heads": 4, + "image_size": 448, + "layers": 1, + "mlp_ratio": 1, + "output_dim": 768, # llama input dim + "patch_size": 14, + "width": 768, + } + ) + return cls(**kawrgs) + model = cls(**kawrgs) + # with deepspeed.zero.Init(mem_efficient_linear=False, enabled=is_deepspeed_zero3_enabled() ): + # model = cls(**kawrgs) + if pretrained_model_path is not None: + print('Load ckpt of qwen visual encoder') + ckpt = torch.load(pretrained_model_path, map_location='cpu') + # missing, unexpected = model.load_state_dict(ckpt, strict=False) + load_zero3_checkpoint(model, ckpt) + # print('Load ckpt of qwen visual encoder') + # print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) + + + return model + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + **kwargs): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + self.image_transform = transforms.Compose([ + transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock( + width, + layers, + heads, + mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + def forward(self, x: torch.Tensor): + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + # shape = [*, width, grid ** 2] + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, x.size(1)) + + x = self.ln_pre(x) + + x = x.permute(1, 
0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + + def encode(self, image_paths: List[str]): + images = [] + for image_path in image_paths: + if image_path.startswith("http://") or image_path.startswith("https://"): + image = Image.open(requests.get(image_path, stream=True).raw) + else: + image = Image.open(image_path) + image = image.convert("RGB") + images.append(self.image_transform(image)) + images = torch.stack(images, dim=0) + return self(images) diff --git a/src/processer/__pycache__/transforms.cpython-311.pyc b/src/processer/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77338e74b5d26543cc16bebd95c7e09b27b80643 Binary files /dev/null and b/src/processer/__pycache__/transforms.cpython-311.pyc differ diff --git a/src/processer/__pycache__/transforms.cpython-38.pyc b/src/processer/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77c9d729903ed2c1f9bc1862b17b6ad012469aa1 Binary files /dev/null and b/src/processer/__pycache__/transforms.cpython-38.pyc differ diff --git a/src/processer/tokenizer.py b/src/processer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cfffdb9816d8a8d1ac16d23b4a2d24a8c17ffd0d --- /dev/null +++ b/src/processer/tokenizer.py @@ -0,0 +1,8 @@ +from transformers import BertTokenizer + + +def bert_tokenizer(pretrained_model_name_or_path): + tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path, + truncation_side='right') + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + return tokenizer diff --git a/src/processer/transforms.py b/src/processer/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..b3aa9ad17dc28293957ebf1fc4ab5ac127f1f645 --- /dev/null +++ b/src/processer/transforms.py @@ -0,0 +1,83 @@ +from torchvision import transforms +from PIL import Image + + +def get_transform(type='clip', keep_ratio=True, image_size=224): + if type == 'clip': + transform = [] + if keep_ratio: + transform.extend([ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ]) + else: + transform.append(transforms.Resize((image_size, image_size))) + transform.extend([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + ]) + + return transforms.Compose(transform) + elif type == 'clipa': + transform = [] + if keep_ratio: + transform.extend([ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + ]) + else: + transform.append(transforms.Resize((image_size, image_size))) + transform.extend([transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]) + + return transforms.Compose(transform) + elif type == 'clipb': + transform = [] + + if keep_ratio: + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), + background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), + background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + background_color = tuple(int(x * 255) for x in (0.48145466, 0.4578275, 0.40821073)) + + transform.append( + transforms.Lambda( + lambda img: expand2square(img, background_color))) + + 
transform.append(transforms.Resize((image_size, image_size))) + else: + transform.append(transforms.Resize((image_size, image_size))) + + transform.extend([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711)) + ]) + + return transforms.Compose(transform) + + elif type == 'sd': + transform = [] + if keep_ratio: + transform.extend([ + transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(image_size), + ]) + else: + transform.append(transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC)) + transform.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + + return transforms.Compose(transform) + else: + raise NotImplementedError diff --git a/start.py b/start.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce492593f893189b588c13d377aa5c973537d60 --- /dev/null +++ b/start.py @@ -0,0 +1,17 @@ +import subprocess + +if __name__ == '__main__': + backend_command = ['python3', 'src/demo/seed_llama_flask.py', '--image_transform', 'configs/processer/qwen_448_transform.yaml', \ + '--tokenizer', 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml', \ + '--llm', 'configs/clm_models/llm_seed_x_i.yaml', \ + '--visual_encoder', 'configs/visual_encoder/qwen_vitg_448.yaml', \ + '--sd_adapter', 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml', \ + '--agent', 'configs/clm_models/agent_seed_x_i.yaml', \ + '--diffusion_path', 'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0', \ + '--port', '7890', '--llm_device', 'cuda:0', '--vit_sd_device', 'cuda:0', '--multi_resolution', 'True', '--has_bbox'] + + frontend_command = ['python3', 'src/demo/seed_llama_gradio.py', '--server_port', '7860', '--request_address', 'http://127.0.0.1:7890/generate'] + + backend_proc = subprocess.Popen(backend_command) + + frontend_proc = subprocess.Popen(frontend_command) diff --git a/start.sh b/start.sh new file mode 100644 index 0000000000000000000000000000000000000000..549ac0e177bf15634c8908d28ef78e9d1ca6828a --- /dev/null +++ b/start.sh @@ -0,0 +1,18 @@ +nohup python3 src/demo/seed_llama_flask.py \ + --image_transform configs/processer/qwen_448_transform.yaml \ + --tokenizer configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml \ + --llm configs/clm_models/llm_seed_x_i.yaml \ + --visual_encoder configs/visual_encoder/qwen_vitg_448.yaml \ + --sd_adapter configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml \ + --agent configs/clm_models/agent_seed_x_i.yaml \ + --diffusion_path https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 \ + --port 7890 \ + --llm_device 'cuda:0' \ + --vit_sd_device 'cuda:0' \ + --multi_resolution True \ + --has_bbox & + +python3 src/demo/seed_llama_gradio.py \ + --server_port 7860 \ + --server_name '0.0.0.0' \ + --request_address 'http://127.0.0.1:7890/generate'
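Closing note on the preprocessing utilities: get_transform in src/processer/transforms.py returns a torchvision pipeline selected by name, where 'clip' and 'clipb' normalize with CLIP statistics, 'clipa' uses ImageNet statistics, and 'sd' normalizes to [-1, 1] for the diffusion decoder. A minimal sketch of how it might be called; the image path below is a placeholder, not a file in this patch:

from PIL import Image
from src.processer.transforms import get_transform

# CLIP-style preprocessing at the 448x448 resolution used elsewhere in this patch
transform = get_transform(type='clip', keep_ratio=True, image_size=448)
image = Image.open('example.jpg').convert('RGB')     # placeholder image path
pixel_values = transform(image)                      # float tensor of shape [3, 448, 448]

With keep_ratio=True the short side is resized to image_size and the result is center-cropped, so the aspect ratio is preserved before cropping; with keep_ratio=False the image is simply resized to a square.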