diff --git a/elia/.gitattributes b/elia/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea --- /dev/null +++ b/elia/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/elia/LICENSE b/elia/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7 --- /dev/null +++ b/elia/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. 
+ + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. 
+ + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. 
You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. 
In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>. diff --git a/elia/README.md b/elia/README.md new file mode 100644 index 0000000000000000000000000000000000000000..869de8afb5569dca5c730028fe17228f54198bef --- /dev/null +++ b/elia/README.md @@ -0,0 +1,222 @@ +# LAVT: Language-Aware Vision Transformer for Referring Image Segmentation +Welcome to the official repository for the method presented in +"LAVT: Language-Aware Vision Transformer for Referring Image Segmentation." + + +![Pipeline Image](pipeline.jpg) + +Code in this repository is written using [PyTorch](https://pytorch.org/) and is organized in the following way (assuming the working directory is the root directory of this repository): +* `./lib` contains files implementing the main network. +* Inside `./lib`, `_utils.py` defines the highest-level model, which incorporates the backbone network +defined in `backbone.py` and the simple mask decoder defined in `mask_predictor.py`. +`segmentation.py` provides the model interface and initialization functions.
+* `./bert` contains files migrated from [Hugging Face Transformers v3.0.2](https://huggingface.co/transformers/v3.0.2/quicktour.html), +which implement the BERT language model. +We used Transformers v3.0.2 during development, but it had a bug that would appear when using `DistributedDataParallel`. +Therefore we maintain a copy of the relevant source files in this repository. +This way, the bug is fixed and code in this repository is self-contained. +* `./train.py` is invoked to train the model. +* `./test.py` is invoked to run inference on the evaluation subsets after training. +* `./refer` contains data pre-processing code and is also where data should be placed, including the images and all annotations. +It is cloned from [refer](https://github.com/lichengunc/refer). +* `./data/dataset_refer_bert.py` is where the dataset class is defined. +* `./utils.py` defines functions that track training statistics and setup +functions for `DistributedDataParallel`. + + +## Updates +**June 21st, 2022**. Uploaded the training logs and trained +model weights of lavt_one. + +**June 9th, 2022**. +Added a more efficient implementation of LAVT. +* To train this new model, specify `--model` as `lavt_one` +(and `lavt` is still valid for specifying the old model). +The rest of the configuration stays unchanged. +* The difference between this version and the previous one +is that the language model has been moved inside the overall model, +so that `DistributedDataParallel` needs to be applied only once. +Applying it twice (on the standalone language model and the main branch) +as done in the old implementation led to low GPU utilization, +which prevented scaling up training speed with more GPUs. +We recommend training this model on 8 GPUs +(with the same batch size of 32 as before). + +## Setting Up +### Preliminaries +The code has been verified to work with PyTorch v1.7.1 and Python 3.7. +1. Clone this repository. +2. Change directory to the root of this repository. +### Package Dependencies +1. Create a new Conda environment with Python 3.7, then activate it: +```shell +conda create -n lavt python==3.7 +conda activate lavt +``` + +2. Install PyTorch v1.7.1 with a CUDA version that works on your cluster/machine (CUDA 10.2 is used in this example): +```shell +conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch +``` + +3. Install the packages in `requirements.txt` via `pip`: +```shell +pip install -r requirements.txt +``` + +### Datasets +1. Follow instructions in the `./refer` directory to set up subdirectories +and download annotations. +This directory is a git clone (minus two data files that we do not need) +from the [refer](https://github.com/lichengunc/refer) public API. + +2. Download images from [COCO](https://cocodataset.org/#download). +Please use the first download link *2014 Train images [83K/13GB]*, and extract +the downloaded `train2014.zip` file to `./refer/data/images/mscoco/images`. + +### The Initialization Weights for Training +1. Create the `./pretrained_weights` directory where we will be storing the weights. +```shell +mkdir ./pretrained_weights +``` +2. Download [pre-trained classification weights of +the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth), +and put the `pth` file in `./pretrained_weights`. +These weights are needed for training to initialize the model. + +### Trained Weights of LAVT for Testing +1. 
Create the `./checkpoints` directory where we will be storing the weights. +```shell +mkdir ./checkpoints +``` +2. Download LAVT model weights (which are stored on Google Drive) using links below and put them in `./checkpoints`. + +| [RefCOCO](https://drive.google.com/file/d/13D-OeEOijV8KTC3BkFP-gOJymc6DLwVT/view?usp=sharing) | [RefCOCO+](https://drive.google.com/file/d/1B8Q44ZWsc8Pva2xD_M-KFh7-LgzeH2-2/view?usp=sharing) | [G-Ref (UMD)](https://drive.google.com/file/d/1BjUnPVpALurkGl7RXXvQiAHhA-gQYKvK/view?usp=sharing) | [G-Ref (Google)](https://drive.google.com/file/d/1weiw5UjbPfo3tCBPfB8tu6xFXCUG16yS/view?usp=sharing) | +|---|---|---|---| + +3. Model weights and training logs of the new lavt_one implementation are below. + +| RefCOCO | RefCOCO+ | G-Ref (UMD) | G-Ref (Google) | +|:-----:|:-----:|:-----:|:-----:| +|[log](https://drive.google.com/file/d/1YIojIHqe3bxxsWOltifa2U9jH67hPHLM/view?usp=sharing) | [weights](https://drive.google.com/file/d/1xFMEXr6AGU97Ypj1yr8oo00uObbeIQvJ/view?usp=sharing)|[log](https://drive.google.com/file/d/1Z34T4gEnWlvcSUQya7txOuM0zdLK7MRT/view?usp=sharing) | [weights](https://drive.google.com/file/d/1HS8ZnGaiPJr-OmoUn4-4LVnVtD_zHY6w/view?usp=sharing)|[log](https://drive.google.com/file/d/14VAgahngOV8NA6noLZCqDoqaUrlW14v8/view?usp=sharing) | [weights](https://drive.google.com/file/d/14g8NzgZn6HzC6tP_bsQuWmh5LnOcovsE/view?usp=sharing)|[log](https://drive.google.com/file/d/1JBXfmlwemWSvs92Rky0TlHcVuuLpt4Da/view?usp=sharing) | [weights](https://drive.google.com/file/d/1IJeahFVLgKxu_BVmWacZs3oUzgTCeWcz/view?usp=sharing)| + +* The Prec@K, overall IoU and mean IoU numbers in the training logs will differ +from the final results obtained by running `test.py`, +because only one out of multiple annotated expressions is +randomly selected and evaluated for each object during training. +But these numbers give a good idea about the test performance. +The two should be fairly close. + + +## Training +We use `DistributedDataParallel` from PyTorch. +The released `lavt` weights were trained using 4 x 32G V100 cards (max mem on each card was about 26G). +The released `lavt_one` weights were trained using 8 x 32G V100 cards (max mem on each card was about 13G). +Using more cards was to accelerate training. 
+To run on 4 GPUs (with IDs 0, 1, 2, and 3) on a single node: +```shell +mkdir ./models + +mkdir ./models/refcoco +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco --model_id refcoco --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco/output + +mkdir ./models/refcoco+ +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco+ --model_id refcoco+ --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco+/output + +mkdir ./models/gref_umd +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy umd --model_id gref_umd --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_umd/output + +mkdir ./models/gref_google +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy google --model_id gref_google --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_google/output +``` +* *--model* is a pre-defined model name. Options include `lavt` and `lavt_one`. See [Updates](#updates). +* *--dataset* is the dataset name. One can choose from `refcoco`, `refcoco+`, and `refcocog`. +* *--splitBy* needs to be specified if and only if the dataset is G-Ref (which is also called RefCOCOg). +`umd` identifies the UMD partition and `google` identifies the Google partition. +* *--model_id* is the model name one should define oneself (*e.g.*, customize it to contain training/model configurations, dataset information, experiment IDs, *etc*.). +It is used in two ways: Training log will be saved as `./models/[args.model_id]/output` and the best checkpoint will be saved as `./checkpoints/model_best_[args.model_id].pth`. +* *--swin_type* specifies the version of the Swin Transformer. +One can choose from `tiny`, `small`, `base`, and `large`. The default is `base`. +* *--pretrained_swin_weights* specifies the path to pre-trained Swin Transformer weights used for model initialization. +* Note that currently we need to manually create the `./models/[args.model_id]` directory via `mkdir` before running `train.py`. +This is because we use `tee` to redirect `stdout` and `stderr` to `./models/[args.model_id]/output` for logging. +This is a nuisance and should be resolved in the future, *i.e.*, using a proper logger or a bash script for initiating training. 
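+ +As a stopgap for the manual `mkdir`-plus-`tee` step noted in the last point above, a small wrapper script can create the log directory and start training in one go. The sketch below is only an illustration: the script name `launch_train.sh` and its argument handling are not part of this repository, it simply reuses the `lavt`/RefCOCO flags from the commands above, and `refcocog` runs would additionally need `--splitBy`. +```shell +#!/usr/bin/env bash +# Hypothetical launcher sketch: creates ./models/<model_id> and tees the training log into it. +# Example: ./launch_train.sh refcoco refcoco 0,1,2,3 +set -e +DATASET=${1:-refcoco}    # refcoco, refcoco+, or refcocog +MODEL_ID=${2:-$DATASET}  # experiment name used for the log directory and best checkpoint +GPUS=${3:-0,1,2,3}       # comma-separated GPU IDs +NPROC=$(echo "$GPUS" | awk -F',' '{print NF}')  # one process per listed GPU +mkdir -p "./models/${MODEL_ID}" +CUDA_VISIBLE_DEVICES=$GPUS python -m torch.distributed.launch --nproc_per_node "$NPROC" --master_port 12345 train.py --model lavt --dataset "$DATASET" --model_id "$MODEL_ID" --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee "./models/${MODEL_ID}/output" +```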
+ +## Testing +For RefCOCO/RefCOCO+, run one of +```shell +python test.py --model lavt --swin_type base --dataset refcoco --split val --resume ./checkpoints/refcoco.pth --workers 4 --ddp_trained_weights --window12 --img_size 480 +python test.py --model lavt --swin_type base --dataset refcoco+ --split val --resume ./checkpoints/refcoco+.pth --workers 4 --ddp_trained_weights --window12 --img_size 480 +``` +* *--split* is the subset to evaluate, and one can choose from `val`, `testA`, and `testB`. +* *--resume* is the path to the weights of a trained model. + +For G-Ref (UMD)/G-Ref (Google), run one of +```shell +python test.py --model lavt --swin_type base --dataset refcocog --splitBy umd --split val --resume ./checkpoints/gref_umd.pth --workers 4 --ddp_trained_weights --window12 --img_size 480 +python test.py --model lavt --swin_type base --dataset refcocog --splitBy google --split val --resume ./checkpoints/gref_google.pth --workers 4 --ddp_trained_weights --window12 --img_size 480 +``` +* *--splitBy* specifies the partition to evaluate. +One can choose from `umd` or `google`. +* *--split* is the subset (according to the specified partition) to evaluate, and one can choose from `val` and `test` for the UMD partition, and only `val` for the Google partition. +* *--resume* is the path to the weights of a trained model. + +## Results +The complete test results of the released LAVT models are summarized as follows: + +| Dataset | P@0.5 | P@0.6 | P@0.7 | P@0.8 | P@0.9 | Overall IoU | Mean IoU | +|:---------------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----------:|:--------:| +| RefCOCO val | 84.46 | 80.90 | 75.28 | 64.71 | 34.30 | 72.73 | 74.46 | +| RefCOCO test A | 88.07 | 85.17 | 79.90 | 68.52 | 35.69 | 75.82 | 76.89 | +| RefCOCO test B | 79.12 | 74.94 | 69.17 | 59.37 | 34.45 | 68.79 | 70.94 | +| RefCOCO+ val | 74.44 | 70.91 | 65.58 | 56.34 | 30.23 | 62.14 | 65.81 | +| RefCOCO+ test A | 80.68 | 77.96 | 72.90 | 62.21 | 32.36 | 68.38 | 70.97 | +| RefCOCO+ test B | 65.66 | 61.85 | 55.94 | 47.56 | 27.24 | 55.10 | 59.23 | +| G-Ref val (UMD) | 70.81 | 65.28 | 58.60 | 47.49 | 22.73 | 61.24 | 63.34 | +| G-Ref test (UMD)| 71.54 | 66.38 | 59.00 | 48.21 | 23.10 | 62.09 | 63.62 | +|G-Ref val (Goog.)| 71.16 | 67.21 | 61.76 | 51.98 | 27.30 | 60.50 | 63.66 | + +We have validated LAVT on RefCOCO with multiple runs. +The overall IoU on the val set generally lies in the range of 72.73±0.5%. + + +## Demo: Try LAVT on Your Own Image-text Pairs! +One can run inference on a custom image-text pair +and visualize the result by running the script `./demo_inference.py`. +Choose your photos and expressions and have fun. + + +## Citing LAVT +``` +@inproceedings{yang2022lavt, + title={LAVT: Language-Aware Vision Transformer for Referring Image Segmentation}, + author={Yang, Zhao and Wang, Jiaqi and Tang, Yansong and Chen, Kai and Zhao, Hengshuang and Torr, Philip HS}, + booktitle={CVPR}, + year={2022} +} +``` + + +## Contributing +We appreciate all contributions. +It helps the project if you could +- report issues you are facing, +- give a :+1: on issues reported by others that are relevant to you, +- answer issues reported by others for which you have found solutions, +- and implement helpful new features or improve the code otherwise with pull requests. + +## Acknowledgements +Code in this repository is built upon several public repositories.
+Specifically, +* data pre-processing leverages the [refer](https://github.com/lichengunc/refer) repository, +* the backbone model is implemented based on code from [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), +* the training and testing pipelines are adapted from [RefVOS](https://github.com/miriambellver/refvos), +* and implementation of the BERT model (files in the bert directory) is from [Hugging Face Transformers v3.0.2](https://github.com/huggingface/transformers/tree/v3.0.2) +(we migrated over the relevant code to fix a bug and simplify the installation process). + +Some of these repositories in turn adapt code from [OpenMMLab](https://github.com/open-mmlab) and [TorchVision](https://github.com/pytorch/vision). +We'd like to thank the authors/organizations of these repositories for open sourcing their projects. + + +## License +GNU GPLv3 diff --git a/elia/__pycache__/args.cpython-37.pyc b/elia/__pycache__/args.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f640c5b4ee64eda39a04bf8049ed80dc140ef6e Binary files /dev/null and b/elia/__pycache__/args.cpython-37.pyc differ diff --git a/elia/__pycache__/args.cpython-38.pyc b/elia/__pycache__/args.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e42ca8813cec4f1a54d78aebc2ea603adf3defe3 Binary files /dev/null and b/elia/__pycache__/args.cpython-38.pyc differ diff --git a/elia/__pycache__/transforms.cpython-37.pyc b/elia/__pycache__/transforms.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23f8d1523aee9c09d0276e0b628bd712c7bad041 Binary files /dev/null and b/elia/__pycache__/transforms.cpython-37.pyc differ diff --git a/elia/__pycache__/transforms.cpython-38.pyc b/elia/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..815e3982a1d1d71927cc081f1e04de5d988183bc Binary files /dev/null and b/elia/__pycache__/transforms.cpython-38.pyc differ diff --git a/elia/__pycache__/utils.cpython-37.pyc b/elia/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a04ac1d1a6fdbcf6d8b1f00ee78c4698efc273b Binary files /dev/null and b/elia/__pycache__/utils.cpython-37.pyc differ diff --git a/elia/__pycache__/utils.cpython-38.pyc b/elia/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59c425f22032f4435c681238cd31a83bd8fb1061 Binary files /dev/null and b/elia/__pycache__/utils.cpython-38.pyc differ diff --git a/elia/app.py b/elia/app.py new file mode 100644 index 0000000000000000000000000000000000000000..03b11a42bb61ff9a4b2dc51fd6d8e4fe106de893 --- /dev/null +++ b/elia/app.py @@ -0,0 +1,310 @@ +import gradio as gr + +image_path = './image001.png' +sentence = 'spoon on the dish' +weights = './checkpoints/model_best_refcoco_0508.pth' +device = 'cpu' + +# pre-process the input image +from PIL import Image +import torchvision.transforms as T +import numpy as np +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from bert.multimodal_bert import MultiModalBert +import torchvision + +from lib import multimodal_segmentation_ppm +#import transforms as T +import utils + +import numpy as np +from PIL import Image +import torch.nn.functional as F + +from modeling.MaskFormerModel import MaskFormerHead +from addict import Dict +#from bert.modeling_bert import BertLMPredictionHead, 
BertEncoder +import cv2 +import textwrap + +class WrapperModel(nn.Module): + def __init__(self, image_model, language_model, classifier) : + super(WrapperModel, self).__init__() + self.image_model = image_model + self.language_model = language_model + self.classifier = classifier + + config = Dict({ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + #"max_position_embeddings": 16+20, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 8, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522 + }) + + + + def _get_binary_mask(self, target): + # 返回每类的binary mask + y, x = target.size() + target_onehot = torch.zeros(self.num_classes + 1, y, x) + target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1) + return target_onehot[1:] + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=1)[...,1:] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred) + return semseg + + def forward(self, image, sentences, attentions): + print(image.sum(), sentences.sum(), attentions.sum()) + input_shape = image.shape[-2:] + l_mask = attentions.unsqueeze(dim=-1) + + i0, Wh, Ww = self.image_model.forward_stem(image) + l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions) + + i1 = self.image_model.forward_stage1(i0, Wh, Ww) + l1 = self.language_model.forward_stage1(l0, extended_attention_mask) + i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask) + l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) + i1 = i1_temp + + i2 = self.image_model.forward_stage2(i1, Wh, Ww) + l2 = self.language_model.forward_stage2(l1, extended_attention_mask) + i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask) + l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) + i2 = i2_temp + + i3 = self.image_model.forward_stage3(i2, Wh, Ww) + l3 = self.language_model.forward_stage3(l2, extended_attention_mask) + i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask) + l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) + i3 = i3_temp + + i4 = self.image_model.forward_stage4(i3, Wh, Ww) + l4 = self.language_model.forward_stage4(l3, extended_attention_mask) + i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask) + l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) + i4 = i4_temp + + #i1_residual, i2_residual, i3_residual, i4_residual = features + #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual) + #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) + outputs = {} + outputs['s1'] = i1_residual + outputs['s2'] = i2_residual + outputs['s3'] = i3_residual + outputs['s4'] = i4_residual + + predictions = self.classifier(outputs) + return predictions + +#img = Image.open(image_path).convert("RGB") + +# pre-process the raw sentence +from bert.tokenization_bert import BertTokenizer +import torch + +# initialize model and load weights 
+#from bert.modeling_bert import BertModel +#from lib import segmentation + +# construct a mini args class; like from a config file + + +class args: + swin_type = 'base' + window12 = True + mha = '' + fusion_drop = 0.0 + + +#single_model = segmentation.__dict__['lavt'](pretrained='', args=args) +single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='',args=args) +single_model.to(device) +model_class = MultiModalBert +single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim) +single_bert_model.pooler = None + +input_shape = dict() +input_shape['s1'] = Dict({'channel': 128, 'stride': 4}) +input_shape['s2'] = Dict({'channel': 256, 'stride': 8}) +input_shape['s3'] = Dict({'channel': 512, 'stride': 16}) +input_shape['s4'] = Dict({'channel': 1024, 'stride': 32}) + + + +cfg = Dict() +cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 +cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 +cfg.MODEL.MASK_FORMER.NHEADS = 8 +cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4 +cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256 +cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 +cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"] + +cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1 +cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 +cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1 +cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 +cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10 +cfg.MODEL.MASK_FORMER.PRE_NORM = False + + +maskformer_head = MaskFormerHead(cfg, input_shape) + + +model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head) + + + +checkpoint = torch.load(weights, map_location='cpu') + +model.load_state_dict(checkpoint['model'], strict=False) +model.to(device) +model.eval() +#single_bert_model.load_state_dict(checkpoint['bert_model']) +#single_model.load_state_dict(checkpoint['model']) +#model = single_model.to(device) +#bert_model = single_bert_model.to(device) + + +# inference +#import torch.nn.functional as F +#last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0] +#embedding = last_hidden_states.permute(0, 2, 1) +#output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1)) +#output = output.argmax(1, keepdim=True) # (1, 1, 480, 480) +#output = F.interpolate(output.float(), (original_h, original_w)) # 'nearest'; resize to the original image size +#output = output.squeeze() # (orig_h, orig_w) +#output = output.cpu().data.numpy() # (orig_h, orig_w) + +#output = pred_masks[0] + +#output = output.cpu() + + + +#print(output.shape) +#output_mask = output.argmax(1).data.numpy() +#output = (output > 0.5).data.cpu().numpy() + + +# show/save results +def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4): + from scipy.ndimage.morphology import binary_dilation + + colors = np.reshape(colors, (-1, 3)) + colors = np.atleast_2d(colors) * cscale + + im_overlay = image.copy() + object_ids = np.unique(mask) + + for object_id in object_ids[1:]: + # Overlay color on binary mask + foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) + binary_mask = mask == object_id + + # Compose image + im_overlay[binary_mask] = foreground[binary_mask] + + # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask + countours = binary_dilation(binary_mask) ^ binary_mask + # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask + im_overlay[countours, :] = 0 + + return im_overlay.astype(image.dtype) + + +def 
run_model(img, sentence): + +#img = Image.open(image_path).convert("RGB") + img = Image.fromarray(img) + img = img.convert("RGB") + #print(img.shape) + img_ndarray = np.array(img) # (orig_h, orig_w, 3); for visualization + original_w, original_h = img.size # PIL .size returns width first and height second + + image_transforms = T.Compose( + [ + T.Resize((480, 480)), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + ) + + img = image_transforms(img).unsqueeze(0) # (1, 3, 480, 480) + img = img.to(device) # for inference (input) + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True) + sentence_tokenized = sentence_tokenized[:20] # if the sentence is longer than 20, then this truncates it to 20 words + # pad the tokenized sentence + padded_sent_toks = [0] * 20 + padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized + # create a sentence token mask: 1 for real words; 0 for padded tokens + attention_mask = [0] * 20 + attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized) + # convert lists to tensors + padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0) # (1, 20) + attention_mask = torch.tensor(attention_mask).unsqueeze(0) # (1, 20) + padded_sent_toks = padded_sent_toks.to(device) # for inference (input) + attention_mask = attention_mask.to(device) # for inference (input) + + output = model(img, padded_sent_toks, attention_mask)[0] + #print(output[0].keys()) + #print(output[1].shape) + mask_cls_results = output["pred_logits"] + mask_pred_results = output["pred_masks"] + + target_shape = img_ndarray.shape[:2] + #print(target_shape, mask_pred_results.shape) + mask_pred_results = F.interpolate(mask_pred_results, size=(480,480), mode='bilinear', align_corners=True) + + pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results) + + output = torch.nn.functional.interpolate(pred_masks, target_shape) + output = (output > 0.5).data.cpu().numpy() + + output = output.astype(np.uint8) # (orig_h, orig_w), np.uint8 + # Overlay the mask on the image + print(img_ndarray.shape, output.shape) + visualization = overlay_davis(img_ndarray, output[0][0]) # red + visualization = Image.fromarray(visualization) + # show the visualization + #visualization.show() + # Save the visualization + #visualization.save('./demo/spoon_on_the_dish.jpg') + return visualization + + + + +demo = gr.Interface(run_model, inputs=[gr.Image(), "text"], outputs=["image"]) +demo.launch() + diff --git a/elia/args.py b/elia/args.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f8480b92afc4c61e3c45a79b3d332d9c53226a --- /dev/null +++ b/elia/args.py @@ -0,0 +1,74 @@ +import argparse + + +def get_parser(): + parser = argparse.ArgumentParser(description='LAVT training and testing') + parser.add_argument('--amsgrad', action='store_true', + help='if true, set amsgrad to True in an Adam or AdamW optimizer.') + parser.add_argument('-b', '--batch-size', default=8, type=int) + parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer') + parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights') + parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog') + parser.add_argument('--ddp_trained_weights', action='store_true', + help='Only needs specified when testing,' + 'whether the weights to be loaded are from a DDP-trained model') + 
parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine + parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run') + parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs') + parser.add_argument('--img_size', default=480, type=int, help='input image size') + parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel') + parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate') + parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,' + 'where a, b, c, and d refer to the numbers of heads in stage-1,' + 'stage-2, stage-3, and stage-4 PWAMs') + parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one') + parser.add_argument('--model_id', default='lavt', help='name to identify the model') + parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights') + parser.add_argument('--pin_mem', action='store_true', + help='If true, pin memory when using the data loader.') + parser.add_argument('--pretrained_swin_weights', default='', + help='path to pre-trained Swin backbone weights') + parser.add_argument('--print-freq', default=10, type=int, help='print frequency') + parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory') + parser.add_argument('--resume', default='auto', help='resume from checkpoint') + parser.add_argument('--split', default='test', help='only used when testing') + parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)') + parser.add_argument('--swin_type', default='base', + help='tiny, small, base, or large variants of the Swin Transformer') + parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay', + dest='weight_decay') + parser.add_argument('--window12', action='store_true', + help='only needs specified when testing,' + 'when training, window size is inferred from pre-trained weights file name' + '(containing \'window12\'). 
Initialize Swin with window size 12 instead of the default 7.') + parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--max_ckpt', default=2, type=int) + parser.add_argument('--num_object_queries', default=1, type=int) + parser.add_argument('--no_object_weight', default=0.0, type=float) + parser.add_argument('--class_weight', default=2.0, type=float) + parser.add_argument('--dice_weight', default=2.0, type=float) + parser.add_argument('--mask_weight', default=2.0, type=float) + parser.add_argument('--train_num_points', default=12544, type=int) + parser.add_argument('--dim_feedforward', default=2048, type=int) + parser.add_argument('--dec_layers', default=10, type=int) + parser.add_argument('--transformer_enc_layers', default=4, type=int) + + parser.add_argument('--plic_pos_weight', default=0.5, type=float) + parser.add_argument('--plic_neg_weight', default=0.5, type=float) + parser.add_argument('--plic_lang_weight', default=0.5, type=float) + parser.add_argument('--plic_pos_alpha', default=0.0, type=float) + parser.add_argument('--plic_neg_alpha', default=0.0, type=float) + parser.add_argument('--plic_lang_alpha', default=0.0, type=float) + parser.add_argument('--plic_pos_temp', default=0.2, type=float) + parser.add_argument('--plic_neg_temp', default=0.2, type=float) + parser.add_argument('--plic_lang_temp', default=0.2, type=float) + parser.add_argument('--smlm_weight', default=1.0, type=float) + parser.add_argument('--vis_dir', default='./vis_dir') + + return parser + + +if __name__ == "__main__": + parser = get_parser() + args_dict = parser.parse_args() diff --git a/elia/bert/__pycache__/activations.cpython-37.pyc b/elia/bert/__pycache__/activations.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf79bfc969d2dfa652d29321c45aec02bc8ed301 Binary files /dev/null and b/elia/bert/__pycache__/activations.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/activations.cpython-38.pyc b/elia/bert/__pycache__/activations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d3804e94dab5c36f8cdda958065b4e32dc4aabd Binary files /dev/null and b/elia/bert/__pycache__/activations.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/configuration_bert.cpython-37.pyc b/elia/bert/__pycache__/configuration_bert.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8084771ac9831c916fa953b7b10017435ea14d0 Binary files /dev/null and b/elia/bert/__pycache__/configuration_bert.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/configuration_bert.cpython-38.pyc b/elia/bert/__pycache__/configuration_bert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eb473e317fe8fb5c667a3dbe69bace5c29e8986 Binary files /dev/null and b/elia/bert/__pycache__/configuration_bert.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/configuration_utils.cpython-37.pyc b/elia/bert/__pycache__/configuration_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea208dae8700e64e59cc156989eabbf2e528dc80 Binary files /dev/null and b/elia/bert/__pycache__/configuration_utils.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/configuration_utils.cpython-38.pyc b/elia/bert/__pycache__/configuration_utils.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..eefb7f8062db85c760031b0999814cf84c2f992b Binary files /dev/null and b/elia/bert/__pycache__/configuration_utils.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/file_utils.cpython-37.pyc b/elia/bert/__pycache__/file_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48ad2774a20d2634cf7e3541436cc4a40aca9ca5 Binary files /dev/null and b/elia/bert/__pycache__/file_utils.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/file_utils.cpython-38.pyc b/elia/bert/__pycache__/file_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29ede2127f22edebb5409a5bfcbec143227b7d01 Binary files /dev/null and b/elia/bert/__pycache__/file_utils.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/generation_utils.cpython-37.pyc b/elia/bert/__pycache__/generation_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d79d8c2dd616ce91f49c643e29b40e084012bc28 Binary files /dev/null and b/elia/bert/__pycache__/generation_utils.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/generation_utils.cpython-38.pyc b/elia/bert/__pycache__/generation_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c36afeb132be1ff8b5cca362f0911bf211e1b1e3 Binary files /dev/null and b/elia/bert/__pycache__/generation_utils.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/modeling_bert.cpython-37.pyc b/elia/bert/__pycache__/modeling_bert.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4114e8086ba2f24b8786db255b4d7c0391b372c Binary files /dev/null and b/elia/bert/__pycache__/modeling_bert.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/modeling_bert.cpython-38.pyc b/elia/bert/__pycache__/modeling_bert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e4f837799032e0940315dd0506de477acb1acf1 Binary files /dev/null and b/elia/bert/__pycache__/modeling_bert.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/modeling_utils.cpython-37.pyc b/elia/bert/__pycache__/modeling_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c77d5fa351f765bffe538cc8fe462ec1f119a704 Binary files /dev/null and b/elia/bert/__pycache__/modeling_utils.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/modeling_utils.cpython-38.pyc b/elia/bert/__pycache__/modeling_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52806021ea6d36cc4fad71c04b02165b2e14f4e7 Binary files /dev/null and b/elia/bert/__pycache__/modeling_utils.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/multimodal_bert.cpython-37.pyc b/elia/bert/__pycache__/multimodal_bert.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d286f1ac28e65d2c220c49e2cc53f8283ef8e87 Binary files /dev/null and b/elia/bert/__pycache__/multimodal_bert.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/multimodal_bert.cpython-38.pyc b/elia/bert/__pycache__/multimodal_bert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0919175640605c3d91aa56da62a3a916bca08648 Binary files /dev/null and b/elia/bert/__pycache__/multimodal_bert.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/tokenization_bert.cpython-37.pyc b/elia/bert/__pycache__/tokenization_bert.cpython-37.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..047e5c66121aba964c5d9ecd7c5d3079213f1055 Binary files /dev/null and b/elia/bert/__pycache__/tokenization_bert.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/tokenization_bert.cpython-38.pyc b/elia/bert/__pycache__/tokenization_bert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46e86bdfb72fd99e3a12e8304d1a87da8874af31 Binary files /dev/null and b/elia/bert/__pycache__/tokenization_bert.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/tokenization_utils.cpython-37.pyc b/elia/bert/__pycache__/tokenization_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f93a01807dbfcf5f158ea2d18ec7b1a3cf156ff Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/tokenization_utils.cpython-38.pyc b/elia/bert/__pycache__/tokenization_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c7b8e10c8c808b0d8466184f2107f6efa0261c Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils.cpython-38.pyc differ diff --git a/elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc b/elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e583920001390bd94d4c8cbd447cc601e15d9846 Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc differ diff --git a/elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc b/elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d46a9416d56f3d4ed2ae4473434002b42eb68c Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc differ diff --git a/elia/bert/activations.py b/elia/bert/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..8a1206ee285ce3f0484d129711a2d684700a20a1 --- /dev/null +++ b/elia/bert/activations.py @@ -0,0 +1,56 @@ +import logging +import math + +import torch +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +def swish(x): + return x * torch.sigmoid(x) + + +def _gelu_python(x): + """ Original Implementation of the gelu activation function in Google Bert repo when initially created. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + This is now written in C in torch.nn.functional + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def gelu_new(x): + """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 
+ Also see https://arxiv.org/abs/1606.08415 + """ + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +if torch.__version__ < "1.4.0": + gelu = _gelu_python +else: + gelu = F.gelu + + +def gelu_fast(x): + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) + + +ACT2FN = { + "relu": F.relu, + "swish": swish, + "gelu": gelu, + "tanh": torch.tanh, + "gelu_new": gelu_new, + "gelu_fast": gelu_fast, +} + + +def get_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) diff --git a/elia/bert/configuration_bert.py b/elia/bert/configuration_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..8e815837bc4dbc5fc8eec7ee37547b5d41519af5 --- /dev/null +++ b/elia/bert/configuration_bert.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BERT model configuration """ + + +import logging + +from .configuration_utils import PretrainedConfig + + +logger = logging.getLogger(__name__) + +BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", + "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", + "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", + "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", + "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", + "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", + "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", + "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", + "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", + "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", + "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", + "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", + "bert-base-cased-finetuned-mrpc": 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", + "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", + "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", + "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", + "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", + "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", + "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", + "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", + "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", + "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", + # See all BERT models at https://huggingface.co/models?filter=bert +} + + +class BertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. + It is used to instantiate an BERT model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the BERT `bert-base-uncased `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. + + + Args: + vocab_size (:obj:`int`, optional, defaults to 30522): + Vocabulary size of the BERT model. Defines the different tokens that + can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. + hidden_size (:obj:`int`, optional, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, optional, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, optional, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, optional, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): + The non-linear activation function (function or string) in the encoder and pooler. + If string, "gelu", "relu", "swish" and "gelu_new" are supported. + hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, optional, defaults to 512): + The maximum sequence length that this model might ever be used with. 
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, optional, defaults to 2): + The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. + initializer_range (:obj:`float`, optional, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, optional, defaults to False): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example:: + + >>> from transformers import BertModel, BertConfig + + >>> # Initializing a BERT bert-base-uncased style configuration + >>> configuration = BertConfig() + + >>> # Initializing a model from the bert-base-uncased style configuration + >>> model = BertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + gradient_checkpointing=False, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.gradient_checkpointing = gradient_checkpointing diff --git a/elia/bert/configuration_utils.py b/elia/bert/configuration_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9929ee4c19dab9a88bb51e0281d220a8456c2fce --- /dev/null +++ b/elia/bert/configuration_utils.py @@ -0,0 +1,408 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Configuration base class and utilities.""" + + +import copy +import json +import logging +import os +from typing import Dict, Tuple + +from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url + + +logger = logging.getLogger(__name__) + + +class PretrainedConfig(object): + r""" Base class for all configuration classes. 
+ Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. + + Note: + A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. + It only affects the model's configuration. + + Class attributes (overridden by derived classes): + - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`. + + Args: + finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`): + Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. + num_labels (:obj:`int`, `optional`, defaults to `2`): + Number of classes to use when the model is a classification model (sequences/tokens) + output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`): + Should the model returns all hidden-states. + output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`): + Should the model returns all attentions. + torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`): + Is the model used with Torchscript (for PyTorch models). + """ + model_type: str = "" + + def __init__(self, **kwargs): + # Attributes with defaults + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + self.output_attentions = kwargs.pop("output_attentions", False) + self.use_cache = kwargs.pop("use_cache", True) # Not used by all models + self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models + self.use_bfloat16 = kwargs.pop("use_bfloat16", False) + self.pruned_heads = kwargs.pop("pruned_heads", {}) + + # Is decoder is used in encoder-decoder models to differentiate encoder from decoder + self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) + self.is_decoder = kwargs.pop("is_decoder", False) + + # Parameters for sequence generation + self.max_length = kwargs.pop("max_length", 20) + self.min_length = kwargs.pop("min_length", 0) + self.do_sample = kwargs.pop("do_sample", False) + self.early_stopping = kwargs.pop("early_stopping", False) + self.num_beams = kwargs.pop("num_beams", 1) + self.temperature = kwargs.pop("temperature", 1.0) + self.top_k = kwargs.pop("top_k", 50) + self.top_p = kwargs.pop("top_p", 1.0) + self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) + self.length_penalty = kwargs.pop("length_penalty", 1.0) + self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) + self.bad_words_ids = kwargs.pop("bad_words_ids", None) + self.num_return_sequences = kwargs.pop("num_return_sequences", 1) + + # Fine-tuning task arguments + self.architectures = kwargs.pop("architectures", None) + self.finetuning_task = kwargs.pop("finetuning_task", None) + self.id2label = kwargs.pop("id2label", None) + self.label2id = kwargs.pop("label2id", None) + if self.id2label is not None: + kwargs.pop("num_labels", None) + self.id2label = dict((int(key), value) for key, value in self.id2label.items()) + # Keys are always strings in JSON so convert ids to int here. 
+ else: + self.num_labels = kwargs.pop("num_labels", 2) + + # Tokenizer arguments TODO: eventually tokenizer and models should share the same config + self.prefix = kwargs.pop("prefix", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + + # task specific arguments + self.task_specific_params = kwargs.pop("task_specific_params", None) + + # TPU arguments + self.xla_device = kwargs.pop("xla_device", None) + + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error("Can't set {} with value {} for {}".format(key, value, self)) + raise err + + @property + def num_labels(self): + return len(self.id2label) + + @num_labels.setter + def num_labels(self, num_labels): + self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)} + self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) + + def save_pretrained(self, save_directory): + """ + Save a configuration object to the directory `save_directory`, so that it + can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. + + Args: + save_directory (:obj:`string`): + Directory where the configuration JSON file will be saved. + """ + if os.path.isfile(save_directory): + raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory)) + os.makedirs(save_directory, exist_ok=True) + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file, use_diff=True) + logger.info("Configuration saved in {}".format(output_config_file)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig": + r""" + + Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. + + Args: + pretrained_model_name_or_path (:obj:`string`): + either: + - a string with the `shortcut name` of a pre-trained model configuration to load from cache or + download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to + our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing a configuration file saved using the + :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. + - a path or url to a saved configuration JSON `file`, e.g.: + ``./my_model_directory/configuration.json``. + cache_dir (:obj:`string`, `optional`): + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + kwargs (:obj:`Dict[str, any]`, `optional`): + The values in kwargs of any keys which are configuration attributes will be used to override the loaded + values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is + controlled by the `return_unused_kwargs` keyword parameter. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. + proxies (:obj:`Dict`, `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g.: + :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` + The proxies are used on each request. + return_unused_kwargs: (`optional`) bool: + If False, then this function returns just the final configuration object. + If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a + dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part + of kwargs which has not been used to update `config` and is otherwise ignored. + + Returns: + :class:`PretrainedConfig`: An instance of a configuration object + + Examples:: + + # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a + # derived class: BertConfig + config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. + config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` + config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') + config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) + assert config.output_attention == True + config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, + foo=False, return_unused_kwargs=True) + assert config.output_attention == True + assert unused_kwargs == {'foo': False} + + """ + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + return cls.from_dict(config_dict, **kwargs) + + @classmethod + def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used + for instantiating a Config using `from_dict`. + + Parameters: + pretrained_model_name_or_path (:obj:`string`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object. + + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + else: + config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) + + try: + # Load from URL or cache if already cached + resolved_config_file = cached_path( + config_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + ) + # Load config dict + if resolved_config_file is None: + raise EnvironmentError + config_dict = cls._dict_from_json_file(resolved_config_file) + + except EnvironmentError: + msg = ( + f"Can't load config for '{pretrained_model_name_or_path}'. 
Make sure that:\n\n" + f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" + f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n" + ) + raise EnvironmentError(msg) + + except json.JSONDecodeError: + msg = ( + "Couldn't reach server at '{}' to download configuration file or " + "configuration file is not a valid JSON file. " + "Please check network or file content here: {}.".format(config_file, resolved_config_file) + ) + raise EnvironmentError(msg) + + if resolved_config_file == config_file: + logger.info("loading configuration file {}".format(config_file)) + else: + logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file)) + + return config_dict, kwargs + + @classmethod + def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": + """ + Constructs a `Config` from a Python dictionary of parameters. + + Args: + config_dict (:obj:`Dict[str, any]`): + Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved + from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict` + method. + kwargs (:obj:`Dict[str, any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + :class:`PretrainedConfig`: An instance of a configuration object + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + config = cls(**config_dict) + + if hasattr(config, "pruned_heads"): + config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) + + # Update config with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info("Model config %s", str(config)) + if return_unused_kwargs: + return config, kwargs + else: + return config + + @classmethod + def from_json_file(cls, json_file: str) -> "PretrainedConfig": + """ + Constructs a `Config` from the path to a json file of parameters. + + Args: + json_file (:obj:`string`): + Path to the JSON file containing the parameters. + + Returns: + :class:`PretrainedConfig`: An instance of a configuration object + + """ + config_dict = cls._dict_from_json_file(json_file) + return cls(**config_dict) + + @classmethod + def _dict_from_json_file(cls, json_file: str): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "{} {}".format(self.__class__.__name__, self.to_json_string()) + + def to_diff_dict(self): + """ + Removes all attributes from config which correspond to the default + config attributes for better readability and serializes to a Python + dictionary. 
+ + Returns: + :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = PretrainedConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if key not in default_config_dict or value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. + + Returns: + :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + if hasattr(self.__class__, "model_type"): + output["model_type"] = self.__class__.model_type + return output + + def to_json_string(self, use_diff=True): + """ + Serializes this instance to a JSON string. + + Args: + use_diff (:obj:`bool`): + If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string. + + Returns: + :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format. + """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path, use_diff=True): + """ + Save this instance to a json file. + + Args: + json_file_path (:obj:`string`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (:obj:`bool`): + If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string(use_diff=use_diff)) + + def update(self, config_dict: Dict): + """ + Updates attributes of this class + with attributes from `config_dict`. + + Args: + :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class. + """ + for key, value in config_dict.items(): + setattr(self, key, value) diff --git a/elia/bert/file_utils.py b/elia/bert/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81b76b7fefd186d540fda1014dd69724049a4483 --- /dev/null +++ b/elia/bert/file_utils.py @@ -0,0 +1,808 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" + +import fnmatch +import json +import logging +import os +import shutil +import sys +import tarfile +import tempfile +from contextlib import contextmanager +from functools import partial, wraps +from hashlib import sha256 +from pathlib import Path +from typing import Dict, Optional, Union +from urllib.parse import urlparse +from zipfile import ZipFile, is_zipfile + +import requests +from filelock import FileLock +from tqdm.auto import tqdm + +#from . 
import __version__ +__version__ = "3.0.2" + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +try: + USE_TF = os.environ.get("USE_TF", "AUTO").upper() + USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): + import torch + + _torch_available = True # pylint: disable=invalid-name + logger.info("PyTorch version {} available.".format(torch.__version__)) + else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False +except ImportError: + _torch_available = False # pylint: disable=invalid-name + +try: + USE_TF = os.environ.get("USE_TF", "AUTO").upper() + USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + + if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): + import tensorflow as tf + + assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 + _tf_available = True # pylint: disable=invalid-name + logger.info("TensorFlow version {} available.".format(tf.__version__)) + else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False +except (ImportError, AssertionError): + _tf_available = False # pylint: disable=invalid-name + + +try: + from torch.hub import _get_torch_home + + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) + ) + + +try: + import torch_xla.core.xla_model as xm # noqa: F401 + + if _torch_available: + _torch_tpu_available = True # pylint: disable= + else: + _torch_tpu_available = False +except ImportError: + _torch_tpu_available = False + + +try: + import psutil # noqa: F401 + + _psutil_available = True + +except ImportError: + _psutil_available = False + + +try: + import py3nvml # noqa: F401 + + _py3nvml_available = True + +except ImportError: + _py3nvml_available = False + + +try: + from apex import amp # noqa: F401 + + _has_apex = True +except ImportError: + _has_apex = False + +default_cache_path = os.path.join(torch_cache_home, "transformers") + + +PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) +PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) +TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) + +WEIGHTS_NAME = "pytorch_model.bin" +TF2_WEIGHTS_NAME = "tf_model.h5" +TF_WEIGHTS_NAME = "model.ckpt" +CONFIG_NAME = "config.json" +MODEL_CARD_NAME = "modelcard.json" + + +MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]] +DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] +DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] + +S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" +CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available + + +def is_torch_tpu_available(): + return _torch_tpu_available + + +def is_psutil_available(): + return _psutil_available + + +def is_py3nvml_available(): + return _py3nvml_available + + +def is_apex_available(): + return _has_apex + + +def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +def add_start_docstrings_to_callable(*docstr): + def docstring_decorator(fn): + class_name 
= ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0]) + intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name) + note = r""" + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + pre and post processing steps while the latter silently ignores them. + """ + fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +def add_end_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = fn.__doc__ + "".join(docstr) + return fn + + return docstring_decorator + + +PT_TOKEN_CLASSIFICATION_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import torch + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1 + + >>> outputs = model(**inputs, labels=labels) + >>> loss, scores = outputs[:2] +""" + +PT_QUESTION_ANSWERING_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import torch + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> start_positions = torch.tensor([1]) + >>> end_positions = torch.tensor([3]) + + >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions) + >>> loss, start_scores, end_scores = outputs[:3] +""" + +PT_SEQUENCE_CLASSIFICATION_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import torch + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + >>> outputs = model(**inputs, labels=labels) + >>> loss, logits = outputs[:2] +""" + +PT_MASKED_LM_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import torch + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"] + + >>> outputs = model(input_ids, labels=input_ids) + >>> loss, prediction_scores = outputs[:2] +""" + +PT_BASE_MODEL_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import torch + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple +""" + +PT_MULTIPLE_CHOICE_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import torch + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> prompt = 
"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> choice0 = "It is eaten with a fork and a knife." + >>> choice1 = "It is eaten while held in the hand." + >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 + + >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True) + >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1 + + >>> # the linear classifier still needs to be trained + >>> loss, logits = outputs[:2] +""" + +PT_CAUSAL_LM_SAMPLE = r""" + Example:: + + >>> import torch + >>> from transformers import {tokenizer_class}, {model_class} + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs, labels=inputs["input_ids"]) + >>> loss, logits = outputs[:2] +""" + +TF_TOKEN_CLASSIFICATION_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + >>> input_ids = inputs["input_ids"] + >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1 + + >>> outputs = model(inputs) + >>> loss, scores = outputs[:2] +""" + +TF_QUESTION_ANSWERING_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + >>> input_dict = tokenizer(question, text, return_tensors='tf') + >>> start_scores, end_scores = model(input_dict) + + >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0]) + >>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) +""" + +TF_SEQUENCE_CLASSIFICATION_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1 + + >>> outputs = model(inputs) + >>> loss, logits = outputs[:2] +""" + +TF_MASKED_LM_SAMPLE = r""" + Example:: + >>> from transformers import {tokenizer_class}, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1 + + >>> outputs = model(input_ids) + >>> prediction_scores = outputs[0] +""" + +TF_BASE_MODEL_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> 
inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + >>> outputs = model(inputs) + + >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple +""" + +TF_MULTIPLE_CHOICE_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> choice0 = "It is eaten with a fork and a knife." + >>> choice1 = "It is eaten while held in the hand." + + >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True) + >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}} + >>> outputs = model(inputs) # batch size is 1 + + >>> # the linear classifier still needs to be trained + >>> logits = outputs[0] +""" + +TF_CAUSAL_LM_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + >>> outputs = model(inputs) + >>> logits = outputs[0] +""" + + +def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None): + def docstring_decorator(fn): + model_class = fn.__qualname__.split(".")[0] + is_tf_class = model_class[:2] == "TF" + + if "SequenceClassification" in model_class: + code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE + elif "QuestionAnswering" in model_class: + code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE + elif "TokenClassification" in model_class: + code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE + elif "MultipleChoice" in model_class: + code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE + elif "MaskedLM" in model_class: + code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE + elif "LMHead" in model_class: + code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE + elif "Model" in model_class: + code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE + else: + raise ValueError(f"Docstring can't be built for model {model_class}") + + built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint) + fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc + return fn + + return docstring_decorator + + +def is_remote_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str: + """ + Resolve a model identifier, and a file name, to a HF-hosted url + on either S3 or Cloudfront (a Content Delivery Network, or CDN). + + Cloudfront is replicated over the globe so downloads are way faster + for the end user (and it also lowers our bandwidth costs). However, it + is more aggressively cached by default, so may not always reflect the + latest changes to the underlying file (default TTL is 24 hours). 
+
+    In terms of client-side caching from this library, even though
+    Cloudfront relays the ETags from S3, using one or the other
+    (or switching from one to the other) will affect caching: cached files
+    are not shared between the two because the cached file's name contains
+    a hash of the url.
+    """
+    endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
+    legacy_format = "/" not in model_id
+    if legacy_format:
+        return f"{endpoint}/{model_id}-{filename}"
+    else:
+        return f"{endpoint}/{model_id}/{filename}"
+
+
+def url_to_filename(url, etag=None):
+    """
+    Convert `url` into a hashed filename in a repeatable way.
+    If `etag` is specified, append its hash to the url's, delimited
+    by a period.
+    If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
+    so that TF 2.0 can identify it as an HDF5 file
+    (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
+    """
+    url_bytes = url.encode("utf-8")
+    url_hash = sha256(url_bytes)
+    filename = url_hash.hexdigest()
+
+    if etag:
+        etag_bytes = etag.encode("utf-8")
+        etag_hash = sha256(etag_bytes)
+        filename += "." + etag_hash.hexdigest()
+
+    if url.endswith(".h5"):
+        filename += ".h5"
+
+    return filename
+
+
+def filename_to_url(filename, cache_dir=None):
+    """
+    Return the url and etag (which may be ``None``) stored for `filename`.
+    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
+    """
+    if cache_dir is None:
+        cache_dir = TRANSFORMERS_CACHE
+    if isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    cache_path = os.path.join(cache_dir, filename)
+    if not os.path.exists(cache_path):
+        raise EnvironmentError("file {} not found".format(cache_path))
+
+    meta_path = cache_path + ".json"
+    if not os.path.exists(meta_path):
+        raise EnvironmentError("file {} not found".format(meta_path))
+
+    with open(meta_path, encoding="utf-8") as meta_file:
+        metadata = json.load(meta_file)
+    url = metadata["url"]
+    etag = metadata["etag"]
+
+    return url, etag
+
+
+def cached_path(
+    url_or_filename,
+    cache_dir=None,
+    force_download=False,
+    proxies=None,
+    resume_download=False,
+    user_agent: Union[Dict, str, None] = None,
+    extract_compressed_file=False,
+    force_extract=False,
+    local_files_only=False,
+) -> Optional[str]:
+    """
+    Given something that might be a URL (or might be a local path),
+    determine which. If it's a URL, download the file and cache it, and
+    return the path to the cached file. If it's already a local path,
+    make sure the file exists and then return the path.
+    Args:
+        cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
+        force_download: if True, re-download the file even if it's already cached in the cache dir.
+        resume_download: if True, resume the download if an incompletely received file is found.
+        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
+        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
+            file in a folder along the archive.
+        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
+            re-extract the archive and override the folder where it was extracted.
+
+    Return:
+        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
+ Local path (string) otherwise + """ + if cache_dir is None: + cache_dir = TRANSFORMERS_CACHE + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if is_remote_url(url_or_filename): + # URL, so get it from the cache (downloading if necessary) + output_path = get_from_cache( + url_or_filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + user_agent=user_agent, + local_files_only=local_files_only, + ) + elif os.path.exists(url_or_filename): + # File, and it exists. + output_path = url_or_filename + elif urlparse(url_or_filename).scheme == "": + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + if extract_compressed_file: + if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): + return output_path + + # Path where we extract compressed archives + # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" + output_dir, output_file = os.path.split(output_path) + output_extract_dir_name = output_file.replace(".", "-") + "-extracted" + output_path_extracted = os.path.join(output_dir, output_extract_dir_name) + + if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: + return output_path_extracted + + # Prevent parallel extractions + lock_path = output_path + ".lock" + with FileLock(lock_path): + shutil.rmtree(output_path_extracted, ignore_errors=True) + os.makedirs(output_path_extracted) + if is_zipfile(output_path): + with ZipFile(output_path, "r") as zip_file: + zip_file.extractall(output_path_extracted) + zip_file.close() + elif tarfile.is_tarfile(output_path): + tar_file = tarfile.open(output_path) + tar_file.extractall(output_path_extracted) + tar_file.close() + else: + raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) + + return output_path_extracted + + return output_path + + +def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None): + ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) + if is_torch_available(): + ua += "; torch/{}".format(torch.__version__) + if is_tf_available(): + ua += "; tensorflow/{}".format(tf.__version__) + if isinstance(user_agent, dict): + ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + headers = {"user-agent": ua} + if resume_size > 0: + headers["Range"] = "bytes=%d-" % (resume_size,) + response = requests.get(url, stream=True, proxies=proxies, headers=headers) + if response.status_code == 416: # Range not satisfiable + return + content_length = response.headers.get("Content-Length") + total = resume_size + int(content_length) if content_length is not None else None + progress = tqdm( + unit="B", + unit_scale=True, + total=total, + initial=resume_size, + desc="Downloading", + disable=bool(logger.getEffectiveLevel() == logging.NOTSET), + ) + for chunk in response.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache( + url, + cache_dir=None, + force_download=False, + proxies=None, + etag_timeout=10, + 
resume_download=False, + user_agent: Union[Dict, str, None] = None, + local_files_only=False, +) -> Optional[str]: + """ + Given a URL, look for the corresponding file in the local cache. + If it's not there, download it. Then return the path to the cached file. + + Return: + None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). + Local path (string) otherwise + """ + if cache_dir is None: + cache_dir = TRANSFORMERS_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + os.makedirs(cache_dir, exist_ok=True) + + etag = None + if not local_files_only: + try: + response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) + if response.status_code == 200: + etag = response.headers.get("ETag") + except (EnvironmentError, requests.exceptions.Timeout): + # etag is already None + pass + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. + # try to get the last downloaded one + if etag is None: + if os.path.exists(cache_path): + return cache_path + else: + matching_files = [ + file + for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") + if not file.endswith(".json") and not file.endswith(".lock") + ] + if len(matching_files) > 0: + return os.path.join(cache_dir, matching_files[-1]) + else: + # If files cannot be found and local_files_only=True, + # the models might've been found if local_files_only=False + # Notify the user about that + if local_files_only: + raise ValueError( + "Cannot find the requested files in the cached path and outgoing traffic has been" + " disabled. To enable model look-ups and downloads online, set 'local_files_only'" + " to False." + ) + return None + + # From now on, etag is not None. + if os.path.exists(cache_path) and not force_download: + return cache_path + + # Prevent parallel downloads of the same file with a lock. + lock_path = cache_path + ".lock" + with FileLock(lock_path): + + # If the download just completed while the lock was activated. + if os.path.exists(cache_path) and not force_download: + # Even if returning early like here, the lock will be released. + return cache_path + + if resume_download: + incomplete_path = cache_path + ".incomplete" + + @contextmanager + def _resumable_file_manager(): + with open(incomplete_path, "a+b") as f: + yield f + + temp_file_manager = _resumable_file_manager + if os.path.exists(incomplete_path): + resume_size = os.stat(incomplete_path).st_size + else: + resume_size = 0 + else: + temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) + resume_size = 0 + + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
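+            # Illustrative sketch (not literal output) of the on-disk layout this cache uses,
+            # with <name> = url_to_filename(url, etag):
+            #
+            #   <cache_dir>/<name>              # the cached file itself
+            #   <cache_dir>/<name>.json         # {"url": ..., "etag": ...} metadata written below
+            #   <cache_dir>/<name>.lock         # FileLock guarding this download
+            #   <cache_dir>/<name>.incomplete   # partial data kept when resume_download=True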
+ with temp_file_manager() as temp_file: + logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) + + http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) + + logger.info("storing %s in cache at %s", url, cache_path) + os.replace(temp_file.name, cache_path) + + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) + + return cache_path + + +class cached_property(property): + """ + Descriptor that mimics @property but caches output in member variable. + + From tensorflow_datasets + + Built-in in functools from Python 3.8. + """ + + def __get__(self, obj, objtype=None): + # See docs.python.org/3/howto/descriptor.html#properties + if obj is None: + return self + if self.fget is None: + raise AttributeError("unreadable attribute") + attr = "__cached_" + self.fget.__name__ + cached = getattr(obj, attr, None) + if cached is None: + cached = self.fget(obj) + setattr(obj, attr, cached) + return cached + + +def torch_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + @wraps(func) + def wrapper(*args, **kwargs): + if is_torch_available(): + return func(*args, **kwargs) + else: + raise ImportError(f"Method `{func.__name__}` requires PyTorch.") + + return wrapper + + +def tf_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + @wraps(func) + def wrapper(*args, **kwargs): + if is_tf_available(): + return func(*args, **kwargs) + else: + raise ImportError(f"Method `{func.__name__}` requires TF.") + + return wrapper diff --git a/elia/bert/generation_utils.py b/elia/bert/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3c49e15abf7bc822940112207f10e18e2e0230cc --- /dev/null +++ b/elia/bert/generation_utils.py @@ -0,0 +1,993 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import functional as F + + +logger = logging.getLogger(__name__) + + +class GenerationMixin: + """ + A class contraining all of the functions supporting generation, to be used as a mixin in PreTrainedModel. 
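+
+    A minimal usage sketch (illustrative only; the class and checkpoint names below are
+    placeholders, generation is normally invoked on a concrete `PreTrainedModel` subclass
+    with a LM head)::
+
+        model = SomeLMHeadModel.from_pretrained("some-checkpoint")  # hypothetical subclass / checkpoint
+        output_ids = model.generate(input_ids, max_length=20, num_beams=4, early_stopping=True)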
+ """ + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + return {"input_ids": input_ids} + + def adjust_logits_during_generation(self, logits, **kwargs): + return logits + + def _use_cache(self, outputs, use_cache): + """During generation, decide whether to pass the `past` variable to the next forward pass.""" + if len(outputs) <= 1 or use_cache is False: + return False + if hasattr(self.config, "mem_len") and self.config.mem_len == 0: + return False + return True + + def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): + """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """ + for i in range(batch_size * num_beams): + for previous_token in set(prev_output_tokens[i].tolist()): + # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if lprobs[i, previous_token] < 0: + lprobs[i, previous_token] *= repetition_penalty + else: + lprobs[i, previous_token] /= repetition_penalty + + def postprocess_next_token_scores( + self, + scores, + input_ids, + no_repeat_ngram_size, + bad_words_ids, + cur_len, + min_length, + max_length, + eos_token_id, + repetition_penalty, + batch_size, + num_beams, + ): + # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + self.enforce_repetition_penalty_( + scores, batch_size, num_beams, input_ids, repetition_penalty, + ) + + # set eos token prob to zero if min_length is not reached + if eos_token_id is not None and cur_len < min_length: + scores[:, eos_token_id] = -float("inf") + + if no_repeat_ngram_size > 0: + # calculate a list of banned tokens to prevent repetitively generating the same ngrams + num_batch_hypotheses = batch_size * num_beams + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 + banned_batch_tokens = calc_banned_ngram_tokens( + input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len + ) + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + if bad_words_ids is not None: + # calculate a list of banned tokens according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + + for i, banned_tokens in enumerate(banned_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.LongTensor] = None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, + **model_specific_kwargs + ) -> torch.LongTensor: + r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. 
+ + Adapted in part from `Facebook's XLM beam search code`_. + + .. _`Facebook's XLM beam search code`: + https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529 + + + Parameters: + + input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)` + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape `(1,)`. + + max_length: (`optional`) int + The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20. + + min_length: (`optional`) int + The min length of the sequence to be generated. Between 0 and infinity. Default to 0. + + do_sample: (`optional`) bool + If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. + + early_stopping: (`optional`) bool + if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`. + + num_beams: (`optional`) int + Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1. + + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + + top_k: (`optional`) int + The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + + top_p: (`optional`) float + The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. + + repetition_penalty: (`optional`) float + The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. + + pad_token_id: (`optional`) int + Padding token. Default to specicic model pad_token_id or None if it does not exist. + + bos_token_id: (`optional`) int + BOS token. Defaults to `bos_token_id` as defined in the models config. + + eos_token_id: (`optional`) int + EOS token. Defaults to `eos_token_id` as defined in the models config. + + length_penalty: (`optional`) float + Exponential penalty to the length. Default to 1. + + no_repeat_ngram_size: (`optional`) int + If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once. + bad_words_ids: (`optional`) list of lists of int + `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. + + num_return_sequences: (`optional`) int + The number of independently computed returned sequences for each element in the batch. Default to 1. + + attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids` + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + Defaults to `None`. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + decoder_start_token_id=None: (`optional`) int + If an encoder-decoder model starts decoding with a different token than BOS. + Defaults to `None` and is changed to `BOS` later. + + use_cache: (`optional`) bool + If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`. 
+
+            model_specific_kwargs: (`optional`) dict
+                Additional model-specific kwargs will be forwarded to the `forward` function of the model.
+
+        Return:
+
+            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
+                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
+
+        Examples::
+
+            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
+            outputs = model.generate(max_length=40)  # do greedy decoding
+            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+
+            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
+            input_context = 'The dog'
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) from initial context 'The dog'
+            for i in range(3): #  3 output sequences were generated
+                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+
+            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
+            input_context = 'The dog'
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, num_return_sequences=3)  # generate 3 independent sequences by sampling
+            for i in range(3): #  3 output sequences were generated
+                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+
+            tokenizer = AutoTokenizer.from_pretrained('ctrl')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
+            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
+            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+
+            tokenizer = AutoTokenizer.from_pretrained('gpt2')    # Initialize tokenizer
+            model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
+            input_context = 'My cute dog'
+            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
+            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
+        """
+
+        # We cannot generate if the model does not have a LM head
+        if self.get_output_embeddings() is None:
+            raise AttributeError(
+                "You tried to generate sequences with a model that does not have a LM Head."
+                "Please use another model class (e.g.
`OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" + ) + + max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + do_sample = do_sample if do_sample is not None else self.config.do_sample + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + use_cache = use_cache if use_cache is not None else self.config.use_cache + num_beams = num_beams if num_beams is not None else self.config.num_beams + temperature = temperature if temperature is not None else self.config.temperature + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] # overriden by the input batch_size + else: + batch_size = 1 + + assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." + assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." + assert isinstance(do_sample, bool), "`do_sample` should be a boolean." + assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." + assert isinstance(use_cache, bool), "`use_cache` should be a boolean." + assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." + assert temperature > 0, "`temperature` should be strictly positive." + assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." + assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." + assert input_ids is not None or ( + isinstance(bos_token_id, int) and bos_token_id >= 0 + ), "If input_ids is not defined, `bos_token_id` should be a positive integer." + assert pad_token_id is None or ( + isinstance(pad_token_id, int) and (pad_token_id >= 0) + ), "`pad_token_id` should be a positive integer." + assert (eos_token_id is None) or ( + isinstance(eos_token_id, int) and (eos_token_id >= 0) + ), "`eos_token_id` should be a positive integer." + assert length_penalty > 0, "`length_penalty` should be strictly positive." + assert ( + isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 + ), "`no_repeat_ngram_size` should be a positive integer." 
+ assert ( + isinstance(num_return_sequences, int) and num_return_sequences > 0 + ), "`num_return_sequences` should be a strictly positive integer." + assert ( + bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) + ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" + + if input_ids is None: + assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( + "you should either supply a context to complete as `input_ids` input " + "or a `bos_token_id` (integer >= 0) as a first token to start the generation." + ) + input_ids = torch.full( + (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device, + ) + else: + assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." + + # not allow to duplicate outputs when greedy decoding + if do_sample is False: + if num_beams == 1: + # no_beam_search greedy generation conditions + assert ( + num_return_sequences == 1 + ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" + + else: + # beam_search greedy generation conditions + assert ( + num_beams >= num_return_sequences + ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" + + # create attention mask if necessary + # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 + if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): + attention_mask = input_ids.ne(pad_token_id).long() + elif attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + # set pad_token_id to eos_token_id if not set. 
Important that this is done after + # attention_mask is created + if pad_token_id is None and eos_token_id is not None: + logger.warning( + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) + ) + pad_token_id = eos_token_id + + # current position and vocab size + if hasattr(self.config, "vocab_size"): + vocab_size = self.config.vocab_size + elif ( + self.config.is_encoder_decoder + and hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "vocab_size") + ): + vocab_size = self.config.decoder.vocab_size + + # set effective batch size and effective batch multiplier according to do_sample + if do_sample: + effective_batch_size = batch_size * num_return_sequences + effective_batch_mult = num_return_sequences + else: + effective_batch_size = batch_size + effective_batch_mult = 1 + + if self.config.is_encoder_decoder: + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + + assert ( + decoder_start_token_id is not None + ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) + assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) + + # get encoder and store encoder outputs + encoder = self.get_encoder() + + encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask) + + # Expand input ids if num_beams > 1 or num_return_sequences > 1 + if num_return_sequences > 1 or num_beams > 1: + input_ids_len = input_ids.shape[-1] + input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) + attention_mask = attention_mask.unsqueeze(1).expand( + batch_size, effective_batch_mult * num_beams, input_ids_len + ) + + input_ids = input_ids.contiguous().view( + effective_batch_size * num_beams, input_ids_len + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + attention_mask = attention_mask.contiguous().view( + effective_batch_size * num_beams, input_ids_len + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + + if self.config.is_encoder_decoder: + # create empty decoder_input_ids + input_ids = torch.full( + (effective_batch_size * num_beams, 1), + decoder_start_token_id, + dtype=torch.long, + device=next(self.parameters()).device, + ) + cur_len = 1 + + assert ( + batch_size == encoder_outputs[0].shape[0] + ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " + + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) + expanded_batch_idxs = ( + torch.arange(batch_size) + .view(-1, 1) + .repeat(1, num_beams * effective_batch_mult) + .view(-1) + .to(input_ids.device) + ) + # expand encoder_outputs + encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:]) + + else: + encoder_outputs = None + cur_len = input_ids.shape[-1] + + assert ( + cur_len < max_length + ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. 
Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" + + if num_beams > 1: + output = self._generate_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + use_cache=use_cache, + model_specific_kwargs=model_specific_kwargs, + ) + else: + output = self._generate_no_beam_search( + input_ids, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + batch_size=effective_batch_size, + encoder_outputs=encoder_outputs, + attention_mask=attention_mask, + use_cache=use_cache, + model_specific_kwargs=model_specific_kwargs, + ) + + return output + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + min_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + pad_token_id, + eos_token_id, + batch_size, + encoder_outputs, + attention_mask, + use_cache, + model_specific_kwargs, + ): + """ Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. 
+ """ + # length of generated sentences / unfinished sentences + unfinished_sents = input_ids.new(batch_size).fill_(1) + sent_lengths = input_ids.new(batch_size).fill_(max_length) + + past = (encoder_outputs, None) if encoder_outputs is not None else None + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs + ) + + outputs = self(**model_inputs) + next_token_logits = outputs[0][:, -1, :] + + scores = self.postprocess_next_token_scores( + scores=next_token_logits, + input_ids=input_ids, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + cur_len=cur_len, + min_length=min_length, + max_length=max_length, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + batch_size=batch_size, + num_beams=1, + ) + + # if model has past, then set the past variable to speed up decoding + if self._use_cache(outputs, use_cache): + past = outputs[1] + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + scores = scores / temperature + # Top-p/top-k filtering + next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) + # Sample + probs = F.softmax(next_token_logscores, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1) + + # update generations and finished sentences + if eos_token_id is not None: + # pad finished sentences if eos_token_id exist + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) + else: + tokens_to_add = next_token + + # add token and increase length by one + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + cur_len = cur_len + 1 + + if eos_token_id is not None: + eos_in_sents = tokens_to_add == eos_token_id + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() + sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len) + # unfinished_sents is set to zero if eos in sentence + unfinished_sents.mul_((~eos_in_sents).long()) + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sents.max() == 0: + break + + # extend attention_mask for new generated input if only decoder + if self.config.is_encoder_decoder is False: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return input_ids + + def _generate_beam_search( + self, + input_ids, + cur_len, + max_length, + min_length, + do_sample, + early_stopping, + temperature, + top_k, + top_p, + repetition_penalty, + no_repeat_ngram_size, + bad_words_ids, + pad_token_id, + eos_token_id, + batch_size, + num_return_sequences, + length_penalty, + num_beams, + vocab_size, + encoder_outputs, + attention_mask, + use_cache, + model_specific_kwargs, + ): + """ Generate sequences for each example with beam search. 
+ """ + + # generated hypotheses + generated_hyps = [ + BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) + for _ in range(batch_size) + ] + + # scores for each sentence in the beam + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + + # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times + if do_sample is False: + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) + + # cache compute states + past = (encoder_outputs, None) if encoder_outputs is not None else None + + # done sentences + done = [False for _ in range(batch_size)] + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs + ) + outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) + next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) + + # if model has past, then set the past variable to speed up decoding + if self._use_cache(outputs, use_cache): + past = outputs[1] + if self.config.is_encoder_decoder and do_sample is False: + # TODO (PVP) still a bit hacky here - there might be a better solution + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) + + scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) + + scores = self.postprocess_next_token_scores( + scores=scores, + input_ids=input_ids, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + cur_len=cur_len, + min_length=min_length, + max_length=max_length, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + batch_size=batch_size, + num_beams=num_beams, + ) + + assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( + scores.shape, (batch_size * num_beams, vocab_size) + ) + + if do_sample: + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + # Temperature + if temperature != 1.0: + _scores = _scores / temperature + # Top-p/top-k filtering + _scores = top_k_top_p_filtering( + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + ) # (batch_size * num_beams, vocab_size) + # re-organize to group the beam together to sample from all beam_idxs + _scores = _scores.contiguous().view( + batch_size, num_beams * vocab_size + ) # (batch_size, num_beams * vocab_size) + + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) + probs = F.softmax(_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) + # Compute next scores + next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) + # sort the sampled vector to make sure that the first num_beams samples are the best + next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) + + else: + next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + next_scores = next_scores.view( + batch_size, num_beams * vocab_size + ) # 
(batch_size, num_beams * vocab_size) + + next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True) + + assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) + + # next batch beam content + next_batch_beam = [] + + # for each sentence + for batch_idx in range(batch_size): + + # if we are done with this sentence, add a pad token + if done[batch_idx]: + assert ( + len(generated_hyps[batch_idx]) >= num_beams + ), "Batch can only be done if at least {} beams have been generated".format(num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch + continue + + # next sentence beam content, this will get added to next_batch_beam + next_sent_beam = [] + + # next tokens for this sentence + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx]) + ): + # get beam and token IDs + beam_id = beam_token_id // vocab_size + token_id = beam_token_id % vocab_size + + effective_beam_id = batch_idx * num_beams + beam_id + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (token_id.item() == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams + if is_beam_token_worse_than_top_num_beams: + continue + generated_hyps[batch_idx].add( + input_ids[effective_beam_id].clone(), beam_token_score.item(), + ) + else: + # add next predicted token since it is not eos_token + next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) + + # once the beam for next step is full, don't add more tokens to it. 
+ if len(next_sent_beam) == num_beams: + break + + # Check if we are done so that we can save a pad step if all(done) + done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + # update next beam content + assert len(next_sent_beam) == num_beams, "Beam should always be full" + next_batch_beam.extend(next_sent_beam) + assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" + + # stop when we are done with each sentence + if all(done): + break + + # sanity check / prepare next batch + assert len(next_batch_beam) == batch_size * num_beams + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) + beam_idx = input_ids.new([x[2] for x in next_batch_beam]) + + # re-order batch and update current length + input_ids = input_ids[beam_idx, :] + input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) + cur_len = cur_len + 1 + + # re-order internal states + if past is not None: + past = self._reorder_cache(past, beam_idx) + + # extend attention_mask for new generated input if only decoder + if self.config.is_encoder_decoder is False: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx in range(batch_size): + if done[batch_idx]: + continue + + # test that beam scores match previously calculated scores if not eos and batch_idx not done + if eos_token_id is not None and all( + (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx] + ): + assert torch.all( + next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] + ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( + next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx], + ) + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(num_beams): + effective_beam_id = batch_idx * num_beams + beam_id + final_score = beam_scores[effective_beam_id].item() + final_tokens = input_ids[effective_beam_id] + generated_hyps[batch_idx].add(final_tokens, final_score) + + # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch + output_batch_size = batch_size if do_sample else batch_size * num_return_sequences + output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences + + # select the best hypotheses + sent_lengths = input_ids.new(output_batch_size) + best = [] + + # retrieve best hypotheses + for i, hypotheses in enumerate(generated_hyps): + sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) + for j in range(output_num_return_sequences_per_batch): + effective_batch_idx = output_num_return_sequences_per_batch * i + j + best_hyp = sorted_hyps.pop()[1] + sent_lengths[effective_batch_idx] = len(best_hyp) + best.append(best_hyp) + + # shorter batches are padded + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`Pad_token_id` has to be defined" + sent_max_len = min(sent_lengths.max().item() + 1, max_length) + decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) + + # fill with hypothesis and eos_token_id if necessary + for i, hypo in enumerate(best): + decoded[i, : 
sent_lengths[i]] = hypo + if sent_lengths[i] < max_length: + decoded[i, sent_lengths[i]] = eos_token_id + else: + # none of the hypotheses have an eos_token + assert (len(hypo) == max_length for hypo in best) + decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) + + return decoded + + @staticmethod + def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]: + return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) + + +def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None: + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < no_repeat_ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - no_repeat_ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] + return banned_tokens + + +def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]: + banned_tokens = [] + + def _tokens_match(prev_tokens, tokens): + if len(tokens) == 0: + # if bad word tokens is just one token always ban it + return True + if len(tokens) > len(prev_input_ids): + # if bad word tokens are longer then prev input_ids they can't be equal + return False + + if prev_tokens[-len(tokens) :] == tokens: + # if tokens match + return True + else: + return False + + for prev_input_ids_slice in prev_input_ids: + banned_tokens_slice = [] + + for banned_token_seq in bad_words_ids: + assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( + bad_words_ids + ) + + if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False: + # if tokens do not match continue + continue + + banned_tokens_slice.append(banned_token_seq[-1]) + + banned_tokens.append(banned_tokens_slice) + + return banned_tokens + + +def top_k_top_p_filtering( + logits: Tensor, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +) -> Tensor: + """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +class BeamHypotheses(object): + def __init__(self, num_beams, max_length, length_penalty, early_stopping): + """ + Initialize n-best list of hypotheses. + """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, hyp, sum_logprobs): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / len(hyp) ** self.length_penalty + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp)) + if len(self) > self.num_beams: + sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) + del self.beams[sorted_scores[0][1]] + self.worst_score = sorted_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs, cur_len): + """ + If there are enough hypotheses and that none of the hypotheses being generated + can become better than the worst one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self.worst_score >= cur_score + return ret diff --git a/elia/bert/modeling_bert.py b/elia/bert/modeling_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..e796878aa2e6e39d6e65b0941396bcedca396a46 --- /dev/null +++ b/elia/bert/modeling_bert.py @@ -0,0 +1,1569 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. """ + + +import logging +import math +import os +import warnings + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from .activations import gelu, gelu_new, swish +from .configuration_bert import BertConfig +from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer + + +logger = logging.getLogger(__name__) + +_TOKENIZER_FOR_DOC = "BertTokenizer" + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model. + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def mish(x): + return x * torch.tanh(nn.functional.softplus(x)) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish} + + +BertLayerNorm = torch.nn.LayerNorm + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. 
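+
+    Shape sketch (illustrative): `input_ids`, `position_ids` and `token_type_ids` of shape
+    `(batch_size, seq_len)` are each embedded to `(batch_size, seq_len, hidden_size)`,
+    summed element-wise, then passed through LayerNorm and dropout::
+
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings  # (batch_size, seq_len, hidden_size)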
+ """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
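To make the head bookkeeping in `transpose_for_scores` and the scaled dot-product that follows concrete, here is a small self-contained sketch in plain torch (illustrative only, toy dimensions):

import math
import torch

batch, seq_len, num_heads, head_size = 2, 5, 4, 8                 # toy dimensions
hidden = num_heads * head_size

q = torch.randn(batch, seq_len, hidden)
k = torch.randn(batch, seq_len, hidden)

def split_heads(x):
    # (batch, seq, hidden) -> (batch, heads, seq, head_size), as transpose_for_scores does
    return x.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

scores = torch.matmul(split_heads(q), split_heads(k).transpose(-1, -2)) / math.sqrt(head_size)
probs = torch.softmax(scores, dim=-1)                              # (batch, heads, seq, seq)
assert probs.shape == (batch, num_heads, seq_len, seq_len)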
+ if encoder_hidden_states is not None: + mixed_key_layer = self.key(encoder_hidden_states) + mixed_value_layer = self.value(encoder_hidden_states) + attention_mask = encoder_attention_mask + else: + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + 
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + if self.is_decoder: + self.crossattention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + ): + self_attention_outputs = self.attention( + hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs + return outputs + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + output_hidden_states=False, + ): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + head_mask[i], + encoder_hidden_states, + 
encoder_attention_mask, + output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. 
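The `gradient_checkpointing` branch in `BertEncoder.forward` above is driven purely by an attribute on the config. A rough sketch of turning it on with a tiny randomly initialised model; the `bert.*` import path is an assumption about how this vendored package is imported:

import torch

from bert.configuration_bert import BertConfig        # assumed import path
from bert.modeling_bert import BertModel

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
config.gradient_checkpointing = True                   # picked up via getattr() in BertEncoder.forward
model = BertModel(config)
model.train()

input_ids = torch.randint(0, 100, (2, 7))
sequence_output, pooled_output = model(input_ids)[:2]
sequence_output.sum().backward()                        # activations are recomputed layer by layer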
+ """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +BERT_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.BertTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.__call__` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. 
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + :obj:`is_decoder` argument of the configuration set to :obj:`True`; an + :obj:`encoder_hidden_states` is expected as an input to the forward pass. + + .. _`Attention is all you need`: + https://arxiv.org/abs/1706.03762 + + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased") + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. 
The Linear + layer weights are trained from the next sentence prediction (classification) + objective during pre-training. + + This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
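For reference, a plain-torch re-creation of what `get_extended_attention_mask` produces here for an encoder: the 2D padding mask is broadcast to (batch, 1, 1, seq) and turned into an additive bias, so padded positions receive a large negative value before the softmax.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],          # 1 = real token, 0 = padding
                               [1, 1, 1, 1, 1]])

extended = attention_mask[:, None, None, :].float()      # (batch, 1, 1, seq), broadcastable over heads and queries
extended = (1.0 - extended) * -10000.0                    # real tokens -> 0.0, padding -> -10000.0
# `extended` is what gets added to the raw attention scores in BertSelfAttention.forward.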
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + outputs = (sequence_output, pooled_output,) + encoder_outputs[ + 1: + ] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and + a `next sentence prediction (classification)` head. """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + **kwargs + ): + r""" + labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) + Indices should be in ``[0, 1]``. 
+ ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False + continuation before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + + Examples:: + + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_scores, seq_relationship_scores = outputs[:2] + + """ + if "masked_lm_labels" in kwargs: + warnings.warn( + "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + DeprecationWarning, + ) + labels = kwargs.pop("masked_lm_labels") + assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
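A small end-to-end sketch of the pre-training wiring that follows: when both `labels` and `next_sentence_label` are supplied, the returned loss is the sum of the MLM and NSP terms. A tiny random config is used so the snippet runs offline; the `bert.*` import path is an assumption about the vendored package layout.

import torch

from bert.configuration_bert import BertConfig          # assumed import path
from bert.modeling_bert import BertForPreTraining

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertForPreTraining(config)

input_ids = torch.randint(0, 100, (2, 8))
labels = input_ids.clone()
labels[:, :4] = -100                                     # -100 positions are ignored by the MLM loss
next_sentence_label = torch.tensor([0, 1])               # 0 = "B follows A", 1 = "B is random"

total_loss, prediction_scores, seq_relationship_score = model(
    input_ids, labels=labels, next_sentence_label=next_sentence_label
)[:3]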
+ + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + outputs = (prediction_scores, seq_relationship_score,) + outputs[ + 2: + ] # add hidden states and attention if they are here + + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + outputs = (total_loss,) + outputs + + return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING +) +class BertLMHeadModel(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`." + + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the left-to-right language modeling loss (next word prediction). + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Next token prediction loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Example:: + + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> config.is_decoder = True + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + outputs = (ltr_lm_loss,) + outputs + + return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + assert ( + not config.is_decoder + ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention." 
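The shift performed in the causal-LM loss above is the whole trick behind using BERT as a left-to-right LM: position t is scored against token t+1. A standalone sketch of the same indexing in plain torch:

import torch
from torch.nn import CrossEntropyLoss

vocab_size, batch, seq_len = 11, 2, 6
prediction_scores = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

shifted_scores = prediction_scores[:, :-1, :].contiguous()   # drop the prediction for the last position
shifted_labels = labels[:, 1:].contiguous()                   # drop the first label
ltr_lm_loss = CrossEntropyLoss()(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))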
+ + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased") + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Masked language modeling loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + if "masked_lm_labels" in kwargs: + warnings.warn( + "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + DeprecationWarning, + ) + labels = kwargs.pop("masked_lm_labels") + assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task." + assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
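For the masked-LM loss that follows, only positions whose label is not -100 contribute; everything else is ignored by `CrossEntropyLoss`. A sketch of how such labels are typically built (the token ids shown are illustrative values for the standard uncased vocabulary):

import torch

input_ids = torch.tensor([[101, 7592, 103, 2003, 10140, 102]])   # [CLS] ... [MASK] ... [SEP] (illustrative ids)
labels = torch.full_like(input_ids, -100)                         # ignore every position by default
labels[0, 2] = 4937                                               # supervise only the masked position (illustrative id)

# With a BertForMaskedLM instance `model` built from a config as in the other sketches:
#     masked_lm_loss, prediction_scores = model(input_ids, labels=labels)[:2]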
+ + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + outputs = (masked_lm_loss,) + outputs + + return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + self.init_weights() + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided): + Next sequence prediction (classification) loss. + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') + + >>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1])) + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = outputs[1] + + seq_relationship_score = self.cls(pooled_output) + + outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + outputs = (next_sentence_loss,) + outputs + + return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased") + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
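As the labels docstring above notes, `num_labels` switches the head between classification and regression. A sketch with a tiny random config so it runs offline; the `bert.*` import path is assumed:

import torch

from bert.configuration_bert import BertConfig            # assumed import path
from bert.modeling_bert import BertForSequenceClassification

base = dict(vocab_size=100, hidden_size=32, num_hidden_layers=2,
            num_attention_heads=2, intermediate_size=64)
input_ids = torch.randint(0, 100, (4, 6))

clf = BertForSequenceClassification(BertConfig(num_labels=3, **base))
clf_loss, clf_logits = clf(input_ids, labels=torch.tensor([0, 2, 1, 0]))[:2]           # cross-entropy

reg = BertForSequenceClassification(BertConfig(num_labels=1, **base))
reg_loss, reg_logits = reg(input_ids, labels=torch.tensor([0.1, 0.7, 0.3, 0.9]))[:2]   # mean-squared error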
+ + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), logits, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased") + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension + of the input tensors. 
(see `input_ids` above) + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): + `num_choices` is the second dimension of the input tensors. (see `input_ids` above). + + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + outputs = (loss,) + outputs + + return outputs # (loss), reshaped_logits, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. 
""", + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased") + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the token classification loss. + Indices should be in ``[0, ..., config.num_labels - 1]``. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : + Classification loss. + scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`) + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased") + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-start scores (before SoftMax). + end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-end scores (before SoftMax). 
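The token-classification loss above only counts positions whose attention-mask entry is 1; padded positions are rewritten to the loss function's `ignore_index`. The same masking in a standalone form:

import torch
from torch.nn import CrossEntropyLoss

num_labels, batch, seq_len = 5, 2, 4
logits = torch.randn(batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (batch, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])              # 0 = padding

loss_fct = CrossEntropyLoss()                               # ignore_index defaults to -100
active = attention_mask.view(-1) == 1
active_labels = torch.where(active, labels.view(-1),
                            torch.tensor(loss_fct.ignore_index).type_as(labels))
loss = loss_fct(logits.view(-1, num_labels), active_labels)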
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = (start_logits, end_logits,) + outputs[2:] + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + outputs = (total_loss,) + outputs + + return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) diff --git a/elia/bert/modeling_utils.py b/elia/bert/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c5855e204247404cbfa6fed64fc930ca61d13780 --- /dev/null +++ b/elia/bert/modeling_utils.py @@ -0,0 +1,1268 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import logging +import os +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F + +from .activations import get_activation +from .configuration_utils import PretrainedConfig +from .file_utils import ( + DUMMY_INPUTS, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + WEIGHTS_NAME, + cached_path, + hf_bucket_url, + is_remote_url, +) +from .generation_utils import GenerationMixin + + +logger = logging.getLogger(__name__) + + +try: + from torch.nn import Identity +except ImportError: + # Older PyTorch compatibility + class Identity(nn.Module): + r"""A placeholder identity operator that is argument-insensitive. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, input): + return input + + +def find_pruneable_heads_and_indices( + heads: List, n_heads: int, head_size: int, already_pruned_heads: set +) -> Tuple[set, "torch.LongTensor"]: + mask = torch.ones(n_heads, head_size) + heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in already_pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index: torch.LongTensor = torch.arange(len(mask))[mask].long() + return heads, index + + +class ModuleUtilsMixin: + """ + A few utilities for torch.nn.Modules, to be used as a mixin. + """ + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Get number of (optionally, trainable) parameters in the module. + """ + params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() + return sum(p.numel() for p in params) + + @staticmethod + def _hook_rss_memory_pre_forward(module, *args, **kwargs): + try: + import psutil + except (ImportError): + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_pre_forward = mem.rss + return None + + @staticmethod + def _hook_rss_memory_post_forward(module, *args, **kwargs): + try: + import psutil + except (ImportError): + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_post_forward = mem.rss + mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward + module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) + return None + + def add_memory_hooks(self): + """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. + Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()` + """ + for module in self.modules(): + module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) + module.register_forward_hook(self._hook_rss_memory_post_forward) + self.reset_memory_hooks_state() + + def reset_memory_hooks_state(self): + for module in self.modules(): + module.mem_rss_diff = 0 + module.mem_rss_post_forward = 0 + module.mem_rss_pre_forward = 0 + + @property + def device(self) -> device: + """ + Get torch.device from module, assuming that the whole module has one device. 
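+
+        Example (illustrative)::
+
+            # allocate new tensors on the same device as the model's parameters
+            mask = torch.ones(1, 128, device=model.device)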
+ """ + try: + return next(self.parameters()).device + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + @property + def dtype(self) -> dtype: + """ + Get torch.dtype from module, assuming that the whole module has one dtype. + """ + try: + return next(self.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: + """type: torch.Tensor -> torch.Tensor""" + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + + if self.dtype == torch.float16: + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4 + elif self.dtype == torch.float32: + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 + else: + raise ValueError( + "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format( + self.dtype + ) + ) + + return encoder_extended_attention_mask + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor: + """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored. + + Arguments: + attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to + input_shape: tuple, shape of input_ids + device: torch.Device, usually self.device + + Returns: + torch.Tensor with dtype of attention_mask.dtype + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
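+        # Illustrative example: a padding mask [[1, 1, 1, 0]] of shape (batch_size=1, seq_length=4)
+        # becomes, in the non-decoder case below, a (1, 1, 1, 4) tensor equal to
+        # [[[[0.0, 0.0, 0.0, -10000.0]]]], i.e. an additive bias that suppresses the padded position.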
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor: + """ + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + attention_probs has shape bsz x n_heads x N x N + Arguments: + head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads] + num_hidden_layers: int + Returns: + Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + or list with [None] for each layer + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility + return head_mask + + +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): + r""" Base class for all models. + + :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models + as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads. 
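+
+    A typical lifecycle (illustrative; the model name and directory are placeholders)::
+
+        model = BertModel.from_pretrained('bert-base-uncased')      # download or load cached weights
+        # ... fine-tune on a downstream task ...
+        model.save_pretrained('./my_model_directory/')              # writes the weights file and config.json
+        model = BertModel.from_pretrained('./my_model_directory/')  # reload from the saved directory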
+ + Class attributes (overridden by derived classes): + - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. + - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments: + + - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`, + - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`, + - ``path``: a path (string) to the TensorFlow checkpoint. + + - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. + """ + config_class = None + base_model_prefix = "" + + @property + def dummy_inputs(self): + """ Dummy inputs to do a forward pass in the network. + + Returns: + torch.Tensor with dummy inputs + """ + return {"input_ids": torch.tensor(DUMMY_INPUTS)} + + def __init__(self, config, *inputs, **kwargs): + super().__init__() + if not isinstance(config, PretrainedConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " + "To create a model from a pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + ) + ) + # Save config in model + self.config = config + + @property + def base_model(self): + return getattr(self, self.base_model_prefix, self) + + def get_input_embeddings(self): + """ + Returns the model's input embeddings. + + Returns: + :obj:`nn.Module`: + A torch module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.get_input_embeddings() + else: + raise NotImplementedError + + def set_input_embeddings(self, value: nn.Module): + """ + Set model's input embeddings + + Args: + value (:obj:`nn.Module`): + A module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + base_model.set_input_embeddings(value) + else: + raise NotImplementedError + + def get_output_embeddings(self): + """ + Returns the model's output embeddings. + + Returns: + :obj:`nn.Module`: + A torch module mapping hidden states to vocabulary. + """ + return None # Overwrite for models with output embeddings + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning + the weights instead. 
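+
+        Called automatically at the end of ``init_weights()`` and ``resize_token_embeddings()``,
+        so it rarely needs to be invoked by hand. For example (``tokenizer`` standing in for any
+        matching tokenizer object)::
+
+            model.resize_token_embeddings(len(tokenizer))  # re-ties (or re-clones) the output embeddings afterwards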
+ """ + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): + """ Tie or clone module weights depending of whether we are using TorchScript or not + """ + if self.config.torchscript: + output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) + else: + output_embeddings.weight = input_embeddings.weight + + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias.data = torch.nn.functional.pad( + output_embeddings.bias.data, + (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],), + "constant", + 0, + ) + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None): + """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. + Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + + new_num_tokens: (`optional`) int: + New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. + If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. + + Return: ``torch.nn.Embeddings`` + Pointer to the input tokens Embeddings Module of the model + """ + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + model_embeds = base_model._resize_token_embeddings(new_num_tokens) + if new_num_tokens is None: + return model_embeds + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + base_model.vocab_size = new_num_tokens + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + self.set_input_embeddings(new_embeddings) + return self.get_input_embeddings() + + def _get_resized_embeddings( + self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None + ) -> torch.nn.Embedding: + """ Build a resized Embedding Module from a provided token Embedding Module. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + + Args: + old_embeddings: ``torch.nn.Embedding`` + Old embeddings to be resized. + new_num_tokens: (`optional`) int + New number of tokens in the embedding matrix. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + If not provided or None: return the provided token Embedding Module. 
+ Return: ``torch.nn.Embedding`` + Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None + """ + if new_num_tokens is None: + return old_embeddings + + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + if old_num_tokens == new_num_tokens: + return old_embeddings + + # Build new embeddings + new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) + new_embeddings.to(old_embeddings.weight.device) + + # initialize all new embeddings (in particular added tokens) + self._init_weights(new_embeddings) + + # Copy token embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] + + return new_embeddings + + def init_weights(self): + """ Initialize and prunes weights if needed. """ + # Initialize weights + self.apply(self._init_weights) + + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + + # Tie weights if needed + self.tie_weights() + + def prune_heads(self, heads_to_prune: Dict): + """ Prunes heads of the base model. + + Arguments: + + heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). + E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. + """ + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads + for layer, heads in heads_to_prune.items(): + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON + + self.base_model._prune_heads(heads_to_prune) + + def save_pretrained(self, save_directory): + """ Save a model and its configuration file to a directory, so that it + can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method. + + Arguments: + save_directory: directory to which to save. + """ + if os.path.isfile(save_directory): + logger.error("Provided path ({}) should be a directory, not a file".format(save_directory)) + return + os.makedirs(save_directory, exist_ok=True) + + # Only save the model itself if we are using distributed training + model_to_save = self.module if hasattr(self, "module") else self + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(save_directory, WEIGHTS_NAME) + + if getattr(self.config, "xla_device", False): + import torch_xla.core.xla_model as xm + + if xm.is_master_ordinal(): + # Save configuration file + model_to_save.config.save_pretrained(save_directory) + # xm.save takes care of saving only from master + xm.save(model_to_save.state_dict(), output_model_file) + else: + model_to_save.config.save_pretrained(save_directory) + torch.save(model_to_save.state_dict(), output_model_file) + + logger.info("Model weights saved in {}".format(output_model_file)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. 
+ + The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with ``model.train()`` + + The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model. + It is up to you to train those weights with a downstream fine-tuning task. + + The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded. + + Parameters: + pretrained_model_name_or_path: either: + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``) + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) one of: + - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or + - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()` + + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + resume_download: (`optional`) boolean, default False: + Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. 
+ + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + # For example purposes. Not runnable. + model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') + model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + config = kwargs.pop("config", None) + state_dict = kwargs.pop("state_dict", None) + cache_dir = kwargs.pop("cache_dir", None) + from_tf = kwargs.pop("from_tf", False) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) + use_cdn = kwargs.pop("use_cdn", True) + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + *model_args, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + **kwargs, + ) + else: + model_kwargs = kwargs + + # Load model + if pretrained_model_name_or_path is not None: + if os.path.isdir(pretrained_model_name_or_path): + if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")): + # Load from a TF 1.0 checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") + elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): + # Load from a TF 2.0 checkpoint + archive_file = 
os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + else: + raise EnvironmentError( + "Error no file named {} found in directory {} or `from_tf` set to False".format( + [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], + pretrained_model_name_or_path, + ) + ) + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + archive_file = pretrained_model_name_or_path + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + assert ( + from_tf + ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( + pretrained_model_name_or_path + ".index" + ) + archive_file = pretrained_model_name_or_path + ".index" + else: + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME), + use_cdn=use_cdn, + ) + + try: + # Load from URL or cache if already cached + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + ) + if resolved_archive_file is None: + raise EnvironmentError + except EnvironmentError: + msg = ( + f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" + f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" + f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n" + ) + raise EnvironmentError(msg) + + if resolved_archive_file == archive_file: + logger.info("loading weights file {}".format(archive_file)) + else: + logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file)) + else: + resolved_archive_file = None + + # Instantiate model. + model = cls(config, *model_args, **model_kwargs) + + if state_dict is None and not from_tf: + try: + state_dict = torch.load(resolved_archive_file, map_location="cpu") + except Exception: + raise OSError( + "Unable to load weights from pytorch checkpoint file. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " + ) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + + if from_tf: + if resolved_archive_file.endswith(".index"): + # Load from a TensorFlow 1.X checkpoint - provided by original authors + model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' + else: + # Load from our TensorFlow 2.0 checkpoints + try: + from transformers import load_tf2_checkpoint_in_pytorch_model + + model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + else: + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + ############################################################################################## + # Print out state_dict's contents: keys + ''' + for key, _ in state_dict.items(): + print(key) + ''' + ############################################################################################## + + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: nn.Module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) + if not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." + if hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + + load(model_to_load, prefix=start_prefix) + + if model.__class__.__name__ != model_to_load.__class__.__name__: + base_model_state_dict = model_to_load.state_dict().keys() + head_model_state_dict_without_base_prefix = [ + key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() + ] + + missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." 
+ ) + else: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"If your task is similar to the task the model of the ckeckpoint was trained on, " + f"you can already use {model.__class__.__name__} for predictions without further training." + ) + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + model.tie_weights() # make sure token embedding weights are still tied if needed + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "error_msgs": error_msgs, + } + return model, loading_info + + if hasattr(config, "xla_device") and config.xla_device: + import torch_xla.core.xla_model as xm + + model = xm.send_cpu_data_to_device(model, xm.xla_device()) + model.to(xm.xla_device()) + + return model + + +class Conv1D(nn.Module): + def __init__(self, nf, nx): + """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2) + Basically works like a Linear layer but the weights are transposed + """ + super().__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.zeros(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(*size_out) + return x + + +class PoolerStartLogits(nn.Module): + """ Compute SQuAD start_logits from sequence hidden states. """ + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward(self, hidden_states, p_mask=None): + """ Args: + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` + invalid position mask such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if next(self.parameters()).dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerEndLogits(nn.Module): + """ Compute SQuAD end_logits from sequence hidden states and start token hidden state. + """ + + def __init__(self, config): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): + """ Args: + One of ``start_states``, ``start_positions`` should be not None. + If both are set, ``start_positions`` overrides ``start_states``. + + **start_states**: ``torch.LongTensor`` of shape identical to hidden_states + hidden states of the first tokens for the labeled span. + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span: + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` + Mask of invalid position such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. 
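+
+            Illustrative example: with ``p_mask = [0., 0., 1.]`` the last position is set to
+            ``-1e30`` (``-65500`` under fp16), so it receives ~zero probability after a softmax.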
+ """ + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if next(self.parameters()).dtype == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerAnswerClass(nn.Module): + """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """ + + def __init__(self, config): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): + """ + Args: + One of ``start_states``, ``start_positions`` should be not None. + If both are set, ``start_positions`` overrides ``start_states``. + + **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. + hidden states of the first tokens for the labeled span. + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span. + **cls_index**: torch.LongTensor of shape ``(batch_size,)`` + position of the CLS token. If None, take the last token. + + note(Original repo): + no dependency on end_feature so that we can obtain one single `cls_logits` + for each sample + """ + hsz = hidden_states.shape[-1] + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +class SQuADHead(nn.Module): + r""" A SQuAD head inspired by XLNet. + + Parameters: + config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model. + + Inputs: + **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` + hidden states of sequence tokens + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span. + **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the last token for the labeled span. + **cls_index**: torch.LongTensor of shape ``(batch_size,)`` + position of the CLS token. If None, take the last token. + **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` + Whether the question has a possible answer in the paragraph or not. 
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` + Mask of invalid position such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. + **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)`` + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)`` + Indices for the top config.start_n_top start token possibilities (beam-search). + **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). + **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)`` + Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). + **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) + ``torch.FloatTensor`` of shape ``(batch_size,)`` + Log probabilities for the ``is_impossible`` label of the answers. 
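+
+    Example (an illustrative sketch; ``encoder_outputs``, ``starts`` and ``ends`` are placeholders)::
+
+        head = SQuADHead(config)
+        hidden_states = encoder_outputs[0]   # (batch_size, seq_len, hidden_size)
+
+        # training: provide the gold span to obtain the loss
+        total_loss = head(hidden_states, start_positions=starts, end_positions=ends)[0]
+
+        # inference: omit the labels to obtain beam-search candidates and the answerability logits
+        start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = head(hidden_states)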
+ """ + + def __init__(self, config): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = PoolerStartLogits(config) + self.end_logits = PoolerEndLogits(config) + self.answer_class = PoolerAnswerClass(config) + + def forward( + self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None, + ): + outputs = () + + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + outputs = (total_loss,) + outputs + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs + + # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits + # or (if labels are provided) (total_loss,) + return outputs + + +class SequenceSummary(nn.Module): + r""" Compute a single vector summary of a sequence hidden states according to various possibilities: + Args of the config class: + 
summary_type: + - 'last' => [default] take the last token hidden state (like XLNet) + - 'first' => take the first token hidden state (like Bert) + - 'mean' => take the mean of all tokens hidden states + - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) + - 'attn' => Not implemented now, use multi-head attention + summary_use_proj: Add a projection after the vector extraction + summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False. + summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default + summary_first_dropout: Add a dropout before the projection and activation + summary_last_dropout: Add a dropout after the projection and activation + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = (get_activation(activation_string) if activation_string else Identity()) + + self.first_dropout = Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward(self, hidden_states, cls_index=None): + """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer. + cls_index: [optional] position of the classification token if summary_type == 'cls_index', + shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. 
+ if summary_type == 'cls_index' and cls_index is None: + we take the last token of the sequence as classification token + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +def prune_linear_layer(layer, index, dim=0): + """ Prune a linear layer (a model parameters) to keep only entries in index. + Return the pruned layer as a new layer with requires_grad=True. + Used to remove heads. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + if layer.bias is not None: + if dim == 1: + b = layer.bias.clone().detach() + else: + b = layer.bias[index].clone().detach() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + if layer.bias is not None: + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +def prune_conv1d_layer(layer, index, dim=1): + """ Prune a Conv1D layer (a model parameters) to keep only entries in index. + A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed. + Return the pruned layer as a new layer with requires_grad=True. + Used to remove heads. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).clone().detach() + if dim == 0: + b = layer.bias.clone().detach() + else: + b = layer.bias[index].clone().detach() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +def prune_layer(layer, index, dim=None): + """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index. + Return the pruned layer as a new layer with requires_grad=True. + Used to remove heads. 
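+
+        Example (illustrative)::
+
+            layer = nn.Linear(768, 768)
+            index = torch.LongTensor([0, 2, 5])
+            pruned = prune_layer(layer, index)  # nn.Linear -> prune_linear_layer with dim=0, keeping 3 output units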
+ """ + if isinstance(layer, nn.Linear): + return prune_linear_layer(layer, index, dim=0 if dim is None else dim) + elif isinstance(layer, Conv1D): + return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) + else: + raise ValueError("Can't prune layer of class {}".format(layer.__class__)) + + +def apply_chunking_to_forward( + chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors +) -> torch.Tensor: + """ + This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. + It then applies a layer `forward_fn` to each chunk independently to save memory. + If the `forward_fn` is independent across the `chunk_dim` this function will yield the + same result as not applying it. + + Args: + chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size` + chunk_dim: int - the dimension over which the input_tensors should be chunked + forward_fn: fn - the forward fn of the model + input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked + Returns: + a Tensor with the same shape the foward_fn would have given if applied + + + Examples:: + + # rename the usual forward() fn to forward_chunk() + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + # implement a chunked forward function + def forward(self, hidden_states): + return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states) + """ + + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) + tensor_shape = input_tensors[0].shape + assert all( + input_tensor.shape == tensor_shape for input_tensor in input_tensors + ), "All input tenors have to be of the same shape" + + # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) + assert num_args_in_forward_chunk_fn == len( + input_tensors + ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format( + num_args_in_forward_chunk_fn, len(input_tensors) + ) + + if chunk_size > 0: + assert ( + input_tensors[0].shape[chunk_dim] % chunk_size == 0 + ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( + input_tensors[0].shape[chunk_dim], chunk_size + ) + + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size + + # chunk input tensor into tuples + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + # apply forward fn to every tuple + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + # concatenate output at same dimension + return torch.cat(output_chunks, dim=chunk_dim) + + return forward_fn(*input_tensors) diff --git a/elia/bert/multimodal_bert.py b/elia/bert/multimodal_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..b01cdbefe1197b45e25bdd754f847fd7efdd7c01 --- /dev/null +++ b/elia/bert/multimodal_bert.py @@ -0,0 +1,277 @@ + +from .modeling_bert import BertModel +import torch +import torch.nn as nn +import torch.nn.functional as F + + + +class MultiModalBert(BertModel): + def __init__(self, config, embed_dim, pwam_idx=[3,6,9,12], num_heads_fusion=[1,1,1,1], fusion_drop=0.0): + super().__init__(config) + self.pwam_idx = pwam_idx + 
self.num_heads_fusion = num_heads_fusion + self.fusion_drop = fusion_drop + + pwam_dims=[embed_dim * 2** i for i in range(len(pwam_idx))] + #print(pwam_dims) + self.pwams = nn.ModuleList() + self.res_gates = nn.ModuleList() + self.norms = nn.ModuleList() + for i in range(0, len(pwam_idx)): + dim = pwam_dims[i] + fusion = PWAM(768, # both the visual input and for combining, num of channels + dim, # v_in + 768, # l_in + 768, # key + 768, # value + num_heads=num_heads_fusion[i], + dropout=fusion_drop) + self.pwams.append(fusion) + + res_gate = nn.Sequential( + nn.Linear(768, 768, bias=False), + nn.ReLU(), + nn.Linear(768, 768, bias=False), + nn.Tanh() + ) + nn.init.zeros_(res_gate[0].weight) + nn.init.zeros_(res_gate[2].weight) + self.res_gates.append(res_gate) + + self.norms.append(nn.LayerNorm(768)) + + def forward_stem(self, input_ids, attention_mask): + input_shape = input_ids.size() + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, input_ids.device) + + embedding_output = self.embeddings( + input_ids=input_ids, token_type_ids=token_type_ids + ) + #print(embedding_output.shape, extended_attention_mask.shape, "?>>>") + return embedding_output, extended_attention_mask + + def forward_stage1(self, hidden_states, attention_mask): + for i in range(0, self.pwam_idx[0]): + layer_module = self.encoder.layer[i] + layer_outputs = layer_module( + hidden_states, + attention_mask, + ) + hidden_states = layer_outputs[0] + + return layer_outputs[0] + + def forward_stage2(self, hidden_states, attention_mask): + for i in range(self.pwam_idx[0], self.pwam_idx[1]): + layer_module = self.encoder.layer[i] + layer_outputs = layer_module( + hidden_states, + attention_mask, + ) + hidden_states = layer_outputs[0] + + return layer_outputs[0] + + def forward_stage3(self, hidden_states, attention_mask): + for i in range(self.pwam_idx[1], self.pwam_idx[2]): + layer_module = self.encoder.layer[i] + layer_outputs = layer_module( + hidden_states, + attention_mask, + ) + hidden_states = layer_outputs[0] + + return layer_outputs[0] + + def forward_stage4(self, hidden_states, attention_mask): + for i in range(self.pwam_idx[2], self.pwam_idx[3]): + layer_module = self.encoder.layer[i] + layer_outputs = layer_module( + hidden_states, + attention_mask, + ) + hidden_states = layer_outputs[0] + + return layer_outputs[0] + + def forward_pwam1(self, x, l, l_mask): + l_residual = self.pwams[0](x, l, l_mask) + l = l + (self.res_gates[0](l_residual) * l_residual) + return self.norms[0](l_residual), l + + def forward_pwam2(self, x, l, l_mask): + l_residual = self.pwams[1](x, l, l_mask) + l = l + (self.res_gates[1](l_residual) * l_residual) + return self.norms[1](l_residual), l + + def forward_pwam3(self, x, l, l_mask): + l_residual = self.pwams[2](x, l, l_mask) + l = l + (self.res_gates[2](l_residual) * l_residual) + return self.norms[2](l_residual), l + + def forward_pwam4(self, x, l, l_mask): + l_residual = self.pwams[3](x, l, l_mask) + l = l + (self.res_gates[3](l_residual) * l_residual) + return self.norms[3](l_residual), l + +class PWAM(nn.Module): + def __init__(self, dim, v_in_channels, l_in_channels, key_channels, value_channels, num_heads=0, dropout=0.0): + super(PWAM, self).__init__() + # input x shape: (B, H*W, dim) + #self.vis_project = nn.Sequential(nn.Conv1d(dim, dim, 1, 1), # the init function sets bias to 0 if bias is True + # nn.GELU(), + # nn.Dropout(dropout) + # ) + #self.vis_project = 
nn.Sequential(nn.Conv1d(dim, dim, 1, 1), # the init function sets bias to 0 if bias is True + self.vis_project = nn.Sequential(nn.Linear(dim, dim), # the init function sets bias to 0 if bias is True + nn.GELU(), + nn.Dropout(dropout) + ) + + self.image_lang_att = SpatialImageLanguageAttention(v_in_channels, # v_in + l_in_channels, # l_in + key_channels, # key + value_channels, # value + out_channels=value_channels, # out + num_heads=num_heads) + + self.project_mm = nn.Sequential(nn.Conv1d(value_channels, value_channels, 1, 1), + nn.GELU(), + nn.Dropout(dropout) + ) + + def forward(self, x, l, l_mask): + # input x shape: (B, H*W, dim) + #print("???", x.shape, l.shape, l_mask.shape) + #print(self.vis_project) + #vis = self.vis_project(x.permute(0, 2, 1)) # (B, dim, H*W) + vis = self.vis_project(l) # (B, dim, H*W) + + lang = self.image_lang_att(x, l, l_mask) # (B, H*W, dim) + + lang = lang.permute(0, 2, 1) # (B, dim, H*W) + + #print("vis", vis.shape, "lang", lang.shape) + mm = torch.mul(vis.permute(0,2,1), lang) + #print(mm.shape) + mm = self.project_mm(mm) # (B, dim, H*W) + + mm = mm.permute(0, 2, 1) # (B, H*W, dim) + + return mm + + #self.fusion = PWAM(dim, # both the visual input and for combining, num of channels + # dim, # v_in + # 768, # l_in + # dim, # key + # dim, # value + # num_heads=num_heads_fusion, + # dropout=fusion_drop) + +class SpatialImageLanguageAttention(nn.Module): + def __init__(self, v_in_channels, l_in_channels, key_channels, value_channels, out_channels=None, num_heads=1): + super(SpatialImageLanguageAttention, self).__init__() + # x shape: (B, H*W, v_in_channels) + # l input shape: (B, l_in_channels, N_l) + # l_mask shape: (B, N_l, 1) + self.v_in_channels = v_in_channels + self.l_in_channels = l_in_channels + self.out_channels = out_channels + self.key_channels = key_channels + self.value_channels = value_channels + self.num_heads = num_heads + if out_channels is None: + self.out_channels = self.value_channels + + # Keys: language features: (B, l_in_channels, #words) + # avoid any form of spatial normalization because a sentence contains many padding 0s + self.f_query = nn.Sequential( + nn.Conv1d(self.l_in_channels, self.key_channels, kernel_size=1, stride=1), + ) + + # Queries: visual features: (B, H*W, v_in_channels) + self.f_key = nn.Sequential( + nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1), + nn.InstanceNorm1d(self.key_channels), + ) + + # Values: language features: (B, l_in_channels, #words) + #self.f_value = nn.Sequential( + # nn.Conv1d(self.l_in_channels, self.value_channels, kernel_size=1, stride=1), + #) + self.f_value = nn.Sequential( + nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1), + nn.InstanceNorm1d(self.key_channels), + ) + + # Out projection + self.W = nn.Sequential( + nn.Conv1d(self.value_channels, self.out_channels, kernel_size=1, stride=1), + nn.InstanceNorm1d(self.out_channels), + ) + + def forward(self, x, l, l_mask): + #print('input shape', x.shape, l.shape, l_mask.shape) + l_mask = l_mask.squeeze(1) + # x shape: (B, H*W, v_in_channels) + # l input shape: (B, l_in_channels, N_l) + # l_mask shape: (B, N_l, 1) + B, HW = x.size(0), x.size(1) + x = x.permute(0, 2, 1) # (B, key_channels, H*W) + l = l.permute(0,2,1) + #l_mask = l_mask.permute(0, 2, 1) # (B, N_l, 1) -> (B, 1, N_l) + l_mask = l_mask # (B, N_l, 1) -> (B, 1, N_l) + + #query = self.f_query(x) # (B, key_channels, H*W) if Conv1D + #query = query.permute(0, 2, 1) # (B, H*W, key_channels) + #key = self.f_key(l) # (B, key_channels, N_l) + 
#value = self.f_value(l) # (B, self.value_channels, N_l) + #key = key * l_mask # (B, key_channels, N_l) + #value = value * l_mask # (B, self.value_channels, N_l) + + #print(l.shape, self.f_query) + query = self.f_query(l) # (B, key_channels, H*W) if Conv1D + query = query * l_mask # (B, key_channels, N_l) + query = query.permute(0, 2, 1) # (B, N_l, key_channels) + + key = self.f_key(x) # (B, key_channels, H*W) if Conv1D + value = self.f_value(x) # (B, key_channels, H*W) if Conv1D + + n_l = query.size(1) + #print(query.shape, key.shape, value.shape) + + #query = query.reshape(B, HW, self.num_heads, self.key_channels//self.num_heads).permute(0, 2, 1, 3) + # (b, num_heads, H*W, self.key_channels//self.num_heads) + #key = key.reshape(B, self.num_heads, self.key_channels//self.num_heads, n_l) + # (b, num_heads, self.key_channels//self.num_heads, n_l) + #value = value.reshape(B, self.num_heads, self.value_channels//self.num_heads, n_l) + # # (b, num_heads, self.value_channels//self.num_heads, n_l) + key = key.reshape(B, self.num_heads, self.key_channels//self.num_heads, HW) + value = value.reshape(B, self.num_heads, self.key_channels//self.num_heads, HW) + # (b, num_heads, H*W, self.key_channels//self.num_heads) + #query = query.reshape(B, self.num_heads, self.key_channels//self.num_heads, n_l) + query = query.reshape(B, n_l, self.num_heads, self.key_channels//self.num_heads).permute(0, 2, 1, 3) + # (b, num_heads, self.key_channels//self.num_heads, n_l) + #value = value.reshape(B, self.num_heads, self.value_channels//self.num_heads, n_l) + #print('after reshape', query.shape, key.shape, value.shape) + + l_mask = l_mask.unsqueeze(-1) # (b, 1, 1, n_l) + + #sim_map = torch.matmul(query, key) # (B, self.num_heads, H*W, N_l) + sim_map = torch.matmul(query, key) # (B, self.num_heads, N_l, H*W) + sim_map = (self.key_channels ** -.5) * sim_map # scaled dot product + + sim_map = sim_map + (1e4*l_mask - 1e4) # assign a very small number to padding positions + sim_map = F.softmax(sim_map, dim=-1) # (B, num_heads, h*w, N_l) + out = torch.matmul(sim_map, value.permute(0, 1, 3, 2)) # (B, num_heads, H*W, self.value_channels//num_heads) + #print('out', out.shape) + #out = out.permute(0, 2, 1, 3).contiguous().reshape(B, HW, self.value_channels) # (B, H*W, value_channels) + out = out.permute(0, 2, 1, 3).contiguous().reshape(B, n_l, self.value_channels) # (B, H*W, value_channels) + out = out.permute(0, 2, 1) # (B, value_channels, HW) + out = self.W(out) # (B, value_channels, HW) + out = out.permute(0, 2, 1) # (B, HW, value_channels) + + return out diff --git a/elia/bert/tokenization_bert.py b/elia/bert/tokenization_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..972e1733163522359750dddedf6dea885085b2ca --- /dev/null +++ b/elia/bert/tokenization_bert.py @@ -0,0 +1,545 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization classes.""" + + +import collections +import logging +import os +import unicodedata +from typing import List, Optional + +from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace + + +logger = logging.getLogger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", + "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", + "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", + "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", + "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", + "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", + "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", + "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", + "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", + "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt", + "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt", + "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "bert-base-uncased": 512, + "bert-large-uncased": 512, + "bert-base-cased": 512, + "bert-large-cased": 512, + "bert-base-multilingual-uncased": 512, + "bert-base-multilingual-cased": 512, + "bert-base-chinese": 512, + "bert-base-german-cased": 512, + "bert-large-uncased-whole-word-masking": 512, + "bert-large-cased-whole-word-masking": 512, + "bert-large-uncased-whole-word-masking-finetuned-squad": 512, + "bert-large-cased-whole-word-masking-finetuned-squad": 512, + "bert-base-cased-finetuned-mrpc": 512, + "bert-base-german-dbmdz-cased": 512, + "bert-base-german-dbmdz-uncased": 512, + "TurkuNLP/bert-base-finnish-cased-v1": 512, + "TurkuNLP/bert-base-finnish-uncased-v1": 512, + "wietsedv/bert-base-dutch-cased": 512, +} + 
+PRETRAINED_INIT_CONFIGURATION = { + "bert-base-uncased": {"do_lower_case": True}, + "bert-large-uncased": {"do_lower_case": True}, + "bert-base-cased": {"do_lower_case": False}, + "bert-large-cased": {"do_lower_case": False}, + "bert-base-multilingual-uncased": {"do_lower_case": True}, + "bert-base-multilingual-cased": {"do_lower_case": False}, + "bert-base-chinese": {"do_lower_case": False}, + "bert-base-german-cased": {"do_lower_case": False}, + "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, + "bert-large-cased-whole-word-masking": {"do_lower_case": False}, + "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, + "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, + "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, + "bert-base-german-dbmdz-cased": {"do_lower_case": False}, + "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, + "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, + "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, + "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, +} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(PreTrainedTokenizer): + r""" + Constructs a BERT tokenizer. Based on WordPiece. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users + should refer to the superclass for more information regarding methods. + + Args: + vocab_file (:obj:`string`): + File containing the vocabulary. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to lowercase the input when tokenizing. + do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to do basic tokenization before WordPiece. + never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`): + Collection of tokens which will never be split during tokenization. Only has an effect when + :obj:`do_basic_tokenize=True` + unk_token (:obj:`string`, `optional`, defaults to "[UNK]"): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`string`, `optional`, defaults to "[SEP]"): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences + for sequence classification or for a text and a question for question answering. + It is also used as the last token of a sequence built with special tokens. + pad_token (:obj:`string`, `optional`, defaults to "[PAD]"): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`string`, `optional`, defaults to "[CLS]"): + The classifier token which is used when doing sequence classification (classification of the whole + sequence instead of per-token classification). It is the first token of the sequence when built with + special tokens. + mask_token (:obj:`string`, `optional`, defaults to "[MASK]"): + The token used for masking values. 
This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to tokenize Chinese characters. + This should likely be deactivated for Japanese: + see: https://github.com/huggingface/transformers/issues/328 + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + **kwargs + ): + super().__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A BERT sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. 
+ + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A BERT sequence pair mask has the following format: + + :: + + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + + if token_ids_1 is None, only returns the first portion of the mask (0's). + + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, vocab_path): + """ + Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory. + + Args: + vocab_path (:obj:`str`): + The directory in which to save the vocabulary. + + Returns: + :obj:`Tuple(str)`: Paths to the files saved. + """ + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) + else: + vocab_file = vocab_path + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + "Saving vocabulary to {}: vocabulary indices are not consecutive." 
+ " Please check that the vocabulary is not corrupted!".format(vocab_file) + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): + """ Constructs a BasicTokenizer. + + Args: + **do_lower_case**: Whether to lower case the input. + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. + Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) + List of token not to split. + **tokenize_chinese_chars**: (`optional`) boolean (default True) + Whether to tokenize Chinese characters. + This should likely be deactivated for Japanese: + see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 + """ + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + + def tokenize(self, text, never_split=None): + """ Basic Tokenization of a piece of text. + Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. + + Args: + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. + Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) + List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
+ if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if never_split is not None and text in never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. 
+ """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + diff --git a/elia/bert/tokenization_utils.py b/elia/bert/tokenization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d878210f407d8fb10226d9e4435c761a1f7483fc --- /dev/null +++ b/elia/bert/tokenization_utils.py @@ -0,0 +1,723 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization classes for python tokenizers. + For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py +""" + +import itertools +import logging +import re +import unicodedata +from typing import Dict, List, Optional, Tuple, Union + +from .file_utils import add_end_docstrings +from .tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, + AddedToken, + BatchEncoding, + EncodedInput, + EncodedInputPair, + PaddingStrategy, + PreTokenizedInput, + PreTokenizedInputPair, + PreTrainedTokenizerBase, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, +) + + +logger = logging.getLogger(__name__) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def _is_end_of_word(text): + """Checks whether the last character in text is one of a punctuation, control or whitespace character.""" + last_char = text[-1] + return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char)) + + +def _is_start_of_word(text): + """Checks whether the first character in text is one of a punctuation, control or whitespace character.""" + first_char = text[0] + return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char)) + + +class PreTrainedTokenizer(PreTrainedTokenizerBase): + """ Base class for all slow tokenizers. + + Handle all the shared methods for tokenization and special tokens as well as methods + downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. + + This class also contain the added tokens in a unified way on top of all tokenizers so we don't + have to handle the specific vocabulary augmentation methods of the various underlying + dictionary structures (BPE, sentencepiece...). + + Class attributes (overridden by derived classes): + + - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file + required by the model, and as associated values, the filename for saving the associated file (string). + - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys + being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the + `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the + associated pretrained vocabulary file. + - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained + models, and as associated values, the maximum length of the sequence inputs of this model, or None if the + model has no maximum input size. + - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the + pretrained models, and as associated values, a dictionnary of specific arguments to pass to the + ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the + ``from_pretrained()`` method. + + Args: + - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model. + When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated + model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`). + no associated max_length can be found in ``max_model_input_sizes``. + - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied. + Should be selected between ['right', 'left'] + - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the + model ("token_type_ids", "attention_mask"...). + - ``bos_token``: (`Optional`) string: a beginning of sentence token. + Will be associated to ``self.bos_token`` and ``self.bos_token_id`` + - ``eos_token``: (`Optional`) string: an end of sentence token. + Will be associated to ``self.eos_token`` and ``self.eos_token_id`` + - ``unk_token``: (`Optional`) string: an unknown token. 
+ Will be associated to ``self.unk_token`` and ``self.unk_token_id`` + - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). + Will be associated to ``self.sep_token`` and ``self.sep_token_id`` + - ``pad_token``: (`Optional`) string: a padding token. + Will be associated to ``self.pad_token`` and ``self.pad_token_id`` + - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence + leveraging self-attention along the full depth of the model). + Will be associated to ``self.cls_token`` and ``self.cls_token_id`` + - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language + modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id`` + - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. + Adding all special tokens here ensure they won't be split by the tokenization process. + Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` + + + .. automethod:: __call__ + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Added tokens - We store this for both slow and fast tokenizers + # until the serialization of Fast tokenizers is updated + self.added_tokens_encoder: Dict[str, int] = {} + self.added_tokens_decoder: Dict[int, str] = {} + self.unique_no_split_tokens: List[str] = [] + + @property + def is_fast(self) -> bool: + return False + + @property + def vocab_size(self) -> int: + """ Size of the base vocabulary (without the added tokens) """ + raise NotImplementedError + + def get_vocab(self): + """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """ + raise NotImplementedError() + + def get_added_vocab(self) -> Dict[str, int]: + return self.added_tokens_encoder + + def __len__(self): + """ Size of the full vocabulary with the added tokens """ + return self.vocab_size + len(self.added_tokens_encoder) + + def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int: + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the + vocabulary, they are added to it with indices starting from length of the current vocabulary. + + Args: + new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not + already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). + + Returns: + Number of tokens added to the vocabulary. + + Examples:: + + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertModel.from_pretrained('bert-base-uncased') + + num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) + print('We have added', num_added_toks, 'tokens') + model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 
+ """ + new_tokens = [str(tok) for tok in new_tokens] + + tokens_to_add = [] + for token in new_tokens: + assert isinstance(token, str) + if not special_tokens and self.init_kwargs.get("do_lower_case", False): + token = token.lower() + if ( + token != self.unk_token + and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) + and token not in tokens_to_add + ): + tokens_to_add.append(token) + if self.verbose: + logger.info("Adding %s to the vocabulary", token) + + added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add)) + added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} + self.added_tokens_encoder.update(added_tok_encoder) + self.added_tokens_decoder.update(added_tok_decoder) + + # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert) + if special_tokens: + self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens))) + else: + # Or on the newly added tokens + self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add))) + + return len(tokens_to_add) + + def num_special_tokens_to_add(self, pair=False): + """ + Returns the number of added tokens when encoding a sequence with special tokens. + + Note: + This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this + inside your training loop. + + Args: + pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the + number of added tokens in the case of a single sequence if set to False. + + Returns: + Number of tokens added to sequences + """ + token_ids_0 = [] + token_ids_1 = [] + return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) + + def tokenize(self, text: TextInput, **kwargs): + """ Converts a string in a sequence of tokens (string), using the tokenizer. + Split in words for word-based vocabulary or sub-words for sub-word-based + vocabularies (BPE/SentencePieces/WordPieces). + + Take care of added tokens. + + Args: + text (:obj:`string`): The sequence to be encoded. + **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method. + """ + # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors + all_special_tokens_extended = dict( + (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken) + ) + + text, kwargs = self.prepare_for_tokenization(text, **kwargs) + + if kwargs: + logger.warning(f"Keyword arguments {kwargs} not recognized.") + + # TODO: should this be in the base class? + if self.init_kwargs.get("do_lower_case", False): + # convert non-special tokens to lowercase + escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens] + pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" + text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) + + def split_on_token(tok, text): + result = [] + tok_extended = all_special_tokens_extended.get(tok, None) + split_text = text.split(tok) + full_word = "" + for i, sub_text in enumerate(split_text): + # AddedToken can control whitespace stripping around them. + # We use them for GPT2 and Roberta to have different behavior depending on the special token + # Cf. 
https://github.com/huggingface/transformers/pull/2778 + # and https://github.com/huggingface/transformers/issues/3788 + if isinstance(tok_extended, AddedToken): + if tok_extended.single_word: + # Try to avoid splitting on token + if ( + i < len(split_text) - 1 + and not _is_end_of_word(sub_text) + and not _is_start_of_word(split_text[i + 1]) + ): + # Don't extract the special token + full_word += sub_text + tok + elif full_word: + full_word += sub_text + result += [full_word] + full_word = "" + continue + # Strip white spaces on the right + if tok_extended.rstrip and i > 0: + # A bit counter-intuitive but we strip the left of the string + # since tok_extended.rstrip means the special token is eating all white spaces on its right + sub_text = sub_text.lstrip() + # Strip white spaces on the left + if tok_extended.lstrip and i < len(split_text) - 1: + sub_text = sub_text.rstrip() # Opposite here + else: + # We strip left and right by default + if i < len(split_text) - 1: + sub_text = sub_text.rstrip() + if i > 0: + sub_text = sub_text.lstrip() + + if i == 0 and not sub_text: + result += [tok] + elif i == len(split_text) - 1: + if sub_text: + result += [sub_text] + else: + pass + else: + if sub_text: + result += [sub_text] + result += [tok] + return result + + def split_on_tokens(tok_list, text): + if not text.strip(): + return [] + if not tok_list: + return self._tokenize(text) + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self.unique_no_split_tokens: + tokenized_text += split_on_token(tok, sub_text) + else: + tokenized_text += [sub_text] + text_list = tokenized_text + + return list( + itertools.chain.from_iterable( + ( + self._tokenize(token) if token not in self.unique_no_split_tokens else [token] + for token in tokenized_text + ) + ) + ) + + no_split_token = self.unique_no_split_tokens + tokenized_text = split_on_tokens(no_split_token, text) + return tokenized_text + + def _tokenize(self, text, **kwargs): + """ Converts a string in a sequence of tokens (string), using the tokenizer. + Split in words for word-based vocabulary or sub-words for sub-word-based + vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def convert_tokens_to_ids(self, tokens): + """ Converts a token string (or a sequence of tokens) in a single integer id + (or a sequence of ids), using the vocabulary. 
+ """ + if tokens is None: + return None + + if isinstance(tokens, str): + return self._convert_token_to_id_with_added_voc(tokens) + + ids = [] + for token in tokens: + ids.append(self._convert_token_to_id_with_added_voc(token)) + return ids + + def _convert_token_to_id_with_added_voc(self, token): + if token is None: + return None + + if token in self.added_tokens_encoder: + return self.added_tokens_encoder[token] + return self._convert_token_to_id(token) + + def _convert_token_to_id(self, token): + raise NotImplementedError + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_pretokenized: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_pretokenized: + tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text))) + return self.convert_tokens_to_ids(tokens) + else: + return self.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text + else: + if is_pretokenized: + raise ValueError( + f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`." + ) + else: + raise ValueError( + f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers." + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + first_ids = get_input_ids(text) + second_ids = get_input_ids(text_pair) if text_pair is not None else None + + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_pretokenized: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_pretokenized: + tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text))) + return self.convert_tokens_to_ids(tokens) + else: + return self.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text + else: + raise ValueError( + "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers." + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ ) + + input_ids = [] + for ids_or_pair_ids in batch_text_or_text_pairs: + if not isinstance(ids_or_pair_ids, (list, tuple)): + ids, pair_ids = ids_or_pair_ids, None + elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)): + ids, pair_ids = ids_or_pair_ids, None + else: + ids, pair_ids = ids_or_pair_ids + + first_ids = get_input_ids(ids) + second_ids = get_input_ids(pair_ids) if pair_ids is not None else None + input_ids.append((first_ids, second_ids)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. 
+ It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for first_ids, second_ids in batch_ids_pairs: + outputs = self.prepare_for_model( + first_ids, + second_ids, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict): + """ Performs any necessary transformations before tokenization. + + This method should pop the arguments from kwargs and return kwargs as well. + We test kwargs at the end of the encoding process to be sure all the arguments have been used. + """ + return (text, kwargs) + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0: list of ids (must not contain special tokens) + token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formated with + special tokens for the model + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) + + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: + """ Converts a single index or a sequence of indices (integers) in a token " + (resp.) a sequence of tokens (str), using the vocabulary and added tokens. + + Args: + skip_special_tokens: Don't decode special tokens (self.all_special_tokens). 
Default: False + """ + if isinstance(ids, int): + if ids in self.added_tokens_decoder: + return self.added_tokens_decoder[ids] + else: + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + if index in self.added_tokens_decoder: + tokens.append(self.added_tokens_decoder[index]) + else: + tokens.append(self._convert_id_to_token(index)) + return tokens + + def _convert_id_to_token(self, index: int) -> str: + raise NotImplementedError + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """ Converts a sequence of tokens (string) in a single string. + The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) + but we often want to remove sub-word tokenization artifacts at the same time. + """ + return " ".join(self.convert_ids_to_tokens(tokens)) + + def decode( + self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True + ) -> str: + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separatly for added tokens and byte-level tokens + # cf. https://github.com/huggingface/transformers/issues/1133 + sub_texts = [] + current_sub_text = [] + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_ids: + continue + if token in self.added_tokens_encoder: + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + text = " ".join(sub_texts) + + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def save_vocabulary(self, save_directory) -> Tuple[str]: + """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens + and special token mappings. + + Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full + Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` + class method. + """ + raise NotImplementedError diff --git a/elia/bert/tokenization_utils_base.py b/elia/bert/tokenization_utils_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1219a4d3473ae510f0da905ea09a76019a7996 --- /dev/null +++ b/elia/bert/tokenization_utils_base.py @@ -0,0 +1,2317 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Base classes common to both the slow and the fast tokenization classes: + PreTrainedTokenizerBase (host all the user fronting encoding methodes) + Special token mixing (host the special tokens logic) and + BatchEncoding (wrap the dictionnary of output with special method for the Fast tokenizers) +""" + +import copy +import json +import logging +import os +import warnings +from collections import UserDict +from enum import Enum +from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union + +import numpy as np +from tokenizers import AddedToken +from tokenizers import Encoding as EncodingFast + +from .file_utils import ( + add_end_docstrings, + cached_path, + hf_bucket_url, + is_remote_url, + is_tf_available, + is_torch_available, + torch_required, +) + + +if is_tf_available(): + import tensorflow as tf +if is_torch_available(): + import torch + + +logger = logging.getLogger(__name__) + +VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input +LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER + +# Define type aliases and NamedTuples +TextInput = str +PreTokenizedInput = List[str] +EncodedInput = List[int] +TextInputPair = Tuple[str, str] +PreTokenizedInputPair = Tuple[List[str], List[str]] +EncodedInputPair = Tuple[List[int], List[int]] + + +# Slow tokenizers used to be saved in three separated files +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +ADDED_TOKENS_FILE = "added_tokens.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" + +# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file +FULL_TOKENIZER_FILE = "tokenizer.json" + + +class ExplicitEnum(Enum): + """ Enum with more explicit error message for missing values. + """ + + @classmethod + def _missing_(cls, value): + raise ValueError( + "%r is not a valid %s, please select one of %s" + % (value, cls.__name__, str(list(cls._value2member_map_.keys()))) + ) + + +class TruncationStrategy(ExplicitEnum): + ONLY_FIRST = "only_first" + ONLY_SECOND = "only_second" + LONGEST_FIRST = "longest_first" + DO_NOT_TRUNCATE = "do_not_truncate" + + +class PaddingStrategy(ExplicitEnum): + LONGEST = "longest" + MAX_LENGTH = "max_length" + DO_NOT_PAD = "do_not_pad" + + +class TensorType(ExplicitEnum): + PYTORCH = "pt" + TENSORFLOW = "tf" + NUMPY = "np" + + +class CharSpan(NamedTuple): + """ Character span in the original string + + Args: + start: index of the first character in the original string + end: index of the character following the last character in the original string + """ + + start: int + end: int + + +class TokenSpan(NamedTuple): + """ Token span in an encoded string (list of tokens) + + Args: + start: index of the first token in the span + end: index of the token following the last token in the span + """ + + start: int + end: int + + +class BatchEncoding(UserDict): + """ BatchEncoding hold the output of the encode and batch_encode methods (tokens, attention_masks, etc). + This class is derived from a python Dictionary and can be used as a dictionnary. + In addition, this class expose utility methods to map from word/char space to token space. + + Args: + data (:obj:`dict`): Dictionary of lists/arrays returned by the encode/batch_encode methods ('input_ids', 'attention_mask'...) 
+ encoding (:obj:`EncodingFast`, :obj:`list(EncodingFast)`, `optional`, defaults to :obj:`None`): + If the tokenizer is a fast tokenizer which outputs additional informations like mapping from word/char space to token space + the `EncodingFast` instance or list of instance (for batches) hold these informations. + tensor_type (:obj:`Union[None, str, TensorType]`, `optional`, defaults to :obj:`None`): + You can give a tensor_type here to convert the lists of integers in PyTorch/TF/Numpy Tensors at initialization + prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True to add a batch axis when converting in Tensors (see :obj:`tensor_type` above) + """ + + def __init__( + self, + data: Optional[Dict[str, Any]] = None, + encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None, + tensor_type: Union[None, str, TensorType] = None, + prepend_batch_axis: bool = False, + ): + super().__init__(data) + + if isinstance(encoding, EncodingFast): + encoding = [encoding] + + self._encodings = encoding + + self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) + + @property + def is_fast(self): + """ + Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast + Returns: True if generated from subclasses of PreTrainedTokenizerFast, else otherwise + """ + return self._encodings is not None + + def __getitem__(self, item: Union[int, str]) -> EncodingFast: + """ If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...) + If the key is an integer, get the EncodingFast for batch item with index `key` + """ + if isinstance(item, str): + return self.data[item] + elif self._encodings is not None: + return self._encodings[item] + else: + raise KeyError( + "Indexing with integers (to access backend Encoding for a given batch index) " + "is not available when using Python based tokenizers" + ) + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def __getstate__(self): + return {"data": self.data, "encodings": self._encodings} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + if "encodings" in state: + self._encodings = state["encodings"] + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + # After this point: + # Extended properties and methods only available for fast (Rust-based) tokenizers + # provided by HuggingFace tokenizers library. + + @property + def encodings(self) -> Optional[List[EncodingFast]]: + """ + Return the list all encoding from the tokenization process + + Returns: List[EncodingFast] or None if input was tokenized through Python (i.e. not fast) tokenizer + """ + return self._encodings + + def tokens(self, batch_index: int = 0) -> List[str]: + if not self._encodings: + raise ValueError("tokens() is not available when using Python based tokenizers") + return self._encodings[batch_index].tokens + + def words(self, batch_index: int = 0) -> List[Optional[int]]: + if not self._encodings: + raise ValueError("words() is not available when using Python based tokenizers") + return self._encodings[batch_index].words + + def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: + """ + Get the index of the word corresponding (i.e. comprising) to an encoded token + in a sequence of the batch. 
+ + Can be called as: + + - ``self.token_to_word(token_index)`` if batch size is 1 + - ``self.token_to_word(batch_index, token_index)`` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. + + Args: + batch_or_token_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the token in the sequence + token_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the token in the sequence. + + Returns: + :obj:`int`: + index of the word in the input sequence. + + """ + + if not self._encodings: + raise ValueError("token_to_word() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if token_index < 0: + token_index = self._seq_len + token_index + return self._encodings[batch_index].token_to_word(token_index) + + def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan: + """ + Get the encoded token span corresponding to a word in the sequence of the batch. + + Token spans are returned as a TokenSpan NamedTuple with: + + - start: index of the first token + - end: index of the token following the last token + + Can be called as: + + - ``self.word_to_tokens(word_index)`` if batch size is 1 + - ``self.word_to_tokens(batch_index, word_index)`` if batch size is greater or equal to 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. + + Args: + batch_or_word_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprises one sequence, + this can be the index of the word in the sequence + word_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the word in the sequence. + + Returns: + :obj:`TokenSpan`: + Span of tokens in the encoded sequence. + + :obj:`TokenSpan` are NamedTuple with: + + - start: index of the first token + - end: index of the token following the last token + """ + + if not self._encodings: + raise ValueError("word_to_tokens() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if word_index < 0: + word_index = self._seq_len + word_index + return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index))) + + def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: + """ + Get the character span corresponding to an encoded token in a sequence of the batch. 
+ + Character spans are returned as a CharSpan NamedTuple with: + + - start: index of the first character in the original string associated to the token + - end: index of the character following the last character in the original string associated to the token + + Can be called as: + + - ``self.token_to_chars(token_index)`` if batch size is 1 + - ``self.token_to_chars(batch_index, token_index)`` if batch size is greater or equal to 1 + + Args: + batch_or_token_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the token in the sequence + token_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the token or tokens in the sequence. + + Returns: + :obj:`CharSpan`: + Span of characters in the original string. + + :obj:`CharSpan` are NamedTuple with: + + - start: index of the first character in the original string + - end: index of the character following the last character in the original string + """ + + if not self._encodings: + raise ValueError("token_to_chars() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index))) + + def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int: + """ + Get the index of the token in the encoded output comprising a character + in the original string for a sequence of the batch. + + Can be called as: + + - ``self.char_to_token(char_index)`` if batch size is 1 + - ``self.char_to_token(batch_index, char_index)`` if batch size is greater or equal to 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. + + Args: + batch_or_char_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the word in the sequence + char_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the word in the sequence. + + + Returns: + :obj:`int`: Index of the token. + """ + + if not self._encodings: + raise ValueError("char_to_token() is not available when using Python based tokenizers") + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_token(char_index) + + def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = None) -> CharSpan: + """ + Get the character span in the original string corresponding to given word in a sequence + of the batch. + + Character spans are returned as a CharSpan NamedTuple with: + + - start: index of the first character in the original string + - end: index of the character following the last character in the original string + + Can be called as: + + - ``self.word_to_chars(word_index)`` if batch size is 1 + - ``self.word_to_chars(batch_index, word_index)`` if batch size is greater or equal to 1 + + Args: + batch_or_word_index (:obj:`int`): + Index of the sequence in the batch. 
If the batch only comprise one sequence, + this can be the index of the word in the sequence + word_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the word in the sequence. + + Returns: + :obj:`CharSpan` or :obj:`List[CharSpan]`: + Span(s) of the associated character or characters in the string. + CharSpan are NamedTuple with: + + - start: index of the first character associated to the token in the original string + - end: index of the character following the last character associated to the token in the original string + """ + + if not self._encodings: + raise ValueError("word_to_chars() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index))) + + def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int: + """ + Get the word in the original string corresponding to a character in the original string of + a sequence of the batch. + + Can be called as: + + - ``self.char_to_word(char_index)`` if batch size is 1 + - ``self.char_to_word(batch_index, char_index)`` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as + pre-tokenized sequences (i.e. words are defined by the user). In this case it allows + to easily associate encoded tokens with provided tokenized words. + + Args: + batch_or_char_index (:obj:`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, + this can be the index of the character in the orginal string. + char_index (:obj:`int`, `optional`): + If a batch index is provided in `batch_or_token_index`, this can be the index + of the character in the orginal string. + + + Returns: + :obj:`int` or :obj:`List[int]`: + Index or indices of the associated encoded token(s). + """ + + if not self._encodings: + raise ValueError("char_to_word() is not available when using Python based tokenizers") + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_word(char_index) + + def convert_to_tensors(self, tensor_type: Union[None, str, TensorType], prepend_batch_axis: bool = False): + if tensor_type is None: + return self + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW and is_tf_available(): + as_tensor = tf.constant + elif tensor_type == TensorType.PYTORCH and is_torch_available(): + as_tensor = torch.tensor + elif tensor_type == TensorType.NUMPY: + as_tensor = np.asarray + else: + raise ImportError( + "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( + tensor_type + ) + ) + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if prepend_batch_axis: + value = [value] + + tensor = as_tensor(value) + + # at-least2d + if tensor.ndim > 2: + tensor = tensor.squeeze(0) + elif tensor.ndim < 2: + tensor = tensor[None, :] + + self[key] = tensor + except: # noqa E722 + raise ValueError( + "Unable to create tensor, you should probably activate truncation and/or padding " + "with 'padding=True' 'truncation=True' to have batched tensors with the same length." 
+ ) + + return self + + @torch_required + def to(self, device: str): + """Send all values to device by calling v.to(device)""" + self.data = {k: v.to(device) for k, v in self.data.items()} + return self + + +# class AddedToken(UserString): +# """ AddedToken represents a token to be added to a Tokenizer + +# An AddedToken can have special options defining the way it should behave. + +# Args: +# content: str: +# The content of the token + +# single_word: bool +# Whether this token should only match against single word. If True, +# this token will never match inside of a word. + +# lstrip: bool +# Whether this token should strip all potential whitespaces on the left side. +# If True, this token will greedily match any whitespace on the left and then strip +# them out. + +# rstrip: bool +# Whether this token should strip all potential whitespaces on the right side. +# If True, this token will greedily match any whitespace on the right and then strip +# them out. +# """ + +# def __init__( +# self, data: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False, +# ): +# super().__init__(data) + +# self._single_word = single_word +# self._lstrip = lstrip +# self._rstrip = rstrip + +# def lower(self): +# return AddedToken(self.data.lower(), self._single_word, self._lstrip, self._rstrip) + + +class SpecialTokensMixin: + """ SpecialTokensMixin is derived by ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` and + handles specific behaviors related to special tokens. In particular, this class hold the + attributes which can be used to directly access to these special tokens in a + model-independant manner and allow to set and update the special tokens. + """ + + SPECIAL_TOKENS_ATTRIBUTES = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + "additional_special_tokens", + ] + + def __init__(self, verbose=True, **kwargs): + self._bos_token = None + self._eos_token = None + self._unk_token = None + self._sep_token = None + self._pad_token = None + self._cls_token = None + self._mask_token = None + self._pad_token_type_id = 0 + self._additional_special_tokens = [] + self.verbose = verbose + + # We directly set the hidden value to allow initialization with special tokens + # which are not yet in the vocabulary. Necesssary for serialization/de-serialization + # TODO clean this up at some point (probably by sitching to fast tokenizers) + for key, value in kwargs.items(): + if key in self.SPECIAL_TOKENS_ATTRIBUTES: + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value) + setattr(self, key, value) + elif isinstance(value, (str, AddedToken)): + setattr(self, key, value) + else: + raise TypeError( + "special token {} has to be either str or AddedToken but got: {}".format(key, type(value)) + ) + + def sanitize_special_tokens(self) -> int: + """ Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...) + are in the vocabulary. Add the missing ones to the vocabulary if needed. + + Return: + Number of tokens added in the vocaulary during the operation. + """ + return self.add_tokens(self.all_special_tokens_extended, special_tokens=True) + + def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int: + """ + Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them + to class attributes. 
If special tokens are NOT in the vocabulary, they are added + to it (indexed starting from the last index of the current vocabulary). + + Using `add_special_tokens` will ensure your special tokens can be used in several ways: + + - special tokens are carefully handled by the tokenizer (they are never split) + - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts. + + When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '') + + Args: + special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: + [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, + ``additional_special_tokens``]. + + Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). + + Returns: + Number of tokens added to the vocabulary. + + Examples:: + + # Let's see how to add a new classification token to GPT-2 + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + model = GPT2Model.from_pretrained('gpt2') + + special_tokens_dict = {'cls_token': ''} + + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + print('We have added', num_added_toks, 'tokens') + model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. + + assert tokenizer.cls_token == '' + """ + if not special_tokens_dict: + return 0 + + added_tokens = 0 + for key, value in special_tokens_dict.items(): + assert key in self.SPECIAL_TOKENS_ATTRIBUTES + + if self.verbose: + logger.info("Assigning %s to the %s key of the tokenizer", value, key) + setattr(self, key, value) + + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)) and all( + isinstance(t, (str, AddedToken)) for t in value + ), f"Tokens {value} for key {key} should all be str or AddedToken instances" + added_tokens += self.add_tokens(value, special_tokens=True) + else: + assert isinstance( + value, (str, AddedToken) + ), f"Token {value} for key {key} should be a str or an AddedToken instance" + added_tokens += self.add_tokens([value], special_tokens=True) + + return added_tokens + + def add_tokens(self, new_tokens: Union[str, AddedToken, List[str], List[AddedToken]], special_tokens=False) -> int: + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the + vocabulary, they are added to it with indices starting from length of the current vocabulary. + + Args: + new_tokens: string or list of string or :class:`~transformers.AddedToken`. Each string is a token to add. + Tokens are only added if they are not already in the vocabulary. AddedToken wrap a string token to + let you personnalize it's behavior (Whether this token should only match against single word, whether + this token should strip all potential whitespaces on the left side, Whether this token should strip + all potential whitespaces on the right side...). + special_token: can be used to specify if the token is a special token. 
This mostly change the normalization + behavior (special tokens like CLS or [MASK] are usually not lower-cased for instance) + + See details for :class:`~transformers.AddedToken` in HuggingFace tokenizers library. + + Returns: + Number of tokens added to the vocabulary. + + Examples:: + + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + model = BertModel.from_pretrained('bert-base-uncased') + + num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) + print('We have added', num_added_toks, 'tokens') + model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. + """ + if not new_tokens: + return 0 + + if not isinstance(new_tokens, (list, tuple)): + new_tokens = [new_tokens] + + return self._add_tokens(new_tokens, special_tokens=special_tokens) + + @property + def bos_token(self): + """ Beginning of sentence token (string). Log an error if used while not having been set. """ + if self._bos_token is None and self.verbose: + logger.error("Using bos_token, but it is not set yet.") + return None + return str(self._bos_token) + + @property + def eos_token(self): + """ End of sentence token (string). Log an error if used while not having been set. """ + if self._eos_token is None and self.verbose: + logger.error("Using eos_token, but it is not set yet.") + return None + return str(self._eos_token) + + @property + def unk_token(self): + """ Unknown token (string). Log an error if used while not having been set. """ + if self._unk_token is None and self.verbose: + logger.error("Using unk_token, but it is not set yet.") + return None + return str(self._unk_token) + + @property + def sep_token(self): + """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """ + if self._sep_token is None and self.verbose: + logger.error("Using sep_token, but it is not set yet.") + return None + return str(self._sep_token) + + @property + def pad_token(self): + """ Padding token (string). Log an error if used while not having been set. """ + if self._pad_token is None and self.verbose: + logger.error("Using pad_token, but it is not set yet.") + return None + return str(self._pad_token) + + @property + def cls_token(self): + """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ + if self._cls_token is None and self.verbose: + logger.error("Using cls_token, but it is not set yet.") + return None + return str(self._cls_token) + + @property + def mask_token(self): + """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ + if self._mask_token is None and self.verbose: + logger.error("Using mask_token, but it is not set yet.") + return None + return str(self._mask_token) + + @property + def additional_special_tokens(self): + """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. 
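+
+        A minimal usage sketch on an already-instantiated tokenizer (the token strings
+        are illustrative placeholders, not the defaults of any pretrained model)::
+
+            tokenizer.add_special_tokens({'additional_special_tokens': ['<ctx>', '<qry>']})
+            assert tokenizer.additional_special_tokens == ['<ctx>', '<qry>']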
""" + if self._additional_special_tokens is None and self.verbose: + logger.error("Using additional_special_tokens, but it is not set yet.") + return None + return [str(tok) for tok in self._additional_special_tokens] + + @bos_token.setter + def bos_token(self, value): + self._bos_token = value + + @eos_token.setter + def eos_token(self, value): + self._eos_token = value + + @unk_token.setter + def unk_token(self, value): + self._unk_token = value + + @sep_token.setter + def sep_token(self, value): + self._sep_token = value + + @pad_token.setter + def pad_token(self, value): + self._pad_token = value + + @cls_token.setter + def cls_token(self, value): + self._cls_token = value + + @mask_token.setter + def mask_token(self, value): + self._mask_token = value + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + + @property + def bos_token_id(self): + """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """ + if self._bos_token is None: + return None + return self.convert_tokens_to_ids(self.bos_token) + + @property + def eos_token_id(self): + """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """ + if self._eos_token is None: + return None + return self.convert_tokens_to_ids(self.eos_token) + + @property + def unk_token_id(self): + """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """ + if self._unk_token is None: + return None + return self.convert_tokens_to_ids(self.unk_token) + + @property + def sep_token_id(self): + """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """ + if self._sep_token is None: + return None + return self.convert_tokens_to_ids(self.sep_token) + + @property + def pad_token_id(self): + """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """ + if self._pad_token is None: + return None + return self.convert_tokens_to_ids(self.pad_token) + + @property + def pad_token_type_id(self): + """ Id of the padding token type in the vocabulary.""" + return self._pad_token_type_id + + @property + def cls_token_id(self): + """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ + if self._cls_token is None: + return None + return self.convert_tokens_to_ids(self.cls_token) + + @property + def mask_token_id(self): + """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ + if self._mask_token is None: + return None + return self.convert_tokens_to_ids(self.mask_token) + + @property + def additional_special_tokens_ids(self): + """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """ + return self.convert_tokens_to_ids(self.additional_special_tokens) + + @property + def special_tokens_map(self): + """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their + values ('', ''...) + Convert tokens of AddedToken type in string. 
+ All returned tokens are strings + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = getattr(self, "_" + attr) + if attr_value: + set_attr[attr] = str(attr_value) + return set_attr + + @property + def special_tokens_map_extended(self): + """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their + values ('', ''...) + Keep the tokens as AddedToken if they are of this type. + + AddedToken can be used to control more finely how special tokens are tokenized. + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = getattr(self, "_" + attr) + if attr_value: + set_attr[attr] = attr_value + return set_attr + + @property + def all_special_tokens(self): + """ List all the special tokens ('', ''...) mapped to class attributes + Convert tokens of AddedToken type in string. + All returned tokens are strings + (cls_token, unk_token...). + """ + all_toks = [str(s) for s in self.all_special_tokens_extended] + return all_toks + + @property + def all_special_tokens_extended(self): + """ List all the special tokens ('', ''...) mapped to class attributes + Keep the tokens as AddedToken if they are of this type. + + AddedToken can be used to control more finely how special tokens are tokenized. + """ + all_toks = [] + set_attr = self.special_tokens_map_extended + for attr_value in set_attr.values(): + all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value]) + all_toks = list(set(all_toks)) + return all_toks + + @property + def all_special_ids(self): + """ List the vocabulary indices of the special tokens ('', ''...) mapped to + class attributes (cls_token, unk_token...). + """ + all_toks = self.all_special_tokens + all_ids = self.convert_tokens_to_ids(all_toks) + return all_ids + + +ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`): + If set to ``True``, the sequences will be encoded with the special tokens relative + to their model. + `padding` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`): + Activate and control padding. Accepts the following values: + + * `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence if provided), + * `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`) + * `False` or `'do_not_pad'` (default): No padding (i.e. can output batch with sequences of uneven lengths) + `truncation` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`): + Activate and control truncation. Accepts the following values: + + * `True` or `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided, + * `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided, + * `'only_second'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). 
This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided, + * `False` or `'do_not_truncate'` (default): No truncation (i.e. can output batch with sequences length greater than the model max admissible input size) + `max_length` (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`): + Control the length for padding/truncation. Accepts the following values + + * `None` (default): This will use the predefined model max length if required by one of the truncation/padding parameters. If the model has no specific max input length (e.g. XLNet) truncation/padding to max length is deactivated. + * `any integer value` (e.g. `42`): Use this specific maximum length value if required by one of the truncation/padding parameters. + stride (:obj:`int`, `optional`, defaults to ``0``): + If set to a number along with max_length, the overflowing tokens returned when `return_overflowing_tokens=True` + will contain some tokens from the end of the truncated sequence returned to provide some overlap between truncated and overflow ing sequences. + The value of this argument defines the number of overlapping tokens. + is_pretokenized (:obj:`bool`, defaults to :obj:`False`): + Set to True to indicate the input is already tokenized + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). + return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`): + Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`, + PyTorch :obj:`torch.Tensor` or Numpy :oj: `np.ndarray` instead of a list of python integers. +""" + +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (:obj:`bool`, `optional`, defaults to :obj:`None`): + Whether to return token type IDs. If left to the default, will return the token type IDs according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`none`): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are attention masks? <../glossary.html#attention-mask>`__ + return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True to return overflowing token sequences (default False). + return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True to return special tokens mask information (default False). + return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True to return (char_start, char_end) for each token (default False). + If using Python's tokenizer, this method will raise NotImplementedError. + This one is only available on fast tokenizers inheriting from PreTrainedTokenizerFast. 
+ **kwargs: passed to the `self.tokenize()` method + + Return: + A Dictionary of shape:: + + { + input_ids: list[int], + token_type_ids: list[int] if return_token_type_ids is True (default) + attention_mask: list[int] if return_attention_mask is True (default) + overflowing_tokens: list[int] if the tokenizer is a slow tokenize, else a List[List[int]] if a ``max_length`` is specified and ``return_overflowing_tokens=True`` + special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` + and return_special_tokens_mask is True + } + + With the fields: + + - ``input_ids``: list of token ids to be fed to a model + - ``token_type_ids``: list of token type ids to be fed to a model + - ``attention_mask``: list of indices specifying which tokens should be attended to by the model + - ``overflowing_tokens``: list of overflowing tokens sequences if a max length is specified and ``return_overflowing_tokens=True``. + - ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added + tokens and 1 specifying sequence tokens. +""" + + +class PreTrainedTokenizerBase(SpecialTokensMixin): + """ Base class for slow and fast tokenizers. + + Handle shared (mostly boiler plate) methods for slow and fast tokenizers. + """ + + vocab_files_names: Dict[str, str] = {} + pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {} + pretrained_init_configuration: Dict[str, Dict[str, Any]] = {} + max_model_input_sizes: Dict[str, int] = {} + model_input_names: List[str] = ["token_type_ids", "attention_mask"] + + padding_side: str = "right" + + def __init__(self, **kwargs): + # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) + self.init_inputs = () + self.init_kwargs = kwargs + + # For backward compatibility we fallback to set model_max_length from max_len if provided + model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) + self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER + + # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed. + self.padding_side = kwargs.pop("padding_side", self.padding_side) + assert self.padding_side in [ + "right", + "left", + ], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" + self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) + + super().__init__(**kwargs) + + @property + def max_len(self) -> int: + """ Kept here for backward compatibility. + Now renamed to `model_max_length` to avoid ambiguity. + """ + return self.model_max_length + + @property + def max_len_single_sentence(self) -> int: + return self.model_max_length - self.num_special_tokens_to_add(pair=False) + + @property + def max_len_sentences_pair(self) -> int: + return self.model_max_length - self.num_special_tokens_to_add(pair=True) + + @max_len_single_sentence.setter + def max_len_single_sentence(self, value) -> int: + """ For backward compatibility, allow to try to setup 'max_len_single_sentence' """ + if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose: + logger.warning( + "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up." + ) + else: + raise ValueError( + "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up." 
+ ) + + @max_len_sentences_pair.setter + def max_len_sentences_pair(self, value) -> int: + """ For backward compatibility, allow to try to setup 'max_len_sentences_pair' """ + if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose: + logger.warning( + "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up." + ) + else: + raise ValueError( + "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up." + ) + + @classmethod + def from_pretrained(cls, *inputs, **kwargs): + r""" + Instantiate a :class:`~transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer. + + Args: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``. + - (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the vocabulary files and override the cached versions if they exists. + + resume_download: (`optional`) boolean, default False: + Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. + + kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details. + + Examples:: + + # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer + + # Download vocabulary from S3 and cache. + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + + # Download vocabulary from S3 (user-uploaded) and cache. + tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased') + + # If vocabulary files are in a directory (e.g. 
tokenizer was saved using `save_pretrained('./test/saved_model/')`) + tokenizer = BertTokenizer.from_pretrained('./test/saved_model/') + + # If the tokenizer uses a single vocabulary file, you can point directly to this file + tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt') + + # You can link tokens to special vocabulary when instantiating + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='') + # You should be sure '' is in the vocabulary when doing that. + # Otherwise use tokenizer.add_special_tokens({'unk_token': ''}) instead) + assert tokenizer.unk_token == '' + + """ + return cls._from_pretrained(*inputs, **kwargs) + + @classmethod + def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + + s3_models = list(cls.max_model_input_sizes.keys()) + vocab_files = {} + init_configuration = {} + if pretrained_model_name_or_path in s3_models: + # Get the vocabulary from AWS S3 bucket + for file_id, map_list in cls.pretrained_vocab_files_map.items(): + vocab_files[file_id] = map_list[pretrained_model_name_or_path] + if ( + cls.pretrained_init_configuration + and pretrained_model_name_or_path in cls.pretrained_init_configuration + ): + init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path].copy() + else: + # Get the vocabulary from local files + logger.info( + "Model name '{}' not found in model shortcut name list ({}). " + "Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format( + pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path + ) + ) + + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + if len(cls.vocab_files_names) > 1: + raise ValueError( + "Calling {}.from_pretrained() with the path to a single file or url is not supported." + "Use a model identifier or the path to a directory instead.".format(cls.__name__) + ) + logger.warning( + "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format( + cls.__name__ + ) + ) + file_id = list(cls.vocab_files_names.keys())[0] + vocab_files[file_id] = pretrained_model_name_or_path + else: + # At this point pretrained_model_name_or_path is either a directory or a model identifier name + additional_files_names = { + "added_tokens_file": ADDED_TOKENS_FILE, + "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, + "tokenizer_config_file": TOKENIZER_CONFIG_FILE, + "full_tokenizer_file": FULL_TOKENIZER_FILE, + } + # Look for the tokenizer files + for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items(): + if os.path.isdir(pretrained_model_name_or_path): + full_file_name = os.path.join(pretrained_model_name_or_path, file_name) + if not os.path.exists(full_file_name): + logger.info("Didn't find file {}. 
We won't load it.".format(full_file_name)) + full_file_name = None + else: + full_file_name = hf_bucket_url( + pretrained_model_name_or_path, filename=file_name, use_cdn=False + ) + + vocab_files[file_id] = full_file_name + + # Get files from url, cache, or disk depending on the case + try: + resolved_vocab_files = {} + for file_id, file_path in vocab_files.items(): + if file_path is None: + resolved_vocab_files[file_id] = None + else: + resolved_vocab_files[file_id] = cached_path( + file_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + ) + except EnvironmentError: + if pretrained_model_name_or_path in s3_models: + msg = "Couldn't reach server at '{}' to download vocabulary files." + else: + msg = ( + "Model name '{}' was not found in tokenizers model name list ({}). " + "We assumed '{}' was a path or url to a directory containing vocabulary files " + "named {}, but couldn't find such vocabulary files at this path or url.".format( + pretrained_model_name_or_path, + ", ".join(s3_models), + pretrained_model_name_or_path, + list(cls.vocab_files_names.values()), + ) + ) + + raise EnvironmentError(msg) + + if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): + raise EnvironmentError( + "Model name '{}' was not found in tokenizers model name list ({}). " + "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files " + "named {} but couldn't find such vocabulary files at this path or url.".format( + pretrained_model_name_or_path, + ", ".join(s3_models), + pretrained_model_name_or_path, + list(cls.vocab_files_names.values()), + ) + ) + + for file_id, file_path in vocab_files.items(): + if file_path == resolved_vocab_files[file_id]: + logger.info("loading file {}".format(file_path)) + else: + logger.info("loading file {} from cache at {}".format(file_path, resolved_vocab_files[file_id])) + + # Prepare tokenizer initialization kwargs + # Did we saved some inputs and kwargs to reload ? + tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None) + if tokenizer_config_file is not None: + with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: + init_kwargs = json.load(tokenizer_config_handle) + saved_init_inputs = init_kwargs.pop("init_inputs", ()) + if not init_inputs: + init_inputs = saved_init_inputs + else: + init_kwargs = init_configuration + + # Update with newly provided kwargs + init_kwargs.update(kwargs) + + # Set max length if needed + if pretrained_model_name_or_path in cls.max_model_input_sizes: + # if we're using a pretrained model, ensure the tokenizer + # wont index sequences longer than the number of positional embeddings + model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path] + if model_max_length is not None and isinstance(model_max_length, (int, float)): + init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length) + + # Merge resolved_vocab_files arguments in init_kwargs. + added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None) + for args_name, file_path in resolved_vocab_files.items(): + if args_name not in init_kwargs: + init_kwargs[args_name] = file_path + + # Instantiate tokenizer. + try: + tokenizer = cls(*init_inputs, **init_kwargs) + except OSError: + raise OSError( + "Unable to load vocabulary from file. 
" + "Please check that the provided vocabulary is accessible and not corrupted." + ) + + # Save inputs and kwargs for saving and re-loading with ``save_pretrained`` + tokenizer.init_inputs = init_inputs + tokenizer.init_kwargs = init_kwargs + + # If there is a complementary special token map, load it + special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None) + if special_tokens_map_file is not None: + with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle: + special_tokens_map = json.load(special_tokens_map_handle) + + for key, value in special_tokens_map.items(): + if isinstance(value, dict): + value = AddedToken(**value) + setattr(tokenizer, key, value) + + # Add supplementary tokens. + special_tokens = tokenizer.all_special_tokens + if added_tokens_file is not None: + with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: + added_tok_encoder = json.load(added_tokens_handle) + + # Sort added tokens by index + added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1])) + + for token, index in added_tok_encoder_sorted: + assert index == len(tokenizer), ( + f"Non-consecutive added token '{token}' found. " + f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary." + ) + tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens)) + + # Check all our special tokens are registrered as "no split" token (we don't cut them) and are in the vocab + added_tokens = tokenizer.sanitize_special_tokens() + if added_tokens: + logger.warning( + "Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained." + ) + + return tokenizer + + def save_pretrained(self, save_directory) -> Tuple[str]: + """ Save the tokenizer vocabulary files together with: + - added tokens, + - special-tokens-to-class-attributes-mapping, + - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert). + + Warning: This won't save modifications you may have applied to the tokenizer after the instantiation + (e.g. modifying tokenizer.do_lower_case after creation). + + This method make sure the full tokenizer can then be re-loaded using the + :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method. 
+ """ + if os.path.isfile(save_directory): + logger.error("Provided path ({}) should be a directory, not a file".format(save_directory)) + return + os.makedirs(save_directory, exist_ok=True) + + special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) + added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) + tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE) + + tokenizer_config = copy.deepcopy(self.init_kwargs) + if len(self.init_inputs) > 0: + tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) + for file_id in self.vocab_files_names.keys(): + tokenizer_config.pop(file_id, None) + + with open(tokenizer_config_file, "w", encoding="utf-8") as f: + f.write(json.dumps(tokenizer_config, ensure_ascii=False)) + + with open(special_tokens_map_file, "w", encoding="utf-8") as f: + write_dict = {} + for key, value in self.special_tokens_map_extended.items(): + if isinstance(value, AddedToken): + write_dict[key] = value.__getstate__() + else: + write_dict[key] = value + f.write(json.dumps(write_dict, ensure_ascii=False)) + + added_vocab = self.get_added_vocab() + if added_vocab: + with open(added_tokens_file, "w", encoding="utf-8") as f: + out_str = json.dumps(added_vocab, ensure_ascii=False) + f.write(out_str) + + vocab_files = self.save_vocabulary(save_directory) + + return vocab_files + (special_tokens_map_file, added_tokens_file) + + @add_end_docstrings( + ENCODE_KWARGS_DOCSTRING, + """ + **kwargs: passed to the `self.tokenize()` method. + """, + ) + def encode( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ): + """ + Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. + + Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``. + + Args: + text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method) + text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized + string using the `tokenize` method) or a list of integers (tokenized string ids using the + `convert_tokens_to_ids` method) + """ + encoded_inputs = self.encode_plus( + text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def num_special_tokens_to_add(self, pair: bool = False) -> int: + raise NotImplementedError + + def _get_padding_truncation_strategies( + self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs + ): + """ Find the correct padding/truncation strategy with backward compatibility + for old arguments (truncation_strategy and pad_to_max_length) and behaviors. 
+ """ + old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") + old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) + + # Backward compatibility for previous behavior, maybe we should deprecate it: + # If you only set max_length, it activates truncation for max_length + if max_length is not None and padding is False and truncation is False: + if verbose: + logger.warning( + "Truncation was not explicitely activated but `max_length` is provided a specific value, " + "please use `truncation=True` to explicitely truncate examples to max length. " + "Defaulting to 'longest_first' truncation strategy. " + "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy " + "more precisely by providing a specific strategy to `truncation`." + ) + truncation = "longest_first" + + # Get padding strategy + if padding is False and old_pad_to_max_length: + if verbose: + warnings.warn( + "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " + "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " + "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " + "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " + "maximal input size of the model (e.g. 512 for Bert).", + DeprecationWarning, + ) + if max_length is None: + padding_strategy = PaddingStrategy.LONGEST + else: + padding_strategy = PaddingStrategy.MAX_LENGTH + elif padding is not False: + if padding is True: + padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch + elif not isinstance(padding, PaddingStrategy): + padding_strategy = PaddingStrategy(padding) + else: + padding_strategy = PaddingStrategy.DO_NOT_PAD + + # Get truncation strategy + if truncation is False and old_truncation_strategy != "do_not_truncate": + if verbose: + warnings.warn( + "The `truncation_strategy` argument is deprecated and will be removed in a future version, " + "use `truncation=True` to truncate examples to a max length. You can give a specific " + "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the " + "maximal input size of the model (e.g. 512 for Bert). " + " If you have pairs of inputs, you can give a specific truncation strategy selected among " + "`truncation='only_first'` (will only truncate the first sentence in the pairs) " + "`truncation='only_second'` (will only truncate the second sentence in the pairs) " + "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).", + DeprecationWarning, + ) + truncation_strategy = TruncationStrategy(old_truncation_strategy) + elif truncation is not False: + if truncation is True: + truncation_strategy = ( + TruncationStrategy.LONGEST_FIRST + ) # Default to truncate the longest sequences in pairs of inputs + elif not isinstance(truncation, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation) + else: + truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE + + # Set max length if needed + if max_length is None: + if padding_strategy == PaddingStrategy.MAX_LENGTH: + if self.model_max_length > LARGE_INTEGER: + if verbose: + logger.warning( + "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. " + "Default to no padding." 
+ ) + padding_strategy = PaddingStrategy.DO_NOT_PAD + else: + max_length = self.model_max_length + + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: + if self.model_max_length > LARGE_INTEGER: + if verbose: + logger.warning( + "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. " + "Default to no truncation." + ) + truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE + else: + max_length = self.model_max_length + + # Test if we have a padding token + if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0): + raise ValueError( + "Asking to pad but the tokenizer does not have a padding token. " + "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " + "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." + ) + + # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided + if ( + truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE + and padding_strategy != PaddingStrategy.DO_NOT_PAD + and pad_to_multiple_of is not None + and max_length is not None + and (max_length % pad_to_multiple_of != 0) + ): + raise ValueError( + f"Truncation and padding are both activated but " + f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." + ) + + return padding_strategy, truncation_strategy, max_length, kwargs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_pretokenized: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Returns a dictionary containing the encoded sequence or sequence pair and additional information: + the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. + + Args: + text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]``): + The sequence or batch of sequences to be encoded. + Each sequence can be a string or a list of strings (pre-tokenized string). + If the sequences are provided as list of strings (pretokenized), you must set `is_pretokenized=True` + (to lift the ambiguity with a batch of sequences) + text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]``): + The sequence or batch of sequences to be encoded. + Each sequence can be a string or a list of strings (pre-tokenized string). 
+ If the sequences are provided as list of strings (pretokenized), you must set `is_pretokenized=True` + (to lift the ambiguity with a batch of sequences) + """ + # Input type checking for clearer error + assert isinstance(text, str) or ( + isinstance(text, (list, tuple)) + and ( + len(text) == 0 + or ( + isinstance(text[0], str) + or (isinstance(text[0], (list, tuple)) and (len(text[0]) == 0 or isinstance(text[0][0], str))) + ) + ) + ), ( + "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + assert ( + text_pair is None + or isinstance(text_pair, str) + or ( + isinstance(text_pair, (list, tuple)) + and ( + len(text_pair) == 0 + or ( + isinstance(text_pair[0], str) + or ( + isinstance(text_pair[0], (list, tuple)) + and (len(text_pair[0]) == 0 or isinstance(text_pair[0][0], str)) + ) + ) + ) + ) + ), ( + "text_pair input must of type `str` (single example), `List[str]` (batch or single pretokenized example) " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + is_batched = bool( + (not is_pretokenized and isinstance(text, (list, tuple))) + or (is_pretokenized and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))) + ) + + if is_batched: + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_pretokenized=is_pretokenized, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_pretokenized=is_pretokenized, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_pretokenized: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ 
+ Returns a dictionary containing the encoded sequence or sequence pair and additional information: + the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. + + Args: + text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the later only for not-fast tokenizers)): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method) + text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized + string using the `tokenize` method) or a list of integers (tokenized string ids using the + `convert_tokens_to_ids` method) + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_pretokenized=is_pretokenized, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_pretokenized: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + raise NotImplementedError + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_pretokenized: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + 
return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Returns a dictionary containing the encoded sequence or sequence pair and additional information: + the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. + + Args: + batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`, + :obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, + and for not-fast tokenizers, also: + :obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`): + Batch of sequences or pair of sequences to be encoded. + This can be a list of string/string-sequences/int-sequences or a list of pair of + string/string-sequences/int-sequence (see details in encode_plus) + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_pretokenized=is_pretokenized, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_pretokenized: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + raise NotImplementedError + + def pad( + self, + encoded_inputs: Union[ + BatchEncoding, + List[BatchEncoding], + Dict[str, EncodedInput], + Dict[str, List[EncodedInput]], + List[Dict[str, EncodedInput]], + ], + padding: Union[bool, str] = True, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = True, + ) -> BatchEncoding: + """ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length in the batch. 
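+
+        For example (an illustrative sketch; it assumes ``tokenizer`` is an instantiated subclass with a
+        padding token defined), the method accepts both a dict of lists and a list of dicts, so it can be
+        used directly as a DataLoader collate function::
+
+            features = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
+            batch = tokenizer.pad(features, padding=True, return_tensors="pt")
+            # batch["input_ids"] has shape (2, 3); batch["attention_mask"] marks real vs. padded positions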
+
+        Padding side (left/right) and padding token ids are defined at the tokenizer level
+        (with ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``)
+
+        Args:
+            encoded_inputs: Dictionary of tokenized inputs (`Dict[str, List[int]]`) or batch of tokenized inputs.
+                Batch of tokenized inputs can be given as dicts of lists or lists of dicts, both work so you can
+                use ``tokenizer.pad()`` during pre-processing as well as in a PyTorch Dataloader collate function.
+                (`Dict[str, List[List[int]]]` or `List[Dict[str, List[int]]]`).
+            padding: Boolean or specific strategy to use for padding.
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among:
+                - 'longest' (or `True`): Pad to the longest sequence in the batch
+                - 'max_length': Pad to the max length (default)
+                - 'do_not_pad' (or `False`): Do not pad
+            max_length: maximum length of the returned list and optionally padding length (see below).
+                Will truncate by taking into account the special tokens.
+            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                >= 7.5 (Volta).
+            return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+            return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
+                Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
+                PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of Python integers.
+            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Set to ``False`` to avoid printing info messages and warnings.
+        """
+        # If we have a list of dicts, let's convert it into a dict of lists
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
+
+        assert "input_ids" in encoded_inputs, (
+            "You should supply an encoding or a list of encodings to this method. "
+            "An encoding is the output of one of the encoding methods of the tokenizer, i.e. "
+            "__call__/encode_plus/batch_encode_plus. "
+        )
+
+        if not encoded_inputs["input_ids"]:
+            if return_attention_mask:
+                encoded_inputs["attention_mask"] = []
+            return encoded_inputs
+
+        # Convert the padding argument into a PaddingStrategy
+        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
+            padding=padding, max_length=max_length, verbose=verbose
+        )
+
+        if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)):
+            encoded_inputs = self._pad(
+                encoded_inputs,
+                max_length=max_length,
+                padding_strategy=padding_strategy,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_attention_mask=return_attention_mask,
+            )
+            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
+
+        batch_size = len(encoded_inputs["input_ids"])
+        assert all(
+            len(v) == batch_size for v in encoded_inputs.values()
+        ), "Some items in the output dictionary have a different batch size than others."
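+        # At this point the inputs form a batch (dict of lists): the code below resolves LONGEST to a
+        # concrete MAX_LENGTH (the longest sequence in the batch) and pads every example to that length.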
+ + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"]) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) + outputs = self._pad( + inputs, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + return BatchEncoding(batch_outputs, tensor_type=return_tensors) + + def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]: + if token_ids_1 is None: + return len(token_ids_0) * [0] + return [0] * len(token_ids_0) + [1] * len(token_ids_1) + + def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. This implementation does not add special tokens. + """ + if token_ids_1 is None: + return token_ids_0 + return token_ids_0 + token_ids_1 + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs + ) -> BatchEncoding: + """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. + It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + Args: + ids: list of tokenized input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + """ + + if "return_lengths" in kwargs: + if verbose: + warnings.warn( + "The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. 
" + "Please use `return_length` instead.", + FutureWarning, + ) + return_length = kwargs["return_lengths"] + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Compute the total size of the returned encodings + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) + + # Build output dictionnary + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + # Check lengths + if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose: + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). Running this sequence through the model will result in " + "indexing errors".format(len(ids), self.model_max_length) + ) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + def truncate_sequences( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> Tuple[List[int], List[int], List[int]]: + """ Truncates a sequence pair in place to the maximum length. + + Args: + ids: list of tokenized input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. 
+            pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
+                `tokenize` and `convert_tokens_to_ids` methods.
+            num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
+                number of tokens to remove using the truncation strategy
+            truncation_strategy (:obj:`string`, `optional`, defaults to "longest_first"):
+                String selected in the following options:
+
+                - 'longest_first' (default): Iteratively remove one token at a time from the longest sequence
+                    until the input is under max_length (when there is a pair of input sequences).
+                    Overflowing tokens only contain overflow from the first sequence.
+                - 'only_first': Only truncate the first sequence. Raises an error if the first sequence is shorter than or equal to num_tokens_to_remove.
+                - 'only_second': Only truncate the second sequence
+                - 'do_not_truncate'
+            stride (:obj:`int`, `optional`, defaults to ``0``):
+                If set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence returned. The value of this argument defines the number of additional tokens.
+        """
+        if num_tokens_to_remove <= 0:
+            return ids, pair_ids, []
+
+        if not isinstance(truncation_strategy, TruncationStrategy):
+            truncation_strategy = TruncationStrategy(truncation_strategy)
+
+        overflowing_tokens = []
+        if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+            for _ in range(num_tokens_to_remove):
+                if pair_ids is None or len(ids) > len(pair_ids):
+                    if not overflowing_tokens:
+                        window_len = min(len(ids), stride + 1)
+                    else:
+                        window_len = 1
+                    overflowing_tokens.extend(ids[-window_len:])
+                    ids = ids[:-1]
+                else:
+                    if not overflowing_tokens:
+                        window_len = min(len(pair_ids), stride + 1)
+                    else:
+                        window_len = 1
+                    overflowing_tokens.extend(pair_ids[-window_len:])
+                    pair_ids = pair_ids[:-1]
+        elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
+            if len(ids) > num_tokens_to_remove:
+                window_len = min(len(ids), stride + num_tokens_to_remove)
+                overflowing_tokens = ids[-window_len:]
+                ids = ids[:-num_tokens_to_remove]
+            else:
+                logger.error(
+                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input "
+                    f"but the first sequence has a length {len(ids)}. "
+                    f"Please select a truncation strategy other than {truncation_strategy}, "
+                    f"for instance 'longest_first' or 'only_second'."
+                )
+        elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+            if len(pair_ids) > num_tokens_to_remove:
+                window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+                overflowing_tokens = pair_ids[-window_len:]
+                pair_ids = pair_ids[:-num_tokens_to_remove]
+            else:
+                logger.error(
+                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input "
+                    f"but the second sequence has a length {len(pair_ids)}. "
+                    f"Please select a truncation strategy other than {truncation_strategy}, "
+                    f"for instance 'longest_first' or 'only_first'."
+                )
+
+        return (ids, pair_ids, overflowing_tokens)
+
+    def _pad(
+        self,
+        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+        max_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
+    ) -> dict:
+        """ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+        Args:
+            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). + return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(encoded_inputs["input_ids"]) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = ( + padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length + ) + + if needs_to_be_padded: + difference = max_length - len(encoded_inputs["input_ids"]) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"]) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + else: + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + + return encoded_inputs + + def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]: + return [self.decode(seq, **kwargs) for seq in sequences] + + def decode( + self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True + ) -> str: + """ + Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary + with options to remove special tokens and clean up tokenization spaces. + Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. + + Args: + token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods. 
+ skip_special_tokens: if set to True, will replace special tokens. + clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. + """ + raise NotImplementedError + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. + + Args: + token_ids_0: list of ids (must not contain special tokens) + token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formated with + special tokens for the model + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + assert already_has_special_tokens and token_ids_1 is None, ( + "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " + "Please use a slow (full python) tokenizer to activate this argument." + "Or set `return_special_token_mask=True` when calling the encoding method " + "to get the special tokens mask in any tokenizer. " + ) + + all_special_ids = self.all_special_ids # cache the property + + special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] + + return special_tokens_mask + + @staticmethod + def clean_up_tokenization(out_string: str) -> str: + """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string diff --git a/elia/demo_inference.py b/elia/demo_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..630be118942e8ebc5a9bf6b13df650683f4e4ae8 --- /dev/null +++ b/elia/demo_inference.py @@ -0,0 +1,295 @@ +image_path = './image001.png' +sentence = 'spoon on the dish' +weights = '/cluster/nvme4/cyx/lavt/vis/model_best_refcoco_0508.pth' +device = 'cpu' + +# pre-process the input image +from PIL import Image +import torchvision.transforms as T +import numpy as np +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from bert.multimodal_bert import MultiModalBert +import torchvision + +from lib import multimodal_segmentation_ppm +#import transforms as T +import utils + +import numpy as np +from PIL import Image +import torch.nn.functional as F + +from modeling.MaskFormerModel import MaskFormerHead +from addict import Dict +#from bert.modeling_bert import BertLMPredictionHead, BertEncoder +import cv2 +import textwrap + +class WrapperModel(nn.Module): + def __init__(self, image_model, language_model, classifier) : + super(WrapperModel, self).__init__() + self.image_model = image_model + self.language_model = language_model + self.classifier = classifier + + config = Dict({ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + 
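+            # NOTE: this BERT-style config mirrors the one used for the auxiliary MLM heads at training
+            # time; in this inference-only wrapper it appears to be built but never consumed.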
#"max_position_embeddings": 16+20, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 8, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522 + }) + + + + def _get_binary_mask(self, target): + # 返回每类的binary mask + y, x = target.size() + target_onehot = torch.zeros(self.num_classes + 1, y, x) + target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1) + return target_onehot[1:] + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=1)[...,1:] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred) + return semseg + + def forward(self, image, sentences, attentions): + print(image.sum(), sentences.sum(), attentions.sum()) + input_shape = image.shape[-2:] + l_mask = attentions.unsqueeze(dim=-1) + + i0, Wh, Ww = self.image_model.forward_stem(image) + l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions) + + i1 = self.image_model.forward_stage1(i0, Wh, Ww) + l1 = self.language_model.forward_stage1(l0, extended_attention_mask) + i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask) + l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) + i1 = i1_temp + + i2 = self.image_model.forward_stage2(i1, Wh, Ww) + l2 = self.language_model.forward_stage2(l1, extended_attention_mask) + i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask) + l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) + i2 = i2_temp + + i3 = self.image_model.forward_stage3(i2, Wh, Ww) + l3 = self.language_model.forward_stage3(l2, extended_attention_mask) + i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask) + l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) + i3 = i3_temp + + i4 = self.image_model.forward_stage4(i3, Wh, Ww) + l4 = self.language_model.forward_stage4(l3, extended_attention_mask) + i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask) + l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) + i4 = i4_temp + + #i1_residual, i2_residual, i3_residual, i4_residual = features + #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual) + #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) + outputs = {} + outputs['s1'] = i1_residual + outputs['s2'] = i2_residual + outputs['s3'] = i3_residual + outputs['s4'] = i4_residual + + predictions = self.classifier(outputs) + return predictions + +#img = Image.open(image_path).convert("RGB") +img = Image.open(image_path).convert("RGB") +img_ndarray = np.array(img) # (orig_h, orig_w, 3); for visualization +original_w, original_h = img.size # PIL .size returns width first and height second + +image_transforms = T.Compose( + [ + T.Resize((480, 480)), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] +) + +img = image_transforms(img).unsqueeze(0) # (1, 3, 480, 480) +img = img.to(device) # for inference (input) + +# pre-process the raw sentence +from bert.tokenization_bert import BertTokenizer +import torch +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True) +sentence_tokenized 
= sentence_tokenized[:20] # if the sentence is longer than 20, then this truncates it to 20 words +# pad the tokenized sentence +padded_sent_toks = [0] * 20 +padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized +# create a sentence token mask: 1 for real words; 0 for padded tokens +attention_mask = [0] * 20 +attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized) +# convert lists to tensors +padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0) # (1, 20) +attention_mask = torch.tensor(attention_mask).unsqueeze(0) # (1, 20) +padded_sent_toks = padded_sent_toks.to(device) # for inference (input) +attention_mask = attention_mask.to(device) # for inference (input) + +# initialize model and load weights +#from bert.modeling_bert import BertModel +#from lib import segmentation + +# construct a mini args class; like from a config file + + +class args: + swin_type = 'base' + window12 = True + mha = '' + fusion_drop = 0.0 + + +#single_model = segmentation.__dict__['lavt'](pretrained='', args=args) +single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='',args=args) +single_model.to(device) +model_class = MultiModalBert +single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim) +single_bert_model.pooler = None + +input_shape = dict() +input_shape['s1'] = Dict({'channel': 128, 'stride': 4}) +input_shape['s2'] = Dict({'channel': 256, 'stride': 8}) +input_shape['s3'] = Dict({'channel': 512, 'stride': 16}) +input_shape['s4'] = Dict({'channel': 1024, 'stride': 32}) + + + +cfg = Dict() +cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 +cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 +cfg.MODEL.MASK_FORMER.NHEADS = 8 +cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4 +cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256 +cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 +cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"] + +cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1 +cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 +cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1 +cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 +cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10 +cfg.MODEL.MASK_FORMER.PRE_NORM = False + + +maskformer_head = MaskFormerHead(cfg, input_shape) + + +model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head) + + + +checkpoint = torch.load(weights, map_location='cpu') + +model.load_state_dict(checkpoint['model'], strict=False) +model.to(device) +model.eval() +#single_bert_model.load_state_dict(checkpoint['bert_model']) +#single_model.load_state_dict(checkpoint['model']) +#model = single_model.to(device) +#bert_model = single_bert_model.to(device) + + +# inference +#import torch.nn.functional as F +#last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0] +#embedding = last_hidden_states.permute(0, 2, 1) +#output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1)) +#output = output.argmax(1, keepdim=True) # (1, 1, 480, 480) +#output = F.interpolate(output.float(), (original_h, original_w)) # 'nearest'; resize to the original image size +#output = output.squeeze() # (orig_h, orig_w) +#output = output.cpu().data.numpy() # (orig_h, orig_w) + +output = model(img, padded_sent_toks, attention_mask)[0] +#print(output[0].keys()) +#print(output[1].shape) +mask_cls_results = output["pred_logits"] +mask_pred_results = output["pred_masks"] + +target_shape = img_ndarray.shape[:2] +#print(target_shape, mask_pred_results.shape) +mask_pred_results = F.interpolate(mask_pred_results, 
size=(480,480), mode='bilinear', align_corners=True) + +pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results) +#output = pred_masks[0] + +#output = output.cpu() + + + +#print(output.shape) +#output_mask = output.argmax(1).data.numpy() +#output = (output > 0.5).data.cpu().numpy() +output = torch.nn.functional.interpolate(pred_masks, target_shape) +output = (output > 0.5).data.cpu().numpy() + + +# show/save results +def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4): + from scipy.ndimage.morphology import binary_dilation + + colors = np.reshape(colors, (-1, 3)) + colors = np.atleast_2d(colors) * cscale + + im_overlay = image.copy() + object_ids = np.unique(mask) + + for object_id in object_ids[1:]: + # Overlay color on binary mask + foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) + binary_mask = mask == object_id + + # Compose image + im_overlay[binary_mask] = foreground[binary_mask] + + # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask + countours = binary_dilation(binary_mask) ^ binary_mask + # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask + im_overlay[countours, :] = 0 + + return im_overlay.astype(image.dtype) + + +output = output.astype(np.uint8) # (orig_h, orig_w), np.uint8 +# Overlay the mask on the image +print(img_ndarray.shape, output.shape) +visualization = overlay_davis(img_ndarray, output[0][0]) # red +visualization = Image.fromarray(visualization) +# show the visualization +#visualization.show() +# Save the visualization +visualization.save('./demo/spoon_on_the_dish.jpg') + + + + diff --git a/elia/requirements.txt b/elia/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9cb7e037ed9179ca75ea3a86497e68307401a645 --- /dev/null +++ b/elia/requirements.txt @@ -0,0 +1,14 @@ +requests +filelock +tqdm +timm +mmcv-full==1.3.12 +mmsegmentation==0.17.0 +ftfy +regex +scipy +scikit-image +pycocotools==2.0.2 +opencv-python==4.5.3.56 +tokenizers==0.8.1rc1 +h5py \ No newline at end of file diff --git a/elia/test_elia.py b/elia/test_elia.py new file mode 100644 index 0000000000000000000000000000000000000000..00f79bdc289aaa54520c5036751a94b5bcd01c1b --- /dev/null +++ b/elia/test_elia.py @@ -0,0 +1,312 @@ + +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from bert.multimodal_bert import MultiModalBert +import torchvision + +from lib import multimodal_segmentation_ppm +import transforms as T +import utils + +import numpy as np +from PIL import Image +import torch.nn.functional as F + +from modeling.MaskFormerModel import MaskFormerHead +from addict import Dict +from bert.modeling_bert import BertLMPredictionHead, BertEncoder + +def get_dataset(image_set, transform, args): + from data.dataset_refer_bert import ReferDataset + ds = ReferDataset(args, + split=image_set, + image_transforms=transform, + target_transforms=None, + eval_mode=True + ) + num_classes = 2 + return ds, num_classes + + +def evaluate(model, data_loader, device): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + + # evaluation variables + cum_I, cum_U = 0, 0 + eval_seg_iou_list = [.5, .6, .7, .8, .9] + seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) + seg_total = 0 + mean_IoU = [] + header = 'Test:' + + with torch.no_grad(): + for data in metric_logger.log_every(data_loader, 100, header): + image, target, sentences, attentions 
= data + image, target, sentences, attentions = image.to(device), target.to(device), \ + sentences.to(device), attentions.to(device) + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + target = target.cpu().data.numpy() + for j in range(sentences.size(-1)): + #if bert_model is not None: + # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] + # embedding = last_hidden_states.permute(0, 2, 1) + # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) + #else: + output = model(image, sentences[:, :, j], attentions[:, :, j]) + mask_cls_results = output["pred_logits"] + mask_pred_results = output["pred_masks"] + + target_shape = target.shape[-2:] + mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True) + + pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results) + output = pred_masks[0] + + output = output.cpu() + #print(output.shape) + #output_mask = output.argmax(1).data.numpy() + output_mask = (output > 0.5).data.numpy() + I, U = computeIoU(output_mask, target) + if U == 0: + this_iou = 0.0 + else: + this_iou = I*1.0/U + mean_IoU.append(this_iou) + cum_I += I + cum_U += U + for n_eval_iou in range(len(eval_seg_iou_list)): + eval_seg_iou = eval_seg_iou_list[n_eval_iou] + seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) + seg_total += 1 + + #del image, target, sentences, attentions, output, output_mask + #if bert_model is not None: + # del last_hidden_states, embedding + + mean_IoU = np.array(mean_IoU) + mIoU = np.mean(mean_IoU) + print('Final results:') + print('Mean IoU is %.2f\n' % (mIoU*100.)) + results_str = '' + for n_eval_iou in range(len(eval_seg_iou_list)): + results_str += ' precision@%s = %.2f\n' % \ + (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) + results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) + print(results_str) + + +def get_transform(args): + transforms = [T.Resize(args.img_size, args.img_size), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + + return T.Compose(transforms) + + +def computeIoU(pred_seg, gd_seg): + I = np.sum(np.logical_and(pred_seg, gd_seg)) + U = np.sum(np.logical_or(pred_seg, gd_seg)) + + return I, U + +class WrapperModel(nn.Module): + def __init__(self, image_model, language_model, classifier, args) : + super(WrapperModel, self).__init__() + self.image_model = image_model + self.language_model = language_model + self.classifier = classifier + self.lang_proj = nn.Linear(768,256) + + config = Dict({ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + #"max_position_embeddings": 16+20, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 8, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522 + }) + self.mlm_transformer = BertEncoder(config) + + self.lang_proj = nn.Linear(768,256) + self.mlm_vis_proj = nn.Conv2d(1024,512,1) + self.mlm_lang_proj = nn.Linear(768,512) + #print(vis_proj) + self.mlm_head = BertLMPredictionHead(config) + + assert args.img_size % 4 == 0 + num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2 + print(num_img_tokens) + self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512) + self.mlm_modal_embeds = nn.Embedding(3, 512) + + self.mlm_mask_embed = nn.Embedding(1, 512) + self.mlm_pos_mlp = nn.Sequential( + nn.Linear(2, 512), + nn.LayerNorm(512), + nn.Linear(512,512), + nn.GELU() + ) + + def _get_binary_mask(self, target): + # 返回每类的binary mask + y, x = target.size() + target_onehot = torch.zeros(self.num_classes + 1, y, x) + target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1) + return target_onehot[1:] + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=1)[...,1:] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred) + return semseg + + def forward(self, image, sentences, attentions): + input_shape = image.shape[-2:] + l_mask = attentions.unsqueeze(dim=-1) + + i0, Wh, Ww = self.image_model.forward_stem(image) + l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions) + + i1 = self.image_model.forward_stage1(i0, Wh, Ww) + l1 = self.language_model.forward_stage1(l0, extended_attention_mask) + i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask) + l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) + i1 = i1_temp + + i2 = self.image_model.forward_stage2(i1, Wh, Ww) + l2 = self.language_model.forward_stage2(l1, extended_attention_mask) + i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask) + l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) + i2 = i2_temp + + i3 = self.image_model.forward_stage3(i2, Wh, Ww) + l3 = self.language_model.forward_stage3(l2, extended_attention_mask) + i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask) + l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, 
extended_attention_mask) + i3 = i3_temp + + i4 = self.image_model.forward_stage4(i3, Wh, Ww) + l4 = self.language_model.forward_stage4(l3, extended_attention_mask) + i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask) + l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) + i4 = i4_temp + + #i1_residual, i2_residual, i3_residual, i4_residual = features + #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual) + #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) + outputs = {} + outputs['s1'] = i1_residual + outputs['s2'] = i2_residual + outputs['s3'] = i3_residual + outputs['s4'] = i4_residual + + predictions, _ = self.classifier(outputs) + return predictions + +def main(args): +#def main(local_rank, args): + + #device = torch.device(args.device) + device = 'cuda' + dataset_test, _ = get_dataset(args.split, get_transform(args=args), args) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, + sampler=test_sampler, num_workers=args.workers) + print(args.model) + single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args) + #single_model = MultiModalFocal(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], focal_windows=[9,9,9,9], drop_path_rate=0.3) + #single_model.init_weights('./focalnet_base_lrf.pth') + checkpoint = torch.load(args.resume, map_location='cpu') + #single_model.load_state_dict(checkpoint['model']) + #model = single_model.to(device) + + if args.model != 'lavt_one': + model_class = MultiModalBert + #single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=128) + single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim) + # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines + if args.ddp_trained_weights: + single_bert_model.pooler = None + #single_bert_model.load_state_dict(checkpoint['bert_model']) + #bert_model = single_bert_model.to(device) + else: + bert_model = None + + #model = WrapperModel(single_model.backbone, single_bert_model, single_model.classifier) + #model.load_state_dict(checkpoint['model']) + #model.to(device) + input_shape = dict() + input_shape['s1'] = Dict({'channel': 128, 'stride': 4}) + input_shape['s2'] = Dict({'channel': 256, 'stride': 8}) + input_shape['s3'] = Dict({'channel': 512, 'stride': 16}) + input_shape['s4'] = Dict({'channel': 1024, 'stride': 32}) + + + + cfg = Dict() + cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4 + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"] + + cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1 + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1 + cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 + cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10 + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + + maskformer_head = MaskFormerHead(cfg, input_shape) + #maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head) + #maskformer_head.cuda() + #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False) + #single_head = 
maskformer_head.module + #print(single_head) + + model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args) + model.load_state_dict(checkpoint['model']) + model.to(device) + #model.cuda() + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) + #single_model = model.module + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) + #single_model = model.module + evaluate(model, data_loader_test, device=device) + + +if __name__ == "__main__": + from args import get_parser + parser = get_parser() + args = parser.parse_args() + print('Image size: {}'.format(str(args.img_size))) + print(args) + main(args) + #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count()) diff --git a/elia/test_lavt.py b/elia/test_lavt.py new file mode 100644 index 0000000000000000000000000000000000000000..e85c8ede2fdf46d6ade4b5d5fbc70d12628fa6d4 --- /dev/null +++ b/elia/test_lavt.py @@ -0,0 +1,139 @@ +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from bert.modeling_bert import BertModel +import torchvision + +from lib import segmentation +import transforms as T +import utils + +import numpy as np +from PIL import Image +import torch.nn.functional as F + + +def get_dataset(image_set, transform, args): + from data.dataset_refer_bert import ReferDataset + ds = ReferDataset(args, + split=image_set, + image_transforms=transform, + target_transforms=None, + eval_mode=True + ) + num_classes = 2 + return ds, num_classes + + +def evaluate(model, data_loader, bert_model, device): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + + # evaluation variables + cum_I, cum_U = 0, 0 + eval_seg_iou_list = [.5, .6, .7, .8, .9] + seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) + seg_total = 0 + mean_IoU = [] + header = 'Test:' + + with torch.no_grad(): + for data in metric_logger.log_every(data_loader, 100, header): + image, target, sentences, attentions = data + image, target, sentences, attentions = image.to(device), target.to(device), \ + sentences.to(device), attentions.to(device) + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + target = target.cpu().data.numpy() + for j in range(sentences.size(-1)): + if bert_model is not None: + last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] + embedding = last_hidden_states.permute(0, 2, 1) + output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) + else: + output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j]) + + output = output.cpu() + output_mask = output.argmax(1).data.numpy() + I, U = computeIoU(output_mask, target) + if U == 0: + this_iou = 0.0 + else: + this_iou = I*1.0/U + mean_IoU.append(this_iou) + cum_I += I + cum_U += U + for n_eval_iou in range(len(eval_seg_iou_list)): + eval_seg_iou = eval_seg_iou_list[n_eval_iou] + seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) + seg_total += 1 + + del image, target, sentences, attentions, output, output_mask + if bert_model is not None: + del last_hidden_states, embedding + + mean_IoU = np.array(mean_IoU) + mIoU = np.mean(mean_IoU) + print('Final results:') + print('Mean IoU is %.2f\n' % (mIoU*100.)) + results_str = '' + for n_eval_iou in range(len(eval_seg_iou_list)): + results_str += ' precision@%s = %.2f\n' % \ + (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. 
/ seg_total) + results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U) + print(results_str) + + +def get_transform(args): + transforms = [T.Resize(args.img_size, args.img_size), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + + return T.Compose(transforms) + + +def computeIoU(pred_seg, gd_seg): + I = np.sum(np.logical_and(pred_seg, gd_seg)) + U = np.sum(np.logical_or(pred_seg, gd_seg)) + + return I, U + + +def main(args): + device = torch.device(args.device) + dataset_test, _ = get_dataset(args.split, get_transform(args=args), args) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, + sampler=test_sampler, num_workers=args.workers) + print(args.model) + single_model = segmentation.__dict__[args.model](pretrained='',args=args) + checkpoint = torch.load(args.resume, map_location='cpu') + single_model.load_state_dict(checkpoint['model']) + model = single_model.to(device) + + if args.model != 'lavt_one': + model_class = BertModel + single_bert_model = model_class.from_pretrained(args.ck_bert) + # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines + if args.ddp_trained_weights: + single_bert_model.pooler = None + single_bert_model.load_state_dict(checkpoint['bert_model']) + bert_model = single_bert_model.to(device) + else: + bert_model = None + + evaluate(model, data_loader_test, bert_model, device=device) + + +if __name__ == "__main__": + from args import get_parser + parser = get_parser() + args = parser.parse_args() + print('Image size: {}'.format(str(args.img_size))) + main(args) diff --git a/elia/train_elia.py b/elia/train_elia.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce163e6497f2acb88dd1bc3219cf31997ac5495 --- /dev/null +++ b/elia/train_elia.py @@ -0,0 +1,812 @@ +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from functools import reduce +import operator +from bert.multimodal_bert import MultiModalBert + +import torchvision +from lib import multimodal_segmentation_ppm + +import transforms as T +import utils +import numpy as np + +import torch.nn.functional as F + +import gc +from collections import OrderedDict + +import torch.backends.cudnn as cudnn + +#from ffrecord.torch import DataLoader,Dataset +from modeling.MaskFormerModel import MaskFormerHead +from addict import Dict + +from mask2former_utils.criterion import SetCriterion, Criterion +from mask2former_utils.matcher import HungarianMatcher +from bert.modeling_bert import BertLMPredictionHead, BertEncoder + + + + +class WrapperModel(nn.Module): + def __init__(self, image_model, language_model, classifier, args) : + super(WrapperModel, self).__init__() + self.image_model = image_model + self.language_model = language_model + self.classifier = classifier + + self.lang_proj = nn.Linear(768,256) + + config = Dict({ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + #"max_position_embeddings": 16+20, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 8, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": True, 
+ "vocab_size": 30522 + }) + self.mlm_transformer = BertEncoder(config) + + self.lang_proj = nn.Linear(768,256) + self.mlm_vis_proj = nn.Conv2d(1024,512,1) + self.mlm_lang_proj = nn.Linear(768,512) + #print(vis_proj) + self.mlm_head = BertLMPredictionHead(config) + + assert args.img_size % 4 == 0 + num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2 + print(num_img_tokens) + self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512) + self.mlm_modal_embeds = nn.Embedding(3, 512) + + self.mlm_mask_embed = nn.Embedding(1, 512) + self.mlm_pos_mlp = nn.Sequential( + nn.Linear(2, 512), + nn.LayerNorm(512), + nn.Linear(512,512), + nn.GELU() + ) + + def _get_binary_mask(self, target): + # 返回每类的binary mask + y, x = target.size() + target_onehot = torch.zeros(self.num_classes + 1, y, x) + target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1) + return target_onehot[1:] + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=1)[...,1:] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred) + return semseg + + def forward(self, image, sentences, attentions, mlm_targets, mlm_masks, position): + input_shape = image.shape[-2:] + l_mask = attentions.unsqueeze(dim=-1) + + i0, Wh, Ww = self.image_model.forward_stem(image) + l0, extended_attention_mask = self.language_model.forward_stem(mlm_targets.squeeze(1), attentions) + + i1 = self.image_model.forward_stage1(i0, Wh, Ww) + l1 = self.language_model.forward_stage1(l0, extended_attention_mask) + i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask) + l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) + i1 = i1_temp + + i2 = self.image_model.forward_stage2(i1, Wh, Ww) + l2 = self.language_model.forward_stage2(l1, extended_attention_mask) + i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask) + l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) + i2 = i2_temp + + i3 = self.image_model.forward_stage3(i2, Wh, Ww) + l3 = self.language_model.forward_stage3(l2, extended_attention_mask) + i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask) + l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) + i3 = i3_temp + + i4 = self.image_model.forward_stage4(i3, Wh, Ww) + l4 = self.language_model.forward_stage4(l3, extended_attention_mask) + i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask) + l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) + i4 = i4_temp + + #i1_residual, i2_residual, i3_residual, i4_residual = features + #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual) + #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) + outputs = {} + outputs['s1'] = i1_residual + outputs['s2'] = i2_residual + outputs['s3'] = i3_residual + outputs['s4'] = i4_residual + + predictions, mask_features = self.classifier(outputs) + + #print(target_reshape.shape) + #tmp = np.argwhere(target_reshape[:, 0].detach().cpu().numpy()).reshape(-1, target_reshape.shape[2]*target_reshape[3], 3) + #centroid = tmp.mean(1) + #print(centroid) + #centroid_x, centroid_y = int(centroid[1]), int(centroid[0]) + #last_hidden_states = brt_model(sentences, attention_mask=attentions)[0] # (6, 10, 768) + #embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make 
Conv1d happy + + + l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions) + l1 = self.language_model.forward_stage1(l0, extended_attention_mask) + l2 = self.language_model.forward_stage2(l1, extended_attention_mask) + l3 = self.language_model.forward_stage3(l2, extended_attention_mask) + l4 = self.language_model.forward_stage4(l3, extended_attention_mask) + + + mlp_embed = self.mlm_pos_mlp(position) + #print(centroid_x, centroid_y) + + mlm_targets = torch.where( + mlm_masks > 0, + mlm_targets, + torch.ones_like(mlm_targets) * (-1) + ) + + #print(x_c4[target_reshape[:, [0]].bool()].shape) + vis_features = self.mlm_vis_proj(i4_residual).flatten(2).permute(0,2,1) + #print(l4.shape) + lang_features = self.mlm_lang_proj(l4) + + #print(lang_features.shape, vis_features.shape, mlp_embed.shape) + mm_features = torch.cat([lang_features, vis_features, mlp_embed.unsqueeze(1)], dim=1) + #print(mm_features.shape) + + #print(mlm_modal_embeds.weight.shape) + modal_embeds = torch.cat([self.mlm_modal_embeds.weight[0].unsqueeze(0).repeat(1, lang_features.shape[1], 1), self.mlm_modal_embeds.weight[1].unsqueeze(0).repeat(1, vis_features.shape[1], 1), self.mlm_modal_embeds.weight[2].unsqueeze(0).repeat(1,1,1)], dim=1) + #print(modal_embeds.shape) + + #print(mlm_transformer) + + + #print(attentions.shape) + mixed_attention_mask = torch.cat([attentions.unsqueeze(-1), torch.ones(attentions.shape[0], vis_features.shape[1]+1, 1).to(attentions.device)], dim=1) + mixed_attention_mask = mixed_attention_mask.permute(0,2,1).unsqueeze(1) + mixed_attention_mask = (1-mixed_attention_mask)* -10000.0 + head_mask = [None] * 8 + #extended_attention_mask = get_extended_attention_mask(mixed_attention_mask, mm_features.shape, mm_features.device) + #print(mm_features.shape, mixed_attention_mask.shape, head_mask) + #print(mm_features.shape, self.mlm_pos_embeds.weight.shape, self.mlm_modal_embeds.weight.shape) + head_features = self.mlm_transformer(mm_features + self.mlm_pos_embeds.weight.unsqueeze(0) + modal_embeds, mixed_attention_mask, head_mask)[0] + #print(head_features.shape, attentions.shape) + head_features = head_features[:, :20][attentions.bool()] + + #print(embedding.shape, mask_features.shape) + mlm_predictions = self.mlm_head(head_features) + mlm_predictions = mlm_predictions.reshape(-1, self.language_model.config.vocab_size) + mlm_targets = mlm_targets.squeeze(1)[attentions.bool()] + #mlm_loss = mlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets) + #loss += mlm_loss + #mlm_loss_print=mlm_loss.item() + + return predictions, mask_features, self.lang_proj((l4_residual * l_mask).sum(1)/l_mask.sum(1)), mlm_predictions, mlm_targets +# IoU calculation for validation +def IoU(pred, gt): + #pred = pred.argmax(1) + pred = (pred > 0.5) + + intersection = torch.sum(torch.mul(pred, gt)) + union = torch.sum(torch.add(pred, gt)) - intersection + + if intersection == 0 or union == 0: + iou = 0 + else: + iou = float(intersection) / float(union) + + return iou, intersection, union + +def get_dataset(image_set, transform, args): + from data.dataset_refer_bert_mlm import ReferDataset + ds = ReferDataset(args, + split=image_set, + image_transforms=transform, + target_transforms=None + ) + num_classes = 2 + + return ds, num_classes + + + +def get_transform(args): + transforms = [T.Resize(args.img_size, args.img_size), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + + return T.Compose(transforms) + + +#def criterion(input, target): +# weight 
= torch.FloatTensor([0.9, 1.1]).cuda() +# return nn.functional.cross_entropy(input, target, weight=weight) + + +def evaluate(model, data_loader): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + total_its = 0 + acc_ious = 0 + + # evaluation variables + cum_I, cum_U = 0, 0 + eval_seg_iou_list = [.5, .6, .7, .8, .9] + seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) + seg_total = 0 + mean_IoU = [] + + with torch.no_grad(): + for data in metric_logger.log_every(data_loader, 100, header): + total_its += 1 + #image, target, sentences, attentions = data + #image, target, sentences, attentions = image.cuda(non_blocking=True),\ + # target.cuda(non_blocking=True),\ + # sentences.cuda(non_blocking=True),\ + # attentions.cuda(non_blocking=True) + + image, target, sentences, attentions, mlm_targets, mlm_masks, position = data + image, target, sentences, attentions, mlm_targets, mlm_masks, position = image.cuda(non_blocking=True),\ + target.cuda(non_blocking=True),\ + sentences.cuda(non_blocking=True),\ + attentions.cuda(non_blocking=True), \ + mlm_targets.cuda(non_blocking=True), \ + mlm_masks.cuda(non_blocking=True), \ + position.cuda(non_blocking=True) + + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + #print("sentences", sentences.shape) + #print("attentions", attentions.shape) + + + output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position) + mask_cls_results = output["pred_logits"] + mask_pred_results = output["pred_masks"] + + target_shape = target.shape[-2:] + mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True) + + pred_masks = model.module.semantic_inference(mask_cls_results, mask_pred_results) + output = pred_masks[0] + + + iou, I, U = IoU(output, target) + acc_ious += iou + mean_IoU.append(iou) + cum_I += I + cum_U += U + for n_eval_iou in range(len(eval_seg_iou_list)): + eval_seg_iou = eval_seg_iou_list[n_eval_iou] + seg_correct[n_eval_iou] += (iou >= eval_seg_iou) + seg_total += 1 + iou = acc_ious / total_its + + mean_IoU = np.array(mean_IoU) + mIoU = np.mean(mean_IoU) + print('Final results:') + print('Mean IoU is %.2f\n' % (mIoU * 100.)) + results_str = '' + for n_eval_iou in range(len(eval_seg_iou_list)): + results_str += ' precision@%s = %.2f\n' % \ + (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) + results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) + print(results_str) + + return 100 * iou, 100 * cum_I / cum_U + + +def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq, + iterations, args): + model.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) + header = 'Epoch: [{}]'.format(epoch) + train_loss = 0 + total_its = 0 + + for data in metric_logger.log_every(data_loader, print_freq, header): + total_its += 1 + #image, target, sentences, attentions = data + #image, target, sentences, attentions = image.cuda(non_blocking=True),\ + # target.cuda(non_blocking=True),\ + # sentences.cuda(non_blocking=True),\ + # attentions.cuda(non_blocking=True) + image, target, sentences, attentions, mlm_targets, mlm_masks, position = data + image, target, sentences, attentions, mlm_targets, mlm_masks, position = image.cuda(non_blocking=True),\ + target.cuda(non_blocking=True),\ + sentences.cuda(non_blocking=True),\ + attentions.cuda(non_blocking=True), \ + mlm_targets.cuda(non_blocking=True), \ + mlm_masks.cuda(non_blocking=True), \ + position.cuda(non_blocking=True) + + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + #l_mask = attentions.unsqueeze(dim=-1) + + output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position) + #print(avg_lang_feature.shape) + avg_lang_feature = torch.nn.functional.normalize(avg_lang_feature, dim=1) + #print("----") + #print(output.shape) + #print(mask_features.shape) + #print(avg_lang_feature.shape) + #print( mlm_predictions.shape) + #print(mlm_targets.shape) + #print("----") + + target_shape = target.shape[-2:] + output['pred_masks'] = F.interpolate(output['pred_masks'], size=target_shape, mode='bilinear', align_corners=True) + + if "aux_outputs" in output: + for i, aux_outputs in enumerate(output["aux_outputs"]): + output['aux_outputs'][i]['pred_masks'] = F.interpolate(output['aux_outputs'][i]['pred_masks'], size=target_shape, mode='bilinear', align_corners=True) + + # pixel region + B, C, H, W = mask_features.shape + + target_reshape = F.interpolate(target.unsqueeze(1).float(), size=mask_features.shape[-2:], mode='nearest').long() + + target_reshape = target_reshape.repeat(1, mask_features.shape[1], 1, 1) + #print(avg_pos_feature.shape, avg_lang_feature.shape, avg_neg_feature.shape) + + #cl_loss = 0.0 + plic_lang_loss = 0.0 + plic_pos_loss = 0.0 + plic_neg_loss = 0.0 + for i in range(B): + if ((target_reshape[[i]] == 0).sum() != 0 and (target_reshape[[i]] == 1).sum() != 0): + + avg_pos_feature = (mask_features[[i]] * target_reshape[[i]]).sum(-1).sum(-1) / target_reshape[[i]].sum(-1).sum(-1) + avg_neg_feature = (mask_features[[i]] * (1.0-target_reshape[[i]])).sum(-1).sum(-1) / (1.0-target_reshape[[i]]).sum(-1).sum(-1) + avg_pos_feature = torch.nn.functional.normalize(avg_pos_feature, dim=1) + avg_neg_feature = torch.nn.functional.normalize(avg_neg_feature, dim=1) + + #avg lang feature no normalize??? 
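+                # The block below builds the plic_* weighted contrastive terms for sample i:
+                #   * pos_features / neg_features gather the L2-normalized mask features at
+                #     foreground / background positions of the downsampled target.
+                #   * lang_matrix pairs the pooled sentence embedding with each foreground
+                #     pixel (positive, index 0) against all background pixels (negatives).
+                #   * pos_matrix scores each foreground pixel against the mean foreground
+                #     feature (positive) and against every background pixel (negatives);
+                #     neg_matrix mirrors this for background pixels.
+                #   * Each term is a temperature-scaled cross-entropy towards index 0,
+                #     re-weighted per pixel by the focal-style factor (1 - p_0) ** args.plic_*_alpha,
+                #     where p_0 is the softmax probability of the positive entry, then scaled by
+                #     args.plic_*_weight and averaged over the batch.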
+ + + + pos_features = mask_features[[i]][target_reshape[[i]]==1].view(1, C, -1) + neg_features = mask_features[[i]][target_reshape[[i]]==0].view(1, C, -1) + #inter_neg_features = mask_features[[B-i-1]][target_reshape[[B-i-1]]==1].view(1, C, -1) + #neg_features = torch.cat([intra_neg_features, inter_neg_features], dim=2) + + pos_features = torch.nn.functional.normalize(pos_features, dim=1) + neg_features = torch.nn.functional.normalize(neg_features, dim=1) + + #print(avg_lang_feature.shape, avg_lang_feature[[i]].shape, pos_features.shape) + lang_pos_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], pos_features) + lang_neg_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], neg_features) + + lang_matrix = torch.cat([lang_pos_scores.unsqueeze(-1), lang_neg_scores.unsqueeze(1).repeat(1, lang_pos_scores.shape[1], 1)], dim=2) + lang_labels = torch.zeros(lang_matrix.shape[1], dtype=torch.long).cuda() + lang_labels = lang_labels.unsqueeze(0).repeat(lang_matrix.shape[0], 1) + + lang_score = torch.softmax(lang_matrix, -1) + lang_score = 1.0 - lang_score[:, :, 0] + + pos_pos_scores = torch.einsum("bq,bqn->bn", avg_pos_feature, pos_features) + pos_neg_scores = torch.einsum("bqn,bqm->bnm", pos_features, neg_features) + + pos_matrix = torch.cat([pos_pos_scores.unsqueeze(-1), pos_neg_scores], dim=2) + pos_labels = torch.zeros(pos_matrix.shape[1], dtype=torch.long).cuda() + pos_labels = pos_labels.unsqueeze(0).repeat(pos_matrix.shape[0], 1) + + pos_score = torch.softmax(pos_matrix, -1) + pos_score = 1.0 - pos_score[:, :, 0] + #pos_weight = pos_weight.view(-1, pos_weight.shape[-1]) + + #intra_neg_features = torch.nn.functional.normalize(intra_neg_features, dim=1) + neg_neg_scores = torch.einsum("bq,bqn->bn", avg_neg_feature, neg_features) + neg_pos_scores = torch.einsum("bqn,bqm->bnm", neg_features, pos_features) + + neg_matrix = torch.cat([neg_neg_scores.unsqueeze(-1), neg_pos_scores], dim=2) + neg_labels = torch.zeros(neg_matrix.shape[1], dtype=torch.long).cuda() + neg_labels = neg_labels.unsqueeze(0).repeat(neg_matrix.shape[0], 1) + + neg_score = torch.softmax(neg_matrix, -1) + neg_score = 1.0 - neg_score[:, :, 0] + #neg_weight = neg_weight.view(-1, neg_weight.shape[-1]) + + pos_loss = (torch.pow(pos_score, args.plic_pos_alpha) * torch.nn.functional.cross_entropy(pos_matrix.view(-1, pos_matrix.shape[-1])/args.plic_pos_temp, pos_labels.view(-1), reduction='none')).mean() + neg_loss = (torch.pow(neg_score, args.plic_neg_alpha) * torch.nn.functional.cross_entropy(neg_matrix.view(-1, neg_matrix.shape[-1])/args.plic_neg_temp, neg_labels.view(-1), reduction='none')).mean() + + lang_loss = (torch.pow(lang_score, args.plic_lang_alpha) * torch.nn.functional.cross_entropy(lang_matrix.view(-1, lang_matrix.shape[-1])/args.plic_lang_temp, lang_labels.view(-1), reduction='none')).mean() + + plic_pos_loss += pos_loss + plic_neg_loss += neg_loss + plic_lang_loss += lang_loss + #cl_loss += 0.5 * (torch.nn.functional.cross_entropy(pos_matrix.view(-1, pos_matrix.shape[-1])/cl_temp, pos_labels.view(-1))+torch.nn.functional.cross_entropy(neg_matrix.view(-1, neg_matrix.shape[-1])/cl_temp, neg_labels.view(-1))) + plic_pos_loss = (args.plic_pos_weight * plic_pos_loss) / B + plic_neg_loss = (args.plic_neg_weight * plic_neg_loss) / B + plic_lang_loss = (args.plic_lang_weight * plic_lang_loss) / B + plic_loss = plic_pos_loss + plic_neg_loss +plic_lang_loss + + + #print(output.device, target.device) + losses = criterion(output, target) + weight_dict = criterion.weight_dict + + loss_ce = 0.0 + loss_dice = 0.0 + 
loss_mask = 0.0 + for k in list(losses.keys()): + if k in weight_dict: + losses[k] *= criterion.weight_dict[k] + if '_ce' in k: + loss_ce += losses[k] + elif '_dice' in k: + loss_dice += losses[k] + else: + loss_mask += losses[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + #loss = 0.3 * loss_ce + 0.3 * loss_dice + 0.4 * loss_mask + smlm_loss = args.smlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets) + loss = loss_ce + loss_dice + loss_mask + plic_loss + smlm_loss + + + #loss = criterion(output.squeeze(1), target.float()) + optimizer.zero_grad() # set_to_none=True is only available in pytorch 1.6+ + loss.backward() + optimizer.step() + lr_scheduler.step() + + torch.cuda.synchronize() + train_loss += loss.item() + iterations += 1 + #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) + metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), plic_loss=plic_loss.item(), plic_lang_loss=plic_lang_loss.item(), plic_pos_loss=plic_pos_loss.item(), plic_neg_loss=plic_neg_loss.item(), smlm_loss=smlm_loss.item()) + #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), cl_loss=cl_loss.item(), cl_lang_loss=cl_lang_loss_print, cl_pos_loss=cl_pos_loss_print, cl_neg_loss=cl_neg_loss_print) + + #del image, target, sentences, attentions, loss, output, data + #if bert_model is not None: + # del last_hidden_states, embedding + + #gc.collect() + #torch.cuda.empty_cache() + #del loss + #del cl_loss + #del cl_lang_loss + #del loss_ce + #del loss_dice + #del loss_mask + torch.cuda.synchronize() + + +def main(args): +#def main(local_rank, args): + #ip = os.environ['MASTER_IP'] + #port = os.environ['MASTER_PORT'] + #hosts = int(os.environ['WORLD_SIZE']) # 机器个数 1 + #rank = int(os.environ['RANK']) # 当前机器编号 + #gpus = torch.cuda.device_count() # 每台机器的GPU个数 + #print(local_rank, rank, gpus) #3 0 8 + #dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=hosts*gpus, rank=rank*gpus+local_rank) + #torch.cuda.set_device(local_rank) + #dist.barrier() + + ##utils.init_distributed_mode(args) + #args.distributed=True + #args.gpu = local_rank + #print(args) + ##misc.init_distributed_mode(args) + + #print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + #print("{}".format(args).replace(', ', ',\n')) + + #device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + print('seed', seed) + torch.manual_seed(seed) + np.random.seed(seed) + + #cudnn.benchmark = True + + dataset, num_classes = get_dataset("train", + get_transform(args=args), + args=args) + dataset_test, _ = get_dataset("val", + get_transform(args=args), + args=args) + + # batch sampler + print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.") + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + #num_tasks = hosts*gpus + #global_rank = rank*gpus+local_rank + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, + shuffle=True) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + # data loader + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, + sampler=train_sampler, num_workers=args.workers, pin_memory=True, 
drop_last=True) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers) + + # model initialization + print(args.model) + model = multimodal_segmentation_ppm.__dict__[args.model](pretrained=args.pretrained_swin_weights, + args=args) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + #model.cuda() + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True) + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=False) + #single_model = model.module + + if args.model != 'lavt_one': + model_class = MultiModalBert + bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=model.backbone.embed_dim) + bert_model.pooler = None # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel + #bert_model.cuda() + bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model) + #bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[local_rank]) + #single_bert_model = bert_model.module + else: + bert_model = None + single_bert_model = None + + input_shape = dict() + input_shape['s1'] = Dict({'channel': 128, 'stride': 4}) + input_shape['s2'] = Dict({'channel': 256, 'stride': 8}) + input_shape['s3'] = Dict({'channel': 512, 'stride': 16}) + input_shape['s4'] = Dict({'channel': 1024, 'stride': 32}) + + + + cfg = Dict() + cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = args.transformer_enc_layers + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"] + + cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1 + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = args.num_object_queries + cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = args.dim_feedforward + cfg.MODEL.MASK_FORMER.DEC_LAYERS = args.dec_layers + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True + cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = args.no_object_weight + cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = args.class_weight + cfg.MODEL.MASK_FORMER.DICE_WEIGHT = args.dice_weight + cfg.MODEL.MASK_FORMER.MASK_WEIGHT = args.mask_weight + + cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = args.train_num_points + cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0 + cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75 + print(cfg) + + maskformer_head = MaskFormerHead(cfg, input_shape) + maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head) + #maskformer_head.cuda() + #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False) + #single_head = maskformer_head.module + #print(single_head) + + model = WrapperModel(model.backbone, bert_model, maskformer_head, args) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) + single_model = model.module + + # mask2former loss + deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT + + # loss weights + class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT + dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT + mask_weight = 
cfg.MODEL.MASK_FORMER.MASK_WEIGHT + # self.criterion = Criterion(self.num_classes) + + # building criterion + + matcher = HungarianMatcher( + cost_class=class_weight, + cost_mask=mask_weight, + cost_dice=dice_weight, + num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS, + ) + + weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight} + if deep_supervision: + dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS + aux_weight_dict = {} + for i in range(dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ["labels", "masks"] + criterion = SetCriterion( + cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + matcher=matcher, + weight_dict=weight_dict, + eos_coef=no_object_weight, + losses=losses, + num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS, + oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO, + importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO, + device='cuda' + ) + + if args.resume == "auto": + last_ckpt = "" + for e in range(args.epochs): + ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth') + if os.path.exists(ckpt_path): + last_ckpt = ckpt_path + args.resume = last_ckpt + + # resume training + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + single_model.load_state_dict(checkpoint['model']) + #if args.model != 'lavt_one': + # single_bert_model.load_state_dict(checkpoint['bert_model']) + + # parameters to optimize + backbone_no_decay = list() + backbone_decay = list() + for name, m in single_model.image_model.named_parameters(): + if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name: + backbone_no_decay.append(m) + else: + backbone_decay.append(m) + + params_to_optimize = [ + {'params': backbone_no_decay, 'weight_decay': 0.0}, + {'params': backbone_decay}, + {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]}, + # the following are the parameters of bert + {"params": reduce(operator.concat, + [[p for p in single_model.language_model.encoder.layer[i].parameters() + if p.requires_grad] for i in range(10)])}, + {"params": single_model.language_model.pwams.parameters()}, + {"params": single_model.language_model.res_gates.parameters()}, + {"params": single_model.language_model.norms.parameters()}, + {"params": single_model.lang_proj.parameters()}, + #{"params": single_model.language_model.parameters()}, + {'params': single_model.mlm_head.parameters()}, + {'params': single_model.mlm_vis_proj.parameters()}, + {'params': single_model.mlm_lang_proj.parameters()}, + {'params': single_model.mlm_transformer.parameters()}, + {'params': single_model.mlm_pos_embeds.parameters()}, + {'params': single_model.mlm_modal_embeds.parameters()}, + {'params': single_model.mlm_mask_embed.parameters()}, + {'params': single_model.mlm_pos_mlp.parameters()}, + #{'params': mlm_head.parameters(), 'weight_decay': 0.0}, + #{'params': mlm_vis_proj.parameters(), 'weight_decay': 0.0}, + #{'params': mlm_lang_proj.parameters(), 'weight_decay': 0.0}, + #{'params': mlm_transformer.parameters(), 'weight_decay': 0.0}, + #{'params': mlm_pos_embeds.parameters(), 'weight_decay': 0.0}, + #{'params': mlm_modal_embeds.parameters(), 'weight_decay': 0.0}, + #{'params': mlm_mask_embed.parameters(), 'weight_decay': 0.0}, + #{'params': mlm_pos_mlp.parameters(), 'weight_decay': 0.0}, + ] + + + # optimizer + optimizer = torch.optim.AdamW(params_to_optimize, + lr=args.lr, + weight_decay=args.weight_decay, + 
amsgrad=args.amsgrad + ) + + # learning rate scheduler + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, + lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) + + # housekeeping + start_time = time.time() + iterations = 0 + best_oIoU = -0.1 + + # resume training (optimizer, lr scheduler, and the epoch) + if args.resume: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + resume_epoch = checkpoint['epoch'] + else: + resume_epoch = -999 + + # training loops + for epoch in range(max(0, resume_epoch+1), args.epochs): + data_loader.sampler.set_epoch(epoch) + train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq, + iterations, args) + iou, overallIoU = evaluate(model, data_loader_test) + + print('Average object IoU {}'.format(iou)) + print('Overall IoU {}'.format(overallIoU)) + + + dict_to_save = {'model': single_model.state_dict(), + 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, + 'lr_scheduler': lr_scheduler.state_dict()} + + checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch)) + utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP') + if utils.is_main_process(): + os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path)) + + if utils.is_main_process(): + ckpt_paths = [] + for e in range(args.epochs): + ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth') + print(ckpt_path) + if os.path.exists(ckpt_path): + ckpt_paths.append(ckpt_path) + print(ckpt_paths) + for ckpt_path in ckpt_paths[:-args.max_ckpt]: + os.remove(ckpt_path) + print("remove {:s}".format(ckpt_path)) + + + save_checkpoint = (best_oIoU < overallIoU) + if save_checkpoint: + print('Better epoch: {}\n'.format(epoch)) + dict_to_save = {'model': single_model.state_dict(), + 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, + 'lr_scheduler': lr_scheduler.state_dict()} + + checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id)) + utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP') + if utils.is_main_process(): + os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path)) + best_oIoU = overallIoU + torch.cuda.empty_cache() + + # summarize + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == "__main__": + from args import get_parser + parser = get_parser() + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + # set up distributed learning + utils.init_distributed_mode(args) + print('Image size: {}'.format(str(args.img_size))) + main(args) + #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count()) diff --git a/elia/train_lavt.py b/elia/train_lavt.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6d78314f04b84ce618dcc9e33b1044a615fc1c --- /dev/null +++ b/elia/train_lavt.py @@ -0,0 +1,444 @@ +import haienv +haienv.set_env('lavt2') +import torch.multiprocessing as mp +import torch.distributed as dist + +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from functools import reduce +import operator +from bert.modeling_bert import BertModel + +import torchvision +from lib import segmentation + +import transforms as T +import utils +import numpy as np + +import torch.nn.functional as F + +import gc +from collections import OrderedDict + +import 
torch.backends.cudnn as cudnn + +from ffrecord.torch import DataLoader,Dataset +def get_dataset(image_set, transform, args): + from data.dataset_refer_bert import ReferDataset + ds = ReferDataset(args, + split=image_set, + image_transforms=transform, + target_transforms=None + ) + num_classes = 2 + + return ds, num_classes + + +# IoU calculation for validation +def IoU(pred, gt): + pred = pred.argmax(1) + + intersection = torch.sum(torch.mul(pred, gt)) + union = torch.sum(torch.add(pred, gt)) - intersection + + if intersection == 0 or union == 0: + iou = 0 + else: + iou = float(intersection) / float(union) + + return iou, intersection, union + + +def get_transform(args): + transforms = [T.Resize(args.img_size, args.img_size), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + + return T.Compose(transforms) + + +def criterion(input, target): + weight = torch.FloatTensor([0.9, 1.1]).cuda() + return nn.functional.cross_entropy(input, target, weight=weight) + + +def evaluate(model, data_loader, bert_model): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + total_its = 0 + acc_ious = 0 + + # evaluation variables + cum_I, cum_U = 0, 0 + eval_seg_iou_list = [.5, .6, .7, .8, .9] + seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) + seg_total = 0 + mean_IoU = [] + + with torch.no_grad(): + for data in metric_logger.log_every(data_loader, 100, header): + total_its += 1 + image, target, sentences, attentions = data + image, target, sentences, attentions = image.cuda(non_blocking=True),\ + target.cuda(non_blocking=True),\ + sentences.cuda(non_blocking=True),\ + attentions.cuda(non_blocking=True) + + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + #print("sentences", sentences.shape) + #print("attentions", attentions.shape) + + if bert_model is not None: + last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] + #print("last hidden states", last_hidden_states.shape) + embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy + attentions = attentions.unsqueeze(dim=-1) # (B, N_l, 1) + output = model(image, embedding, l_mask=attentions) + else: + output = model(image, sentences, l_mask=attentions) + + iou, I, U = IoU(output, target) + acc_ious += iou + mean_IoU.append(iou) + cum_I += I + cum_U += U + for n_eval_iou in range(len(eval_seg_iou_list)): + eval_seg_iou = eval_seg_iou_list[n_eval_iou] + seg_correct[n_eval_iou] += (iou >= eval_seg_iou) + seg_total += 1 + iou = acc_ious / total_its + + mean_IoU = np.array(mean_IoU) + mIoU = np.mean(mean_IoU) + print('Final results:') + print('Mean IoU is %.2f\n' % (mIoU * 100.)) + results_str = '' + for n_eval_iou in range(len(eval_seg_iou_list)): + results_str += ' precision@%s = %.2f\n' % \ + (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) + results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) + print(results_str) + + return 100 * iou, 100 * cum_I / cum_U + + +def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq, + iterations, bert_model): + model.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) + header = 'Epoch: [{}]'.format(epoch) + train_loss = 0 + total_its = 0 + + for data in metric_logger.log_every(data_loader, print_freq, header): + total_its += 1 + image, target, sentences, attentions = data + image, target, sentences, attentions = image.cuda(non_blocking=True),\ + target.cuda(non_blocking=True),\ + sentences.cuda(non_blocking=True),\ + attentions.cuda(non_blocking=True) + + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + #print(sentences.shape, attentions.shape, target.shape) + #print(sentences) + #print('a', sentences.shape) + #print('b', attentions.shape) + + if bert_model is not None: + last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] # (6, 10, 768) + #print('c', last_hidden_states.shape) + + embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy + #print('e', embedding.shape) + attentions = attentions.unsqueeze(dim=-1) # (batch, N_l, 1) + #print('f', attentions.shape) + output = model(image, embedding, l_mask=attentions) + else: + output = model(image, sentences, l_mask=attentions) + + loss = criterion(output, target) + optimizer.zero_grad() # set_to_none=True is only available in pytorch 1.6+ + loss.backward() + optimizer.step() + lr_scheduler.step() + + torch.cuda.synchronize() + train_loss += loss.item() + iterations += 1 + metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) + + del image, target, sentences, attentions, loss, output, data + if bert_model is not None: + del last_hidden_states, embedding + + #gc.collect() + #torch.cuda.empty_cache() + torch.cuda.synchronize() + + +#def main(args): +def main(local_rank, args): + ip = os.environ['MASTER_IP'] + port = os.environ['MASTER_PORT'] + hosts = int(os.environ['WORLD_SIZE']) # 机器个数 1 + rank = int(os.environ['RANK']) # 当前机器编号 + gpus = torch.cuda.device_count() # 每台机器的GPU个数 + print(local_rank, rank, gpus) #3 0 8 + dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=hosts*gpus, rank=rank*gpus+local_rank) + torch.cuda.set_device(local_rank) + dist.barrier() + + #utils.init_distributed_mode(args) + args.distributed=True + args.gpu = local_rank + print(args) + #misc.init_distributed_mode(args) + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + #cudnn.benchmark = True + + dataset, num_classes = get_dataset("train", + get_transform(args=args), + args=args) + dataset_test, _ = get_dataset("val", + get_transform(args=args), + args=args) + + # batch sampler + print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.") + #num_tasks = utils.get_world_size() + #global_rank = utils.get_rank() + num_tasks = hosts*gpus + global_rank = rank*gpus+local_rank + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, + shuffle=True) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + # data loader + 
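+    # NOTE: DataLoader here is ffrecord.torch.DataLoader (imported at the top of this
+    # file), not torch.utils.data.DataLoader; drop_last=True avoids a smaller final
+    # batch on each rank under the DistributedSampler.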
data_loader = DataLoader( + dataset, batch_size=args.batch_size, + sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True) + + data_loader_test = DataLoader( + dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers) + + # model initialization + print(args.model) + model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights, + args=args) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=False) + single_model = model.module + + if args.model != 'lavt_one': + model_class = BertModel + bert_model = model_class.from_pretrained(args.ck_bert) + bert_model.pooler = None # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel + bert_model.cuda() + bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model) + bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) + single_bert_model = bert_model.module + else: + bert_model = None + single_bert_model = None + + input_shape = dict() + input_shape['s1'] = Dict({'channel': 128, 'stride': 4}) + input_shape['s2'] = Dict({'channel': 256, 'stride': 8}) + input_shape['s3'] = Dict({'channel': 512, 'stride': 16}) + input_shape['s4'] = Dict({'channel': 1024, 'stride': 32}) + + + + cfg = Dict() + cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4 + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"] + + cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1 + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1 + cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 + cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10 + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + + maskformer_head = MaskFormerHead(cfg, input_shape) + maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head) + maskformer_head.cuda() + maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False) + single_head = maskformer_head.module + print(single_head) + + + if args.resume == "auto": + last_ckpt = "" + for e in range(args.epochs): + ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth') + if os.path.exists(ckpt_path): + last_ckpt = ckpt_path + args.resume = last_ckpt + + # resume training + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + single_model.load_state_dict(checkpoint['model']) + single_head.load_state_dict(checkpoint['head_model']) + if args.model != 'lavt_one': + single_bert_model.load_state_dict(checkpoint['bert_model']) + + # parameters to optimize + backbone_no_decay = list() + backbone_decay = list() + for name, m in single_model.backbone.named_parameters(): + if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name: + backbone_no_decay.append(m) + else: + backbone_decay.append(m) + + if args.model != 'lavt_one': + params_to_optimize = [ + {'params': backbone_no_decay, 'weight_decay': 0.0}, + {'params': backbone_decay}, + {"params": [p for p in 
single_model.classifier.parameters() if p.requires_grad]}, + # the following are the parameters of bert + {"params": reduce(operator.concat, + [[p for p in single_bert_model.encoder.layer[i].parameters() + if p.requires_grad] for i in range(10)])}, + {"params": single_head.parameters()} + ] + else: + params_to_optimize = [ + {'params': backbone_no_decay, 'weight_decay': 0.0}, + {'params': backbone_decay}, + {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]}, + # the following are the parameters of bert + {"params": reduce(operator.concat, + [[p for p in single_model.text_encoder.encoder.layer[i].parameters() + if p.requires_grad] for i in range(10)])}, + ] + + # optimizer + optimizer = torch.optim.AdamW(params_to_optimize, + lr=args.lr, + weight_decay=args.weight_decay, + amsgrad=args.amsgrad + ) + + # learning rate scheduler + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, + lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) + + # housekeeping + start_time = time.time() + iterations = 0 + best_oIoU = -0.1 + + # resume training (optimizer, lr scheduler, and the epoch) + if args.resume: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + resume_epoch = checkpoint['epoch'] + else: + resume_epoch = -999 + + # training loops + for epoch in range(max(0, resume_epoch+1), args.epochs): + data_loader.sampler.set_epoch(epoch) + train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq, + iterations, bert_model, single_head) + iou, overallIoU = evaluate(model, data_loader_test, bert_model, single_head) + + print('Average object IoU {}'.format(iou)) + print('Overall IoU {}'.format(overallIoU)) + + + if single_bert_model is not None: + dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(), + 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, + 'lr_scheduler': lr_scheduler.state_dict(), 'head_model': single_head.state_dict()} + else: + dict_to_save = {'model': single_model.state_dict(), + 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, + 'lr_scheduler': lr_scheduler.state_dict()} + + checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch)) + utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP') + if utils.is_main_process(): + os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path)) + + if utils.is_main_process(): + ckpt_paths = [] + for e in range(args.epochs): + ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth') + print(ckpt_path) + if os.path.exists(ckpt_path): + ckpt_paths.append(ckpt_path) + print(ckpt_paths) + for ckpt_path in ckpt_paths[:-args.max_ckpt]: + os.remove(ckpt_path) + print("remove {:s}".format(ckpt_path)) + + + save_checkpoint = (best_oIoU < overallIoU) + if save_checkpoint: + print('Better epoch: {}\n'.format(epoch)) + if single_bert_model is not None: + dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(), + 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, + 'lr_scheduler': lr_scheduler.state_dict()} + else: + dict_to_save = {'model': single_model.state_dict(), + 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, + 'lr_scheduler': lr_scheduler.state_dict()} + + checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id)) + utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP') + if 
utils.is_main_process(): + os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path)) + best_oIoU = overallIoU + + # summarize + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == "__main__": + from args import get_parser + parser = get_parser() + args = parser.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + # set up distributed learning + #utils.init_distributed_mode(args) + print('Image size: {}'.format(str(args.img_size))) + #main(args) + mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count()) diff --git a/elia/transforms.py b/elia/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..0d22889dbab930ebe2a41dd01b8067465343079f --- /dev/null +++ b/elia/transforms.py @@ -0,0 +1,124 @@ +import numpy as np +from PIL import Image +import random + +import torch +from torchvision import transforms as T +from torchvision.transforms import functional as F + + +def pad_if_smaller(img, size, fill=0): + min_size = min(img.size) + if min_size < size: + ow, oh = img.size + padh = size - oh if oh < size else 0 + padw = size - ow if ow < size else 0 + img = F.pad(img, (0, 0, padw, padh), fill=fill) + return img + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class Resize(object): + def __init__(self, h, w): + self.h = h + self.w = w + + def __call__(self, image, target): + image = F.resize(image, (self.h, self.w)) + # If size is a sequence like (h, w), the output size will be matched to this. + # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio + target = F.resize(target, (self.h, self.w), interpolation=Image.NEAREST) + return image, target + + +class RandomResize(object): + def __init__(self, min_size, max_size=None): + self.min_size = min_size + if max_size is None: + max_size = min_size + self.max_size = max_size + + def __call__(self, image, target): + size = random.randint(self.min_size, self.max_size) # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1) + image = F.resize(image, size) + # If size is a sequence like (h, w), the output size will be matched to this. 
+ # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio + target = F.resize(target, size, interpolation=Image.NEAREST) + return image, target + + +class RandomHorizontalFlip(object): + def __init__(self, flip_prob): + self.flip_prob = flip_prob + + def __call__(self, image, target): + if random.random() < self.flip_prob: + image = F.hflip(image) + target = F.hflip(target) + return image, target + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = pad_if_smaller(image, self.size) + target = pad_if_smaller(target, self.size, fill=255) + crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) + image = F.crop(image, *crop_params) + target = F.crop(target, *crop_params) + return image, target + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = F.center_crop(image, self.size) + target = F.center_crop(target, self.size) + return image, target + + +class ToTensor(object): + def __call__(self, image, target): + image = F.to_tensor(image) + target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64) + return image, target + + +class RandomAffine(object): + def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None): + self.angle = angle + self.translate = translate + self.scale = scale + self.shear = shear + self.resample = resample + self.fillcolor = fillcolor + + def __call__(self, image, target): + affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size) + image = F.affine(image, *affine_params) + target = F.affine(target, *affine_params) + return image, target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target): + image = F.normalize(image, mean=self.mean, std=self.std) + return image, target + diff --git a/elia/utils.py b/elia/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8db2c5a47227f7f3bcd7c4f219d97d7141724ccb --- /dev/null +++ b/elia/utils.py @@ -0,0 +1,222 @@ +from __future__ import print_function +from collections import defaultdict, deque +import datetime +import math +import time +import torch +import torch.distributed as dist +import torch.backends.cudnn as cudnn + +import errno +import os + +import sys + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
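+        Only count and total are all-reduced across processes; the windowed
+        statistics (median, avg, max, value) remain local to each rank.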
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + sys.stdout.flush() + + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {}'.format(header, total_time_str)) + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return 
dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ['WORLD_SIZE']) + print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}") + else: + rank = -1 + world_size = -1 + + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.barrier() + setup_for_distributed(is_main_process()) + + if args.output_dir: + mkdir(args.output_dir) + if args.model_id: + mkdir(os.path.join('./models/', args.model_id)) diff --git a/elia/visualize.py b/elia/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..8e4ae7b657f1a12b6bf7102b17bf414b64d60fa8 --- /dev/null +++ b/elia/visualize.py @@ -0,0 +1,505 @@ + +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from bert.multimodal_bert import MultiModalBert +import torchvision + +from lib import multimodal_segmentation_ppm +import transforms as T +import utils + +import numpy as np +from PIL import Image +import torch.nn.functional as F + +from modeling.MaskFormerModel import MaskFormerHead +from addict import Dict +from bert.modeling_bert import BertLMPredictionHead, BertEncoder +import cv2 +import textwrap + +def get_dataset(image_set, transform, args): + from data.dataset_refer_bert_vis import ReferDataset + ds = ReferDataset(args, + split=image_set, + image_transforms=transform, + target_transforms=None, + eval_mode=True + ) + num_classes = 2 + return ds, num_classes + + +def overlay_davis(image, mask, colors=[[0, 0, 0], [0, 255, 0]], cscale=1, alpha=0.4): + from scipy.ndimage.morphology import binary_dilation + + colors = np.reshape(colors, (-1, 3)) + colors = np.atleast_2d(colors) * cscale + + im_overlay = image.copy() + object_ids = np.unique(mask) + + for object_id in object_ids[1:]: + # Overlay color on binary mask + foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) + binary_mask = mask == object_id + + # Compose image + im_overlay[binary_mask] = foreground[binary_mask] + + # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask + countours = binary_dilation(binary_mask) ^ binary_mask + # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask + im_overlay[countours, :] = 0 + + return im_overlay.astype(image.dtype) + +def evaluate(model, data_loader, device, args): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + + # evaluation variables + cum_I, cum_U = 0, 0 + eval_seg_iou_list = [.5, .6, .7, .8, .9] + seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) + seg_total = 0 + mean_IoU = [] + header = 'Test:' + + with torch.no_grad(): + idx = 0 + for data in metric_logger.log_every(data_loader, 100, header): + idx += 1 + image, target, sentences, attentions, raw_sentences, this_img, orig_img = data + image, target, sentences, attentions = image.to(device), target.to(device), \ + sentences.to(device), attentions.to(device) + + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + #target = target.cpu().data.numpy() + + b, h, w, c = 
orig_img.shape + #orig_img = orig_img.numpy()[:, :, :, ::-1] + + orig_img =orig_img.data.cpu().numpy()[0, :, :, :].astype(np.uint8) + vis = np.zeros((h, w*2,3)).astype(np.uint8) + + #image_mean_iou = [] + target_numpy = target.cpu().numpy() + + for j in range(sentences.size(-1)): + #if bert_model is not None: + # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] + # embedding = last_hidden_states.permute(0, 2, 1) + # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) + #else: + output = model(image, sentences[:, :, j], attentions[:, :, j]) + mask_cls_results = output["pred_logits"] + mask_pred_results = output["pred_masks"] + + target_shape = target.shape[-2:] + mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True) + + pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results) + #output = pred_masks[0] + + #output = output.cpu() + + I, U = computeIoU(pred_masks.cpu().numpy(), target_numpy) + if U == 0: + this_iou = 0.0 + else: + this_iou = I*1.0/U + mean_IoU.append(this_iou) + #image_mean_iou.append(this_iou) + cum_I += I + cum_U += U + for n_eval_iou in range(len(eval_seg_iou_list)): + eval_seg_iou = eval_seg_iou_list[n_eval_iou] + seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) + seg_total += 1 + + + #print(output.shape) + #output_mask = output.argmax(1).data.numpy() + #output_mask = (output > 0.5).data.numpy() + + #vis_output_mask = torch.sigmoid(output[:, 1]).data.numpy() + #vis_output_mask = torch.sigmoid((output>0.5).float()).data.numpy() + #soft + #vis_output_mask = output.data.numpy() + #vis_output_mask = output_mask + + #print(output.shape, orig_shape) + gt_masks = torch.nn.functional.interpolate(target.unsqueeze(0).float(), (h, w)) + pred_masks = torch.nn.functional.interpolate(pred_masks, (h, w)) + #print(orig_mask.shape) + pred_masks = (pred_masks > 0.5).data.cpu().numpy() + #ntarget = target.data.cpu().numpy() + ##orig_mask = orig_mask.argmax(1).data.numpy() + + #print(orig_img[0].shape, orig_mask[0][0].shape, flush=True) + #print(orig_img.dtype, orig_mask.dtype) + predict_imgs = overlay_davis(orig_img, pred_masks[0][0].astype(np.uint8)) + gt_imgs = overlay_davis(orig_img, gt_masks[0][0].cpu().numpy().astype(np.uint8), colors=[[0, 0, 0], [0, 0, 255]]) + #print(orig_mask.shape, orig_img.shape) + #red_mask = np.zeros((orig_mask.shape[1], orig_mask.shape[2], orig_mask.shape[3], 3)).astype(np.uint8) + #print("???", red_mask.shape, orig_mask.shape) + #red_mask[:, :, :, 1] = orig_mask * 255 + #red_mask = cv2.bitwise_and(red_mask, red_mask, orig_mask.astype(np.uint8)) + + #temp = cv2.addWeighted(red_mask, 0.5, orig_img, 0.5, 0) + #print(orig_img.shape, temp.shape, orig_mask.shape, "WHAT?") + #new = orig_img * (1.0 - orig_mask[0][:,:,:,None]) + temp * orig_mask[0][:,:,:,None] + #print(new.shape, orig_mask.shape, temp.shape, "check") + ##print(vis_output_mask) + ##output_mask = output.argmax(1).data.numpy() + # + #print(raw_sentences[j]) + + # print(image.shape, target.shape, output_mask.shape) + + #mean = np.array([0.485, 0.456, 0.406]) + #std = np.array([0.229, 0.224, 0.225]) + #np_image = (((image[0].permute(1,2,0).cpu().numpy() * std) + mean) * 255).astype(np.uint8)[:,:,::-1] + #np_target = (target * 255).transpose(1,2,0).astype(np.uint8) + ##print(output_mask) + #np_output_mask = (vis_output_mask*255).transpose(1,2,0).repeat(3, axis=2).astype(np.uint8) + + font = cv2.FONT_HERSHEY_DUPLEX + fontScale = 1.0 + fontColor = (255,0,0) + thickness = 
1 + lineType = 2 + + wrapped_text = textwrap.wrap(' '.join(raw_sentences[j]), width=35) + for k, line in enumerate(wrapped_text): + bottomLeftCornerOfText = (10,h-60 + k*30) + gt_imgs = cv2.putText(gt_imgs, line, + bottomLeftCornerOfText, + font, + fontScale, + fontColor, + thickness, + lineType) + + + #temp = j + 2 + #split = temp // 3 + #row = temp % 3 + vis[0:h, 0:w, :] = gt_imgs + vis[0:h, w:2*w, :] = predict_imgs + + + + #cv2.imwrite("vis/elifan_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], j), new[0].astype(np.uint8)) + cv2.imwrite("vis/{:s}/{:s}_{:d}_{:d}_{:.2f}.jpg".format(args.vis_dir, this_img[0].split('.')[0], idx, j, this_iou), vis[:, :, ::-1].astype(np.uint8)) + + #print('---------------') + #cv2.imshow("vis", vis) + #cv2.waitKey(0) + + #image_mean_iou = np.mean(np.array(image_mean_iou)) + #print(image_mean_iou) + #if image_mean_iou < 0.5: + #cv2.imwrite("vis/elian_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], idx), vis) + + + #del image, target, sentences, attentions, output, output_mask + #if bert_model is not None: + # del last_hidden_states, embedding + + mean_IoU = np.array(mean_IoU) + mIoU = np.mean(mean_IoU) + print('Final results:') + print('Mean IoU is %.2f\n' % (mIoU*100.)) + results_str = '' + for n_eval_iou in range(len(eval_seg_iou_list)): + results_str += ' precision@%s = %.2f\n' % \ + (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) + results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U) + print(results_str) + +#def evaluate(model, data_loader, device): +# model.eval() +# metric_logger = utils.MetricLogger(delimiter=" ") +# +# # evaluation variables +# cum_I, cum_U = 0, 0 +# eval_seg_iou_list = [.5, .6, .7, .8, .9] +# seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) +# seg_total = 0 +# mean_IoU = [] +# header = 'Test:' +# +# with torch.no_grad(): +# for data in metric_logger.log_every(data_loader, 100, header): +# image, target, sentences, attentions = data +# image, target, sentences, attentions = image.to(device), target.to(device), \ +# sentences.to(device), attentions.to(device) +# sentences = sentences.squeeze(1) +# attentions = attentions.squeeze(1) +# target = target.cpu().data.numpy() +# for j in range(sentences.size(-1)): +# #if bert_model is not None: +# # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] +# # embedding = last_hidden_states.permute(0, 2, 1) +# # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) +# #else: +# output = model(image, sentences[:, :, j], attentions[:, :, j]) +# mask_cls_results = output["pred_logits"] +# mask_pred_results = output["pred_masks"] +# +# target_shape = target.shape[-2:] +# mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True) +# +# pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results) +# output = pred_masks[0] +# +# output = output.cpu() +# #print(output.shape) +# #output_mask = output.argmax(1).data.numpy() +# output_mask = (output > 0.5).data.numpy() +# I, U = computeIoU(output_mask, target) +# if U == 0: +# this_iou = 0.0 +# else: +# this_iou = I*1.0/U +# mean_IoU.append(this_iou) +# cum_I += I +# cum_U += U +# for n_eval_iou in range(len(eval_seg_iou_list)): +# eval_seg_iou = eval_seg_iou_list[n_eval_iou] +# seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) +# seg_total += 1 +# +# #del image, target, sentences, attentions, output, output_mask +# #if bert_model is not None: +# # del 
last_hidden_states, embedding +# +# mean_IoU = np.array(mean_IoU) +# mIoU = np.mean(mean_IoU) +# print('Final results:') +# print('Mean IoU is %.2f\n' % (mIoU*100.)) +# results_str = '' +# for n_eval_iou in range(len(eval_seg_iou_list)): +# results_str += ' precision@%s = %.2f\n' % \ +# (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) +# results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U) +# print(results_str) + + +def get_transform(args): + transforms = [T.Resize(args.img_size, args.img_size), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + + return T.Compose(transforms) + + +def computeIoU(pred_seg, gd_seg): + I = np.sum(np.logical_and(pred_seg, gd_seg)) + U = np.sum(np.logical_or(pred_seg, gd_seg)) + + return I, U + +class WrapperModel(nn.Module): + def __init__(self, image_model, language_model, classifier, args) : + super(WrapperModel, self).__init__() + self.image_model = image_model + self.language_model = language_model + self.classifier = classifier + self.lang_proj = nn.Linear(768,256) + + config = Dict({ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + #"max_position_embeddings": 16+20, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 8, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522 + }) + self.mlm_transformer = BertEncoder(config) + + self.lang_proj = nn.Linear(768,256) + self.mlm_vis_proj = nn.Conv2d(1024,512,1) + self.mlm_lang_proj = nn.Linear(768,512) + #print(vis_proj) + self.mlm_head = BertLMPredictionHead(config) + + assert args.img_size % 4 == 0 + num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2 + print(num_img_tokens) + self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512) + self.mlm_modal_embeds = nn.Embedding(3, 512) + + self.mlm_mask_embed = nn.Embedding(1, 512) + self.mlm_pos_mlp = nn.Sequential( + nn.Linear(2, 512), + nn.LayerNorm(512), + nn.Linear(512,512), + nn.GELU() + ) + + def _get_binary_mask(self, target): + # 返回每类的binary mask + y, x = target.size() + target_onehot = torch.zeros(self.num_classes + 1, y, x) + target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1) + return target_onehot[1:] + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=1)[...,1:] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred) + return semseg + + def forward(self, image, sentences, attentions): + input_shape = image.shape[-2:] + l_mask = attentions.unsqueeze(dim=-1) + + i0, Wh, Ww = self.image_model.forward_stem(image) + l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions) + + i1 = self.image_model.forward_stage1(i0, Wh, Ww) + l1 = self.language_model.forward_stage1(l0, extended_attention_mask) + i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask) + l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) + i1 = i1_temp + + i2 = self.image_model.forward_stage2(i1, Wh, Ww) + l2 = self.language_model.forward_stage2(l1, extended_attention_mask) + i2_residual, H, W, i2_temp, Wh, Ww = 
self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask) + l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) + i2 = i2_temp + + i3 = self.image_model.forward_stage3(i2, Wh, Ww) + l3 = self.language_model.forward_stage3(l2, extended_attention_mask) + i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask) + l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) + i3 = i3_temp + + i4 = self.image_model.forward_stage4(i3, Wh, Ww) + l4 = self.language_model.forward_stage4(l3, extended_attention_mask) + i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask) + l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) + i4 = i4_temp + + #i1_residual, i2_residual, i3_residual, i4_residual = features + #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual) + #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) + outputs = {} + outputs['s1'] = i1_residual + outputs['s2'] = i2_residual + outputs['s3'] = i3_residual + outputs['s4'] = i4_residual + + predictions, mask_predictions = self.classifier(outputs) + return predictions + +def main(args): +#def main(local_rank, args): + + #device = torch.device(args.device) + device = 'cuda' + dataset_test, _ = get_dataset(args.split, get_transform(args=args), args) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, + sampler=test_sampler, num_workers=args.workers) + print(args.model) + single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args) + #single_model = MultiModalFocal(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], focal_windows=[9,9,9,9], drop_path_rate=0.3) + #single_model.init_weights('./focalnet_base_lrf.pth') + checkpoint = torch.load(args.resume, map_location='cpu') + #single_model.load_state_dict(checkpoint['model']) + #model = single_model.to(device) + + if args.model != 'lavt_one': + model_class = MultiModalBert + #single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=128) + single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim) + # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines + if args.ddp_trained_weights: + single_bert_model.pooler = None + #single_bert_model.load_state_dict(checkpoint['bert_model']) + #bert_model = single_bert_model.to(device) + else: + bert_model = None + + #model = WrapperModel(single_model.backbone, single_bert_model, single_model.classifier) + #model.load_state_dict(checkpoint['model']) + #model.to(device) + input_shape = dict() + input_shape['s1'] = Dict({'channel': 128, 'stride': 4}) + input_shape['s2'] = Dict({'channel': 256, 'stride': 8}) + input_shape['s3'] = Dict({'channel': 512, 'stride': 16}) + input_shape['s4'] = Dict({'channel': 1024, 'stride': 32}) + + + + cfg = Dict() + cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4 + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"] + + cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1 + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1 + 
cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 + cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10 + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + + maskformer_head = MaskFormerHead(cfg, input_shape) + #maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head) + #maskformer_head.cuda() + #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False) + #single_head = maskformer_head.module + #print(single_head) + + model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args) + model.load_state_dict(checkpoint['model']) + model.to(device) + #model.cuda() + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) + #single_model = model.module + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) + #single_model = model.module + evaluate(model, data_loader_test, device=device, args=args) + + +if __name__ == "__main__": + from args import get_parser + parser = get_parser() + args = parser.parse_args() + print('Image size: {}'.format(str(args.img_size))) + print(args) + os.makedirs('vis/' + args.vis_dir, exist_ok=True) + main(args) + #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count()) diff --git a/elia/visualize_final.py b/elia/visualize_final.py new file mode 100644 index 0000000000000000000000000000000000000000..0947b95c496b33953301ac4f18a9533076daf27e --- /dev/null +++ b/elia/visualize_final.py @@ -0,0 +1,506 @@ + +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn + +from bert.multimodal_bert import MultiModalBert +import torchvision + +from lib import multimodal_segmentation_ppm +import transforms as T +import utils + +import numpy as np +from PIL import Image +import torch.nn.functional as F + +from modeling.MaskFormerModel import MaskFormerHead +from addict import Dict +from bert.modeling_bert import BertLMPredictionHead, BertEncoder +import cv2 +import textwrap + +def get_dataset(image_set, transform, args): + from data.dataset_refer_bert_vis import ReferDataset + ds = ReferDataset(args, + split=image_set, + image_transforms=transform, + target_transforms=None, + eval_mode=True + ) + num_classes = 2 + return ds, num_classes + + +def overlay_davis(image, mask, colors=[[0, 0, 0], [0, 255, 0]], cscale=1, alpha=0.4): + from scipy.ndimage.morphology import binary_dilation + + colors = np.reshape(colors, (-1, 3)) + colors = np.atleast_2d(colors) * cscale + + im_overlay = image.copy() + object_ids = np.unique(mask) + + for object_id in object_ids[1:]: + # Overlay color on binary mask + foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id]) + binary_mask = mask == object_id + + # Compose image + im_overlay[binary_mask] = foreground[binary_mask] + + # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask + countours = binary_dilation(binary_mask) ^ binary_mask + # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask + im_overlay[countours, :] = 0 + + return im_overlay.astype(image.dtype) + +def evaluate(model, data_loader, device): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + + # evaluation variables + cum_I, cum_U = 0, 0 + eval_seg_iou_list = [.5, .6, .7, .8, .9] + seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) + seg_total = 0 + mean_IoU = [] + header = 'Test:' 
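+    # Bookkeeping used below: cum_I / cum_U accumulate intersection and union pixel
+    # counts over all expressions (overall IoU); mean_IoU collects the per-expression
+    # IoU for the mIoU metric; seg_correct[k] counts expressions whose IoU reaches
+    # eval_seg_iou_list[k], i.e. precision@0.5 ... precision@0.9.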
+ + with torch.no_grad(): + number = 0 + idx = 0 + for data in metric_logger.log_every(data_loader, 100, header): + number +=1 + idx += 1 + print(number) + image, target, sentences, attentions, raw_sentences, this_img, orig_img = data + image, target, sentences, attentions = image.to(device), target.to(device), \ + sentences.to(device), attentions.to(device) + #if number <= 40: + # continue + + sentences = sentences.squeeze(1) + attentions = attentions.squeeze(1) + target = target.cpu().data.numpy() + + orig_shape = orig_img.shape + orig_img = orig_img.numpy()[:, :, :, ::-1] + print(orig_img.shape, "??") + + vis = np.zeros((480*2, 480*3,3)).astype(np.uint8) + + image_mean_iou = [] + + for j in range(sentences.size(-1)): + #if bert_model is not None: + # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] + # embedding = last_hidden_states.permute(0, 2, 1) + # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) + #else: + output = model(image, sentences[:, :, j], attentions[:, :, j]) + mask_cls_results = output["pred_logits"] + mask_pred_results = output["pred_masks"] + + target_shape = target.shape[-2:] + mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True) + + pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results) + output = pred_masks[0] + + output = output.cpu() + + + + #print(output.shape) + #output_mask = output.argmax(1).data.numpy() + output_mask = (output > 0.5).data.numpy() + + #vis_output_mask = torch.sigmoid(output[:, 1]).data.numpy() + #vis_output_mask = torch.sigmoid((output>0.5).float()).data.numpy() + #soft + #vis_output_mask = output.data.numpy() + #vis_output_mask = output_mask + + #print(output.shape, orig_shape) + orig_mask = torch.nn.functional.interpolate(pred_masks, (orig_shape[1], orig_shape[2])) + #print(orig_mask.shape) + orig_mask = (orig_mask > 0.5).data.cpu().numpy() + ##orig_mask = orig_mask.argmax(1).data.numpy() + + print(orig_img[0].shape, orig_mask[0][0].shape, flush=True) + print(orig_img.dtype, orig_mask.dtype) + new = overlay_davis(orig_img[0], orig_mask[0][0].astype(np.uint8)) + #print(orig_mask.shape, orig_img.shape) + #red_mask = np.zeros((orig_mask.shape[1], orig_mask.shape[2], orig_mask.shape[3], 3)).astype(np.uint8) + #print("???", red_mask.shape, orig_mask.shape) + #red_mask[:, :, :, 1] = orig_mask * 255 + #red_mask = cv2.bitwise_and(red_mask, red_mask, orig_mask.astype(np.uint8)) + + #temp = cv2.addWeighted(red_mask, 0.5, orig_img, 0.5, 0) + #print(orig_img.shape, temp.shape, orig_mask.shape, "WHAT?") + #new = orig_img * (1.0 - orig_mask[0][:,:,:,None]) + temp * orig_mask[0][:,:,:,None] + #print(new.shape, orig_mask.shape, temp.shape, "check") + ##print(vis_output_mask) + ##output_mask = output.argmax(1).data.numpy() + # + #print(raw_sentences[j]) + + # print(image.shape, target.shape, output_mask.shape) + + #mean = np.array([0.485, 0.456, 0.406]) + #std = np.array([0.229, 0.224, 0.225]) + #np_image = (((image[0].permute(1,2,0).cpu().numpy() * std) + mean) * 255).astype(np.uint8)[:,:,::-1] + #np_target = (target * 255).transpose(1,2,0).astype(np.uint8) + ##print(output_mask) + #np_output_mask = (vis_output_mask*255).transpose(1,2,0).repeat(3, axis=2).astype(np.uint8) + + #font = cv2.FONT_HERSHEY_SIMPLEX + #fontScale = 0.75 + #fontColor = (0,0,255) + #thickness = 1 + #lineType = 2 + + #wrapped_text = textwrap.wrap(' '.join(raw_sentences[j]), width=35) + #for k, line in enumerate(wrapped_text): + # 
bottomLeftCornerOfText = (10,420+k*20) + # np_output_mask = cv2.putText(np_output_mask, line, + # bottomLeftCornerOfText, + # font, + # fontScale, + # fontColor, + # thickness, + # lineType) + + # + #temp = j + 2 + #split = temp // 3 + #row = temp % 3 + #vis[0:480, 0:480, :] = np_image + #vis[0:480, 480:960, :] = np_target.repeat(3, axis=2) + #vis[split*480:(split+1)*480:, row * 480:(row+1)*480, :] = np_output_mask + + + + I, U = computeIoU(output_mask, target) + if U == 0: + this_iou = 0.0 + else: + this_iou = I*1.0/U + mean_IoU.append(this_iou) + image_mean_iou.append(this_iou) + cum_I += I + cum_U += U + for n_eval_iou in range(len(eval_seg_iou_list)): + eval_seg_iou = eval_seg_iou_list[n_eval_iou] + seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) + seg_total += 1 + #cv2.imwrite("vis/elifan_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], j), new[0].astype(np.uint8)) + cv2.imwrite("vis/elia_refcoco+_green/{:s}_{:d}_{:d}_{:.2f}.jpg".format(this_img[0].split('.')[0], idx, j, this_iou), new.astype(np.uint8)) + + print('---------------') + #cv2.imshow("vis", vis) + #cv2.waitKey(0) + + image_mean_iou = np.mean(np.array(image_mean_iou)) + print(image_mean_iou) + #if image_mean_iou < 0.5: + #cv2.imwrite("vis/elian_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], idx), vis) + + + #del image, target, sentences, attentions, output, output_mask + #if bert_model is not None: + # del last_hidden_states, embedding + + mean_IoU = np.array(mean_IoU) + mIoU = np.mean(mean_IoU) + print('Final results:') + print('Mean IoU is %.2f\n' % (mIoU*100.)) + results_str = '' + for n_eval_iou in range(len(eval_seg_iou_list)): + results_str += ' precision@%s = %.2f\n' % \ + (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) + results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) + print(results_str) + +#def evaluate(model, data_loader, device): +# model.eval() +# metric_logger = utils.MetricLogger(delimiter=" ") +# +# # evaluation variables +# cum_I, cum_U = 0, 0 +# eval_seg_iou_list = [.5, .6, .7, .8, .9] +# seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) +# seg_total = 0 +# mean_IoU = [] +# header = 'Test:' +# +# with torch.no_grad(): +# for data in metric_logger.log_every(data_loader, 100, header): +# image, target, sentences, attentions = data +# image, target, sentences, attentions = image.to(device), target.to(device), \ +# sentences.to(device), attentions.to(device) +# sentences = sentences.squeeze(1) +# attentions = attentions.squeeze(1) +# target = target.cpu().data.numpy() +# for j in range(sentences.size(-1)): +# #if bert_model is not None: +# # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] +# # embedding = last_hidden_states.permute(0, 2, 1) +# # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) +# #else: +# output = model(image, sentences[:, :, j], attentions[:, :, j]) +# mask_cls_results = output["pred_logits"] +# mask_pred_results = output["pred_masks"] +# +# target_shape = target.shape[-2:] +# mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True) +# +# pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results) +# output = pred_masks[0] +# +# output = output.cpu() +# #print(output.shape) +# #output_mask = output.argmax(1).data.numpy() +# output_mask = (output > 0.5).data.numpy() +# I, U = computeIoU(output_mask, target) +# if U == 0: +# this_iou = 0.0 +# else: +# this_iou = I*1.0/U +# mean_IoU.append(this_iou) +# cum_I += I +# cum_U += U +# for n_eval_iou in range(len(eval_seg_iou_list)): +# eval_seg_iou = eval_seg_iou_list[n_eval_iou] +# seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) +# seg_total += 1 +# +# #del image, target, sentences, attentions, output, output_mask +# #if bert_model is not None: +# # del last_hidden_states, embedding +# +# mean_IoU = np.array(mean_IoU) +# mIoU = np.mean(mean_IoU) +# print('Final results:') +# print('Mean IoU is %.2f\n' % (mIoU*100.)) +# results_str = '' +# for n_eval_iou in range(len(eval_seg_iou_list)): +# results_str += ' precision@%s = %.2f\n' % \ +# (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) +# results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) +# print(results_str) + + +def get_transform(args): + transforms = [T.Resize(args.img_size, args.img_size), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + + return T.Compose(transforms) + + +def computeIoU(pred_seg, gd_seg): + I = np.sum(np.logical_and(pred_seg, gd_seg)) + U = np.sum(np.logical_or(pred_seg, gd_seg)) + + return I, U + +class WrapperModel(nn.Module): + def __init__(self, image_model, language_model, classifier, args) : + super(WrapperModel, self).__init__() + self.image_model = image_model + self.language_model = language_model + self.classifier = classifier + self.lang_proj = nn.Linear(768,256) + + config = Dict({ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + #"max_position_embeddings": 16+20, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 8, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": True, + "vocab_size": 30522 + }) + self.mlm_transformer = BertEncoder(config) + + self.lang_proj = nn.Linear(768,256) + self.mlm_vis_proj = nn.Conv2d(1024,512,1) + self.mlm_lang_proj = nn.Linear(768,512) + #print(vis_proj) + self.mlm_head = BertLMPredictionHead(config) + + assert args.img_size % 4 == 0 + num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2 + print(num_img_tokens) + self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512) + self.mlm_modal_embeds = nn.Embedding(3, 512) + + self.mlm_mask_embed = nn.Embedding(1, 512) + self.mlm_pos_mlp = nn.Sequential( + nn.Linear(2, 512), + nn.LayerNorm(512), + nn.Linear(512,512), + nn.GELU() + ) + + def _get_binary_mask(self, target): + # 返回每类的binary mask + y, x = target.size() + target_onehot = torch.zeros(self.num_classes + 1, y, x) + target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1) + return target_onehot[1:] + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=1)[...,1:] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred) + return semseg + + def forward(self, image, sentences, attentions): + input_shape = image.shape[-2:] + l_mask = attentions.unsqueeze(dim=-1) + + i0, Wh, Ww = self.image_model.forward_stem(image) + l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions) + + i1 = self.image_model.forward_stage1(i0, Wh, Ww) + l1 = self.language_model.forward_stage1(l0, extended_attention_mask) + i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask) + l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) + i1 = i1_temp + + i2 = self.image_model.forward_stage2(i1, Wh, Ww) + l2 = self.language_model.forward_stage2(l1, extended_attention_mask) + i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask) + l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) + i2 = i2_temp + + i3 = self.image_model.forward_stage3(i2, Wh, Ww) + l3 = self.language_model.forward_stage3(l2, extended_attention_mask) + i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask) + l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, 
extended_attention_mask) + i3 = i3_temp + + i4 = self.image_model.forward_stage4(i3, Wh, Ww) + l4 = self.language_model.forward_stage4(l3, extended_attention_mask) + i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask) + l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) + i4 = i4_temp + + #i1_residual, i2_residual, i3_residual, i4_residual = features + #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual) + #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) + outputs = {} + outputs['s1'] = i1_residual + outputs['s2'] = i2_residual + outputs['s3'] = i3_residual + outputs['s4'] = i4_residual + + predictions = self.classifier(outputs) + return predictions + +def main(args): +#def main(local_rank, args): + + #device = torch.device(args.device) + device = 'cuda' + dataset_test, _ = get_dataset(args.split, get_transform(args=args), args) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, + sampler=test_sampler, num_workers=args.workers) + print(args.model) + single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args) + #single_model = MultiModalFocal(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], focal_windows=[9,9,9,9], drop_path_rate=0.3) + #single_model.init_weights('./focalnet_base_lrf.pth') + checkpoint = torch.load(args.resume, map_location='cpu') + #single_model.load_state_dict(checkpoint['model']) + #model = single_model.to(device) + + if args.model != 'lavt_one': + model_class = MultiModalBert + #single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=128) + single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim) + # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines + if args.ddp_trained_weights: + single_bert_model.pooler = None + #single_bert_model.load_state_dict(checkpoint['bert_model']) + #bert_model = single_bert_model.to(device) + else: + bert_model = None + + #model = WrapperModel(single_model.backbone, single_bert_model, single_model.classifier) + #model.load_state_dict(checkpoint['model']) + #model.to(device) + input_shape = dict() + input_shape['s1'] = Dict({'channel': 128, 'stride': 4}) + input_shape['s2'] = Dict({'channel': 256, 'stride': 8}) + input_shape['s3'] = Dict({'channel': 512, 'stride': 16}) + input_shape['s4'] = Dict({'channel': 1024, 'stride': 32}) + + + + cfg = Dict() + cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4 + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"] + + cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1 + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1 + cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 + cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10 + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + + maskformer_head = MaskFormerHead(cfg, input_shape) + #maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head) + #maskformer_head.cuda() + #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False) + #single_head = 
maskformer_head.module
+    #print(single_head)
+
+    model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
+    model.load_state_dict(checkpoint['model'])
+    model.to(device)
+    #model.cuda()
+    #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+    #single_model = model.module
+    evaluate(model, data_loader_test, device=device)
+
+
+if __name__ == "__main__":
+    from args import get_parser
+    parser = get_parser()
+    args = parser.parse_args()
+    print('Image size: {}'.format(str(args.img_size)))
+    print(args)
+    # create the hard-coded output directory that evaluate() writes its overlays to
+    os.makedirs('vis/elia_refcoco+_green', exist_ok=True)
+    main(args)
+    #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
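Both visualization scripts report segmentation quality with the same IoU bookkeeping. The standalone sketch below (dummy 4x4 masks invented for illustration; not part of the diff above) shows how computeIoU and the precision@X counters behave:

import numpy as np

def computeIoU(pred_seg, gd_seg):
    # pixel counts of intersection and union between two binary masks
    I = np.sum(np.logical_and(pred_seg, gd_seg))
    U = np.sum(np.logical_or(pred_seg, gd_seg))
    return I, U

pred = np.zeros((4, 4), dtype=bool)
pred[1:3, 1:3] = True            # 4 predicted pixels
gt = np.zeros((4, 4), dtype=bool)
gt[1:4, 1:4] = True              # 9 ground-truth pixels

I, U = computeIoU(pred, gt)      # I = 4, U = 9
this_iou = I * 1.0 / U if U > 0 else 0.0   # ~0.44

eval_seg_iou_list = [.5, .6, .7, .8, .9]
seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
for n, thr in enumerate(eval_seg_iou_list):
    seg_correct[n] += (this_iou >= thr)    # stays all zeros here: 4/9 < 0.5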