diff --git a/elia/.gitattributes b/elia/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea
--- /dev/null
+++ b/elia/.gitattributes
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/elia/LICENSE b/elia/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7
--- /dev/null
+++ b/elia/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/elia/README.md b/elia/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..869de8afb5569dca5c730028fe17228f54198bef
--- /dev/null
+++ b/elia/README.md
@@ -0,0 +1,222 @@
+# LAVT: Language-Aware Vision Transformer for Referring Image Segmentation
+Welcome to the official repository for the method presented in
+"LAVT: Language-Aware Vision Transformer for Referring Image Segmentation."
+
+
+![Pipeline Image](pipeline.jpg)
+
+Code in this repository is written using [PyTorch](https://pytorch.org/) and is organized as follows (assuming the working directory is the root directory of this repository; a minimal usage sketch follows the list):
+* `./lib` contains files implementing the main network.
+* Inside `./lib`, `_utils.py` defines the highest-level model, which incorporates the backbone network
+defined in `backbone.py` and the simple mask decoder defined in `mask_predictor.py`.
+`segmentation.py` provides the model interface and initialization functions.
+* `./bert` contains files migrated from [Hugging Face Transformers v3.0.2](https://huggingface.co/transformers/v3.0.2/quicktour.html),
+which implement the BERT language model.
+We used Transformers v3.0.2 during development but it had a bug that would appear when using `DistributedDataParallel`.
+Therefore we maintain a copy of the relevant source files in this repository.
+This way, the bug is fixed and code in this repository is self-contained.
+* `./train.py` is invoked to train the model.
+* `./test.py` is invoked to run inference on the evaluation subsets after training.
+* `./refer` contains data pre-processing code and is also where data should be placed, including the images and all annotations.
+It is cloned from [refer](https://github.com/lichengunc/refer).
+* `./data/dataset_refer_bert.py` is where the dataset class is defined.
+* `./utils.py` defines functions that track training statistics and setup
+functions for `DistributedDataParallel`.
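+
+As a quick orientation, the sketch below shows how these pieces are typically
+wired together. It is illustrative only and mirrors the commented-out calls in
+`app.py`; the actual entry points are `train.py`, `test.py`, and
+`demo_inference.py`, and the constructor/tokenizer names here are assumptions
+based on those files.
+
+```python
+# Illustrative sketch: build the LAVT model and the bundled BERT language model.
+from lib import segmentation                      # model interface / init functions
+from bert.modeling_bert import BertModel          # BERT implementation kept in ./bert
+from bert.tokenization_bert import BertTokenizer
+
+class Args:                                        # minimal stand-in for the CLI args
+    swin_type = 'base'
+    window12 = True
+    mha = ''
+    fusion_drop = 0.0
+
+model = segmentation.__dict__['lavt'](pretrained='', args=Args)  # backbone + mask decoder
+bert_model = BertModel.from_pretrained('bert-base-uncased')      # language branch
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+tokens = tokenizer.encode('the spoon on the dish', add_special_tokens=True)
+```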
+
+
+## Updates
+**June 21st, 2022**. Uploaded the training logs and trained
+model weights of lavt_one.
+
+**June 9th, 2022**.
+Added a more efficient implementation of LAVT.
+* To train this new model, specify `--model` as `lavt_one`
+(and `lavt` is still valid for specifying the old model).
+The rest of the configuration stays unchanged.
+* The difference between this version and the previous one
+is that the language model has been moved inside the overall model,
+so that `DistributedDataParallel` needs to be applied only once.
+Applying it twice (on the standalone language model and on the main branch),
+as done in the old implementation, led to low GPU utilization,
+which prevented training speed from scaling up with more GPUs.
+We recommend training this model on 8 GPUs
+(and, as before, with batch size 32).
+A minimal sketch of the two wrapping schemes is shown after this list.
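+
+The sketch below illustrates the wrapping difference described above. It is
+schematic only: the function and variable names are placeholders, not the
+exact ones used in `train.py`.
+
+```python
+# Schematic only: how DistributedDataParallel (DDP) is applied in the two variants.
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+def wrap_lavt(vision_model, language_model, local_rank):
+    # old `lavt`: the language model is a separate module, so DDP is applied twice,
+    # which gave low GPU utilization in the old implementation
+    vision_model = DDP(vision_model.cuda(local_rank), device_ids=[local_rank])
+    language_model = DDP(language_model.cuda(local_rank), device_ids=[local_rank])
+    return vision_model, language_model
+
+def wrap_lavt_one(full_model, local_rank):
+    # new `lavt_one`: the language model is a submodule of the overall model,
+    # so a single DDP wrap covers the whole network
+    return DDP(full_model.cuda(local_rank), device_ids=[local_rank])
+```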
+
+## Setting Up
+### Preliminaries
+The code has been verified to work with PyTorch v1.7.1 and Python 3.7.
+1. Clone this repository.
+2. Change directory to root of this repository.
+### Package Dependencies
+1. Create a new Conda environment with Python 3.7 then activate it:
+```shell
+conda create -n lavt python==3.7
+conda activate lavt
+```
+
+2. Install PyTorch v1.7.1 with a CUDA version that works on your cluster/machine (CUDA 10.2 is used in this example):
+```shell
+conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch
+```
+
+3. Install the packages in `requirements.txt` via `pip`:
+```shell
+pip install -r requirements.txt
+```
+
+### Datasets
+1. Follow instructions in the `./refer` directory to set up subdirectories
+and download annotations.
+This directory is a git clone (minus two data files that we do not need)
+from the [refer](https://github.com/lichengunc/refer) public API.
+
+2. Download images from [COCO](https://cocodataset.org/#download).
+Please use the first download link, *2014 Train images [83K/13GB]*, and extract
+the downloaded `train2014.zip` file to `./refer/data/images/mscoco/images`.
+
+### The Initialization Weights for Training
+1. Create the `./pretrained_weights` directory where we will be storing the weights.
+```shell
+mkdir ./pretrained_weights
+```
+2. Download [pre-trained classification weights of
+the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth),
+and put the `pth` file in `./pretrained_weights`.
+These weights are needed for training to initialize the model.
+
+### Trained Weights of LAVT for Testing
+1. Create the `./checkpoints` directory where we will be storing the weights.
+```shell
+mkdir ./checkpoints
+```
+2. Download LAVT model weights (which are stored on Google Drive) using the links below and put them in `./checkpoints`.
+
+| [RefCOCO](https://drive.google.com/file/d/13D-OeEOijV8KTC3BkFP-gOJymc6DLwVT/view?usp=sharing) | [RefCOCO+](https://drive.google.com/file/d/1B8Q44ZWsc8Pva2xD_M-KFh7-LgzeH2-2/view?usp=sharing) | [G-Ref (UMD)](https://drive.google.com/file/d/1BjUnPVpALurkGl7RXXvQiAHhA-gQYKvK/view?usp=sharing) | [G-Ref (Google)](https://drive.google.com/file/d/1weiw5UjbPfo3tCBPfB8tu6xFXCUG16yS/view?usp=sharing) |
+|---|---|---|---|
+
+3. Model weights and training logs of the new lavt_one implementation are below.
+
+| RefCOCO | RefCOCO+ | G-Ref (UMD) | G-Ref (Google) |
+|:-----:|:-----:|:-----:|:-----:|
+|[log](https://drive.google.com/file/d/1YIojIHqe3bxxsWOltifa2U9jH67hPHLM/view?usp=sharing) | [weights](https://drive.google.com/file/d/1xFMEXr6AGU97Ypj1yr8oo00uObbeIQvJ/view?usp=sharing)|[log](https://drive.google.com/file/d/1Z34T4gEnWlvcSUQya7txOuM0zdLK7MRT/view?usp=sharing) | [weights](https://drive.google.com/file/d/1HS8ZnGaiPJr-OmoUn4-4LVnVtD_zHY6w/view?usp=sharing)|[log](https://drive.google.com/file/d/14VAgahngOV8NA6noLZCqDoqaUrlW14v8/view?usp=sharing) | [weights](https://drive.google.com/file/d/14g8NzgZn6HzC6tP_bsQuWmh5LnOcovsE/view?usp=sharing)|[log](https://drive.google.com/file/d/1JBXfmlwemWSvs92Rky0TlHcVuuLpt4Da/view?usp=sharing) | [weights](https://drive.google.com/file/d/1IJeahFVLgKxu_BVmWacZs3oUzgTCeWcz/view?usp=sharing)|
+
+* The Prec@K, overall IoU and mean IoU numbers in the training logs will differ
+from the final results obtained by running `test.py`,
+because only one out of multiple annotated expressions is
+randomly selected and evaluated for each object during training.
+But these numbers give a good idea of the test performance;
+the two should be fairly close.
+
+
+## Training
+We use `DistributedDataParallel` from PyTorch.
+The released `lavt` weights were trained using 4 x 32G V100 cards (max mem on each card was about 26G).
+The released `lavt_one` weights were trained using 8 x 32G V100 cards (max mem on each card was about 13G).
+The additional cards were used to accelerate training.
+To run on 4 GPUs (with IDs 0, 1, 2, and 3) on a single node:
+```shell
+mkdir ./models
+
+mkdir ./models/refcoco
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco --model_id refcoco --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco/output
+
+mkdir ./models/refcoco+
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcoco+ --model_id refcoco+ --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/refcoco+/output
+
+mkdir ./models/gref_umd
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy umd --model_id gref_umd --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_umd/output
+
+mkdir ./models/gref_google
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --model lavt --dataset refcocog --splitBy google --model_id gref_google --batch-size 8 --lr 0.00005 --wd 1e-2 --swin_type base --pretrained_swin_weights ./pretrained_weights/swin_base_patch4_window12_384_22k.pth --epochs 40 --img_size 480 2>&1 | tee ./models/gref_google/output
+```
+* *--model* is a pre-defined model name. Options include `lavt` and `lavt_one`. See [Updates](#updates).
+* *--dataset* is the dataset name. One can choose from `refcoco`, `refcoco+`, and `refcocog`.
+* *--splitBy* needs to be specified if and only if the dataset is G-Ref (which is also called RefCOCOg).
+`umd` identifies the UMD partition and `google` identifies the Google partition.
+* *--model_id* is a model name you define yourself (*e.g.*, customize it to encode training/model configurations, dataset information, experiment IDs, *etc*.).
+It is used in two ways: the training log is saved as `./models/[args.model_id]/output` and the best checkpoint is saved as `./checkpoints/model_best_[args.model_id].pth`.
+* *--swin_type* specifies the version of the Swin Transformer.
+One can choose from `tiny`, `small`, `base`, and `large`. The default is `base`.
+* *--pretrained_swin_weights* specifies the path to pre-trained Swin Transformer weights used for model initialization.
+* Note that currently we need to manually create the `./models/[args.model_id]` directory via `mkdir` before running `train.py`.
+This is because we use `tee` to redirect `stdout` and `stderr` to `./models/[args.model_id]/output` for logging.
+This is a nuisance and should be resolved in the future, *e.g.*, by using a proper logger or a launch script; a small launcher along these lines is sketched after this list.
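+
+As a stopgap, a small launcher along the lines below can create the log
+directory and start a run. It is a hypothetical convenience wrapper, not part
+of the repository; it simply reproduces the RefCOCO command above and writes
+the log to `./models/<model_id>/output` without a manual `mkdir` + `tee`.
+
+```python
+# Hypothetical launcher (not part of the repository).
+import os
+import subprocess
+
+def launch_training(model_id, dataset, gpus="0,1,2,3"):
+    out_dir = os.path.join("./models", model_id)
+    os.makedirs(out_dir, exist_ok=True)           # replaces the manual `mkdir` step
+    cmd = [
+        "python", "-m", "torch.distributed.launch",
+        "--nproc_per_node", str(len(gpus.split(","))), "--master_port", "12345",
+        "train.py", "--model", "lavt", "--dataset", dataset, "--model_id", model_id,
+        "--batch-size", "8", "--lr", "0.00005", "--wd", "1e-2",
+        "--swin_type", "base",
+        "--pretrained_swin_weights",
+        "./pretrained_weights/swin_base_patch4_window12_384_22k.pth",
+        "--epochs", "40", "--img_size", "480",
+    ]
+    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpus)
+    with open(os.path.join(out_dir, "output"), "w") as log:
+        subprocess.run(cmd, env=env, stdout=log, stderr=subprocess.STDOUT, check=True)
+
+# launch_training("refcoco", "refcoco")
+```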
+
+## Testing
+For RefCOCO/RefCOCO+, run one of
+```shell
+python test.py --model lavt --swin_type base --dataset refcoco --split val --resume ./checkpoints/refcoco.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
+python test.py --model lavt --swin_type base --dataset refcoco+ --split val --resume ./checkpoints/refcoco+.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
+```
+* *--split* is the subset to evaluate, and one can choose from `val`, `testA`, and `testB`.
+* *--resume* is the path to the weights of a trained model.
+
+For G-Ref (UMD)/G-Ref (Google), run one of
+```shell
+python test.py --model lavt --swin_type base --dataset refcocog --splitBy umd --split val --resume ./checkpoints/gref_umd.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
+python test.py --model lavt --swin_type base --dataset refcocog --splitBy google --split val --resume ./checkpoints/gref_google.pth --workers 4 --ddp_trained_weights --window12 --img_size 480
+```
+* *--splitBy* specifies the partition to evaluate.
+One can choose from `umd` or `google`.
+* *--split* is the subset (according to the specified partition) to evaluate; one can choose from `val` and `test` for the UMD partition, and only `val` for the Google partition.
+* *--resume* is the path to the weights of a trained model.
+
+## Results
+The complete test results of the released LAVT models are summarized as follows:
+
+| Dataset | P@0.5 | P@0.6 | P@0.7 | P@0.8 | P@0.9 | Overall IoU | Mean IoU |
+|:---------------:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----------:|:--------:|
+| RefCOCO val | 84.46 | 80.90 | 75.28 | 64.71 | 34.30 | 72.73 | 74.46 |
+| RefCOCO test A | 88.07 | 85.17 | 79.90 | 68.52 | 35.69 | 75.82 | 76.89 |
+| RefCOCO test B | 79.12 | 74.94 | 69.17 | 59.37 | 34.45 | 68.79 | 70.94 |
+| RefCOCO+ val | 74.44 | 70.91 | 65.58 | 56.34 | 30.23 | 62.14 | 65.81 |
+| RefCOCO+ test A | 80.68 | 77.96 | 72.90 | 62.21 | 32.36 | 68.38 | 70.97 |
+| RefCOCO+ test B | 65.66 | 61.85 | 55.94 | 47.56 | 27.24 | 55.10 | 59.23 |
+| G-Ref val (UMD) | 70.81 | 65.28 | 58.60 | 47.49 | 22.73 | 61.24 | 63.34 |
+| G-Ref test (UMD)| 71.54 | 66.38 | 59.00 | 48.21 | 23.10 | 62.09 | 63.62 |
+|G-Ref val (Goog.)| 71.16 | 67.21 | 61.76 | 51.98 | 27.30 | 60.50 | 63.66 |
+
+We have validated LAVT on RefCOCO with multiple runs.
+The overall IoU on the val set generally lies in the range of 72.73±0.5%.
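+
+For reference, these are the standard referring-segmentation metrics. The sketch
+below shows how they are commonly computed from binary prediction and ground-truth
+masks; it is illustrative only and not necessarily identical to the code in `test.py`.
+
+```python
+# Illustrative computation of Prec@K, overall IoU, and mean IoU from boolean masks.
+import numpy as np
+
+def referring_metrics(pred_masks, gt_masks, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
+    ious, total_inter, total_union = [], 0, 0
+    for pred, gt in zip(pred_masks, gt_masks):    # boolean arrays of the same shape
+        inter = np.logical_and(pred, gt).sum()
+        union = np.logical_or(pred, gt).sum()
+        ious.append(inter / union if union > 0 else 0.0)
+        total_inter, total_union = total_inter + inter, total_union + union
+    ious = np.asarray(ious)
+    prec_at_k = {t: float((ious > t).mean()) for t in thresholds}  # P@0.5 ... P@0.9
+    overall_iou = total_inter / total_union       # cumulative intersection over union
+    mean_iou = float(ious.mean())                 # average per-sample IoU
+    return prec_at_k, overall_iou, mean_iou
+```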
+
+
+## Demo: Try LAVT on Your Own Image-text Pairs!
+One can run inference on a custom image-text pair
+and visualize the result by running the script `./demo_inference.py`.
+Choose your own photos and expressions and have fun.
+
+
+## Citing LAVT
+```
+@inproceedings{yang2022lavt,
+ title={LAVT: Language-Aware Vision Transformer for Referring Image Segmentation},
+ author={Yang, Zhao and Wang, Jiaqi and Tang, Yansong and Chen, Kai and Zhao, Hengshuang and Torr, Philip HS},
+ booktitle={CVPR},
+ year={2022}
+}
+```
+
+
+## Contributing
+We appreciate all contributions.
+It helps the project if you could
+- report issues you are facing,
+- give a :+1: on issues reported by others that are relevant to you,
+- answer issues reported by others for which you have found solutions,
+- and implement helpful new features or improve the code otherwise with pull requests.
+
+## Acknowledgements
+Code in this repository is built upon several public repositories.
+Specifically,
+* data pre-processing leverages the [refer](https://github.com/lichengunc/refer) repository,
+* the backbone model is implemented based on code from [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation),
+* the training and testing pipelines are adapted from [RefVOS](https://github.com/miriambellver/refvos),
+* and implementation of the BERT model (files in the bert directory) is from [Hugging Face Transformers v3.0.2](https://github.com/huggingface/transformers/tree/v3.0.2)
+(we migrated over the relevant code to fix a bug and simplify the installation process).
+
+Some of these repositories in turn adapt code from [OpenMMLab](https://github.com/open-mmlab) and [TorchVision](https://github.com/pytorch/vision).
+We'd like to thank the authors/organizations of these repositories for open sourcing their projects.
+
+
+## License
+GNU GPLv3
diff --git a/elia/__pycache__/args.cpython-37.pyc b/elia/__pycache__/args.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f640c5b4ee64eda39a04bf8049ed80dc140ef6e
Binary files /dev/null and b/elia/__pycache__/args.cpython-37.pyc differ
diff --git a/elia/__pycache__/args.cpython-38.pyc b/elia/__pycache__/args.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e42ca8813cec4f1a54d78aebc2ea603adf3defe3
Binary files /dev/null and b/elia/__pycache__/args.cpython-38.pyc differ
diff --git a/elia/__pycache__/transforms.cpython-37.pyc b/elia/__pycache__/transforms.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23f8d1523aee9c09d0276e0b628bd712c7bad041
Binary files /dev/null and b/elia/__pycache__/transforms.cpython-37.pyc differ
diff --git a/elia/__pycache__/transforms.cpython-38.pyc b/elia/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..815e3982a1d1d71927cc081f1e04de5d988183bc
Binary files /dev/null and b/elia/__pycache__/transforms.cpython-38.pyc differ
diff --git a/elia/__pycache__/utils.cpython-37.pyc b/elia/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a04ac1d1a6fdbcf6d8b1f00ee78c4698efc273b
Binary files /dev/null and b/elia/__pycache__/utils.cpython-37.pyc differ
diff --git a/elia/__pycache__/utils.cpython-38.pyc b/elia/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59c425f22032f4435c681238cd31a83bd8fb1061
Binary files /dev/null and b/elia/__pycache__/utils.cpython-38.pyc differ
diff --git a/elia/app.py b/elia/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..03b11a42bb61ff9a4b2dc51fd6d8e4fe106de893
--- /dev/null
+++ b/elia/app.py
@@ -0,0 +1,310 @@
+import gradio as gr
+
+image_path = './image001.png'
+sentence = 'spoon on the dish'
+weights = './checkpoints/model_best_refcoco_0508.pth'
+device = 'cpu'
+
+# pre-process the input image
+from PIL import Image
+import torchvision.transforms as T
+import numpy as np
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from bert.multimodal_bert import MultiModalBert
+import torchvision
+
+from lib import multimodal_segmentation_ppm
+#import transforms as T
+import utils
+
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+
+from modeling.MaskFormerModel import MaskFormerHead
+from addict import Dict
+#from bert.modeling_bert import BertLMPredictionHead, BertEncoder
+import cv2
+import textwrap
+
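+# WrapperModel ties the pieces of the demo together: the Swin-based image
+# backbone, the multimodal BERT language model, and the MaskFormer head.
+# In forward(), each image stage is interleaved with the corresponding language
+# stage through the PWAM fusion calls, and the per-stage residual features
+# ('s1'-'s4') are fed to the classifier to predict the mask.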
+class WrapperModel(nn.Module):
+    def __init__(self, image_model, language_model, classifier):
+ super(WrapperModel, self).__init__()
+ self.image_model = image_model
+ self.language_model = language_model
+ self.classifier = classifier
+
+ config = Dict({
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": False,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ #"max_position_embeddings": 16+20,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 8,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": True,
+ "vocab_size": 30522
+ })
+
+
+
+ def _get_binary_mask(self, target):
+        # return a binary mask for each foreground class (the background channel is dropped)
+ y, x = target.size()
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
+ return target_onehot[1:]
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+ return semseg
+
+ def forward(self, image, sentences, attentions):
+ print(image.sum(), sentences.sum(), attentions.sum())
+ input_shape = image.shape[-2:]
+ l_mask = attentions.unsqueeze(dim=-1)
+
+ i0, Wh, Ww = self.image_model.forward_stem(image)
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
+
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
+ i1 = i1_temp
+
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
+ i2 = i2_temp
+
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
+ i3 = i3_temp
+
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
+ i4 = i4_temp
+
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
+ outputs = {}
+ outputs['s1'] = i1_residual
+ outputs['s2'] = i2_residual
+ outputs['s3'] = i3_residual
+ outputs['s4'] = i4_residual
+
+ predictions = self.classifier(outputs)
+ return predictions
+
+#img = Image.open(image_path).convert("RGB")
+
+# pre-process the raw sentence
+from bert.tokenization_bert import BertTokenizer
+import torch
+
+# initialize model and load weights
+#from bert.modeling_bert import BertModel
+#from lib import segmentation
+
+# construct a minimal args namespace, as if it had been read from a config file
+
+
+class args:
+ swin_type = 'base'
+ window12 = True
+ mha = ''
+ fusion_drop = 0.0
+
+
+#single_model = segmentation.__dict__['lavt'](pretrained='', args=args)
+single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='',args=args)
+single_model.to(device)
+model_class = MultiModalBert
+single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim)
+single_bert_model.pooler = None
+
+input_shape = dict()
+input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
+input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
+input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
+input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
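+# The channels/strides above correspond to the Swin-Base feature pyramid (embed_dim 128,
+# features at strides 4/8/16/32); they would need adjusting for other swin_type variants.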
+
+
+
+cfg = Dict()
+cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
+cfg.MODEL.MASK_FORMER.NHEADS = 8
+cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
+cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
+cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
+
+cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
+cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
+cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
+cfg.MODEL.MASK_FORMER.PRE_NORM = False
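+# These head settings mirror the defaults in args.py (one object query, 10 decoder layers,
+# 4 transformer encoder layers, 2048-d feed-forward); the 256-d hidden/conv/mask dims are fixed here.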
+
+
+maskformer_head = MaskFormerHead(cfg, input_shape)
+
+
+model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head)
+
+
+
+checkpoint = torch.load(weights, map_location='cpu')
+
+model.load_state_dict(checkpoint['model'], strict=False)
+model.to(device)
+model.eval()
+#single_bert_model.load_state_dict(checkpoint['bert_model'])
+#single_model.load_state_dict(checkpoint['model'])
+#model = single_model.to(device)
+#bert_model = single_bert_model.to(device)
+
+
+# inference
+#import torch.nn.functional as F
+#last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0]
+#embedding = last_hidden_states.permute(0, 2, 1)
+#output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1))
+#output = output.argmax(1, keepdim=True) # (1, 1, 480, 480)
+#output = F.interpolate(output.float(), (original_h, original_w)) # 'nearest'; resize to the original image size
+#output = output.squeeze() # (orig_h, orig_w)
+#output = output.cpu().data.numpy() # (orig_h, orig_w)
+
+#output = pred_masks[0]
+
+#output = output.cpu()
+
+
+
+#print(output.shape)
+#output_mask = output.argmax(1).data.numpy()
+#output = (output > 0.5).data.cpu().numpy()
+
+
+# show/save results
+def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4):
+    from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated
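+    # DAVIS-style visualization: alpha-blend each object's color over the image and darken
+    # the mask contour (dilation XOR mask) to outline the segmented region.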
+
+ colors = np.reshape(colors, (-1, 3))
+ colors = np.atleast_2d(colors) * cscale
+
+ im_overlay = image.copy()
+ object_ids = np.unique(mask)
+
+ for object_id in object_ids[1:]:
+ # Overlay color on binary mask
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
+ binary_mask = mask == object_id
+
+ # Compose image
+ im_overlay[binary_mask] = foreground[binary_mask]
+
+        # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
+        contours = binary_dilation(binary_mask) ^ binary_mask
+        # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
+        im_overlay[contours, :] = 0
+
+ return im_overlay.astype(image.dtype)
+
+
+def run_model(img, sentence):
+    # img = Image.open(image_path).convert("RGB")  # alternative: load from a file path instead of a NumPy array
+    img = Image.fromarray(img)
+    img = img.convert("RGB")
+ #print(img.shape)
+ img_ndarray = np.array(img) # (orig_h, orig_w, 3); for visualization
+ original_w, original_h = img.size # PIL .size returns width first and height second
+
+ image_transforms = T.Compose(
+ [
+ T.Resize((480, 480)),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+ )
+
+ img = image_transforms(img).unsqueeze(0) # (1, 3, 480, 480)
+ img = img.to(device) # for inference (input)
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
+    sentence_tokenized = sentence_tokenized[:20]  # truncate to at most 20 tokens (not words)
+ # pad the tokenized sentence
+ padded_sent_toks = [0] * 20
+ padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized
+ # create a sentence token mask: 1 for real words; 0 for padded tokens
+ attention_mask = [0] * 20
+ attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized)
+ # convert lists to tensors
+ padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0) # (1, 20)
+ attention_mask = torch.tensor(attention_mask).unsqueeze(0) # (1, 20)
+ padded_sent_toks = padded_sent_toks.to(device) # for inference (input)
+ attention_mask = attention_mask.to(device) # for inference (input)
+
+ output = model(img, padded_sent_toks, attention_mask)[0]
+ #print(output[0].keys())
+ #print(output[1].shape)
+ mask_cls_results = output["pred_logits"]
+ mask_pred_results = output["pred_masks"]
+
+ target_shape = img_ndarray.shape[:2]
+ #print(target_shape, mask_pred_results.shape)
+ mask_pred_results = F.interpolate(mask_pred_results, size=(480,480), mode='bilinear', align_corners=True)
+
+ pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
+
+ output = torch.nn.functional.interpolate(pred_masks, target_shape)
+ output = (output > 0.5).data.cpu().numpy()
+
+ output = output.astype(np.uint8) # (orig_h, orig_w), np.uint8
+ # Overlay the mask on the image
+ print(img_ndarray.shape, output.shape)
+ visualization = overlay_davis(img_ndarray, output[0][0]) # red
+ visualization = Image.fromarray(visualization)
+ # show the visualization
+ #visualization.show()
+ # Save the visualization
+ #visualization.save('./demo/spoon_on_the_dish.jpg')
+ return visualization
+
+
+
+
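+# Sketch (hypothetical paths): the pipeline can also be exercised without the Gradio UI, e.g.
+#   out = run_model(np.array(Image.open('demo/example.jpg')), 'the spoon on the dish')
+#   out.save('demo/output.jpg')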
+demo = gr.Interface(run_model, inputs=[gr.Image(), "text"], outputs=["image"])
+demo.launch()
+
diff --git a/elia/args.py b/elia/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6f8480b92afc4c61e3c45a79b3d332d9c53226a
--- /dev/null
+++ b/elia/args.py
@@ -0,0 +1,74 @@
+import argparse
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description='LAVT training and testing')
+ parser.add_argument('--amsgrad', action='store_true',
+ help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
+ parser.add_argument('-b', '--batch-size', default=8, type=int)
+ parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
+ parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
+ parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
+ parser.add_argument('--ddp_trained_weights', action='store_true',
+                        help='only needs to be specified when testing: '
+                             'whether the weights to be loaded are from a DDP-trained model')
+ parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine
+ parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
+ parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
+ parser.add_argument('--img_size', default=480, type=int, help='input image size')
+ parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
+ parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
+    parser.add_argument('--mha', default='', help='if specified, should be in the format a-b-c-d, e.g., 4-4-4-4, '
+                                                  'where a, b, c, and d are the numbers of attention heads in the '
+                                                  'stage-1, stage-2, stage-3, and stage-4 PWAMs')
+ parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
+ parser.add_argument('--model_id', default='lavt', help='name to identify the model')
+ parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
+ parser.add_argument('--pin_mem', action='store_true',
+ help='If true, pin memory when using the data loader.')
+ parser.add_argument('--pretrained_swin_weights', default='',
+ help='path to pre-trained Swin backbone weights')
+ parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
+ parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
+ parser.add_argument('--resume', default='auto', help='resume from checkpoint')
+ parser.add_argument('--split', default='test', help='only used when testing')
+ parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
+ parser.add_argument('--swin_type', default='base',
+ help='tiny, small, base, or large variants of the Swin Transformer')
+ parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
+ dest='weight_decay')
+ parser.add_argument('--window12', action='store_true',
+                        help='only needs to be specified when testing: initialize Swin with window size 12 '
+                             'instead of the default 7. When training, the window size is inferred from the '
+                             'pre-trained weights file name (containing \'window12\').')
+ parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--max_ckpt', default=2, type=int)
+ parser.add_argument('--num_object_queries', default=1, type=int)
+ parser.add_argument('--no_object_weight', default=0.0, type=float)
+ parser.add_argument('--class_weight', default=2.0, type=float)
+ parser.add_argument('--dice_weight', default=2.0, type=float)
+ parser.add_argument('--mask_weight', default=2.0, type=float)
+ parser.add_argument('--train_num_points', default=12544, type=int)
+ parser.add_argument('--dim_feedforward', default=2048, type=int)
+ parser.add_argument('--dec_layers', default=10, type=int)
+ parser.add_argument('--transformer_enc_layers', default=4, type=int)
+
+ parser.add_argument('--plic_pos_weight', default=0.5, type=float)
+ parser.add_argument('--plic_neg_weight', default=0.5, type=float)
+ parser.add_argument('--plic_lang_weight', default=0.5, type=float)
+ parser.add_argument('--plic_pos_alpha', default=0.0, type=float)
+ parser.add_argument('--plic_neg_alpha', default=0.0, type=float)
+ parser.add_argument('--plic_lang_alpha', default=0.0, type=float)
+ parser.add_argument('--plic_pos_temp', default=0.2, type=float)
+ parser.add_argument('--plic_neg_temp', default=0.2, type=float)
+ parser.add_argument('--plic_lang_temp', default=0.2, type=float)
+ parser.add_argument('--smlm_weight', default=1.0, type=float)
+ parser.add_argument('--vis_dir', default='./vis_dir')
+
+ return parser
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args_dict = parser.parse_args()
diff --git a/elia/bert/__pycache__/activations.cpython-37.pyc b/elia/bert/__pycache__/activations.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf79bfc969d2dfa652d29321c45aec02bc8ed301
Binary files /dev/null and b/elia/bert/__pycache__/activations.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/activations.cpython-38.pyc b/elia/bert/__pycache__/activations.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d3804e94dab5c36f8cdda958065b4e32dc4aabd
Binary files /dev/null and b/elia/bert/__pycache__/activations.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/configuration_bert.cpython-37.pyc b/elia/bert/__pycache__/configuration_bert.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8084771ac9831c916fa953b7b10017435ea14d0
Binary files /dev/null and b/elia/bert/__pycache__/configuration_bert.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/configuration_bert.cpython-38.pyc b/elia/bert/__pycache__/configuration_bert.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9eb473e317fe8fb5c667a3dbe69bace5c29e8986
Binary files /dev/null and b/elia/bert/__pycache__/configuration_bert.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/configuration_utils.cpython-37.pyc b/elia/bert/__pycache__/configuration_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea208dae8700e64e59cc156989eabbf2e528dc80
Binary files /dev/null and b/elia/bert/__pycache__/configuration_utils.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/configuration_utils.cpython-38.pyc b/elia/bert/__pycache__/configuration_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eefb7f8062db85c760031b0999814cf84c2f992b
Binary files /dev/null and b/elia/bert/__pycache__/configuration_utils.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/file_utils.cpython-37.pyc b/elia/bert/__pycache__/file_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48ad2774a20d2634cf7e3541436cc4a40aca9ca5
Binary files /dev/null and b/elia/bert/__pycache__/file_utils.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/file_utils.cpython-38.pyc b/elia/bert/__pycache__/file_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29ede2127f22edebb5409a5bfcbec143227b7d01
Binary files /dev/null and b/elia/bert/__pycache__/file_utils.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/generation_utils.cpython-37.pyc b/elia/bert/__pycache__/generation_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d79d8c2dd616ce91f49c643e29b40e084012bc28
Binary files /dev/null and b/elia/bert/__pycache__/generation_utils.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/generation_utils.cpython-38.pyc b/elia/bert/__pycache__/generation_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c36afeb132be1ff8b5cca362f0911bf211e1b1e3
Binary files /dev/null and b/elia/bert/__pycache__/generation_utils.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/modeling_bert.cpython-37.pyc b/elia/bert/__pycache__/modeling_bert.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4114e8086ba2f24b8786db255b4d7c0391b372c
Binary files /dev/null and b/elia/bert/__pycache__/modeling_bert.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/modeling_bert.cpython-38.pyc b/elia/bert/__pycache__/modeling_bert.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e4f837799032e0940315dd0506de477acb1acf1
Binary files /dev/null and b/elia/bert/__pycache__/modeling_bert.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/modeling_utils.cpython-37.pyc b/elia/bert/__pycache__/modeling_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c77d5fa351f765bffe538cc8fe462ec1f119a704
Binary files /dev/null and b/elia/bert/__pycache__/modeling_utils.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/modeling_utils.cpython-38.pyc b/elia/bert/__pycache__/modeling_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52806021ea6d36cc4fad71c04b02165b2e14f4e7
Binary files /dev/null and b/elia/bert/__pycache__/modeling_utils.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/multimodal_bert.cpython-37.pyc b/elia/bert/__pycache__/multimodal_bert.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d286f1ac28e65d2c220c49e2cc53f8283ef8e87
Binary files /dev/null and b/elia/bert/__pycache__/multimodal_bert.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/multimodal_bert.cpython-38.pyc b/elia/bert/__pycache__/multimodal_bert.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0919175640605c3d91aa56da62a3a916bca08648
Binary files /dev/null and b/elia/bert/__pycache__/multimodal_bert.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/tokenization_bert.cpython-37.pyc b/elia/bert/__pycache__/tokenization_bert.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..047e5c66121aba964c5d9ecd7c5d3079213f1055
Binary files /dev/null and b/elia/bert/__pycache__/tokenization_bert.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/tokenization_bert.cpython-38.pyc b/elia/bert/__pycache__/tokenization_bert.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46e86bdfb72fd99e3a12e8304d1a87da8874af31
Binary files /dev/null and b/elia/bert/__pycache__/tokenization_bert.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/tokenization_utils.cpython-37.pyc b/elia/bert/__pycache__/tokenization_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f93a01807dbfcf5f158ea2d18ec7b1a3cf156ff
Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/tokenization_utils.cpython-38.pyc b/elia/bert/__pycache__/tokenization_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0c7b8e10c8c808b0d8466184f2107f6efa0261c
Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils.cpython-38.pyc differ
diff --git a/elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc b/elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e583920001390bd94d4c8cbd447cc601e15d9846
Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils_base.cpython-37.pyc differ
diff --git a/elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc b/elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9d46a9416d56f3d4ed2ae4473434002b42eb68c
Binary files /dev/null and b/elia/bert/__pycache__/tokenization_utils_base.cpython-38.pyc differ
diff --git a/elia/bert/activations.py b/elia/bert/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a1206ee285ce3f0484d129711a2d684700a20a1
--- /dev/null
+++ b/elia/bert/activations.py
@@ -0,0 +1,56 @@
+import logging
+import math
+
+import torch
+import torch.nn.functional as F
+
+
+logger = logging.getLogger(__name__)
+
+
+def swish(x):
+ return x * torch.sigmoid(x)
+
+
+def _gelu_python(x):
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+ This is now written in C in torch.nn.functional
+ Also see https://arxiv.org/abs/1606.08415
+ """
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def gelu_new(x):
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
+ Also see https://arxiv.org/abs/1606.08415
+ """
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+if torch.__version__ < "1.4.0":
+ gelu = _gelu_python
+else:
+ gelu = F.gelu
+
+
+def gelu_fast(x):
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+
+
+ACT2FN = {
+ "relu": F.relu,
+ "swish": swish,
+ "gelu": gelu,
+ "tanh": torch.tanh,
+ "gelu_new": gelu_new,
+ "gelu_fast": gelu_fast,
+}
+
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
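+
+
+# Usage sketch: ACT2FN maps names to plain callables, e.g. get_activation("swish")(torch.zeros(2))
+# returns a tensor of zeros; get_activation("gelu") resolves to F.gelu on torch >= 1.4.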
diff --git a/elia/bert/configuration_bert.py b/elia/bert/configuration_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e815837bc4dbc5fc8eec7ee37547b5d41519af5
--- /dev/null
+++ b/elia/bert/configuration_bert.py
@@ -0,0 +1,143 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" BERT model configuration """
+
+
+import logging
+
+from .configuration_utils import PretrainedConfig
+
+
+logger = logging.getLogger(__name__)
+
+BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
+ "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
+ "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
+ "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
+ "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
+ # See all BERT models at https://huggingface.co/models?filter=bert
+}
+
+
+class BertConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
+    It is used to instantiate a BERT model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of
+    the BERT ``bert-base-uncased`` architecture.
+
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
+ to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
+ for more information.
+
+
+ Args:
+ vocab_size (:obj:`int`, optional, defaults to 30522):
+            Vocabulary size of the BERT model. Defines the number of different tokens that
+            can be represented by the `input_ids` passed to the forward method of :class:`~transformers.BertModel`.
+ hidden_size (:obj:`int`, optional, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (:obj:`int`, optional, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (:obj:`int`, optional, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (:obj:`int`, optional, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
+ The non-linear activation function (function or string) in the encoder and pooler.
+ If string, "gelu", "relu", "swish" and "gelu_new" are supported.
+ hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (:obj:`int`, optional, defaults to 512):
+ The maximum sequence length that this model might ever be used with.
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (:obj:`int`, optional, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
+ initializer_range (:obj:`float`, optional, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ gradient_checkpointing (:obj:`bool`, optional, defaults to False):
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+
+ Example::
+
+ >>> from transformers import BertModel, BertConfig
+
+ >>> # Initializing a BERT bert-base-uncased style configuration
+ >>> configuration = BertConfig()
+
+ >>> # Initializing a model from the bert-base-uncased style configuration
+ >>> model = BertModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ """
+ model_type = "bert"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ gradient_checkpointing=False,
+ **kwargs
+ ):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.gradient_checkpointing = gradient_checkpointing
diff --git a/elia/bert/configuration_utils.py b/elia/bert/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9929ee4c19dab9a88bb51e0281d220a8456c2fce
--- /dev/null
+++ b/elia/bert/configuration_utils.py
@@ -0,0 +1,408 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Configuration base class and utilities."""
+
+
+import copy
+import json
+import logging
+import os
+from typing import Dict, Tuple
+
+from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
+
+
+logger = logging.getLogger(__name__)
+
+
+class PretrainedConfig(object):
+ r""" Base class for all configuration classes.
+ Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
+
+ Note:
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
+ It only affects the model's configuration.
+
+ Class attributes (overridden by derived classes):
+ - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
+
+ Args:
+ finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
+ num_labels (:obj:`int`, `optional`, defaults to `2`):
+ Number of classes to use when the model is a classification model (sequences/tokens)
+            output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether the model should return all hidden states.
+            output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether the model should return all attention weights.
+ torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Is the model used with Torchscript (for PyTorch models).
+ """
+ model_type: str = ""
+
+ def __init__(self, **kwargs):
+ # Attributes with defaults
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
+ self.output_attentions = kwargs.pop("output_attentions", False)
+ self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
+
+        # is_decoder is used in encoder-decoder models to differentiate the decoder from the encoder
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
+ self.is_decoder = kwargs.pop("is_decoder", False)
+
+ # Parameters for sequence generation
+ self.max_length = kwargs.pop("max_length", 20)
+ self.min_length = kwargs.pop("min_length", 0)
+ self.do_sample = kwargs.pop("do_sample", False)
+ self.early_stopping = kwargs.pop("early_stopping", False)
+ self.num_beams = kwargs.pop("num_beams", 1)
+ self.temperature = kwargs.pop("temperature", 1.0)
+ self.top_k = kwargs.pop("top_k", 50)
+ self.top_p = kwargs.pop("top_p", 1.0)
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
+
+ # Fine-tuning task arguments
+ self.architectures = kwargs.pop("architectures", None)
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
+ self.id2label = kwargs.pop("id2label", None)
+ self.label2id = kwargs.pop("label2id", None)
+ if self.id2label is not None:
+ kwargs.pop("num_labels", None)
+ self.id2label = dict((int(key), value) for key, value in self.id2label.items())
+ # Keys are always strings in JSON so convert ids to int here.
+ else:
+ self.num_labels = kwargs.pop("num_labels", 2)
+
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+ self.prefix = kwargs.pop("prefix", None)
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+
+ # task specific arguments
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
+
+ # TPU arguments
+ self.xla_device = kwargs.pop("xla_device", None)
+
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error("Can't set {} with value {} for {}".format(key, value, self))
+ raise err
+
+ @property
+ def num_labels(self):
+ return len(self.id2label)
+
+ @num_labels.setter
+ def num_labels(self, num_labels):
+ self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
+
+ def save_pretrained(self, save_directory):
+ """
+ Save a configuration object to the directory `save_directory`, so that it
+ can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
+
+ Args:
+ save_directory (:obj:`string`):
+ Directory where the configuration JSON file will be saved.
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
+ os.makedirs(save_directory, exist_ok=True)
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
+
+ self.to_json_file(output_config_file, use_diff=True)
+ logger.info("Configuration saved in {}".format(output_config_file))
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
+ r"""
+
+ Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
+
+ Args:
+ pretrained_model_name_or_path (:obj:`string`):
+ either:
+ - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
+ download, e.g.: ``bert-base-uncased``.
+ - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
+ our S3, e.g.: ``dbmdz/bert-base-german-cased``.
+ - a path to a `directory` containing a configuration file saved using the
+ :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
+ - a path or url to a saved configuration JSON `file`, e.g.:
+ ``./my_model_directory/configuration.json``.
+ cache_dir (:obj:`string`, `optional`):
+ Path to a directory in which a downloaded pre-trained model
+ configuration should be cached if the standard cache should not be used.
+ kwargs (:obj:`Dict[str, any]`, `optional`):
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
+ proxies (:obj:`Dict`, `optional`):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.:
+ :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
+ The proxies are used on each request.
+ return_unused_kwargs: (`optional`) bool:
+ If False, then this function returns just the final configuration object.
+                If True, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part
+ of kwargs which has not been used to update `config` and is otherwise ignored.
+
+ Returns:
+ :class:`PretrainedConfig`: An instance of a configuration object
+
+ Examples::
+
+ # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
+ # derived class: BertConfig
+ config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
+ config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
+ config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
+ config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
+ assert config.output_attention == True
+ config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
+ foo=False, return_unused_kwargs=True)
+ assert config.output_attention == True
+ assert unused_kwargs == {'foo': False}
+
+ """
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+ return cls.from_dict(config_dict, **kwargs)
+
+ @classmethod
+ def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
+ for instantiating a Config using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (:obj:`string`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
+ Returns:
+ :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
+
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+
+ if os.path.isdir(pretrained_model_name_or_path):
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ else:
+ config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
+
+ try:
+ # Load from URL or cache if already cached
+ resolved_config_file = cached_path(
+ config_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ )
+ # Load config dict
+ if resolved_config_file is None:
+ raise EnvironmentError
+ config_dict = cls._dict_from_json_file(resolved_config_file)
+
+ except EnvironmentError:
+ msg = (
+ f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
+ )
+ raise EnvironmentError(msg)
+
+ except json.JSONDecodeError:
+ msg = (
+ "Couldn't reach server at '{}' to download configuration file or "
+ "configuration file is not a valid JSON file. "
+ "Please check network or file content here: {}.".format(config_file, resolved_config_file)
+ )
+ raise EnvironmentError(msg)
+
+ if resolved_config_file == config_file:
+ logger.info("loading configuration file {}".format(config_file))
+ else:
+ logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
+
+ return config_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
+ """
+ Constructs a `Config` from a Python dictionary of parameters.
+
+ Args:
+ config_dict (:obj:`Dict[str, any]`):
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
+ from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
+ method.
+ kwargs (:obj:`Dict[str, any]`):
+ Additional parameters from which to initialize the configuration object.
+
+ Returns:
+ :class:`PretrainedConfig`: An instance of a configuration object
+ """
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ config = cls(**config_dict)
+
+ if hasattr(config, "pruned_heads"):
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
+
+ # Update config with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info("Model config %s", str(config))
+ if return_unused_kwargs:
+ return config, kwargs
+ else:
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file: str) -> "PretrainedConfig":
+ """
+ Constructs a `Config` from the path to a json file of parameters.
+
+ Args:
+ json_file (:obj:`string`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ :class:`PretrainedConfig`: An instance of a configuration object
+
+ """
+ config_dict = cls._dict_from_json_file(json_file)
+ return cls(**config_dict)
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: str):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __eq__(self, other):
+ return self.__dict__ == other.__dict__
+
+ def __repr__(self):
+ return "{} {}".format(self.__class__.__name__, self.to_json_string())
+
+ def to_diff_dict(self):
+ """
+ Removes all attributes from config which correspond to the default
+ config attributes for better readability and serializes to a Python
+ dictionary.
+
+ Returns:
+            :obj:`Dict[str, any]`: Dictionary of the attributes that differ from the default configuration.
+ """
+ config_dict = self.to_dict()
+
+ # get the default config dict
+ default_config_dict = PretrainedConfig().to_dict()
+
+ serializable_config_dict = {}
+
+ # only serialize values that differ from the default config
+ for key, value in config_dict.items():
+ if key not in default_config_dict or value != default_config_dict[key]:
+ serializable_config_dict[key] = value
+
+ return serializable_config_dict
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ if hasattr(self.__class__, "model_type"):
+ output["model_type"] = self.__class__.model_type
+ return output
+
+ def to_json_string(self, use_diff=True):
+ """
+ Serializes this instance to a JSON string.
+
+ Args:
+ use_diff (:obj:`bool`):
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
+
+ Returns:
+ :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ if use_diff is True:
+ config_dict = self.to_diff_dict()
+ else:
+ config_dict = self.to_dict()
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path, use_diff=True):
+ """
+ Save this instance to a json file.
+
+ Args:
+ json_file_path (:obj:`string`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ use_diff (:obj:`bool`):
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string(use_diff=use_diff))
+
+ def update(self, config_dict: Dict):
+ """
+ Updates attributes of this class
+ with attributes from `config_dict`.
+
+ Args:
+ :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
+ """
+ for key, value in config_dict.items():
+ setattr(self, key, value)
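+
+
+# Round-trip sketch (hypothetical paths): cfg = PretrainedConfig(output_attentions=True);
+# cfg.save_pretrained('./tmp_cfg') writes ./tmp_cfg/config.json, which
+# PretrainedConfig.from_json_file('./tmp_cfg/config.json') reads back; with use_diff=True only
+# non-default fields are serialized.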
diff --git a/elia/bert/file_utils.py b/elia/bert/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b76b7fefd186d540fda1014dd69724049a4483
--- /dev/null
+++ b/elia/bert/file_utils.py
@@ -0,0 +1,808 @@
+"""
+Utilities for working with the local dataset cache.
+This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+Copyright by the AllenNLP authors.
+"""
+
+import fnmatch
+import json
+import logging
+import os
+import shutil
+import sys
+import tarfile
+import tempfile
+from contextlib import contextmanager
+from functools import partial, wraps
+from hashlib import sha256
+from pathlib import Path
+from typing import Dict, Optional, Union
+from urllib.parse import urlparse
+from zipfile import ZipFile, is_zipfile
+
+import requests
+from filelock import FileLock
+from tqdm.auto import tqdm
+
+#from . import __version__
+__version__ = "3.0.2"
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+
+try:
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+ if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
+ import torch
+
+ _torch_available = True # pylint: disable=invalid-name
+ logger.info("PyTorch version {} available.".format(torch.__version__))
+ else:
+ logger.info("Disabling PyTorch because USE_TF is set")
+ _torch_available = False
+except ImportError:
+ _torch_available = False # pylint: disable=invalid-name
+
+try:
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+
+ if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
+ import tensorflow as tf
+
+ assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
+ _tf_available = True # pylint: disable=invalid-name
+ logger.info("TensorFlow version {} available.".format(tf.__version__))
+ else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+ _tf_available = False
+except (ImportError, AssertionError):
+ _tf_available = False # pylint: disable=invalid-name
+
+
+try:
+ from torch.hub import _get_torch_home
+
+ torch_cache_home = _get_torch_home()
+except ImportError:
+ torch_cache_home = os.path.expanduser(
+ os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
+ )
+
+
+try:
+ import torch_xla.core.xla_model as xm # noqa: F401
+
+ if _torch_available:
+        _torch_tpu_available = True  # pylint: disable=invalid-name
+ else:
+ _torch_tpu_available = False
+except ImportError:
+ _torch_tpu_available = False
+
+
+try:
+ import psutil # noqa: F401
+
+ _psutil_available = True
+
+except ImportError:
+ _psutil_available = False
+
+
+try:
+ import py3nvml # noqa: F401
+
+ _py3nvml_available = True
+
+except ImportError:
+ _py3nvml_available = False
+
+
+try:
+ from apex import amp # noqa: F401
+
+ _has_apex = True
+except ImportError:
+ _has_apex = False
+
+default_cache_path = os.path.join(torch_cache_home, "transformers")
+
+
+PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
+PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
+TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
+
+WEIGHTS_NAME = "pytorch_model.bin"
+TF2_WEIGHTS_NAME = "tf_model.h5"
+TF_WEIGHTS_NAME = "model.ckpt"
+CONFIG_NAME = "config.json"
+MODEL_CARD_NAME = "modelcard.json"
+
+
+MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
+DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
+
+S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
+CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_tf_available():
+ return _tf_available
+
+
+def is_torch_tpu_available():
+ return _torch_tpu_available
+
+
+def is_psutil_available():
+ return _psutil_available
+
+
+def is_py3nvml_available():
+ return _py3nvml_available
+
+
+def is_apex_available():
+ return _has_apex
+
+
+def add_start_docstrings(*docstr):
+ def docstring_decorator(fn):
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
+ return fn
+
+ return docstring_decorator
+
+
+def add_start_docstrings_to_callable(*docstr):
+ def docstring_decorator(fn):
+ class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
+ intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
+ note = r"""
+
+ .. note::
+ Although the recipe for forward pass needs to be defined within
+ this function, one should call the :class:`Module` instance afterwards
+ instead of this since the former takes care of running the
+ pre and post processing steps while the latter silently ignores them.
+ """
+ fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
+ return fn
+
+ return docstring_decorator
+
+
+def add_end_docstrings(*docstr):
+ def docstring_decorator(fn):
+ fn.__doc__ = fn.__doc__ + "".join(docstr)
+ return fn
+
+ return docstring_decorator
+
+
+PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import torch
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
+
+ >>> outputs = model(**inputs, labels=labels)
+ >>> loss, scores = outputs[:2]
+"""
+
+PT_QUESTION_ANSWERING_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import torch
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> start_positions = torch.tensor([1])
+ >>> end_positions = torch.tensor([3])
+
+ >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
+ >>> loss, start_scores, end_scores = outputs[:3]
+"""
+
+PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import torch
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
+ >>> outputs = model(**inputs, labels=labels)
+ >>> loss, logits = outputs[:2]
+"""
+
+PT_MASKED_LM_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import torch
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
+
+ >>> outputs = model(input_ids, labels=input_ids)
+ >>> loss, prediction_scores = outputs[:2]
+"""
+
+PT_BASE_MODEL_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import torch
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
+"""
+
+PT_MULTIPLE_CHOICE_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import torch
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> choice0 = "It is eaten with a fork and a knife."
+ >>> choice1 = "It is eaten while held in the hand."
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
+
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
+
+ >>> # the linear classifier still needs to be trained
+ >>> loss, logits = outputs[:2]
+"""
+
+PT_CAUSAL_LM_SAMPLE = r"""
+ Example::
+
+ >>> import torch
+ >>> from transformers import {tokenizer_class}, {model_class}
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
+ >>> loss, logits = outputs[:2]
+"""
+
+TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> input_ids = inputs["input_ids"]
+ >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
+
+ >>> outputs = model(inputs)
+ >>> loss, scores = outputs[:2]
+"""
+
+TF_QUESTION_ANSWERING_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+ >>> input_dict = tokenizer(question, text, return_tensors='tf')
+ >>> start_scores, end_scores = model(input_dict)
+
+ >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
+ >>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
+"""
+
+TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
+
+ >>> outputs = model(inputs)
+ >>> loss, logits = outputs[:2]
+"""
+
+TF_MASKED_LM_SAMPLE = r"""
+ Example::
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
+
+ >>> outputs = model(input_ids)
+ >>> prediction_scores = outputs[0]
+"""
+
+TF_BASE_MODEL_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> outputs = model(inputs)
+
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
+"""
+
+TF_MULTIPLE_CHOICE_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> choice0 = "It is eaten with a fork and a knife."
+ >>> choice1 = "It is eaten while held in the hand."
+
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
+ >>> outputs = model(inputs) # batch size is 1
+
+ >>> # the linear classifier still needs to be trained
+ >>> logits = outputs[0]
+"""
+
+TF_CAUSAL_LM_SAMPLE = r"""
+ Example::
+
+ >>> from transformers import {tokenizer_class}, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> outputs = model(inputs)
+ >>> logits = outputs[0]
+"""
+
+
+def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
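+    """Append a framework-specific code sample to the docstring of the decorated model method.
+
+    Illustrative usage (a sketch, assuming `_TOKENIZER_FOR_DOC` is defined in the calling module,
+    as it is in `modeling_bert.py`; the decorated method is expected to live on a model class whose
+    name encodes the task, e.g. `BertForSequenceClassification.forward`)::
+
+        @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
+        def forward(self, input_ids=None, attention_mask=None, labels=None):
+            ...
+    """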
+ def docstring_decorator(fn):
+ model_class = fn.__qualname__.split(".")[0]
+ is_tf_class = model_class[:2] == "TF"
+
+ if "SequenceClassification" in model_class:
+ code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
+ elif "QuestionAnswering" in model_class:
+ code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
+ elif "TokenClassification" in model_class:
+ code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
+ elif "MultipleChoice" in model_class:
+ code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
+ elif "MaskedLM" in model_class:
+ code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
+ elif "LMHead" in model_class:
+ code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
+ elif "Model" in model_class:
+ code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
+ else:
+ raise ValueError(f"Docstring can't be built for model {model_class}")
+
+ built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
+ fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
+ return fn
+
+ return docstring_decorator
+
+
+def is_remote_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+
+def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
+ """
+    Resolve a model identifier and a file name to an HF-hosted URL
+ on either S3 or Cloudfront (a Content Delivery Network, or CDN).
+
+ Cloudfront is replicated over the globe so downloads are way faster
+ for the end user (and it also lowers our bandwidth costs). However, it
+ is more aggressively cached by default, so may not always reflect the
+ latest changes to the underlying file (default TTL is 24 hours).
+
+ In terms of client-side caching from this library, even though
+ Cloudfront relays the ETags from S3, using one or the other
+ (or switching from one to the other) will affect caching: cached files
+ are not shared between the two because the cached file's name contains
+ a hash of the url.
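+
+    Illustrative example (a sketch; actual URLs depend on the configured
+    S3_BUCKET_PREFIX / CLOUDFRONT_DISTRIB_PREFIX values)::
+
+        hf_bucket_url("bert-base-uncased", "config.json", use_cdn=False)
+        # -> "<S3_BUCKET_PREFIX>/bert-base-uncased-config.json"      (legacy flat format: no "/" in model_id)
+        hf_bucket_url("cl-tohoku/bert-base-japanese", "config.json")
+        # -> "<CLOUDFRONT_DISTRIB_PREFIX>/cl-tohoku/bert-base-japanese/config.json"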
+ """
+ endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
+ legacy_format = "/" not in model_id
+ if legacy_format:
+ return f"{endpoint}/{model_id}-{filename}"
+ else:
+ return f"{endpoint}/{model_id}/{filename}"
+
+
+def url_to_filename(url, etag=None):
+ """
+ Convert `url` into a hashed filename in a repeatable way.
+ If `etag` is specified, append its hash to the url's, delimited
+ by a period.
+    If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
+    so that TF 2.0 can identify it as an HDF5 file
+ (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
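+
+    Illustrative sketch of the resulting name (hex digests abbreviated, not real values)::
+
+        url_to_filename("https://example.com/model.h5", etag='"abc"')
+        # -> "<sha256(url)>.<sha256(etag)>.h5"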
+ """
+ url_bytes = url.encode("utf-8")
+ url_hash = sha256(url_bytes)
+ filename = url_hash.hexdigest()
+
+ if etag:
+ etag_bytes = etag.encode("utf-8")
+ etag_hash = sha256(etag_bytes)
+ filename += "." + etag_hash.hexdigest()
+
+ if url.endswith(".h5"):
+ filename += ".h5"
+
+ return filename
+
+
+def filename_to_url(filename, cache_dir=None):
+ """
+ Return the url and etag (which may be ``None``) stored for `filename`.
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
+ """
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ cache_path = os.path.join(cache_dir, filename)
+ if not os.path.exists(cache_path):
+ raise EnvironmentError("file {} not found".format(cache_path))
+
+ meta_path = cache_path + ".json"
+ if not os.path.exists(meta_path):
+ raise EnvironmentError("file {} not found".format(meta_path))
+
+ with open(meta_path, encoding="utf-8") as meta_file:
+ metadata = json.load(meta_file)
+ url = metadata["url"]
+ etag = metadata["etag"]
+
+ return url, etag
+
+
+def cached_path(
+ url_or_filename,
+ cache_dir=None,
+ force_download=False,
+ proxies=None,
+ resume_download=False,
+ user_agent: Union[Dict, str, None] = None,
+ extract_compressed_file=False,
+ force_extract=False,
+ local_files_only=False,
+) -> Optional[str]:
+ """
+ Given something that might be a URL (or might be a local path),
+ determine which. If it's a URL, download the file and cache it, and
+ return the path to the cached file. If it's already a local path,
+ make sure the file exists and then return the path.
+ Args:
+        cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
+        force_download: if True, re-download the file even if it's already cached in the cache dir.
+        resume_download: if True, resume the download if an incompletely received file is found.
+        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
+        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
+            file in a folder alongside the archive.
+        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
+            re-extract the archive and override the folder where it was extracted.
+
+ Return:
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
+ Local path (string) otherwise
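+
+    Illustrative usage (a sketch; the remote case requires network access)::
+
+        # remote URL: downloaded (or served from the local cache) and resolved to a cached file path
+        local_path = cached_path(hf_bucket_url("bert-base-uncased", "config.json"))
+        # existing local file: returned as-is after an existence check
+        local_path = cached_path("/tmp/my_config.json")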
+ """
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ if isinstance(url_or_filename, Path):
+ url_or_filename = str(url_or_filename)
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ if is_remote_url(url_or_filename):
+ # URL, so get it from the cache (downloading if necessary)
+ output_path = get_from_cache(
+ url_or_filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ user_agent=user_agent,
+ local_files_only=local_files_only,
+ )
+ elif os.path.exists(url_or_filename):
+ # File, and it exists.
+ output_path = url_or_filename
+ elif urlparse(url_or_filename).scheme == "":
+ # File, but it doesn't exist.
+ raise EnvironmentError("file {} not found".format(url_or_filename))
+ else:
+ # Something unknown
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
+
+ if extract_compressed_file:
+ if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
+ return output_path
+
+ # Path where we extract compressed archives
+ # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
+ output_dir, output_file = os.path.split(output_path)
+ output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
+ output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
+
+ if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
+ return output_path_extracted
+
+ # Prevent parallel extractions
+ lock_path = output_path + ".lock"
+ with FileLock(lock_path):
+ shutil.rmtree(output_path_extracted, ignore_errors=True)
+ os.makedirs(output_path_extracted)
+ if is_zipfile(output_path):
+ with ZipFile(output_path, "r") as zip_file:
+ zip_file.extractall(output_path_extracted)
+ zip_file.close()
+ elif tarfile.is_tarfile(output_path):
+ tar_file = tarfile.open(output_path)
+ tar_file.extractall(output_path_extracted)
+ tar_file.close()
+ else:
+ raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
+
+ return output_path_extracted
+
+ return output_path
+
+
+def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
+ if is_torch_available():
+ ua += "; torch/{}".format(torch.__version__)
+ if is_tf_available():
+ ua += "; tensorflow/{}".format(tf.__version__)
+ if isinstance(user_agent, dict):
+ ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
+ elif isinstance(user_agent, str):
+ ua += "; " + user_agent
+ headers = {"user-agent": ua}
+ if resume_size > 0:
+ headers["Range"] = "bytes=%d-" % (resume_size,)
+ response = requests.get(url, stream=True, proxies=proxies, headers=headers)
+ if response.status_code == 416: # Range not satisfiable
+ return
+ content_length = response.headers.get("Content-Length")
+ total = resume_size + int(content_length) if content_length is not None else None
+ progress = tqdm(
+ unit="B",
+ unit_scale=True,
+ total=total,
+ initial=resume_size,
+ desc="Downloading",
+ disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
+ )
+ for chunk in response.iter_content(chunk_size=1024):
+ if chunk: # filter out keep-alive new chunks
+ progress.update(len(chunk))
+ temp_file.write(chunk)
+ progress.close()
+
+
+def get_from_cache(
+ url,
+ cache_dir=None,
+ force_download=False,
+ proxies=None,
+ etag_timeout=10,
+ resume_download=False,
+ user_agent: Union[Dict, str, None] = None,
+ local_files_only=False,
+) -> Optional[str]:
+ """
+ Given a URL, look for the corresponding file in the local cache.
+ If it's not there, download it. Then return the path to the cached file.
+
+ Return:
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
+ Local path (string) otherwise
+ """
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ os.makedirs(cache_dir, exist_ok=True)
+
+ etag = None
+ if not local_files_only:
+ try:
+ response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
+ if response.status_code == 200:
+ etag = response.headers.get("ETag")
+ except (EnvironmentError, requests.exceptions.Timeout):
+ # etag is already None
+ pass
+
+ filename = url_to_filename(url, etag)
+
+ # get cache path to put the file
+ cache_path = os.path.join(cache_dir, filename)
+
+ # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
+ # try to get the last downloaded one
+ if etag is None:
+ if os.path.exists(cache_path):
+ return cache_path
+ else:
+ matching_files = [
+ file
+ for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
+ if not file.endswith(".json") and not file.endswith(".lock")
+ ]
+ if len(matching_files) > 0:
+ return os.path.join(cache_dir, matching_files[-1])
+ else:
+ # If files cannot be found and local_files_only=True,
+ # the models might've been found if local_files_only=False
+ # Notify the user about that
+ if local_files_only:
+ raise ValueError(
+ "Cannot find the requested files in the cached path and outgoing traffic has been"
+ " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
+ " to False."
+ )
+ return None
+
+ # From now on, etag is not None.
+ if os.path.exists(cache_path) and not force_download:
+ return cache_path
+
+ # Prevent parallel downloads of the same file with a lock.
+ lock_path = cache_path + ".lock"
+ with FileLock(lock_path):
+
+ # If the download just completed while the lock was activated.
+ if os.path.exists(cache_path) and not force_download:
+ # Even if returning early like here, the lock will be released.
+ return cache_path
+
+ if resume_download:
+ incomplete_path = cache_path + ".incomplete"
+
+ @contextmanager
+ def _resumable_file_manager():
+ with open(incomplete_path, "a+b") as f:
+ yield f
+
+ temp_file_manager = _resumable_file_manager
+ if os.path.exists(incomplete_path):
+ resume_size = os.stat(incomplete_path).st_size
+ else:
+ resume_size = 0
+ else:
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
+ resume_size = 0
+
+ # Download to temporary file, then copy to cache dir once finished.
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
+ with temp_file_manager() as temp_file:
+ logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
+
+ http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+
+ logger.info("storing %s in cache at %s", url, cache_path)
+ os.replace(temp_file.name, cache_path)
+
+ logger.info("creating metadata file for %s", cache_path)
+ meta = {"url": url, "etag": etag}
+ meta_path = cache_path + ".json"
+ with open(meta_path, "w") as meta_file:
+ json.dump(meta, meta_file)
+
+ return cache_path
+
+
+class cached_property(property):
+ """
+ Descriptor that mimics @property but caches output in member variable.
+
+ From tensorflow_datasets
+
+ Built-in in functools from Python 3.8.
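+
+    Illustrative usage (a sketch; `expensive_load` is a hypothetical helper)::
+
+        class ModelConfig:
+            @cached_property
+            def vocab(self):
+                return expensive_load()  # computed once, then cached on the instance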
+ """
+
+ def __get__(self, obj, objtype=None):
+ # See docs.python.org/3/howto/descriptor.html#properties
+ if obj is None:
+ return self
+ if self.fget is None:
+ raise AttributeError("unreadable attribute")
+ attr = "__cached_" + self.fget.__name__
+ cached = getattr(obj, attr, None)
+ if cached is None:
+ cached = self.fget(obj)
+ setattr(obj, attr, cached)
+ return cached
+
+
+def torch_required(func):
+    # We chose a different decorator name than in tests so it's clear they are not the same.
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if is_torch_available():
+ return func(*args, **kwargs)
+ else:
+ raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
+
+ return wrapper
+
+
+def tf_required(func):
+    # We chose a different decorator name than in tests so it's clear they are not the same.
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if is_tf_available():
+ return func(*args, **kwargs)
+ else:
+            raise ImportError(f"Method `{func.__name__}` requires TensorFlow.")
+
+ return wrapper
diff --git a/elia/bert/generation_utils.py b/elia/bert/generation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c49e15abf7bc822940112207f10e18e2e0230cc
--- /dev/null
+++ b/elia/bert/generation_utils.py
@@ -0,0 +1,993 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn import functional as F
+
+
+logger = logging.getLogger(__name__)
+
+
+class GenerationMixin:
+ """
+    A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
+ """
+
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
+ return {"input_ids": input_ids}
+
+ def adjust_logits_during_generation(self, logits, **kwargs):
+ return logits
+
+ def _use_cache(self, outputs, use_cache):
+ """During generation, decide whether to pass the `past` variable to the next forward pass."""
+ if len(outputs) <= 1 or use_cache is False:
+ return False
+ if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
+ return False
+ return True
+
+ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
+ """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
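+        # Illustrative effect (sketch): with repetition_penalty=1.2, a previously generated token
+        # with score -2.0 becomes -2.4 (less likely) and one with score 2.0 becomes ~1.67.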
+ for i in range(batch_size * num_beams):
+ for previous_token in set(prev_output_tokens[i].tolist()):
+                # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
+ if lprobs[i, previous_token] < 0:
+ lprobs[i, previous_token] *= repetition_penalty
+ else:
+ lprobs[i, previous_token] /= repetition_penalty
+
+ def postprocess_next_token_scores(
+ self,
+ scores,
+ input_ids,
+ no_repeat_ngram_size,
+ bad_words_ids,
+ cur_len,
+ min_length,
+ max_length,
+ eos_token_id,
+ repetition_penalty,
+ batch_size,
+ num_beams,
+ ):
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+ if repetition_penalty != 1.0:
+ self.enforce_repetition_penalty_(
+ scores, batch_size, num_beams, input_ids, repetition_penalty,
+ )
+
+ # set eos token prob to zero if min_length is not reached
+ if eos_token_id is not None and cur_len < min_length:
+ scores[:, eos_token_id] = -float("inf")
+
+ if no_repeat_ngram_size > 0:
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
+ num_batch_hypotheses = batch_size * num_beams
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
+ banned_batch_tokens = calc_banned_ngram_tokens(
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
+ )
+ for i, banned_tokens in enumerate(banned_batch_tokens):
+ scores[i, banned_tokens] = -float("inf")
+
+ if bad_words_ids is not None:
+ # calculate a list of banned tokens according to bad words
+ banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
+
+ for i, banned_tokens in enumerate(banned_tokens):
+ scores[i, banned_tokens] = -float("inf")
+
+ return scores
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ max_length: Optional[int] = None,
+ min_length: Optional[int] = None,
+ do_sample: Optional[bool] = None,
+ early_stopping: Optional[bool] = None,
+ num_beams: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ repetition_penalty: Optional[float] = None,
+ bad_words_ids: Optional[Iterable[int]] = None,
+ bos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[int] = None,
+ length_penalty: Optional[float] = None,
+ no_repeat_ngram_size: Optional[int] = None,
+ num_return_sequences: Optional[int] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_start_token_id: Optional[int] = None,
+ use_cache: Optional[bool] = None,
+ **model_specific_kwargs
+ ) -> torch.LongTensor:
+ r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
+
+ Adapted in part from `Facebook's XLM beam search code`_.
+
+ .. _`Facebook's XLM beam search code`:
+ https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
+
+
+ Parameters:
+
+ input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
+ The sequence used as a prompt for the generation. If `None` the method initializes
+ it as an empty `torch.LongTensor` of shape `(1,)`.
+
+ max_length: (`optional`) int
+ The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20.
+
+ min_length: (`optional`) int
+ The min length of the sequence to be generated. Between 0 and infinity. Default to 0.
+
+ do_sample: (`optional`) bool
+ If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
+
+ early_stopping: (`optional`) bool
+ if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
+
+ num_beams: (`optional`) int
+ Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
+
+ temperature: (`optional`) float
+                The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
+
+ top_k: (`optional`) int
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
+
+ top_p: (`optional`) float
+                The cumulative probability threshold for nucleus sampling: only the most probable tokens whose cumulative probability reaches `top_p` are kept. Must be between 0 and 1. Default to 1.
+
+ repetition_penalty: (`optional`) float
+ The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
+
+ pad_token_id: (`optional`) int
+                Padding token. Default to the model-specific pad_token_id or None if it does not exist.
+
+ bos_token_id: (`optional`) int
+                BOS token. Defaults to `bos_token_id` as defined in the model's config.
+
+ eos_token_id: (`optional`) int
+                EOS token. Defaults to `eos_token_id` as defined in the model's config.
+
+ length_penalty: (`optional`) float
+ Exponential penalty to the length. Default to 1.
+
+ no_repeat_ngram_size: (`optional`) int
+ If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
+ bad_words_ids: (`optional`) list of lists of int
+ `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
+
+ num_return_sequences: (`optional`) int
+ The number of independently computed returned sequences for each element in the batch. Default to 1.
+
+ attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
+ Mask to avoid performing attention on padding token indices.
+ Mask values selected in ``[0, 1]``:
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+ Defaults to `None`.
+
+ `What are attention masks? <../glossary.html#attention-mask>`__
+
+ decoder_start_token_id=None: (`optional`) int
+                Start token to use if an encoder-decoder model starts decoding with a different token than BOS.
+                Defaults to `None`, in which case `bos_token_id` is used.
+
+ use_cache: (`optional`) bool
+                If `use_cache` is True, past key values are used to speed up decoding if applicable to the model. Defaults to `True`.
+
+ model_specific_kwargs: (`optional`) dict
+ Additional model specific kwargs will be forwarded to the `forward` function of the model.
+
+ Return:
+
+ output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
+ sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
+
+ Examples::
+
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
+ outputs = model.generate(max_length=40) # do greedy decoding
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+
+ tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
+ input_context = 'The dog'
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5, do_sample=True) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
+ for i in range(3): # 3 output sequences were generated
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
+ input_context = 'The dog'
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 independent sequences using sampling
+ for i in range(3): # 3 output sequences were generated
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
+
+ tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+ outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+
+ tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
+ model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
+            input_context = 'My cute dog'
+ bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
+ outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
+ """
+
+ # We cannot generate if the model does not have a LM head
+ if self.get_output_embeddings() is None:
+ raise AttributeError(
+ "You tried to generate sequences with a model that does not have a LM Head."
+ "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
+ )
+
+ max_length = max_length if max_length is not None else self.config.max_length
+ min_length = min_length if min_length is not None else self.config.min_length
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
+ temperature = temperature if temperature is not None else self.config.temperature
+ top_k = top_k if top_k is not None else self.config.top_k
+ top_p = top_p if top_p is not None else self.config.top_p
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
+ bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+ no_repeat_ngram_size = (
+ no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
+ )
+ bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
+ num_return_sequences = (
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+ )
+ decoder_start_token_id = (
+ decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+ )
+
+ if input_ids is not None:
+            batch_size = input_ids.shape[0]  # overridden by the input batch_size
+ else:
+ batch_size = 1
+
+ assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
+ assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
+ assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
+ assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
+ assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
+ assert temperature > 0, "`temperature` should be strictly positive."
+ assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
+ assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
+ assert input_ids is not None or (
+ isinstance(bos_token_id, int) and bos_token_id >= 0
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
+ assert pad_token_id is None or (
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
+ ), "`pad_token_id` should be a positive integer."
+ assert (eos_token_id is None) or (
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
+ ), "`eos_token_id` should be a positive integer."
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
+ assert (
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
+ ), "`no_repeat_ngram_size` should be a positive integer."
+ assert (
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
+ ), "`num_return_sequences` should be a strictly positive integer."
+ assert (
+ bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
+
+ if input_ids is None:
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
+ "you should either supply a context to complete as `input_ids` input "
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
+ )
+ input_ids = torch.full(
+ (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
+ )
+ else:
+ assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
+
+        # do not allow duplicate outputs when doing greedy decoding
+ if do_sample is False:
+ if num_beams == 1:
+ # no_beam_search greedy generation conditions
+ assert (
+ num_return_sequences == 1
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+
+ else:
+ # beam_search greedy generation conditions
+ assert (
+ num_beams >= num_return_sequences
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
+ # create attention mask if necessary
+ # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
+ if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
+ attention_mask = input_ids.ne(pad_token_id).long()
+ elif attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ # set pad_token_id to eos_token_id if not set. Important that this is done after
+ # attention_mask is created
+ if pad_token_id is None and eos_token_id is not None:
+ logger.warning(
+ "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
+ )
+ pad_token_id = eos_token_id
+
+ # current position and vocab size
+ if hasattr(self.config, "vocab_size"):
+ vocab_size = self.config.vocab_size
+ elif (
+ self.config.is_encoder_decoder
+ and hasattr(self.config, "decoder")
+ and hasattr(self.config.decoder, "vocab_size")
+ ):
+ vocab_size = self.config.decoder.vocab_size
+
+ # set effective batch size and effective batch multiplier according to do_sample
+ if do_sample:
+ effective_batch_size = batch_size * num_return_sequences
+ effective_batch_mult = num_return_sequences
+ else:
+ effective_batch_size = batch_size
+ effective_batch_mult = 1
+
+ if self.config.is_encoder_decoder:
+ if decoder_start_token_id is None:
+ decoder_start_token_id = bos_token_id
+
+ assert (
+ decoder_start_token_id is not None
+ ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+ assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
+ assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
+
+ # get encoder and store encoder outputs
+ encoder = self.get_encoder()
+
+ encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
+
+ # Expand input ids if num_beams > 1 or num_return_sequences > 1
+ if num_return_sequences > 1 or num_beams > 1:
+ input_ids_len = input_ids.shape[-1]
+ input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
+ attention_mask = attention_mask.unsqueeze(1).expand(
+ batch_size, effective_batch_mult * num_beams, input_ids_len
+ )
+
+ input_ids = input_ids.contiguous().view(
+ effective_batch_size * num_beams, input_ids_len
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+ attention_mask = attention_mask.contiguous().view(
+ effective_batch_size * num_beams, input_ids_len
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+
+ if self.config.is_encoder_decoder:
+ # create empty decoder_input_ids
+ input_ids = torch.full(
+ (effective_batch_size * num_beams, 1),
+ decoder_start_token_id,
+ dtype=torch.long,
+ device=next(self.parameters()).device,
+ )
+ cur_len = 1
+
+ assert (
+ batch_size == encoder_outputs[0].shape[0]
+ ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
+
+ # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
+ expanded_batch_idxs = (
+ torch.arange(batch_size)
+ .view(-1, 1)
+ .repeat(1, num_beams * effective_batch_mult)
+ .view(-1)
+ .to(input_ids.device)
+ )
+ # expand encoder_outputs
+ encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
+
+ else:
+ encoder_outputs = None
+ cur_len = input_ids.shape[-1]
+
+ assert (
+ cur_len < max_length
+        ), f"The context has {cur_len} tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
+
+ if num_beams > 1:
+ output = self._generate_beam_search(
+ input_ids,
+ cur_len=cur_len,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=do_sample,
+ early_stopping=early_stopping,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ batch_size=effective_batch_size,
+ num_return_sequences=num_return_sequences,
+ length_penalty=length_penalty,
+ num_beams=num_beams,
+ vocab_size=vocab_size,
+ encoder_outputs=encoder_outputs,
+ attention_mask=attention_mask,
+ use_cache=use_cache,
+ model_specific_kwargs=model_specific_kwargs,
+ )
+ else:
+ output = self._generate_no_beam_search(
+ input_ids,
+ cur_len=cur_len,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ batch_size=effective_batch_size,
+ encoder_outputs=encoder_outputs,
+ attention_mask=attention_mask,
+ use_cache=use_cache,
+ model_specific_kwargs=model_specific_kwargs,
+ )
+
+ return output
+
+ def _generate_no_beam_search(
+ self,
+ input_ids,
+ cur_len,
+ max_length,
+ min_length,
+ do_sample,
+ temperature,
+ top_k,
+ top_p,
+ repetition_penalty,
+ no_repeat_ngram_size,
+ bad_words_ids,
+ pad_token_id,
+ eos_token_id,
+ batch_size,
+ encoder_outputs,
+ attention_mask,
+ use_cache,
+ model_specific_kwargs,
+ ):
+ """ Generate sequences for each example without beam search (num_beams == 1).
+        All returned sequences are generated independently.
+ """
+ # length of generated sentences / unfinished sentences
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
+
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
+
+ while cur_len < max_length:
+ model_inputs = self.prepare_inputs_for_generation(
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
+ )
+
+ outputs = self(**model_inputs)
+ next_token_logits = outputs[0][:, -1, :]
+
+ scores = self.postprocess_next_token_scores(
+ scores=next_token_logits,
+ input_ids=input_ids,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ cur_len=cur_len,
+ min_length=min_length,
+ max_length=max_length,
+ eos_token_id=eos_token_id,
+ repetition_penalty=repetition_penalty,
+ batch_size=batch_size,
+ num_beams=1,
+ )
+
+ # if model has past, then set the past variable to speed up decoding
+ if self._use_cache(outputs, use_cache):
+ past = outputs[1]
+
+ if do_sample:
+ # Temperature (higher temperature => more likely to sample low probability tokens)
+ if temperature != 1.0:
+ scores = scores / temperature
+ # Top-p/top-k filtering
+ next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
+ # Sample
+ probs = F.softmax(next_token_logscores, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ # Greedy decoding
+ next_token = torch.argmax(next_token_logits, dim=-1)
+
+ # update generations and finished sentences
+ if eos_token_id is not None:
+ # pad finished sentences if eos_token_id exist
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
+ else:
+ tokens_to_add = next_token
+
+ # add token and increase length by one
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
+ cur_len = cur_len + 1
+
+ if eos_token_id is not None:
+ eos_in_sents = tokens_to_add == eos_token_id
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
+ # unfinished_sents is set to zero if eos in sentence
+ unfinished_sents.mul_((~eos_in_sents).long())
+
+            # stop when there is an EOS token in each sentence, or if we exceed the maximum length
+ if unfinished_sents.max() == 0:
+ break
+
+ # extend attention_mask for new generated input if only decoder
+ if self.config.is_encoder_decoder is False:
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+
+ return input_ids
+
+ def _generate_beam_search(
+ self,
+ input_ids,
+ cur_len,
+ max_length,
+ min_length,
+ do_sample,
+ early_stopping,
+ temperature,
+ top_k,
+ top_p,
+ repetition_penalty,
+ no_repeat_ngram_size,
+ bad_words_ids,
+ pad_token_id,
+ eos_token_id,
+ batch_size,
+ num_return_sequences,
+ length_penalty,
+ num_beams,
+ vocab_size,
+ encoder_outputs,
+ attention_mask,
+ use_cache,
+ model_specific_kwargs,
+ ):
+ """ Generate sequences for each example with beam search.
+ """
+
+ # generated hypotheses
+ generated_hyps = [
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
+ for _ in range(batch_size)
+ ]
+
+ # scores for each sentence in the beam
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+
+ # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
+ if do_sample is False:
+ beam_scores[:, 1:] = -1e9
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
+
+ # cache compute states
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
+
+ # done sentences
+ done = [False for _ in range(batch_size)]
+
+ while cur_len < max_length:
+ model_inputs = self.prepare_inputs_for_generation(
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
+ )
+ outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
+ next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
+
+ # if model has past, then set the past variable to speed up decoding
+ if self._use_cache(outputs, use_cache):
+ past = outputs[1]
+ if self.config.is_encoder_decoder and do_sample is False:
+ # TODO (PVP) still a bit hacky here - there might be a better solution
+ next_token_logits = self.adjust_logits_during_generation(
+ next_token_logits, cur_len=cur_len, max_length=max_length
+ )
+
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
+
+ scores = self.postprocess_next_token_scores(
+ scores=scores,
+ input_ids=input_ids,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ bad_words_ids=bad_words_ids,
+ cur_len=cur_len,
+ min_length=min_length,
+ max_length=max_length,
+ eos_token_id=eos_token_id,
+ repetition_penalty=repetition_penalty,
+ batch_size=batch_size,
+ num_beams=num_beams,
+ )
+
+ assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
+ scores.shape, (batch_size * num_beams, vocab_size)
+ )
+
+ if do_sample:
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
+ # Temperature
+ if temperature != 1.0:
+ _scores = _scores / temperature
+ # Top-p/top-k filtering
+ _scores = top_k_top_p_filtering(
+ _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+ ) # (batch_size * num_beams, vocab_size)
+ # re-organize to group the beam together to sample from all beam_idxs
+ _scores = _scores.contiguous().view(
+ batch_size, num_beams * vocab_size
+ ) # (batch_size, num_beams * vocab_size)
+
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+ probs = F.softmax(_scores, dim=-1)
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
+ # Compute next scores
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
+ # sort the sampled vector to make sure that the first num_beams samples are the best
+ next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
+ next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
+
+ else:
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
+
+                # re-organize to group the beam together (we are keeping the top hypotheses across beams)
+ next_scores = next_scores.view(
+ batch_size, num_beams * vocab_size
+ ) # (batch_size, num_beams * vocab_size)
+
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
+
+ # next batch beam content
+ next_batch_beam = []
+
+ # for each sentence
+ for batch_idx in range(batch_size):
+
+ # if we are done with this sentence, add a pad token
+ if done[batch_idx]:
+ assert (
+ len(generated_hyps[batch_idx]) >= num_beams
+ ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+ assert (
+ eos_token_id is not None and pad_token_id is not None
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+ next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
+ continue
+
+ # next sentence beam content, this will get added to next_batch_beam
+ next_sent_beam = []
+
+ # next tokens for this sentence
+ for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
+ zip(next_tokens[batch_idx], next_scores[batch_idx])
+ ):
+ # get beam and token IDs
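+                    # e.g. (illustrative) with vocab_size=100 and beam_token_id=205: beam_id=2, token_id=5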
+ beam_id = beam_token_id // vocab_size
+ token_id = beam_token_id % vocab_size
+
+ effective_beam_id = batch_idx * num_beams + beam_id
+ # add to generated hypotheses if end of sentence
+ if (eos_token_id is not None) and (token_id.item() == eos_token_id):
+ # if beam_token does not belong to top num_beams tokens, it should not be added
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+ if is_beam_token_worse_than_top_num_beams:
+ continue
+ generated_hyps[batch_idx].add(
+ input_ids[effective_beam_id].clone(), beam_token_score.item(),
+ )
+ else:
+ # add next predicted token since it is not eos_token
+ next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
+
+ # once the beam for next step is full, don't add more tokens to it.
+ if len(next_sent_beam) == num_beams:
+ break
+
+ # Check if we are done so that we can save a pad step if all(done)
+ done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+ next_scores[batch_idx].max().item(), cur_len
+ )
+
+ # update next beam content
+ assert len(next_sent_beam) == num_beams, "Beam should always be full"
+ next_batch_beam.extend(next_sent_beam)
+ assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
+
+ # stop when we are done with each sentence
+ if all(done):
+ break
+
+ # sanity check / prepare next batch
+ assert len(next_batch_beam) == batch_size * num_beams
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
+ beam_idx = input_ids.new([x[2] for x in next_batch_beam])
+
+ # re-order batch and update current length
+ input_ids = input_ids[beam_idx, :]
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
+ cur_len = cur_len + 1
+
+ # re-order internal states
+ if past is not None:
+ past = self._reorder_cache(past, beam_idx)
+
+ # extend attention_mask for new generated input if only decoder
+ if self.config.is_encoder_decoder is False:
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+
+ # finalize all open beam hypotheses and add to generated hypotheses
+ for batch_idx in range(batch_size):
+ if done[batch_idx]:
+ continue
+
+ # test that beam scores match previously calculated scores if not eos and batch_idx not done
+ if eos_token_id is not None and all(
+ (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
+ ):
+ assert torch.all(
+ next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
+ ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
+ next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
+ )
+
+ # need to add best num_beams hypotheses to generated hyps
+ for beam_id in range(num_beams):
+ effective_beam_id = batch_idx * num_beams + beam_id
+ final_score = beam_scores[effective_beam_id].item()
+ final_tokens = input_ids[effective_beam_id]
+ generated_hyps[batch_idx].add(final_tokens, final_score)
+
+ # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
+ output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+ output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
+
+ # select the best hypotheses
+ sent_lengths = input_ids.new(output_batch_size)
+ best = []
+
+ # retrieve best hypotheses
+ for i, hypotheses in enumerate(generated_hyps):
+ sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+ for j in range(output_num_return_sequences_per_batch):
+ effective_batch_idx = output_num_return_sequences_per_batch * i + j
+ best_hyp = sorted_hyps.pop()[1]
+ sent_lengths[effective_batch_idx] = len(best_hyp)
+ best.append(best_hyp)
+
+ # shorter batches are padded
+ if sent_lengths.min().item() != sent_lengths.max().item():
+ assert pad_token_id is not None, "`Pad_token_id` has to be defined"
+ sent_max_len = min(sent_lengths.max().item() + 1, max_length)
+ decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
+
+ # fill with hypothesis and eos_token_id if necessary
+ for i, hypo in enumerate(best):
+ decoded[i, : sent_lengths[i]] = hypo
+ if sent_lengths[i] < max_length:
+ decoded[i, sent_lengths[i]] = eos_token_id
+ else:
+ # none of the hypotheses have an eos_token
+            assert all(len(hypo) == max_length for hypo in best)
+ decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
+
+ return decoded
+
+ @staticmethod
+ def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
+ return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
+
+
+def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]:
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
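+    # Illustrative example (sketch): with no_repeat_ngram_size=2, if a hypothesis already contains
+    # the bigram (5, 7) and currently ends in token 5, then token 7 is banned for the next step.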
+ if cur_len + 1 < no_repeat_ngram_size:
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+ return [[] for _ in range(num_hypos)]
+ generated_ngrams = [{} for _ in range(num_hypos)]
+ for idx in range(num_hypos):
+ gen_tokens = prev_input_ids[idx].tolist()
+ generated_ngram = generated_ngrams[idx]
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
+ prev_ngram_tuple = tuple(ngram[:-1])
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+
+ def _get_generated_ngrams(hypo_idx):
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
+ start_idx = cur_len + 1 - no_repeat_ngram_size
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
+
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
+ return banned_tokens
+
+
+def calc_banned_bad_words_ids(prev_input_ids: Tensor, bad_words_ids: Iterable[Iterable[int]]) -> List[List[int]]:
+ banned_tokens = []
+
+ def _tokens_match(prev_tokens, tokens):
+ if len(tokens) == 0:
+ # if bad word tokens is just one token always ban it
+ return True
+ if len(tokens) > len(prev_input_ids):
+            # if bad word tokens are longer than prev input_ids they can't be equal
+ return False
+
+ if prev_tokens[-len(tokens) :] == tokens:
+ # if tokens match
+ return True
+ else:
+ return False
+
+ for prev_input_ids_slice in prev_input_ids:
+ banned_tokens_slice = []
+
+ for banned_token_seq in bad_words_ids:
+ assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
+ bad_words_ids
+ )
+
+ if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
+ # if tokens do not match continue
+ continue
+
+ banned_tokens_slice.append(banned_token_seq[-1])
+
+ banned_tokens.append(banned_tokens_slice)
+
+ return banned_tokens
+
+
+def top_k_top_p_filtering(
+ logits: Tensor,
+ top_k: int = 0,
+ top_p: float = 1.0,
+ filter_value: float = -float("Inf"),
+ min_tokens_to_keep: int = 1,
+) -> Tensor:
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+ Args:
+ logits: logits distribution shape (batch size, vocabulary size)
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
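+
+        Illustrative usage (a sketch, mirroring the sampling branch of `_generate_no_beam_search`)::
+
+            next_token_logits = outputs[0][:, -1, :]
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.95)
+            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)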
+ """
+ if top_k > 0:
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ if min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ logits[indices_to_remove] = filter_value
+ return logits
+
+
+class BeamHypotheses(object):
+ def __init__(self, num_beams, max_length, length_penalty, early_stopping):
+ """
+ Initialize n-best list of hypotheses.
+ """
+ self.max_length = max_length - 1 # ignoring bos_token
+ self.length_penalty = length_penalty
+ self.early_stopping = early_stopping
+ self.num_beams = num_beams
+ self.beams = []
+ self.worst_score = 1e9
+
+ def __len__(self):
+ """
+ Number of hypotheses in the list.
+ """
+ return len(self.beams)
+
+ def add(self, hyp, sum_logprobs):
+ """
+ Add a new hypothesis to the list.
+ """
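+        # Length-normalized score; e.g. (illustrative) with length_penalty=1.0 a 10-token
+        # hypothesis with summed log-prob -5.0 gets score -0.5.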
+ score = sum_logprobs / len(hyp) ** self.length_penalty
+ if len(self) < self.num_beams or score > self.worst_score:
+ self.beams.append((score, hyp))
+ if len(self) > self.num_beams:
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+ del self.beams[sorted_scores[0][1]]
+ self.worst_score = sorted_scores[1][0]
+ else:
+ self.worst_score = min(score, self.worst_score)
+
+ def is_done(self, best_sum_logprobs, cur_len):
+ """
+        If there are enough hypotheses and none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+ """
+
+ if len(self) < self.num_beams:
+ return False
+ elif self.early_stopping:
+ return True
+ else:
+ cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+ ret = self.worst_score >= cur_score
+ return ret
diff --git a/elia/bert/modeling_bert.py b/elia/bert/modeling_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..e796878aa2e6e39d6e65b0941396bcedca396a46
--- /dev/null
+++ b/elia/bert/modeling_bert.py
@@ -0,0 +1,1569 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model. """
+
+
+import logging
+import math
+import os
+import warnings
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss
+
+from .activations import gelu, gelu_new, swish
+from .configuration_bert import BertConfig
+from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+
+
+logger = logging.getLogger(__name__)
+
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+
+BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "bert-base-uncased",
+ "bert-large-uncased",
+ "bert-base-cased",
+ "bert-large-cased",
+ "bert-base-multilingual-uncased",
+ "bert-base-multilingual-cased",
+ "bert-base-chinese",
+ "bert-base-german-cased",
+ "bert-large-uncased-whole-word-masking",
+ "bert-large-cased-whole-word-masking",
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
+ "bert-large-cased-whole-word-masking-finetuned-squad",
+ "bert-base-cased-finetuned-mrpc",
+ "bert-base-german-dbmdz-cased",
+ "bert-base-german-dbmdz-uncased",
+ "cl-tohoku/bert-base-japanese",
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
+ "cl-tohoku/bert-base-japanese-char",
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
+ "TurkuNLP/bert-base-finnish-cased-v1",
+ "TurkuNLP/bert-base-finnish-uncased-v1",
+ "wietsedv/bert-base-dutch-cased",
+ # See all BERT models at https://huggingface.co/models?filter=bert
+]
+
+
+def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
+ """ Load tf checkpoints in a pytorch model.
+ """
+ try:
+ import re
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info("Skipping {}".format("/".join(name)))
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info("Skipping {}".format("/".join(name)))
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ assert pointer.shape == array.shape
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info("Initialize PyTorch weight {}".format(name))
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
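+# Illustrative sketch (not part of the original file): converting a TensorFlow
+# BERT checkpoint into this PyTorch model. The config and checkpoint paths are
+# hypothetical placeholders.
+def _example_load_tf_checkpoint():
+    config = BertConfig.from_json_file("/path/to/bert_config.json")
+    model = BertModel(config)
+    load_tf_weights_in_bert(model, config, "/path/to/bert_model.ckpt")
+    return model
+
+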
+def mish(x):
+ return x * torch.tanh(nn.functional.softplus(x))
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
+
+
+BertLayerNorm = torch.nn.LayerNorm
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if position_ids is None:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ position_embeddings = self.position_embeddings(position_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
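+# Illustrative sketch (not part of the original file): tensor shapes through
+# BertEmbeddings; the config sizes and batch/sequence lengths are assumptions.
+def _example_embeddings_shapes():
+    config = BertConfig(vocab_size=30522, hidden_size=768)
+    embeddings = BertEmbeddings(config)
+    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch_size, seq_length)
+    return embeddings(input_ids=input_ids).shape  # torch.Size([2, 16, 768])
+
+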
+class BertSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ if encoder_hidden_states is not None:
+ mixed_key_layer = self.key(encoder_hidden_states)
+ mixed_value_layer = self.value(encoder_hidden_states)
+ attention_mask = encoder_attention_mask
+ else:
+ mixed_key_layer = self.key(hidden_states)
+ mixed_value_layer = self.value(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+ return outputs
+
+
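+# Illustrative sketch (not part of the original file): the head-splitting done by
+# transpose_for_scores and the scaled dot-product that follows, with assumed
+# sizes (batch=2, seq=16, hidden=768, 12 heads of size 64).
+def _example_attention_shapes():
+    hidden_states = torch.randn(2, 16, 768)
+    num_heads, head_size = 12, 64  # 12 * 64 == 768
+    x = hidden_states.view(2, 16, num_heads, head_size).permute(0, 2, 1, 3)  # (batch, heads, seq, head_size)
+    scores = torch.matmul(x, x.transpose(-1, -2)) / math.sqrt(head_size)
+    return scores.shape  # torch.Size([2, 12, 16, 16])
+
+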
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = BertAttention(config)
+ self.is_decoder = config.is_decoder
+ if self.is_decoder:
+ self.crossattention = BertAttention(config)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=False,
+ ):
+ self_attention_outputs = self.attention(
+ hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ if self.is_decoder and encoder_hidden_states is not None:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
+
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ outputs = (layer_output,) + outputs
+ return outputs
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ ):
+ all_hidden_states = ()
+ all_attentions = ()
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if getattr(self.config, "gradient_checkpointing", False):
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = (hidden_states,)
+ if output_hidden_states:
+ outputs = outputs + (all_hidden_states,)
+ if output_attentions:
+ outputs = outputs + (all_attentions,)
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
+
+
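+# Illustrative sketch (not part of the original file): gradient checkpointing in
+# BertEncoder is gated on a config attribute (see the getattr call above), so it
+# can be enabled on an existing config like this.
+def _example_enable_gradient_checkpointing(config):
+    config.gradient_checkpointing = True
+    return BertEncoder(config)
+
+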
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertOnlyNSPHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, pooled_output):
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return seq_relationship_score
+
+
+class BertPreTrainingHeads(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, sequence_output, pooled_output):
+ prediction_scores = self.predictions(sequence_output)
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return prediction_scores, seq_relationship_score
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+
+ config_class = BertConfig
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, BertLayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+BERT_START_DOCSTRING = r"""
+ This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
+ usage and behavior.
+
+ Parameters:
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
+"""
+
+BERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
+ See :func:`transformers.PreTrainedTokenizer.encode` and
+ :func:`transformers.PreTrainedTokenizer.__call__` for details.
+
+ `What are input IDs? <../glossary.html#input-ids>`__
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
+ Mask to avoid performing attention on padding token indices.
+ Mask values selected in ``[0, 1]``:
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
+ Segment token indices to indicate first and second portions of the inputs.
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
+ corresponds to a `sentence B` token
+
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
+
+ `What are position IDs? <../glossary.html#position-ids>`_
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
+ Mask to nullify selected heads of the self-attention modules.
+ Mask values selected in ``[0, 1]``:
+ :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ if the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
+ is used in the cross-attention if the model is configured as a decoder.
+ Mask values selected in ``[0, 1]``:
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
+ If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
+"""
+
+
+@add_start_docstrings(
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ BERT_START_DOCSTRING,
+)
+class BertModel(BertPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well
+ as a decoder, in which case a layer of cross-attention is added between
+ the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+ To behave as a decoder the model needs to be initialized with the
+ :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
+ :obj:`encoder_hidden_states` is expected as an input to the forward pass.
+
+ .. _`Attention is all you need`:
+ https://arxiv.org/abs/1706.03762
+
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+ self.encoder = BertEncoder(config)
+ self.pooler = BertPooler(config)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """ Prunes heads of the model.
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ See base class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ ):
+ r"""
+ Return:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token)
+ further processed by a Linear layer and a Tanh activation function. The Linear
+ layer weights are trained from the next sentence prediction (classification)
+ objective during pre-training.
+
+ This output is usually *not* a good summary
+ of the semantic content of the input; you're often better off averaging or pooling
+ the sequence of hidden-states for the whole input sequence.
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[
+ 1:
+ ] # add hidden_states and attentions if they are here
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
+
+
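+# Illustrative sketch (not part of the original file): the additive mask built by
+# get_extended_attention_mask is equivalent to the arithmetic below -- padded
+# positions receive a large negative bias before the attention softmax.
+def _example_extended_attention_mask():
+    attention_mask = torch.tensor([[1, 1, 1, 0]])  # (batch_size, seq_length); last token is padding
+    extended = attention_mask[:, None, None, :].float()  # (batch, 1, 1, seq), broadcastable over heads
+    extended = (1.0 - extended) * -10000.0  # 0.0 where attended, -10000.0 where masked
+    return extended
+
+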
+@add_start_docstrings(
+ """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
+ a `next sentence prediction (classification)` head. """,
+ BERT_START_DOCSTRING,
+)
+class BertForPreTraining(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertPreTrainingHeads(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ next_sentence_label=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ **kwargs
+ ):
+ r"""
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
+ Labels for computing the masked language modeling loss.
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
+ in ``[0, ..., config.vocab_size]``
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
+ Indices should be in ``[0, 1]``.
+ ``0`` indicates sequence B is a continuation of sequence A,
+ ``1`` indicates sequence B is a random sequence.
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+ Used to hide legacy arguments that have been deprecated.
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False
+ continuation before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+
+
+ Examples::
+
+ >>> from transformers import BertTokenizer, BertForPreTraining
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_scores, seq_relationship_scores = outputs[:2]
+
+ """
+ if "masked_lm_labels" in kwargs:
+ warnings.warn(
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ DeprecationWarning,
+ )
+ labels = kwargs.pop("masked_lm_labels")
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[
+ 2:
+ ] # add hidden states and attention if they are here
+
+ if labels is not None and next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ total_loss = masked_lm_loss + next_sentence_loss
+ outputs = (total_loss,) + outputs
+
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
+)
+class BertLMHeadModel(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
+
+ self.bert = BertModel(config)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ **kwargs
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ Labels for computing the left-to-right language modeling loss (next word prediction).
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
+ in ``[0, ..., config.vocab_size]``
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+ Used to hide legacy arguments that have been deprecated.
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+ Next token prediction loss.
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+
+ Example::
+
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> config.is_decoder = True
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_scores = outputs[0] # Prediction scores are the first element of the output tuple
+ """
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
+
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ outputs = (ltr_lm_loss,) + outputs
+
+ return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
+
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
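+# Illustrative sketch (not part of the original file): the shift used above for
+# next-token prediction, on made-up tensors (batch=2, seq=8, vocab=30522).
+def _example_causal_lm_shift():
+    prediction_scores = torch.randn(2, 8, 30522)
+    labels = torch.randint(0, 30522, (2, 8))
+    shifted_scores = prediction_scores[:, :-1, :].contiguous()  # positions 0..6 predict tokens 1..7
+    shifted_labels = labels[:, 1:].contiguous()
+    return CrossEntropyLoss()(shifted_scores.view(-1, 30522), shifted_labels.view(-1))
+
+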
+@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
+class BertForMaskedLM(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ assert (
+ not config.is_decoder
+ ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
+
+ self.bert = BertModel(config)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ **kwargs
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ Labels for computing the masked language modeling loss.
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
+ in ``[0, ..., config.vocab_size]``
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+ Used to hide legacy arguments that have been deprecated.
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Masked language modeling loss.
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+ if "masked_lm_labels" in kwargs:
+ warnings.warn(
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ DeprecationWarning,
+ )
+ labels = kwargs.pop("masked_lm_labels")
+ assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
+
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ outputs = (masked_lm_loss,) + outputs
+
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
+
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ effective_batch_size = input_shape[0]
+
+ # add a dummy token
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+ dummy_token = torch.full(
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+ )
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
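+# Illustrative sketch (not part of the original file): the -100 label convention
+# used above -- only positions whose label is not -100 contribute to the MLM loss.
+def _example_masked_lm_labels():
+    labels = torch.tensor([[-100, 2023, -100, -100]])  # only position 1 is scored
+    prediction_scores = torch.randn(1, 4, 30522)
+    return CrossEntropyLoss()(prediction_scores.view(-1, 30522), labels.view(-1))
+
+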
+@add_start_docstrings(
+ """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
+)
+class BertForNextSentencePrediction(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertOnlyNSPHead(config)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ next_sentence_label=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ ):
+ r"""
+ next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
+ Indices should be in ``[0, 1]``.
+ ``0`` indicates sequence B is a continuation of sequence A,
+ ``1`` indicates sequence B is a random sequence.
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
+ Next sequence prediction (classification) loss.
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+
+ Examples::
+
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
+
+ >>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
+ """
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = outputs[1]
+
+ seq_relationship_score = self.cls(pooled_output)
+
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
+ if next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss()
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ outputs = (next_sentence_loss,) + outputs
+
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+ """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
+ the pooled output) e.g. for GLUE tasks. """,
+ BERT_START_DOCSTRING,
+)
+class BertForSequenceClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+ Labels for computing the sequence classification/regression loss.
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
+
+ if labels is not None:
+ if self.num_labels == 1:
+ # We are doing regression
+ loss_fct = MSELoss()
+ loss = loss_fct(logits.view(-1), labels.view(-1))
+ else:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ outputs = (loss,) + outputs
+
+ return outputs # (loss), logits, (hidden_states), (attentions)
+
+
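+# Illustrative sketch (not part of the original file): how num_labels selects the
+# loss above -- a single label means regression (MSE), several mean classification.
+def _example_sequence_classification_losses():
+    regression_logits = torch.randn(4, 1)
+    mse = MSELoss()(regression_logits.view(-1), torch.randn(4))
+    classification_logits = torch.randn(4, 3)
+    ce = CrossEntropyLoss()(classification_logits, torch.randint(0, 3, (4,)))
+    return mse, ce
+
+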
+@add_start_docstrings(
+ """Bert Model with a multiple choice classification head on top (a linear layer on top of
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+ BERT_START_DOCSTRING,
+)
+class BertForMultipleChoice(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+ Labels for computing the multiple choice classification loss.
+ Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
+ of the input tensors. (see `input_ids` above)
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
+ Classification loss.
+ classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
+ `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
+
+ Classification scores (before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
+
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+ outputs = (loss,) + outputs
+
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
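+# Illustrative sketch (not part of the original file): the flatten/unflatten of
+# choices done above, with assumed sizes (batch=2, num_choices=4, seq=16).
+def _example_multiple_choice_reshape():
+    input_ids = torch.randint(0, 30522, (2, 4, 16))
+    flat_input_ids = input_ids.view(-1, input_ids.size(-1))  # (8, 16): each choice is scored independently
+    logits = torch.randn(flat_input_ids.size(0), 1)  # one score per (example, choice) pair
+    reshaped_logits = logits.view(-1, 4)  # (2, 4): compare choices within each example
+    return reshaped_logits.shape
+
+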
+@add_start_docstrings(
+ """Bert Model with a token classification head on top (a linear layer on top of
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+ BERT_START_DOCSTRING,
+)
+class BertForTokenClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ Labels for computing the token classification loss.
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
+ Classification loss.
+ scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
+ Classification scores (before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ outputs = (loss,) + outputs
+
+ return outputs # (loss), scores, (hidden_states), (attentions)
+
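+ # Commented-out usage sketch for the token-classification head above (the checkpoint
+ # name and the literal token/label ids below are placeholders, shown only for shapes):
+ #
+ #   model = BertForTokenClassification.from_pretrained("bert-base-uncased")
+ #   input_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0]])   # (1, 6), right-padded
+ #   attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])        # padding masked out
+ #   labels = torch.tensor([[0, 1, 1, 0, 0, 0]])                # (1, 6), values < num_labels
+ #   loss, scores = model(input_ids, attention_mask=attention_mask, labels=labels)[:2]
+ #   # scores: (1, 6, config.num_labels); positions with attention_mask == 0 are mapped
+ #   # to CrossEntropyLoss().ignore_index and do not contribute to the loss.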
+
+@add_start_docstrings(
+ """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
+ BERT_START_DOCSTRING,
+)
+class BertForQuestionAnswering(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`).
+ Positions outside of the sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`).
+ Positions outside of the sequence are not taken into account for computing the loss.
+
+ Returns:
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+ Span-start scores (before SoftMax).
+ end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+ Span-end scores (before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ outputs = (start_logits, end_logits,) + outputs[2:]
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, batch splitting adds a dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions.clamp_(0, ignored_index)
+ end_positions.clamp_(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+ outputs = (total_loss,) + outputs
+
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
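+ # Commented-out inference sketch for the question-answering head above (the checkpoint
+ # name is a placeholder; tokenization and detokenization are omitted):
+ #
+ #   model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
+ #   start_logits, end_logits = model(input_ids, attention_mask=attention_mask)[:2]
+ #   start = torch.argmax(start_logits, dim=-1)   # (batch_size,)
+ #   end = torch.argmax(end_logits, dim=-1)       # (batch_size,)
+ #   # answer tokens for example i: input_ids[i, start[i] : end[i] + 1]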
diff --git a/elia/bert/modeling_utils.py b/elia/bert/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5855e204247404cbfa6fed64fc930ca61d13780
--- /dev/null
+++ b/elia/bert/modeling_utils.py
@@ -0,0 +1,1268 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import logging
+import os
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+
+from .activations import get_activation
+from .configuration_utils import PretrainedConfig
+from .file_utils import (
+ DUMMY_INPUTS,
+ TF2_WEIGHTS_NAME,
+ TF_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ cached_path,
+ hf_bucket_url,
+ is_remote_url,
+)
+from .generation_utils import GenerationMixin
+
+
+logger = logging.getLogger(__name__)
+
+
+try:
+ from torch.nn import Identity
+except ImportError:
+ # Older PyTorch compatibility
+ class Identity(nn.Module):
+ r"""A placeholder identity operator that is argument-insensitive.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def forward(self, input):
+ return input
+
+
+def find_pruneable_heads_and_indices(
+ heads: List, n_heads: int, head_size: int, already_pruned_heads: set
+) -> Tuple[set, "torch.LongTensor"]:
+ mask = torch.ones(n_heads, head_size)
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
+ for head in heads:
+ # Compute how many pruned heads are before the head and move the index accordingly
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+ mask[head] = 0
+ mask = mask.view(-1).contiguous().eq(1)
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+ return heads, index
+
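+ # Worked example (assumed sizes) for the helper above: with heads=[1], n_heads=3,
+ # head_size=2 and no previously pruned heads, the flattened keep-mask is
+ # [1, 1, 0, 0, 1, 1], so the function returns ({1}, tensor([0, 1, 4, 5])) -- the set
+ # of heads to prune plus the indices of the weight rows/columns to keep.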
+
+class ModuleUtilsMixin:
+ """
+ A few utilities for torch.nn.Modules, to be used as a mixin.
+ """
+
+ def num_parameters(self, only_trainable: bool = False) -> int:
+ """
+ Get number of (optionally, trainable) parameters in the module.
+ """
+ params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
+ return sum(p.numel() for p in params)
+
+ @staticmethod
+ def _hook_rss_memory_pre_forward(module, *args, **kwargs):
+ try:
+ import psutil
+ except ImportError:
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+ process = psutil.Process(os.getpid())
+ mem = process.memory_info()
+ module.mem_rss_pre_forward = mem.rss
+ return None
+
+ @staticmethod
+ def _hook_rss_memory_post_forward(module, *args, **kwargs):
+ try:
+ import psutil
+ except ImportError:
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+ process = psutil.Process(os.getpid())
+ mem = process.memory_info()
+ module.mem_rss_post_forward = mem.rss
+ mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
+ module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
+ return None
+
+ def add_memory_hooks(self):
+ """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
+ Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
+ """
+ for module in self.modules():
+ module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
+ module.register_forward_hook(self._hook_rss_memory_post_forward)
+ self.reset_memory_hooks_state()
+
+ def reset_memory_hooks_state(self):
+ for module in self.modules():
+ module.mem_rss_diff = 0
+ module.mem_rss_post_forward = 0
+ module.mem_rss_pre_forward = 0
+
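+ # Usage note (hedged): after calling `model.add_memory_hooks()`, each forward pass
+ # accumulates the per-module resident-memory increase (in bytes, via psutil) in
+ # `module.mem_rss_diff`; `model.reset_memory_hooks_state()` zeroes those counters.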
+ @property
+ def device(self) -> device:
+ """
+ Get torch.device from module, assuming that the whole module has one device.
+ """
+ try:
+ return next(self.parameters()).device
+ except StopIteration:
+ # For nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+ @property
+ def dtype(self) -> dtype:
+ """
+ Get torch.dtype from module, assuming that the whole module has one dtype.
+ """
+ try:
+ return next(self.parameters()).dtype
+ except StopIteration:
+ # For nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
+ """Invert an attention mask (1.0 for tokens to attend to, 0.0 for masked tokens) into a broadcastable additive mask with large negative values at the masked positions."""
+ if encoder_attention_mask.dim() == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if encoder_attention_mask.dim() == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+ # /transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+ # encoder_extended_attention_mask.transpose(-1, -2))
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+
+ if self.dtype == torch.float16:
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
+ elif self.dtype == torch.float32:
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+ else:
+ raise ValueError(
+ "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
+ self.dtype
+ )
+ )
+
+ return encoder_extended_attention_mask
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
+ """Makes broadcastable attention mask and causal mask so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
+ input_shape: tuple, shape of input_ids
+ device: torch.Device, usually self.device
+
+ Returns:
+ torch.Tensor with dtype of attention_mask.dtype
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder:
+ batch_size, seq_length = input_shape
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
+ """
+ Prepare the head mask if needed.
+ A value of 1.0 in head_mask indicates that we keep the head;
+ attention_probs has shape bsz x n_heads x N x N.
+ Arguments:
+ head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
+ num_hidden_layers: int
+ Returns:
+ Tensor of shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ or list with [None] for each layer
+ """
+ if head_mask is not None:
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+ if is_attention_chunked is True:
+ head_mask = head_mask.unsqueeze(-1)
+ else:
+ head_mask = [None] * num_hidden_layers
+
+ return head_mask
+
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
+ assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+ head_mask = head_mask.to(dtype=self.dtype) # switch to float if needed + fp16 compatibility
+ return head_mask
+
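+ # Illustrative example (assumed values) for get_extended_attention_mask above: for an
+ # encoder, a padding mask [[1, 1, 0]] of shape (1, 3) becomes a (1, 1, 1, 3) tensor with
+ # values [[[[0., 0., -10000.]]]], which is added to the raw attention scores so that the
+ # padded position is effectively ignored by the softmax; for a decoder, a causal mask over
+ # future positions is combined in as well.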
+
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
+ r""" Base class for all models.
+
+ :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
+ as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
+
+ Class attributes (overridden by derived classes):
+ - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
+ - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
+
+ - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
+ - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
+ - ``path``: a path (string) to the TensorFlow checkpoint.
+
+ - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
+ """
+ config_class = None
+ base_model_prefix = ""
+
+ @property
+ def dummy_inputs(self):
+ """ Dummy inputs to do a forward pass in the network.
+
+ Returns:
+ torch.Tensor with dummy inputs
+ """
+ return {"input_ids": torch.tensor(DUMMY_INPUTS)}
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__()
+ if not isinstance(config, PretrainedConfig):
+ raise ValueError(
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
+ "To create a model from a pretrained model use "
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+ self.__class__.__name__, self.__class__.__name__
+ )
+ )
+ # Save config in model
+ self.config = config
+
+ @property
+ def base_model(self):
+ return getattr(self, self.base_model_prefix, self)
+
+ def get_input_embeddings(self):
+ """
+ Returns the model's input embeddings.
+
+ Returns:
+ :obj:`nn.Module`:
+ A torch module mapping vocabulary to hidden states.
+ """
+ base_model = getattr(self, self.base_model_prefix, self)
+ if base_model is not self:
+ return base_model.get_input_embeddings()
+ else:
+ raise NotImplementedError
+
+ def set_input_embeddings(self, value: nn.Module):
+ """
+ Set model's input embeddings
+
+ Args:
+ value (:obj:`nn.Module`):
+ A module mapping vocabulary to hidden states.
+ """
+ base_model = getattr(self, self.base_model_prefix, self)
+ if base_model is not self:
+ base_model.set_input_embeddings(value)
+ else:
+ raise NotImplementedError
+
+ def get_output_embeddings(self):
+ """
+ Returns the model's output embeddings.
+
+ Returns:
+ :obj:`nn.Module`:
+ A torch module mapping hidden states to vocabulary.
+ """
+ return None # Overwrite for models with output embeddings
+
+ def tie_weights(self):
+ """
+ Tie the weights between the input embeddings and the output embeddings.
+ If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
+ the weights instead.
+ """
+ output_embeddings = self.get_output_embeddings()
+ if output_embeddings is not None:
+ self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+
+ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
+ """ Tie or clone module weights depending on whether we are using TorchScript or not
+ """
+ if self.config.torchscript:
+ output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
+ else:
+ output_embeddings.weight = input_embeddings.weight
+
+ if getattr(output_embeddings, "bias", None) is not None:
+ output_embeddings.bias.data = torch.nn.functional.pad(
+ output_embeddings.bias.data,
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
+ "constant",
+ 0,
+ )
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+ output_embeddings.out_features = input_embeddings.num_embeddings
+
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
+ """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
+ Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+ Arguments:
+
+ new_num_tokens: (`optional`) int:
+ New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
+ If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
+
+ Return: ``torch.nn.Embeddings``
+ Pointer to the input tokens Embeddings Module of the model
+ """
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
+ model_embeds = base_model._resize_token_embeddings(new_num_tokens)
+ if new_num_tokens is None:
+ return model_embeds
+
+ # Update base model and current model config
+ self.config.vocab_size = new_num_tokens
+ base_model.vocab_size = new_num_tokens
+
+ # Tie weights again if needed
+ self.tie_weights()
+
+ return model_embeds
+
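+ # Illustrative call (numbers are assumptions): after adding 5 tokens to the vocabulary,
+ # `model.resize_token_embeddings(old_vocab_size + 5)` builds a new embedding matrix with
+ # 5 newly initialized rows appended, copies the old rows over, updates config.vocab_size
+ # and re-ties the output embeddings if the model shares them with the input embeddings.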
+ def _resize_token_embeddings(self, new_num_tokens):
+ old_embeddings = self.get_input_embeddings()
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+ self.set_input_embeddings(new_embeddings)
+ return self.get_input_embeddings()
+
+ def _get_resized_embeddings(
+ self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
+ ) -> torch.nn.Embedding:
+ """ Build a resized Embedding Module from a provided token Embedding Module.
+ Increasing the size will add newly initialized vectors at the end.
+ Reducing the size will remove vectors from the end.
+
+ Args:
+ old_embeddings: ``torch.nn.Embedding``
+ Old embeddings to be resized.
+ new_num_tokens: (`optional`) int
+ New number of tokens in the embedding matrix.
+ Increasing the size will add newly initialized vectors at the end.
+ Reducing the size will remove vectors from the end.
+ If not provided or None: return the provided token Embedding Module.
+ Return: ``torch.nn.Embedding``
+ Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
+ """
+ if new_num_tokens is None:
+ return old_embeddings
+
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+ if old_num_tokens == new_num_tokens:
+ return old_embeddings
+
+ # Build new embeddings
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+ new_embeddings.to(old_embeddings.weight.device)
+
+ # initialize all new embeddings (in particular added tokens)
+ self._init_weights(new_embeddings)
+
+ # Copy token embeddings from the previous weights
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
+
+ return new_embeddings
+
+ def init_weights(self):
+ """ Initializes and prunes weights if needed. """
+ # Initialize weights
+ self.apply(self._init_weights)
+
+ # Prune heads if needed
+ if self.config.pruned_heads:
+ self.prune_heads(self.config.pruned_heads)
+
+ # Tie weights if needed
+ self.tie_weights()
+
+ def prune_heads(self, heads_to_prune: Dict):
+ """ Prunes heads of the base model.
+
+ Arguments:
+
+ heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
+ E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
+ """
+ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
+ for layer, heads in heads_to_prune.items():
+ union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
+ self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
+
+ self.base_model._prune_heads(heads_to_prune)
+
+ def save_pretrained(self, save_directory):
+ """ Save a model and its configuration file to a directory, so that it
+ can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
+
+ Arguments:
+ save_directory: directory to which to save.
+ """
+ if os.path.isfile(save_directory):
+ logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
+ return
+ os.makedirs(save_directory, exist_ok=True)
+
+ # Only save the model itself (not a distributed/parallel wrapper such as DataParallel)
+ model_to_save = self.module if hasattr(self, "module") else self
+
+ # Attach architecture to the config
+ model_to_save.config.architectures = [model_to_save.__class__.__name__]
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
+
+ if getattr(self.config, "xla_device", False):
+ import torch_xla.core.xla_model as xm
+
+ if xm.is_master_ordinal():
+ # Save configuration file
+ model_to_save.config.save_pretrained(save_directory)
+ # xm.save takes care of saving only from master
+ xm.save(model_to_save.state_dict(), output_model_file)
+ else:
+ model_to_save.config.save_pretrained(save_directory)
+ torch.save(model_to_save.state_dict(), output_model_file)
+
+ logger.info("Model weights saved in {}".format(output_model_file))
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
+ To train the model, you should first set it back in training mode with ``model.train()``
+
+ The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
+ It is up to you to train those weights with a downstream fine-tuning task.
+
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path: either:
+ - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
+ - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
+ - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
+ - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+ - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
+
+ model_args: (`optional`) Sequence of positional arguments:
+ All remaining positional arguments will be passed to the underlying model's ``__init__`` method
+
+ config: (`optional`) one of:
+ - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
+ - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
+
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
+ - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
+ - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
+ - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
+
+ state_dict: (`optional`) dict:
+ an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
+ This option can be used if you want to create a model from a pretrained configuration but load your own weights.
+ In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
+
+ cache_dir: (`optional`) string:
+ Path to a directory in which a downloaded pre-trained model
+ configuration should be cached if the standard cache should not be used.
+
+ force_download: (`optional`) boolean, default False:
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
+
+ resume_download: (`optional`) boolean, default False:
+ Do not delete incompletely received files. Attempt to resume the download if such a file exists.
+
+ proxies: (`optional`) dict, default None:
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+ The proxies are used on each request.
+
+ output_loading_info: (`optional`) boolean:
+ Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
+
+ kwargs: (`optional`) Remaining dictionary of keyword arguments:
+ Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
+
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
+
+ Examples::
+
+ # For example purposes. Not runnable.
+ model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
+ model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
+ model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
+ assert model.config.output_attention == True
+ # Loading from a TF checkpoint file instead of a PyTorch model (slower)
+ config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
+ model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
+
+ """
+ config = kwargs.pop("config", None)
+ state_dict = kwargs.pop("state_dict", None)
+ cache_dir = kwargs.pop("cache_dir", None)
+ from_tf = kwargs.pop("from_tf", False)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_cdn = kwargs.pop("use_cdn", True)
+
+ # Load config if we don't provide a configuration
+ if not isinstance(config, PretrainedConfig):
+ config_path = config if config is not None else pretrained_model_name_or_path
+ config, model_kwargs = cls.config_class.from_pretrained(
+ config_path,
+ *model_args,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ **kwargs,
+ )
+ else:
+ model_kwargs = kwargs
+
+ # Load model
+ if pretrained_model_name_or_path is not None:
+ if os.path.isdir(pretrained_model_name_or_path):
+ if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
+ # Load from a TF 1.0 checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
+ elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
+ # Load from a TF 2.0 checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ else:
+ raise EnvironmentError(
+ "Error no file named {} found in directory {} or `from_tf` set to False".format(
+ [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
+ pretrained_model_name_or_path,
+ )
+ )
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ archive_file = pretrained_model_name_or_path
+ elif os.path.isfile(pretrained_model_name_or_path + ".index"):
+ assert (
+ from_tf
+ ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
+ pretrained_model_name_or_path + ".index"
+ )
+ archive_file = pretrained_model_name_or_path + ".index"
+ else:
+ archive_file = hf_bucket_url(
+ pretrained_model_name_or_path,
+ filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
+ use_cdn=use_cdn,
+ )
+
+ try:
+ # Load from URL or cache if already cached
+ resolved_archive_file = cached_path(
+ archive_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ )
+ if resolved_archive_file is None:
+ raise EnvironmentError
+ except EnvironmentError:
+ msg = (
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
+ )
+ raise EnvironmentError(msg)
+
+ if resolved_archive_file == archive_file:
+ logger.info("loading weights file {}".format(archive_file))
+ else:
+ logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
+ else:
+ resolved_archive_file = None
+
+ # Instantiate model.
+ model = cls(config, *model_args, **model_kwargs)
+
+ if state_dict is None and not from_tf:
+ try:
+ state_dict = torch.load(resolved_archive_file, map_location="cpu")
+ except Exception:
+ raise OSError(
+ "Unable to load weights from pytorch checkpoint file. "
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
+ )
+
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+
+ if from_tf:
+ if resolved_archive_file.endswith(".index"):
+ # Load from a TensorFlow 1.X checkpoint - provided by original authors
+ model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
+ else:
+ # Load from our TensorFlow 2.0 checkpoints
+ try:
+ from transformers import load_tf2_checkpoint_in_pytorch_model
+
+ model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ else:
+ # Convert old format to new format if needed from a PyTorch state_dict
+ old_keys = []
+ new_keys = []
+ for key in state_dict.keys():
+ new_key = None
+ if "gamma" in key:
+ new_key = key.replace("gamma", "weight")
+ if "beta" in key:
+ new_key = key.replace("beta", "bias")
+ if new_key:
+ old_keys.append(key)
+ new_keys.append(new_key)
+ for old_key, new_key in zip(old_keys, new_keys):
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # Debugging aid: uncomment to print the keys present in the checkpoint's state_dict.
+ # for key in state_dict.keys():
+ #     print(key)
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: nn.Module, prefix=""):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
+ )
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ start_prefix = ""
+ model_to_load = model
+ has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
+ if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
+ start_prefix = cls.base_model_prefix + "."
+ if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
+ model_to_load = getattr(model, cls.base_model_prefix)
+
+ load(model_to_load, prefix=start_prefix)
+
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
+ base_model_state_dict = model_to_load.state_dict().keys()
+ head_model_state_dict_without_base_prefix = [
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
+ ]
+
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+ f"and are newly initialized: {missing_keys}\n"
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ else:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+ f"If your task is similar to the task the model of the checkpoint was trained on, "
+ f"you can already use {model.__class__.__name__} for predictions without further training."
+ )
+ if len(error_msgs) > 0:
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
+ model.__class__.__name__, "\n\t".join(error_msgs)
+ )
+ )
+ model.tie_weights() # make sure token embedding weights are still tied if needed
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "error_msgs": error_msgs,
+ }
+ return model, loading_info
+
+ if hasattr(config, "xla_device") and config.xla_device:
+ import torch_xla.core.xla_model as xm
+
+ model = xm.send_cpu_data_to_device(model, xm.xla_device())
+ model.to(xm.xla_device())
+
+ return model
+
+
+class Conv1D(nn.Module):
+ def __init__(self, nf, nx):
+ """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
+ Basically works like a Linear layer but the weights are transposed
+ """
+ super().__init__()
+ self.nf = nf
+ w = torch.empty(nx, nf)
+ nn.init.normal_(w, std=0.02)
+ self.weight = nn.Parameter(w)
+ self.bias = nn.Parameter(torch.zeros(nf))
+
+ def forward(self, x):
+ size_out = x.size()[:-1] + (self.nf,)
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+ x = x.view(*size_out)
+ return x
+
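+ # Shape note (illustrative sizes): Conv1D(nf=3072, nx=768) applied to a tensor of shape
+ # (batch, seq_len, 768) returns (batch, seq_len, 3072); it behaves like nn.Linear(768, 3072)
+ # with the weight stored transposed, i.e. with shape (nx, nf).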
+
+class PoolerStartLogits(nn.Module):
+ """ Compute SQuAD start_logits from sequence hidden states. """
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, 1)
+
+ def forward(self, hidden_states, p_mask=None):
+ """ Args:
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
+ Mask of invalid positions such as query and special symbols (PAD, SEP, CLS).
+ 1.0 means the token should be masked.
+ """
+ x = self.dense(hidden_states).squeeze(-1)
+
+ if p_mask is not None:
+ if next(self.parameters()).dtype == torch.float16:
+ x = x * (1 - p_mask) - 65500 * p_mask
+ else:
+ x = x * (1 - p_mask) - 1e30 * p_mask
+
+ return x
+
+
+class PoolerEndLogits(nn.Module):
+ """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+ self.activation = nn.Tanh()
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dense_1 = nn.Linear(config.hidden_size, 1)
+
+ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
+ """ Args:
+ One of ``start_states``, ``start_positions`` should be not None.
+ If both are set, ``start_positions`` overrides ``start_states``.
+
+ **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
+ hidden states of the first tokens for the labeled span.
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
+ position of the first token for the labeled span:
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
+ Mask of invalid positions such as query and special symbols (PAD, SEP, CLS).
+ 1.0 means the token should be masked.
+ """
+ assert (
+ start_states is not None or start_positions is not None
+ ), "One of start_states, start_positions should be not None"
+ if start_positions is not None:
+ slen, hsz = hidden_states.shape[-2:]
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
+
+ x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
+ x = self.activation(x)
+ x = self.LayerNorm(x)
+ x = self.dense_1(x).squeeze(-1)
+
+ if p_mask is not None:
+ if next(self.parameters()).dtype == torch.float16:
+ x = x * (1 - p_mask) - 65500 * p_mask
+ else:
+ x = x * (1 - p_mask) - 1e30 * p_mask
+
+ return x
+
+
+class PoolerAnswerClass(nn.Module):
+ """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+ self.activation = nn.Tanh()
+ self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
+
+ def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
+ """
+ Args:
+ One of ``start_states``, ``start_positions`` should be not None.
+ If both are set, ``start_positions`` overrides ``start_states``.
+
+ **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
+ hidden states of the first tokens for the labeled span.
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
+ position of the first token for the labeled span.
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
+ position of the CLS token. If None, take the last token.
+
+ note(Original repo):
+ no dependency on end_feature so that we can obtain one single `cls_logits`
+ for each sample
+ """
+ hsz = hidden_states.shape[-1]
+ assert (
+ start_states is not None or start_positions is not None
+ ), "One of start_states, start_positions should be not None"
+ if start_positions is not None:
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
+ start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
+
+ if cls_index is not None:
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
+ else:
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
+
+ x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
+ x = self.activation(x)
+ x = self.dense_1(x).squeeze(-1)
+
+ return x
+
+
+class SQuADHead(nn.Module):
+ r""" A SQuAD head inspired by XLNet.
+
+ Parameters:
+ config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
+
+ Inputs:
+ **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
+ hidden states of sequence tokens
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
+ position of the first token for the labeled span.
+ **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
+ position of the last token for the labeled span.
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
+ position of the CLS token. If None, take the last token.
+ **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
+ Whether the question has a possible answer in the paragraph or not.
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
+ Mask of invalid positions such as query and special symbols (PAD, SEP, CLS).
+ 1.0 means the token should be masked.
+
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+ **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
+ **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
+ **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
+ Indices for the top config.start_n_top start token possibilities (beam-search).
+ **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+ Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+ **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+ Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+ **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+ ``torch.FloatTensor`` of shape ``(batch_size,)``
+ Log probabilities for the ``is_impossible`` label of the answers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.start_n_top = config.start_n_top
+ self.end_n_top = config.end_n_top
+
+ self.start_logits = PoolerStartLogits(config)
+ self.end_logits = PoolerEndLogits(config)
+ self.answer_class = PoolerAnswerClass(config)
+
+ def forward(
+ self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
+ ):
+ outputs = ()
+
+ start_logits = self.start_logits(hidden_states, p_mask=p_mask)
+
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, let's remove the dimension added by batch splitting
+ for x in (start_positions, end_positions, cls_index, is_impossible):
+ if x is not None and x.dim() > 1:
+ x.squeeze_(-1)
+
+ # during training, compute the end logits based on the ground truth of the start position
+ end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+
+ loss_fct = CrossEntropyLoss()
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if cls_index is not None and is_impossible is not None:
+ # Predict answerability from the representation of CLS and START
+ cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+ loss_fct_cls = nn.BCEWithLogitsLoss()
+ cls_loss = loss_fct_cls(cls_logits, is_impossible)
+
+ # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
+ total_loss += cls_loss * 0.5
+
+ outputs = (total_loss,) + outputs
+
+ else:
+ # during inference, compute the end logits based on beam search
+ bsz, slen, hsz = hidden_states.size()
+ start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
+
+ start_top_log_probs, start_top_index = torch.topk(
+ start_log_probs, self.start_n_top, dim=-1
+ ) # shape (bsz, start_n_top)
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
+ start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
+
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
+ start_states
+ ) # shape (bsz, slen, start_n_top, hsz)
+ p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
+ end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
+ end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
+
+ end_top_log_probs, end_top_index = torch.topk(
+ end_log_probs, self.end_n_top, dim=1
+ ) # shape (bsz, end_n_top, start_n_top)
+ end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
+ end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
+
+ start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
+ cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+
+ outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
+
+ # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
+ # or (if labels are provided) (total_loss,)
+ return outputs
+
+
+class SequenceSummary(nn.Module):
+ r""" Compute a single vector summary of a sequence's hidden states according to various possibilities:
+ Args of the config class:
+ summary_type:
+ - 'last' => [default] take the last token hidden state (like XLNet)
+ - 'first' => take the first token hidden state (like Bert)
+ - 'mean' => take the mean of all tokens hidden states
+ - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
+ - 'attn' => Not implemented now, use multi-head attention
+ summary_use_proj: Add a projection after the vector extraction
+ summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
+ summary_activation: 'tanh' or another activation string supported by get_activation => add that activation to the output; None or empty => no activation (default)
+ summary_first_dropout: Add a dropout before the projection and activation
+ summary_last_dropout: Add a dropout after the projection and activation
+ """
+
+ def __init__(self, config: PretrainedConfig):
+ super().__init__()
+
+ self.summary_type = getattr(config, "summary_type", "last")
+ if self.summary_type == "attn":
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+ raise NotImplementedError
+
+ self.summary = Identity()
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+ num_classes = config.num_labels
+ else:
+ num_classes = config.hidden_size
+ self.summary = nn.Linear(config.hidden_size, num_classes)
+
+ activation_string = getattr(config, "summary_activation", None)
+ self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
+
+ self.first_dropout = Identity()
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+ self.last_dropout = Identity()
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+ def forward(self, hidden_states, cls_index=None):
+ """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
+ cls_index: [optional] position of the classification token if summary_type == 'cls_index',
+ shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
+ if summary_type == 'cls_index' and cls_index is None:
+ we take the last token of the sequence as classification token
+ """
+ if self.summary_type == "last":
+ output = hidden_states[:, -1]
+ elif self.summary_type == "first":
+ output = hidden_states[:, 0]
+ elif self.summary_type == "mean":
+ output = hidden_states.mean(dim=1)
+ elif self.summary_type == "cls_index":
+ if cls_index is None:
+ cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
+ else:
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
+ elif self.summary_type == "attn":
+ raise NotImplementedError
+
+ output = self.first_dropout(output)
+ output = self.summary(output)
+ output = self.activation(output)
+ output = self.last_dropout(output)
+
+ return output
+
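+ # Configuration sketch (attribute values are assumptions): with summary_type="cls_index",
+ # summary_use_proj=True, summary_proj_to_labels=True and num_labels=2, the summary of a
+ # (bsz, seq_len, hidden_size) tensor is the hidden state at `cls_index` (the last token
+ # when cls_index is None), passed through the optional dropouts, a Linear(hidden_size, 2)
+ # projection and the configured activation, giving an output of shape (bsz, 2).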
+
+def prune_linear_layer(layer, index, dim=0):
+ """ Prune a linear layer (a model parameter) to keep only the entries in index.
+ Return the pruned layer as a new layer with requires_grad=True.
+ Used to remove heads.
+ """
+ index = index.to(layer.weight.device)
+ W = layer.weight.index_select(dim, index).clone().detach()
+ if layer.bias is not None:
+ if dim == 1:
+ b = layer.bias.clone().detach()
+ else:
+ b = layer.bias[index].clone().detach()
+ new_size = list(layer.weight.size())
+ new_size[dim] = len(index)
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
+ new_layer.weight.requires_grad = False
+ new_layer.weight.copy_(W.contiguous())
+ new_layer.weight.requires_grad = True
+ if layer.bias is not None:
+ new_layer.bias.requires_grad = False
+ new_layer.bias.copy_(b.contiguous())
+ new_layer.bias.requires_grad = True
+ return new_layer
+
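+ # Worked example (assumed sizes): pruning nn.Linear(6, 4) with index=torch.tensor([0, 1, 4, 5])
+ # and dim=1 keeps 4 of the 6 input features and returns a new nn.Linear(4, 4) whose weight is
+ # the corresponding column slice of the original (bias unchanged, requires_grad=True) -- this
+ # pairs with the index returned by find_pruneable_heads_and_indices when pruning head 1 of
+ # 3 heads of size 2.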
+
+def prune_conv1d_layer(layer, index, dim=1):
+ """ Prune a Conv1D layer (a model parameter) to keep only the entries in index.
+ A Conv1D works like a Linear layer (see e.g. BERT) but the weights are transposed.
+ Return the pruned layer as a new layer with requires_grad=True.
+ Used to remove heads.
+ """
+ index = index.to(layer.weight.device)
+ W = layer.weight.index_select(dim, index).clone().detach()
+ if dim == 0:
+ b = layer.bias.clone().detach()
+ else:
+ b = layer.bias[index].clone().detach()
+ new_size = list(layer.weight.size())
+ new_size[dim] = len(index)
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
+ new_layer.weight.requires_grad = False
+ new_layer.weight.copy_(W.contiguous())
+ new_layer.weight.requires_grad = True
+ new_layer.bias.requires_grad = False
+ new_layer.bias.copy_(b.contiguous())
+ new_layer.bias.requires_grad = True
+ return new_layer
+
+
+def prune_layer(layer, index, dim=None):
+ """ Prune a Conv1D or nn.Linear layer (a model parameter) to keep only the entries in index.
+ Returns the pruned layer as a new layer with requires_grad=True.
+ Used to remove heads.
+ """
+ if isinstance(layer, nn.Linear):
+ return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
+ elif isinstance(layer, Conv1D):
+ return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
+ else:
+ raise ValueError("Can't prune layer of class {}".format(layer.__class__))
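+
+ # NOTE (illustrative): a hedged usage sketch of the pruning helpers above; `attn`, `query`, `output_proj`
+ # and the head size below are hypothetical names supplied by the caller, not defined in this file.
+ #
+ #   head_size = 64
+ #   keep = torch.tensor([i for i in range(attn.query.out_features) if i // head_size != 2])  # drop head 2
+ #   attn.query = prune_linear_layer(attn.query, keep, dim=0)         # prune output rows
+ #   attn.output_proj = prune_linear_layer(attn.output_proj, keep, dim=1)  # prune matching input columns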
+
+
+def apply_chunking_to_forward(
+ chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
+) -> torch.Tensor:
+ """
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
+ It then applies a layer `forward_fn` to each chunk independently to save memory.
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the
+ same result as not applying it.
+
+ Args:
+ chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `input_tensors[0].shape[chunk_dim] / chunk_size`
+ chunk_dim: int - the dimension over which the input_tensors should be chunked
+ forward_fn: fn - the forward fn of the model
+ input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
+ Returns:
+ a Tensor with the same shape that `forward_fn` would have produced if applied directly
+
+
+ Examples::
+
+ # rename the usual forward() fn to forward_chunk()
+ def forward_chunk(self, hidden_states):
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+ # implement a chunked forward function
+ def forward(self, hidden_states):
+ return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
+ """
+
+ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
+ tensor_shape = input_tensors[0].shape
+ assert all(
+ input_tensor.shape == tensor_shape for input_tensor in input_tensors
+ ), "All input tensors have to be of the same shape"
+
+ # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+ assert num_args_in_forward_chunk_fn == len(
+ input_tensors
+ ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
+ num_args_in_forward_chunk_fn, len(input_tensors)
+ )
+
+ if chunk_size > 0:
+ assert (
+ input_tensors[0].shape[chunk_dim] % chunk_size == 0
+ ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
+ input_tensors[0].shape[chunk_dim], chunk_size
+ )
+
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
+
+ # chunk input tensor into tuples
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
+ # apply forward fn to every tuple
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+ # concatenate output at same dimension
+ return torch.cat(output_chunks, dim=chunk_dim)
+
+ return forward_fn(*input_tensors)
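+
+ # NOTE (illustrative): with seq_len = 128 along chunk_dim and chunk_size = 32, the input is split into
+ # 128 // 32 = 4 chunks, forward_fn is called on each (batch, 32, hidden) slice, and the 4 outputs are
+ # concatenated back along chunk_dim; with chunk_size = 0 the call is simply forward_fn(*input_tensors).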
diff --git a/elia/bert/multimodal_bert.py b/elia/bert/multimodal_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..b01cdbefe1197b45e25bdd754f847fd7efdd7c01
--- /dev/null
+++ b/elia/bert/multimodal_bert.py
@@ -0,0 +1,277 @@
+
+from .modeling_bert import BertModel
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+class MultiModalBert(BertModel):
+ def __init__(self, config, embed_dim, pwam_idx=[3,6,9,12], num_heads_fusion=[1,1,1,1], fusion_drop=0.0):
+ super().__init__(config)
+ self.pwam_idx = pwam_idx
+ self.num_heads_fusion = num_heads_fusion
+ self.fusion_drop = fusion_drop
+
+ pwam_dims = [embed_dim * 2 ** i for i in range(len(pwam_idx))]
+ self.pwams = nn.ModuleList()
+ self.res_gates = nn.ModuleList()
+ self.norms = nn.ModuleList()
+ for i in range(0, len(pwam_idx)):
+ dim = pwam_dims[i]
+ fusion = PWAM(768, # both the visual input and for combining, num of channels
+ dim, # v_in
+ 768, # l_in
+ 768, # key
+ 768, # value
+ num_heads=num_heads_fusion[i],
+ dropout=fusion_drop)
+ self.pwams.append(fusion)
+
+ res_gate = nn.Sequential(
+ nn.Linear(768, 768, bias=False),
+ nn.ReLU(),
+ nn.Linear(768, 768, bias=False),
+ nn.Tanh()
+ )
+ nn.init.zeros_(res_gate[0].weight)
+ nn.init.zeros_(res_gate[2].weight)
+ self.res_gates.append(res_gate)
+
+ self.norms.append(nn.LayerNorm(768))
+
+ def forward_stem(self, input_ids, attention_mask):
+ input_shape = input_ids.size()
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
+
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, input_ids.device)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids, token_type_ids=token_type_ids
+ )
+ return embedding_output, extended_attention_mask
+
+ def forward_stage1(self, hidden_states, attention_mask):
+ for i in range(0, self.pwam_idx[0]):
+ layer_module = self.encoder.layer[i]
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+
+ return layer_outputs[0]
+
+ def forward_stage2(self, hidden_states, attention_mask):
+ for i in range(self.pwam_idx[0], self.pwam_idx[1]):
+ layer_module = self.encoder.layer[i]
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+
+ return layer_outputs[0]
+
+ def forward_stage3(self, hidden_states, attention_mask):
+ for i in range(self.pwam_idx[1], self.pwam_idx[2]):
+ layer_module = self.encoder.layer[i]
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+
+ return layer_outputs[0]
+
+ def forward_stage4(self, hidden_states, attention_mask):
+ for i in range(self.pwam_idx[2], self.pwam_idx[3]):
+ layer_module = self.encoder.layer[i]
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+
+ return layer_outputs[0]
+
+ def forward_pwam1(self, x, l, l_mask):
+ l_residual = self.pwams[0](x, l, l_mask)
+ l = l + (self.res_gates[0](l_residual) * l_residual)
+ return self.norms[0](l_residual), l
+
+ def forward_pwam2(self, x, l, l_mask):
+ l_residual = self.pwams[1](x, l, l_mask)
+ l = l + (self.res_gates[1](l_residual) * l_residual)
+ return self.norms[1](l_residual), l
+
+ def forward_pwam3(self, x, l, l_mask):
+ l_residual = self.pwams[2](x, l, l_mask)
+ l = l + (self.res_gates[2](l_residual) * l_residual)
+ return self.norms[2](l_residual), l
+
+ def forward_pwam4(self, x, l, l_mask):
+ l_residual = self.pwams[3](x, l, l_mask)
+ l = l + (self.res_gates[3](l_residual) * l_residual)
+ return self.norms[3](l_residual), l
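+
+ # NOTE (illustrative): a minimal sketch of how the staged interface above is typically driven alongside a
+ # vision backbone; `vis_feats1..2` (flattened per-stage visual features) and `lang_mask` are hypothetical
+ # names supplied by the caller, not defined in this file.
+ #
+ #   bert = MultiModalBert(config, embed_dim=96)
+ #   l, ext_mask = bert.forward_stem(input_ids, attention_mask)
+ #   l = bert.forward_stage1(l, ext_mask)
+ #   l_fused1, l = bert.forward_pwam1(vis_feats1, l, lang_mask)   # fuse stage-1 visual features
+ #   l = bert.forward_stage2(l, ext_mask)
+ #   l_fused2, l = bert.forward_pwam2(vis_feats2, l, lang_mask)
+ #   # ... forward_stage3/forward_pwam3 and forward_stage4/forward_pwam4 follow the same pattern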
+
+class PWAM(nn.Module):
+ def __init__(self, dim, v_in_channels, l_in_channels, key_channels, value_channels, num_heads=0, dropout=0.0):
+ super(PWAM, self).__init__()
+ # input x shape: (B, H*W, dim)
+ self.vis_project = nn.Sequential(nn.Linear(dim, dim), # the init function sets bias to 0 if bias is True
+ nn.GELU(),
+ nn.Dropout(dropout)
+ )
+
+ self.image_lang_att = SpatialImageLanguageAttention(v_in_channels, # v_in
+ l_in_channels, # l_in
+ key_channels, # key
+ value_channels, # value
+ out_channels=value_channels, # out
+ num_heads=num_heads)
+
+ self.project_mm = nn.Sequential(nn.Conv1d(value_channels, value_channels, 1, 1),
+ nn.GELU(),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, l, l_mask):
+ # x: visual features (B, H*W, v_in_channels); l: language features (B, N_l, dim);
+ # l_mask: mask over the N_l word positions
+ vis = self.vis_project(l) # (B, N_l, dim)
+
+ lang = self.image_lang_att(x, l, l_mask) # (B, N_l, dim), word positions attending to image positions
+
+ lang = lang.permute(0, 2, 1) # (B, dim, N_l)
+
+ mm = torch.mul(vis.permute(0, 2, 1), lang) # (B, dim, N_l), element-wise gating
+ mm = self.project_mm(mm) # (B, dim, N_l)
+
+ mm = mm.permute(0, 2, 1) # (B, N_l, dim)
+
+ return mm
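+
+ # NOTE (illustrative): a shape walk-through for PWAM as instantiated above (dim = key = value = 768);
+ # the tensors below are hypothetical.
+ #
+ #   pwam = PWAM(768, v_in_channels=192, l_in_channels=768, key_channels=768, value_channels=768, num_heads=1)
+ #   x = torch.randn(2, 56*56, 192)      # visual features (B, H*W, v_in_channels)
+ #   l = torch.randn(2, 20, 768)         # language features (B, N_l, 768)
+ #   l_mask = torch.ones(2, 1, 1, 20)    # broadcasts as (B, 1, N_l) after squeeze(1)
+ #   out = pwam(x, l, l_mask)            # (2, 20, 768): one fused vector per word position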
+
+
+class SpatialImageLanguageAttention(nn.Module):
+ def __init__(self, v_in_channels, l_in_channels, key_channels, value_channels, out_channels=None, num_heads=1):
+ super(SpatialImageLanguageAttention, self).__init__()
+ # x shape: (B, H*W, v_in_channels)
+ # l input shape: (B, l_in_channels, N_l)
+ # l_mask shape: (B, N_l, 1)
+ self.v_in_channels = v_in_channels
+ self.l_in_channels = l_in_channels
+ self.out_channels = out_channels
+ self.key_channels = key_channels
+ self.value_channels = value_channels
+ self.num_heads = num_heads
+ if out_channels is None:
+ self.out_channels = self.value_channels
+
+ # Queries: language features: (B, l_in_channels, #words)
+ # avoid any form of spatial normalization because a sentence contains many padding 0s
+ self.f_query = nn.Sequential(
+ nn.Conv1d(self.l_in_channels, self.key_channels, kernel_size=1, stride=1),
+ )
+
+ # Keys: visual features: (B, H*W, v_in_channels)
+ self.f_key = nn.Sequential(
+ nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1),
+ nn.InstanceNorm1d(self.key_channels),
+ )
+
+ # Values: visual features: (B, H*W, v_in_channels)
+ # note: this projects to key_channels; the forward pass below assumes key_channels == value_channels
+ self.f_value = nn.Sequential(
+ nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1),
+ nn.InstanceNorm1d(self.key_channels),
+ )
+
+ # Out projection
+ self.W = nn.Sequential(
+ nn.Conv1d(self.value_channels, self.out_channels, kernel_size=1, stride=1),
+ nn.InstanceNorm1d(self.out_channels),
+ )
+
+ def forward(self, x, l, l_mask):
+ # x: visual features (B, H*W, v_in_channels)
+ # l: language features (B, N_l, l_in_channels)
+ # l_mask: mask over the N_l word positions; the ops below rely on it broadcasting as (B, 1, N_l)
+ l_mask = l_mask.squeeze(1)
+ B, HW = x.size(0), x.size(1)
+ x = x.permute(0, 2, 1) # (B, v_in_channels, H*W)
+ l = l.permute(0, 2, 1) # (B, l_in_channels, N_l)
+
+ query = self.f_query(l) # (B, key_channels, N_l)
+ query = query * l_mask # zero out padded word positions
+ query = query.permute(0, 2, 1) # (B, N_l, key_channels)
+
+ key = self.f_key(x) # (B, key_channels, H*W)
+ value = self.f_value(x) # (B, key_channels, H*W); key_channels == value_channels here
+
+ n_l = query.size(1)
+
+ key = key.reshape(B, self.num_heads, self.key_channels//self.num_heads, HW)
+ value = value.reshape(B, self.num_heads, self.key_channels//self.num_heads, HW)
+ # key/value: (B, num_heads, key_channels//num_heads, H*W)
+ query = query.reshape(B, n_l, self.num_heads, self.key_channels//self.num_heads).permute(0, 2, 1, 3)
+ # query: (B, num_heads, N_l, key_channels//num_heads)
+
+ l_mask = l_mask.unsqueeze(-1) # (B, 1, N_l, 1)
+
+ sim_map = torch.matmul(query, key) # (B, num_heads, N_l, H*W)
+ sim_map = (self.key_channels ** -.5) * sim_map # scaled dot product
+
+ sim_map = sim_map + (1e4*l_mask - 1e4) # assign a very small number to padded word positions
+ sim_map = F.softmax(sim_map, dim=-1) # (B, num_heads, N_l, H*W)
+ out = torch.matmul(sim_map, value.permute(0, 1, 3, 2)) # (B, num_heads, N_l, value_channels//num_heads)
+ out = out.permute(0, 2, 1, 3).contiguous().reshape(B, n_l, self.value_channels) # (B, N_l, value_channels)
+ out = out.permute(0, 2, 1) # (B, value_channels, N_l)
+ out = self.W(out) # (B, value_channels, N_l)
+ out = out.permute(0, 2, 1) # (B, N_l, value_channels)
+
+ return out
diff --git a/elia/bert/tokenization_bert.py b/elia/bert/tokenization_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..972e1733163522359750dddedf6dea885085b2ca
--- /dev/null
+++ b/elia/bert/tokenization_bert.py
@@ -0,0 +1,545 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes."""
+
+
+import collections
+import logging
+import os
+import unicodedata
+from typing import List, Optional
+
+from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+
+
+logger = logging.getLogger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+ "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
+ }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "bert-base-uncased": 512,
+ "bert-large-uncased": 512,
+ "bert-base-cased": 512,
+ "bert-large-cased": 512,
+ "bert-base-multilingual-uncased": 512,
+ "bert-base-multilingual-cased": 512,
+ "bert-base-chinese": 512,
+ "bert-base-german-cased": 512,
+ "bert-large-uncased-whole-word-masking": 512,
+ "bert-large-cased-whole-word-masking": 512,
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
+ "bert-base-cased-finetuned-mrpc": 512,
+ "bert-base-german-dbmdz-cased": 512,
+ "bert-base-german-dbmdz-uncased": 512,
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
+ "wietsedv/bert-base-dutch-cased": 512,
+}
+
+PRETRAINED_INIT_CONFIGURATION = {
+ "bert-base-uncased": {"do_lower_case": True},
+ "bert-large-uncased": {"do_lower_case": True},
+ "bert-base-cased": {"do_lower_case": False},
+ "bert-large-cased": {"do_lower_case": False},
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
+ "bert-base-multilingual-cased": {"do_lower_case": False},
+ "bert-base-chinese": {"do_lower_case": False},
+ "bert-base-german-cased": {"do_lower_case": False},
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
+}
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
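+
+ # NOTE (illustrative): for a vocab.txt whose first lines are "[PAD]", "[unused0]", "[unused1]", ...,
+ # load_vocab returns an OrderedDict mapping each token to its line index, e.g. {"[PAD]": 0, "[unused0]": 1, ...}.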
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class BertTokenizer(PreTrainedTokenizer):
+ r"""
+ Constructs a BERT tokenizer. Based on WordPiece.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
+ should refer to the superclass for more information regarding methods.
+
+ Args:
+ vocab_file (:obj:`string`):
+ File containing the vocabulary.
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether to lowercase the input when tokenizing.
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether to do basic tokenization before WordPiece.
+ never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ :obj:`do_basic_tokenize=True`
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+ for sequence classification or for a text and a question for question answering.
+ It is also used as the last token of a sequence built with special tokens.
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
+ The classifier token which is used when doing sequence classification (classification of the whole
+ sequence instead of per-token classification). It is the first token of the sequence when built with
+ special tokens.
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether to tokenize Chinese characters.
+ This should likely be deactivated for Japanese:
+ see: https://github.com/huggingface/transformers/issues/328
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ do_basic_tokenize=True,
+ never_split=None,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ tokenize_chinese_chars=True,
+ **kwargs
+ ):
+ super().__init__(
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
+ )
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+
+ # If the token is part of the never_split set
+ if token in self.basic_tokenizer.never_split:
+ split_tokens.append(token)
+ else:
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) into an id using the vocab. """
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) into a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """ Converts a sequence of tokens (string) into a single string. """
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
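+
+ # NOTE (illustrative): convert_tokens_to_string(["un", "##aff", "##able"]) -> "unaffable",
+ # since the " ##" continuation markers are stripped when the pieces are joined.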
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ A BERT sequence has the following format:
+
+ - single sequence: ``[CLS] X [SEP]``
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
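+
+ # NOTE (illustrative): with cls_token_id C and sep_token_id S, token_ids_0 = [t1, t2] becomes
+ # [C, t1, t2, S]; with a second sequence [u1, u2] it becomes [C, t1, t2, S, u1, u2, S].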
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True if the token list is already formatted with special tokens for the model
+
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formatted with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+ A BERT sequence pair mask has the following format:
+
+ ::
+
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+ sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ def save_vocabulary(self, vocab_path):
+ """
+ Save the tokenizer's vocabulary to a directory or file (one token per line).
+
+ Args:
+ vocab_path (:obj:`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ index = 0
+ if os.path.isdir(vocab_path):
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
+ else:
+ vocab_file = vocab_path
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+
+class BasicTokenizer(object):
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
+ """ Constructs a BasicTokenizer.
+
+ Args:
+ **do_lower_case**: Whether to lower case the input.
+ **never_split**: (`optional`) list of str
+ Kept for backward compatibility purposes.
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
+ List of tokens not to split.
+ **tokenize_chinese_chars**: (`optional`) boolean (default True)
+ Whether to tokenize Chinese characters.
+ This should likely be deactivated for Japanese:
+ see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
+ """
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+
+ def tokenize(self, text, never_split=None):
+ """ Basic Tokenization of a piece of text.
+ Splits on whitespace only; for sub-word tokenization, see WordpieceTokenizer.
+
+ Args:
+ **never_split**: (`optional`) list of str
+ Kept for backward compatibility purposes.
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
+ List of tokens not to split.
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ orig_tokens = whitespace_tokenize(text)
+ split_tokens = []
+ for token in orig_tokens:
+ if self.do_lower_case and token not in never_split:
+ token = token.lower()
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
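+
+ # NOTE (illustrative): with do_lower_case=True, tokenize('Héllo, WORLD!') -> ['hello', ',', 'world', '!']
+ # (lower-cased, accents stripped, punctuation split into separate tokens).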
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if never_split is not None and text in never_split:
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
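+
+ # NOTE (illustrative): _run_split_on_punc("hi-there!") -> ["hi", "-", "there", "!"], while a token
+ # contained in never_split is returned unchanged as ["hi-there!"].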
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
+ ): #
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+class WordpieceTokenizer(object):
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text into its word pieces.
+
+ This uses a greedy longest-match-first algorithm to perform tokenization
+ using the given vocabulary.
+
+ For example:
+ input = "unaffable"
+ output = ["un", "##aff", "##able"]
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through `BasicTokenizer`.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
diff --git a/elia/bert/tokenization_utils.py b/elia/bert/tokenization_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d878210f407d8fb10226d9e4435c761a1f7483fc
--- /dev/null
+++ b/elia/bert/tokenization_utils.py
@@ -0,0 +1,723 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Tokenization classes for python tokenizers.
+ For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
+"""
+
+import itertools
+import logging
+import re
+import unicodedata
+from typing import Dict, List, Optional, Tuple, Union
+
+from .file_utils import add_end_docstrings
+from .tokenization_utils_base import (
+ ENCODE_KWARGS_DOCSTRING,
+ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
+ AddedToken,
+ BatchEncoding,
+ EncodedInput,
+ EncodedInputPair,
+ PaddingStrategy,
+ PreTokenizedInput,
+ PreTokenizedInputPair,
+ PreTrainedTokenizerBase,
+ TensorType,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def _is_whitespace(char):
+ """Checks whether `chars` is a whitespace character."""
+ # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+def _is_control(char):
+ """Checks whether `chars` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat.startswith("C"):
+ return True
+ return False
+
+
+def _is_punctuation(char):
+ """Checks whether `chars` is a punctuation character."""
+ cp = ord(char)
+ # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode
+ # Punctuation class but we treat them as punctuation anyways, for
+ # consistency.
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
+ return True
+ cat = unicodedata.category(char)
+ if cat.startswith("P"):
+ return True
+ return False
+
+
+def _is_end_of_word(text):
+ """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
+ last_char = text[-1]
+ return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
+
+
+def _is_start_of_word(text):
+ """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
+ first_char = text[0]
+ return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
+
+
+class PreTrainedTokenizer(PreTrainedTokenizerBase):
+ """ Base class for all slow tokenizers.
+
+ Handle all the shared methods for tokenization and special tokens as well as methods
+ downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
+
+ This class also contains the added tokens in a unified way on top of all tokenizers so we don't
+ have to handle the specific vocabulary augmentation methods of the various underlying
+ dictionary structures (BPE, sentencepiece...).
+
+ Class attributes (overridden by derived classes):
+
+ - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
+ required by the model, and as associated values, the filename for saving the associated file (string).
+ - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
+ being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
+ `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
+ associated pretrained vocabulary file.
+ - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
+ models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
+ model has no maximum input size.
+ - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
+ pretrained models, and as associated values, a dictionary of specific arguments to pass to the
+ ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the
+ ``from_pretrained()`` method.
+
+ Args:
+ - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
+ When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
+ model in ``max_model_input_sizes`` (see above). If no value is provided, or no associated max_length
+ can be found in ``max_model_input_sizes``, this will default to VERY_LARGE_INTEGER (`int(1e30)`).
+ - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
+ Should be selected between ['right', 'left']
+ - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
+ model ("token_type_ids", "attention_mask"...).
+ - ``bos_token``: (`Optional`) string: a beginning of sentence token.
+ Will be associated to ``self.bos_token`` and ``self.bos_token_id``
+ - ``eos_token``: (`Optional`) string: an end of sentence token.
+ Will be associated to ``self.eos_token`` and ``self.eos_token_id``
+ - ``unk_token``: (`Optional`) string: an unknown token.
+ Will be associated to ``self.unk_token`` and ``self.unk_token_id``
+ - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
+ Will be associated to ``self.sep_token`` and ``self.sep_token_id``
+ - ``pad_token``: (`Optional`) string: a padding token.
+ Will be associated to ``self.pad_token`` and ``self.pad_token_id``
+ - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
+ leveraging self-attention along the full depth of the model).
+ Will be associated to ``self.cls_token`` and ``self.cls_token_id``
+ - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
+ modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
+ - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
+ Adding all special tokens here ensures they won't be split by the tokenization process.
+ Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
+
+
+ .. automethod:: __call__
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ # Added tokens - We store this for both slow and fast tokenizers
+ # until the serialization of Fast tokenizers is updated
+ self.added_tokens_encoder: Dict[str, int] = {}
+ self.added_tokens_decoder: Dict[int, str] = {}
+ self.unique_no_split_tokens: List[str] = []
+
+ @property
+ def is_fast(self) -> bool:
+ return False
+
+ @property
+ def vocab_size(self) -> int:
+ """ Size of the base vocabulary (without the added tokens) """
+ raise NotImplementedError
+
+ def get_vocab(self):
+ """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
+ raise NotImplementedError()
+
+ def get_added_vocab(self) -> Dict[str, int]:
+ return self.added_tokens_encoder
+
+ def __len__(self):
+ """ Size of the full vocabulary with the added tokens """
+ return self.vocab_size + len(self.added_tokens_encoder)
+
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
+ """
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
+ vocabulary, they are added to it with indices starting from length of the current vocabulary.
+
+ Args:
+ new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
+ already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
+
+ Returns:
+ Number of tokens added to the vocabulary.
+
+ Examples::
+
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertModel.from_pretrained('bert-base-uncased')
+
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
+ print('We have added', num_added_toks, 'tokens')
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+ """
+ new_tokens = [str(tok) for tok in new_tokens]
+
+ tokens_to_add = []
+ for token in new_tokens:
+ assert isinstance(token, str)
+ if not special_tokens and self.init_kwargs.get("do_lower_case", False):
+ token = token.lower()
+ if (
+ token != self.unk_token
+ and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
+ and token not in tokens_to_add
+ ):
+ tokens_to_add.append(token)
+ if self.verbose:
+ logger.info("Adding %s to the vocabulary", token)
+
+ added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
+ added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
+ self.added_tokens_encoder.update(added_tok_encoder)
+ self.added_tokens_decoder.update(added_tok_decoder)
+
+ # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
+ if special_tokens:
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
+ else:
+ # Or on the newly added tokens
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
+
+ return len(tokens_to_add)
+
+ def num_special_tokens_to_add(self, pair=False):
+ """
+ Returns the number of added tokens when encoding a sequence with special tokens.
+
+ Note:
+ This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
+ inside your training loop.
+
+ Args:
+ pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
+ number of added tokens in the case of a single sequence if set to False.
+
+ Returns:
+ Number of tokens added to sequences
+ """
+ token_ids_0 = []
+ token_ids_1 = []
+ return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
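+
+ # NOTE (illustrative): for a BERT-style tokenizer this returns 2 for a single sequence
+ # ([CLS] ... [SEP]) and 3 for a pair ([CLS] ... [SEP] ... [SEP]).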
+
+ def tokenize(self, text: TextInput, **kwargs):
+ """ Converts a string into a sequence of tokens (string), using the tokenizer.
+ Split in words for word-based vocabulary or sub-words for sub-word-based
+ vocabularies (BPE/SentencePieces/WordPieces).
+
+ Take care of added tokens.
+
+ Args:
+ text (:obj:`string`): The sequence to be encoded.
+ **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
+ """
+ # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
+ all_special_tokens_extended = dict(
+ (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
+ )
+
+ text, kwargs = self.prepare_for_tokenization(text, **kwargs)
+
+ if kwargs:
+ logger.warning(f"Keyword arguments {kwargs} not recognized.")
+
+ # TODO: should this be in the base class?
+ if self.init_kwargs.get("do_lower_case", False):
+ # convert non-special tokens to lowercase
+ escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
+ pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
+ text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
+
+ def split_on_token(tok, text):
+ result = []
+ tok_extended = all_special_tokens_extended.get(tok, None)
+ split_text = text.split(tok)
+ full_word = ""
+ for i, sub_text in enumerate(split_text):
+ # AddedToken can control whitespace stripping around them.
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
+ # Cf. https://github.com/huggingface/transformers/pull/2778
+ # and https://github.com/huggingface/transformers/issues/3788
+ if isinstance(tok_extended, AddedToken):
+ if tok_extended.single_word:
+ # Try to avoid splitting on token
+ if (
+ i < len(split_text) - 1
+ and not _is_end_of_word(sub_text)
+ and not _is_start_of_word(split_text[i + 1])
+ ):
+ # Don't extract the special token
+ full_word += sub_text + tok
+ elif full_word:
+ full_word += sub_text
+ result += [full_word]
+ full_word = ""
+ continue
+ # Strip white spaces on the right
+ if tok_extended.rstrip and i > 0:
+ # A bit counter-intuitive but we strip the left of the string
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
+ sub_text = sub_text.lstrip()
+ # Strip white spaces on the left
+ if tok_extended.lstrip and i < len(split_text) - 1:
+ sub_text = sub_text.rstrip() # Opposite here
+ else:
+ # We strip left and right by default
+ if i < len(split_text) - 1:
+ sub_text = sub_text.rstrip()
+ if i > 0:
+ sub_text = sub_text.lstrip()
+
+ if i == 0 and not sub_text:
+ result += [tok]
+ elif i == len(split_text) - 1:
+ if sub_text:
+ result += [sub_text]
+ else:
+ pass
+ else:
+ if sub_text:
+ result += [sub_text]
+ result += [tok]
+ return result
+
+ def split_on_tokens(tok_list, text):
+ if not text.strip():
+ return []
+ if not tok_list:
+ return self._tokenize(text)
+
+ tokenized_text = []
+ text_list = [text]
+ for tok in tok_list:
+ tokenized_text = []
+ for sub_text in text_list:
+ if sub_text not in self.unique_no_split_tokens:
+ tokenized_text += split_on_token(tok, sub_text)
+ else:
+ tokenized_text += [sub_text]
+ text_list = tokenized_text
+
+ return list(
+ itertools.chain.from_iterable(
+ (
+ self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
+ for token in tokenized_text
+ )
+ )
+ )
+
+ no_split_token = self.unique_no_split_tokens
+ tokenized_text = split_on_tokens(no_split_token, text)
+ return tokenized_text
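+
+ # NOTE (illustrative): after tokenizer.add_tokens(["[ENT]"]), tokenize("a [ENT] b") keeps "[ENT]" as a
+ # single token and only the surrounding text goes through _tokenize; special tokens are handled the same way.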
+
+ def _tokenize(self, text, **kwargs):
+ """ Converts a string into a sequence of tokens (string), using the tokenizer.
+ Split in words for word-based vocabulary or sub-words for sub-word-based
+ vocabularies (BPE/SentencePieces/WordPieces).
+
+ Do NOT take care of added tokens.
+ """
+ raise NotImplementedError
+
+ def convert_tokens_to_ids(self, tokens):
+ """ Converts a token string (or a sequence of tokens) in a single integer id
+ (or a sequence of ids), using the vocabulary.
+ """
+ if tokens is None:
+ return None
+
+ if isinstance(tokens, str):
+ return self._convert_token_to_id_with_added_voc(tokens)
+
+ ids = []
+ for token in tokens:
+ ids.append(self._convert_token_to_id_with_added_voc(token))
+ return ids
+
+ def _convert_token_to_id_with_added_voc(self, token):
+ if token is None:
+ return None
+
+ if token in self.added_tokens_encoder:
+ return self.added_tokens_encoder[token]
+ return self._convert_token_to_id(token)
+
+ def _convert_token_to_id(self, token):
+ raise NotImplementedError
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_pretokenized: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ def get_input_ids(text):
+ if isinstance(text, str):
+ tokens = self.tokenize(text, **kwargs)
+ return self.convert_tokens_to_ids(tokens)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+ if is_pretokenized:
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
+ return self.convert_tokens_to_ids(tokens)
+ else:
+ return self.convert_tokens_to_ids(text)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+ return text
+ else:
+ if is_pretokenized:
+ raise ValueError(
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
+ )
+ else:
+ raise ValueError(
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
+ )
+
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers."
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ "More information on available tokenizers at "
+ "https://github.com/huggingface/transformers/pull/2674"
+ )
+
+ first_ids = get_input_ids(text)
+ second_ids = get_input_ids(text_pair) if text_pair is not None else None
+
+ return self.prepare_for_model(
+ first_ids,
+ pair_ids=second_ids,
+ add_special_tokens=add_special_tokens,
+ padding=padding_strategy.value,
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ prepend_batch_axis=True,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ verbose=verbose,
+ )
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ List[PreTokenizedInputPair],
+ List[EncodedInput],
+ List[EncodedInputPair],
+ ],
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_pretokenized: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ def get_input_ids(text):
+ if isinstance(text, str):
+ tokens = self.tokenize(text, **kwargs)
+ return self.convert_tokens_to_ids(tokens)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+ if is_pretokenized:
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
+ return self.convert_tokens_to_ids(tokens)
+ else:
+ return self.convert_tokens_to_ids(text)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+ return text
+ else:
+ raise ValueError(
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
+ )
+
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers."
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ )
+
+ input_ids = []
+ for ids_or_pair_ids in batch_text_or_text_pairs:
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
+ ids, pair_ids = ids_or_pair_ids, None
+ elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
+ ids, pair_ids = ids_or_pair_ids, None
+ else:
+ ids, pair_ids = ids_or_pair_ids
+
+ first_ids = get_input_ids(ids)
+ second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
+ input_ids.append((first_ids, second_ids))
+
+ batch_outputs = self._batch_prepare_for_model(
+ input_ids,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=return_tensors,
+ verbose=verbose,
+ )
+
+ return BatchEncoding(batch_outputs)
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def _batch_prepare_for_model(
+ self,
+ batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
+ It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user defined stride) for overflowing tokens
+
+ Args:
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
+ """
+
+ batch_outputs = {}
+ for first_ids, second_ids in batch_ids_pairs:
+ outputs = self.prepare_for_model(
+ first_ids,
+ second_ids,
+ add_special_tokens=add_special_tokens,
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=None, # we pad in batch afterward
+ return_attention_mask=False, # we pad in batch afterward
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=None, # We convert the whole batch to tensors at the end
+ prepend_batch_axis=False,
+ verbose=verbose,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ batch_outputs = self.pad(
+ batch_outputs,
+ padding=padding_strategy.value,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ return batch_outputs
+
+ def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
+ """ Performs any necessary transformations before tokenization.
+
+ This method should pop the arguments it uses from kwargs and return the remaining kwargs as well.
+ We check kwargs at the end of the encoding process to be sure all the arguments have been used.
+ """
+ return (text, kwargs)
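+
+ # Illustrative sketch (not upstream code) of how a subclass might use this hook; the
+ # kwarg name `lowercase_example` is hypothetical:
+ #
+ #   def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
+ #       if kwargs.pop("lowercase_example", False):   # hypothetical kwarg
+ #           text = text.lower()
+ #       return (text, kwargs)  # return the remaining kwargs so unused arguments can be detected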
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+
+ Args:
+ token_ids_0: list of ids (must not contain special tokens)
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
+ for sequence pairs
+ already_has_special_tokens: (default False) Set to True if the token list is already formatted with
+ special tokens for the model
+
+ Returns:
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
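+
+ # Illustrative behaviour of the base implementation above: the mask is all zeros, e.g.
+ #   get_special_tokens_mask([5, 6, 7])          -> [0, 0, 0]
+ #   get_special_tokens_mask([5, 6, 7], [8, 9])  -> [0, 0, 0, 0, 0]
+ # Model-specific subclasses override it so that positions of tokens added by
+ # `prepare_for_model` (e.g. [CLS]/[SEP]) are marked with 1.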
+
+ def convert_ids_to_tokens(
+ self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+ ) -> Union[str, List[str]]:
+ """ Converts a single index or a sequence of indices (integers) in a token "
+ (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
+
+ Args:
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
+ """
+ if isinstance(ids, int):
+ if ids in self.added_tokens_decoder:
+ return self.added_tokens_decoder[ids]
+ else:
+ return self._convert_id_to_token(ids)
+ tokens = []
+ for index in ids:
+ index = int(index)
+ if skip_special_tokens and index in self.all_special_ids:
+ continue
+ if index in self.added_tokens_decoder:
+ tokens.append(self.added_tokens_decoder[index])
+ else:
+ tokens.append(self._convert_id_to_token(index))
+ return tokens
+
+ def _convert_id_to_token(self, index: int) -> str:
+ raise NotImplementedError
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ """ Converts a sequence of tokens (string) in a single string.
+ The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
+ but we often want to remove sub-word tokenization artifacts at the same time.
+ """
+ return " ".join(self.convert_ids_to_tokens(tokens))
+
+ def decode(
+ self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+ ) -> str:
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+ # To avoid mixing byte-level and unicode for byte-level BPE
+ # we need to build the string separately for added tokens and byte-level tokens
+ # cf. https://github.com/huggingface/transformers/issues/1133
+ sub_texts = []
+ current_sub_text = []
+ for token in filtered_tokens:
+ if skip_special_tokens and token in self.all_special_tokens:
+ continue
+ if token in self.added_tokens_encoder:
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ current_sub_text = []
+ sub_texts.append(token)
+ else:
+ current_sub_text.append(token)
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+ text = " ".join(sub_texts)
+
+ if clean_up_tokenization_spaces:
+ clean_text = self.clean_up_tokenization(text)
+ return clean_text
+ else:
+ return text
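+
+ # Illustrative usage (assuming `ids` was produced by this tokenizer's encode method):
+ #   text = tokenizer.decode(ids, skip_special_tokens=True)
+ # Added tokens never go through `convert_tokens_to_string`; they are re-joined with single
+ # spaces, which is why the sub-texts are built separately above.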
+
+ def save_vocabulary(self, save_directory) -> Tuple[str]:
+ """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
+ and special token mappings.
+
+ Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` to save the full
+ Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
+ class method.
+ """
+ raise NotImplementedError
diff --git a/elia/bert/tokenization_utils_base.py b/elia/bert/tokenization_utils_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1219a4d3473ae510f0da905ea09a76019a7996
--- /dev/null
+++ b/elia/bert/tokenization_utils_base.py
@@ -0,0 +1,2317 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Base classes common to both the slow and the fast tokenization classes:
+ PreTrainedTokenizerBase (hosts all the user-facing encoding methods),
+ SpecialTokensMixin (hosts the special tokens logic) and
+ BatchEncoding (wraps the dictionary of outputs with special methods for the fast tokenizers)
+"""
+
+import copy
+import json
+import logging
+import os
+import warnings
+from collections import UserDict
+from enum import Enum
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from tokenizers import AddedToken
+from tokenizers import Encoding as EncodingFast
+
+from .file_utils import (
+ add_end_docstrings,
+ cached_path,
+ hf_bucket_url,
+ is_remote_url,
+ is_tf_available,
+ is_torch_available,
+ torch_required,
+)
+
+
+if is_tf_available():
+ import tensorflow as tf
+if is_torch_available():
+ import torch
+
+
+logger = logging.getLogger(__name__)
+
+VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input
+LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
+
+# Define type aliases and NamedTuples
+TextInput = str
+PreTokenizedInput = List[str]
+EncodedInput = List[int]
+TextInputPair = Tuple[str, str]
+PreTokenizedInputPair = Tuple[List[str], List[str]]
+EncodedInputPair = Tuple[List[int], List[int]]
+
+
+# Slow tokenizers used to be saved in three separated files
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+ADDED_TOKENS_FILE = "added_tokens.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
+FULL_TOKENIZER_FILE = "tokenizer.json"
+
+
+class ExplicitEnum(Enum):
+ """ Enum with more explicit error message for missing values.
+ """
+
+ @classmethod
+ def _missing_(cls, value):
+ raise ValueError(
+ "%r is not a valid %s, please select one of %s"
+ % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
+ )
+
+
+class TruncationStrategy(ExplicitEnum):
+ ONLY_FIRST = "only_first"
+ ONLY_SECOND = "only_second"
+ LONGEST_FIRST = "longest_first"
+ DO_NOT_TRUNCATE = "do_not_truncate"
+
+
+class PaddingStrategy(ExplicitEnum):
+ LONGEST = "longest"
+ MAX_LENGTH = "max_length"
+ DO_NOT_PAD = "do_not_pad"
+
+
+class TensorType(ExplicitEnum):
+ PYTORCH = "pt"
+ TENSORFLOW = "tf"
+ NUMPY = "np"
+
+
+class CharSpan(NamedTuple):
+ """ Character span in the original string
+
+ Args:
+ start: index of the first character in the original string
+ end: index of the character following the last character in the original string
+ """
+
+ start: int
+ end: int
+
+
+class TokenSpan(NamedTuple):
+ """ Token span in an encoded string (list of tokens)
+
+ Args:
+ start: index of the first token in the span
+ end: index of the token following the last token in the span
+ """
+
+ start: int
+ end: int
+
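+# Both spans are half-open intervals ([start, end)), so they can be used directly for
+# slicing. Illustrative example (assuming `encoding` comes from a fast tokenizer):
+#   span = encoding.word_to_tokens(0)
+#   encoding.tokens()[span.start:span.end]   # the tokens making up the first word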
+
+class BatchEncoding(UserDict):
+ """ BatchEncoding hold the output of the encode and batch_encode methods (tokens, attention_masks, etc).
+ This class is derived from a python Dictionary and can be used as a dictionnary.
+ In addition, this class expose utility methods to map from word/char space to token space.
+
+ Args:
+ data (:obj:`dict`): Dictionary of lists/arrays returned by the encode/batch_encode methods ('input_ids', 'attention_mask'...)
+ encoding (:obj:`EncodingFast`, :obj:`list(EncodingFast)`, `optional`, defaults to :obj:`None`):
+ If the tokenizer is a fast tokenizer which outputs additional information like the mapping from word/char space to token space,
+ the `EncodingFast` instance or list of instances (for batches) holds this information.
+ tensor_type (:obj:`Union[None, str, TensorType]`, `optional`, defaults to :obj:`None`):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TF/Numpy Tensors at initialization
+ prepend_batch_axis (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True to add a batch axis when converting in Tensors (see :obj:`tensor_type` above)
+ """
+
+ def __init__(
+ self,
+ data: Optional[Dict[str, Any]] = None,
+ encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None,
+ tensor_type: Union[None, str, TensorType] = None,
+ prepend_batch_axis: bool = False,
+ ):
+ super().__init__(data)
+
+ if isinstance(encoding, EncodingFast):
+ encoding = [encoding]
+
+ self._encodings = encoding
+
+ self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
+
+ @property
+ def is_fast(self):
+ """
+ Indicates whether this BatchEncoding was generated from the result of a PreTrainedTokenizerFast
+ Returns: True if generated from subclasses of PreTrainedTokenizerFast, False otherwise
+ """
+ return self._encodings is not None
+
+ def __getitem__(self, item: Union[int, str]) -> EncodingFast:
+ """ If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
+ If the key is an integer, get the EncodingFast for batch item with index `key`
+ """
+ if isinstance(item, str):
+ return self.data[item]
+ elif self._encodings is not None:
+ return self._encodings[item]
+ else:
+ raise KeyError(
+ "Indexing with integers (to access backend Encoding for a given batch index) "
+ "is not available when using Python based tokenizers"
+ )
+
+ def __getattr__(self, item: str):
+ try:
+ return self.data[item]
+ except KeyError:
+ raise AttributeError
+
+ def __getstate__(self):
+ return {"data": self.data, "encodings": self._encodings}
+
+ def __setstate__(self, state):
+ if "data" in state:
+ self.data = state["data"]
+
+ if "encodings" in state:
+ self._encodings = state["encodings"]
+
+ def keys(self):
+ return self.data.keys()
+
+ def values(self):
+ return self.data.values()
+
+ def items(self):
+ return self.data.items()
+
+ # After this point:
+ # Extended properties and methods only available for fast (Rust-based) tokenizers
+ # provided by HuggingFace tokenizers library.
+
+ @property
+ def encodings(self) -> Optional[List[EncodingFast]]:
+ """
+ Return the list of all encodings from the tokenization process
+
+ Returns: List[EncodingFast] or None if input was tokenized through Python (i.e. not fast) tokenizer
+ """
+ return self._encodings
+
+ def tokens(self, batch_index: int = 0) -> List[str]:
+ if not self._encodings:
+ raise ValueError("tokens() is not available when using Python based tokenizers")
+ return self._encodings[batch_index].tokens
+
+ def words(self, batch_index: int = 0) -> List[Optional[int]]:
+ if not self._encodings:
+ raise ValueError("words() is not available when using Python based tokenizers")
+ return self._encodings[batch_index].words
+
+ def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
+ """
+ Get the index of the word corresponding to (i.e. comprising) an encoded token
+ in a sequence of the batch.
+
+ Can be called as:
+
+ - ``self.token_to_word(token_index)`` if batch size is 1
+ - ``self.token_to_word(batch_index, token_index)`` if batch size is greater than 1
+
+ This method is particularly suited when the input sequences are provided as
+ pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it
+ easy to associate encoded tokens with the provided tokenized words.
+
+ Args:
+ batch_or_token_index (:obj:`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence,
+ this can be the index of the token in the sequence
+ token_index (:obj:`int`, `optional`):
+ If a batch index is provided in `batch_or_token_index`, this can be the index
+ of the token in the sequence.
+
+ Returns:
+ :obj:`int`:
+ index of the word in the input sequence.
+
+ """
+
+ if not self._encodings:
+ raise ValueError("token_to_word() is not available when using Python based tokenizers")
+ if token_index is not None:
+ batch_index = batch_or_token_index
+ else:
+ batch_index = 0
+ token_index = batch_or_token_index
+ if batch_index < 0:
+ batch_index = self._batch_size + batch_index
+ if token_index < 0:
+ token_index = self._seq_len + token_index
+ return self._encodings[batch_index].token_to_word(token_index)
+
+ def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan:
+ """
+ Get the encoded token span corresponding to a word in the sequence of the batch.
+
+ Token spans are returned as a TokenSpan NamedTuple with:
+
+ - start: index of the first token
+ - end: index of the token following the last token
+
+ Can be called as:
+
+ - ``self.word_to_tokens(word_index)`` if batch size is 1
+ - ``self.word_to_tokens(batch_index, word_index)`` if batch size is greater than or equal to 1
+
+ This method is particularly suited when the input sequences are provided as
+ pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it
+ easy to associate encoded tokens with the provided tokenized words.
+
+ Args:
+ batch_or_word_index (:obj:`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence,
+ this can be the index of the word in the sequence
+ word_index (:obj:`int`, `optional`):
+ If a batch index is provided in `batch_or_word_index`, this can be the index
+ of the word in the sequence.
+
+ Returns:
+ :obj:`TokenSpan`:
+ Span of tokens in the encoded sequence.
+
+ :obj:`TokenSpan` are NamedTuple with:
+
+ - start: index of the first token
+ - end: index of the token following the last token
+ """
+
+ if not self._encodings:
+ raise ValueError("word_to_tokens() is not available when using Python based tokenizers")
+ if word_index is not None:
+ batch_index = batch_or_word_index
+ else:
+ batch_index = 0
+ word_index = batch_or_word_index
+ if batch_index < 0:
+ batch_index = self._batch_size + batch_index
+ if word_index < 0:
+ word_index = self._seq_len + word_index
+ return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index)))
+
+ def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
+ """
+ Get the character span corresponding to an encoded token in a sequence of the batch.
+
+ Character spans are returned as a CharSpan NamedTuple with:
+
+ - start: index of the first character in the original string associated to the token
+ - end: index of the character following the last character in the original string associated to the token
+
+ Can be called as:
+
+ - ``self.token_to_chars(token_index)`` if batch size is 1
+ - ``self.token_to_chars(batch_index, token_index)`` if batch size is greater than or equal to 1
+
+ Args:
+ batch_or_token_index (:obj:`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence,
+ this can be the index of the token in the sequence
+ token_index (:obj:`int`, `optional`):
+ If a batch index is provided in `batch_or_token_index`, this can be the index
+ of the token or tokens in the sequence.
+
+ Returns:
+ :obj:`CharSpan`:
+ Span of characters in the original string.
+
+ :obj:`CharSpan` are NamedTuple with:
+
+ - start: index of the first character in the original string
+ - end: index of the character following the last character in the original string
+ """
+
+ if not self._encodings:
+ raise ValueError("token_to_chars() is not available when using Python based tokenizers")
+ if token_index is not None:
+ batch_index = batch_or_token_index
+ else:
+ batch_index = 0
+ token_index = batch_or_token_index
+ return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index)))
+
+ def char_to_token(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int:
+ """
+ Get the index of the token in the encoded output comprising a character
+ in the original string for a sequence of the batch.
+
+ Can be called as:
+
+ - ``self.char_to_token(char_index)`` if batch size is 1
+ - ``self.char_to_token(batch_index, char_index)`` if batch size is greater than or equal to 1
+
+ This method is particularly suited when the input sequences are provided as
+ pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it
+ easy to associate encoded tokens with the provided tokenized words.
+
+ Args:
+ batch_or_char_index (:obj:`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence,
+ this can be the index of the character in the original string
+ char_index (:obj:`int`, `optional`):
+ If a batch index is provided in `batch_or_char_index`, this can be the index
+ of the character in the original string.
+
+
+ Returns:
+ :obj:`int`: Index of the token.
+ """
+
+ if not self._encodings:
+ raise ValueError("char_to_token() is not available when using Python based tokenizers")
+ if char_index is not None:
+ batch_index = batch_or_char_index
+ else:
+ batch_index = 0
+ char_index = batch_or_char_index
+ return self._encodings[batch_index].char_to_token(char_index)
+
+ def word_to_chars(self, batch_or_word_index: int, word_index: Optional[int] = None) -> CharSpan:
+ """
+ Get the character span in the original string corresponding to given word in a sequence
+ of the batch.
+
+ Character spans are returned as a CharSpan NamedTuple with:
+
+ - start: index of the first character in the original string
+ - end: index of the character following the last character in the original string
+
+ Can be called as:
+
+ - ``self.word_to_chars(word_index)`` if batch size is 1
+ - ``self.word_to_chars(batch_index, word_index)`` if batch size is greater than or equal to 1
+
+ Args:
+ batch_or_word_index (:obj:`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence,
+ this can be the index of the word in the sequence
+ word_index (:obj:`int`, `optional`):
+ If a batch index is provided in `batch_or_word_index`, this can be the index
+ of the word in the sequence.
+
+ Returns:
+ :obj:`CharSpan` or :obj:`List[CharSpan]`:
+ Span(s) of the associated character or characters in the string.
+ CharSpan are NamedTuple with:
+
+ - start: index of the first character associated to the token in the original string
+ - end: index of the character following the last character associated to the token in the original string
+ """
+
+ if not self._encodings:
+ raise ValueError("word_to_chars() is not available when using Python based tokenizers")
+ if word_index is not None:
+ batch_index = batch_or_word_index
+ else:
+ batch_index = 0
+ word_index = batch_or_word_index
+ return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index)))
+
+ def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None) -> int:
+ """
+ Get the word in the original string corresponding to a character in the original string of
+ a sequence of the batch.
+
+ Can be called as:
+
+ - ``self.char_to_word(char_index)`` if batch size is 1
+ - ``self.char_to_word(batch_index, char_index)`` if batch size is greater than 1
+
+ This method is particularly suited when the input sequences are provided as
+ pre-tokenized sequences (i.e. words are defined by the user). In this case it makes it
+ easy to associate encoded tokens with the provided tokenized words.
+
+ Args:
+ batch_or_char_index (:obj:`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence,
+ this can be the index of the character in the original string.
+ char_index (:obj:`int`, `optional`):
+ If a batch index is provided in `batch_or_char_index`, this can be the index
+ of the character in the original string.
+
+
+ Returns:
+ :obj:`int` or :obj:`List[int]`:
+ Index or indices of the corresponding word(s) in the input sequence.
+ """
+
+ if not self._encodings:
+ raise ValueError("char_to_word() is not available when using Python based tokenizers")
+ if char_index is not None:
+ batch_index = batch_or_char_index
+ else:
+ batch_index = 0
+ char_index = batch_or_char_index
+ return self._encodings[batch_index].char_to_word(char_index)
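+
+ # Illustrative round trip through the alignment helpers above (fast tokenizers only;
+ # names and inputs are examples):
+ #   enc = fast_tokenizer("hello world")
+ #   enc.token_to_word(1)    # index of the word comprising token 1
+ #   enc.word_to_tokens(0)   # TokenSpan covering the tokens of word 0
+ #   enc.token_to_chars(1)   # CharSpan of token 1 in "hello world"
+ #   enc.char_to_token(6)    # token covering the character 'w'
+ #   enc.char_to_word(6)     # word covering the character 'w'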
+
+ def convert_to_tensors(self, tensor_type: Union[None, str, TensorType], prepend_batch_axis: bool = False):
+ if tensor_type is None:
+ return self
+
+ # Convert to TensorType
+ if not isinstance(tensor_type, TensorType):
+ tensor_type = TensorType(tensor_type)
+
+ # Get a function reference for the correct framework
+ if tensor_type == TensorType.TENSORFLOW and is_tf_available():
+ as_tensor = tf.constant
+ elif tensor_type == TensorType.PYTORCH and is_torch_available():
+ as_tensor = torch.tensor
+ elif tensor_type == TensorType.NUMPY:
+ as_tensor = np.asarray
+ else:
+ raise ImportError(
+ "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
+ tensor_type
+ )
+ )
+
+ # Do the tensor conversion in batch
+ for key, value in self.items():
+ try:
+ if prepend_batch_axis:
+ value = [value]
+
+ tensor = as_tensor(value)
+
+ # at least 2d
+ if tensor.ndim > 2:
+ tensor = tensor.squeeze(0)
+ elif tensor.ndim < 2:
+ tensor = tensor[None, :]
+
+ self[key] = tensor
+ except: # noqa E722
+ raise ValueError(
+ "Unable to create tensor, you should probably activate truncation and/or padding "
+ "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
+ )
+
+ return self
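+
+ # Illustrative usage (values chosen for illustration): converting an already-padded batch
+ # to NumPy arrays; "pt" and "tf" behave analogously when PyTorch/TensorFlow are installed.
+ #   be = BatchEncoding({"input_ids": [[1, 2, 0], [3, 4, 5]]})
+ #   be.convert_to_tensors("np")
+ #   be["input_ids"].shape   # (2, 3)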
+
+ @torch_required
+ def to(self, device: str):
+ """Send all values to device by calling v.to(device)"""
+ self.data = {k: v.to(device) for k, v in self.data.items()}
+ return self
+
+
+# class AddedToken(UserString):
+# """ AddedToken represents a token to be added to a Tokenizer
+
+# An AddedToken can have special options defining the way it should behave.
+
+# Args:
+# content: str:
+# The content of the token
+
+# single_word: bool
+# Whether this token should only match against single word. If True,
+# this token will never match inside of a word.
+
+# lstrip: bool
+# Whether this token should strip all potential whitespaces on the left side.
+# If True, this token will greedily match any whitespace on the left and then strip
+# them out.
+
+# rstrip: bool
+# Whether this token should strip all potential whitespaces on the right side.
+# If True, this token will greedily match any whitespace on the right and then strip
+# them out.
+# """
+
+# def __init__(
+# self, data: str, single_word: bool = False, lstrip: bool = False, rstrip: bool = False,
+# ):
+# super().__init__(data)
+
+# self._single_word = single_word
+# self._lstrip = lstrip
+# self._rstrip = rstrip
+
+# def lower(self):
+# return AddedToken(self.data.lower(), self._single_word, self._lstrip, self._rstrip)
+
+
+class SpecialTokensMixin:
+ """ SpecialTokensMixin is derived by ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` and
+ handles specific behaviors related to special tokens. In particular, this class hold the
+ attributes which can be used to directly access to these special tokens in a
+ model-independant manner and allow to set and update the special tokens.
+ """
+
+ SPECIAL_TOKENS_ATTRIBUTES = [
+ "bos_token",
+ "eos_token",
+ "unk_token",
+ "sep_token",
+ "pad_token",
+ "cls_token",
+ "mask_token",
+ "additional_special_tokens",
+ ]
+
+ def __init__(self, verbose=True, **kwargs):
+ self._bos_token = None
+ self._eos_token = None
+ self._unk_token = None
+ self._sep_token = None
+ self._pad_token = None
+ self._cls_token = None
+ self._mask_token = None
+ self._pad_token_type_id = 0
+ self._additional_special_tokens = []
+ self.verbose = verbose
+
+ # We directly set the hidden value to allow initialization with special tokens
+ # which are not yet in the vocabulary. Necessary for serialization/de-serialization
+ # TODO clean this up at some point (probably by switching to fast tokenizers)
+ for key, value in kwargs.items():
+ if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+ if key == "additional_special_tokens":
+ assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
+ setattr(self, key, value)
+ elif isinstance(value, (str, AddedToken)):
+ setattr(self, key, value)
+ else:
+ raise TypeError(
+ "special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
+ )
+
+ def sanitize_special_tokens(self) -> int:
+ """ Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
+ are in the vocabulary. Add the missing ones to the vocabulary if needed.
+
+ Return:
+ Number of tokens added to the vocabulary during the operation.
+ """
+ return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
+
+ def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
+ """
+ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
+ to class attributes. If special tokens are NOT in the vocabulary, they are added
+ to it (indexed starting from the last index of the current vocabulary).
+
+ Using `add_special_tokens` will ensure your special tokens can be used in several ways:
+
+ - special tokens are carefully handled by the tokenizer (they are never split)
+ - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts.
+
+ When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer's cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '</s>')
+
+ Args:
+ special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
+ [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
+ ``additional_special_tokens``].
+
+ Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
+
+ Returns:
+ Number of tokens added to the vocabulary.
+
+ Examples::
+
+ # Let's see how to add a new classification token to GPT-2
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ model = GPT2Model.from_pretrained('gpt2')
+
+ special_tokens_dict = {'cls_token': '<CLS>'}
+
+ num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+ print('We have added', num_added_toks, 'tokens')
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+
+ assert tokenizer.cls_token == '<CLS>'
+ """
+ if not special_tokens_dict:
+ return 0
+
+ added_tokens = 0
+ for key, value in special_tokens_dict.items():
+ assert key in self.SPECIAL_TOKENS_ATTRIBUTES
+
+ if self.verbose:
+ logger.info("Assigning %s to the %s key of the tokenizer", value, key)
+ setattr(self, key, value)
+
+ if key == "additional_special_tokens":
+ assert isinstance(value, (list, tuple)) and all(
+ isinstance(t, (str, AddedToken)) for t in value
+ ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
+ added_tokens += self.add_tokens(value, special_tokens=True)
+ else:
+ assert isinstance(
+ value, (str, AddedToken)
+ ), f"Token {value} for key {key} should be a str or an AddedToken instance"
+ added_tokens += self.add_tokens([value], special_tokens=True)
+
+ return added_tokens
+
+ def add_tokens(self, new_tokens: Union[str, AddedToken, List[str], List[AddedToken]], special_tokens=False) -> int:
+ """
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
+ vocabulary, they are added to it with indices starting from length of the current vocabulary.
+
+ Args:
+ new_tokens: string or list of strings or :class:`~transformers.AddedToken`. Each string is a token to add.
+ Tokens are only added if they are not already in the vocabulary. AddedToken wraps a string token to
+ let you personalize its behavior (whether this token should only match against a single word, whether
+ this token should strip all potential whitespaces on the left side, whether this token should strip
+ all potential whitespaces on the right side...).
+ special_tokens: can be used to specify if the token is a special token. This mostly changes the normalization
+ behavior (special tokens like CLS or [MASK] are usually not lower-cased for instance)
+
+ See details for :class:`~transformers.AddedToken` in HuggingFace tokenizers library.
+
+ Returns:
+ Number of tokens added to the vocabulary.
+
+ Examples::
+
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+ model = BertModel.from_pretrained('bert-base-uncased')
+
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
+ print('We have added', num_added_toks, 'tokens')
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+ """
+ if not new_tokens:
+ return 0
+
+ if not isinstance(new_tokens, (list, tuple)):
+ new_tokens = [new_tokens]
+
+ return self._add_tokens(new_tokens, special_tokens=special_tokens)
+
+ @property
+ def bos_token(self):
+ """ Beginning of sentence token (string). Log an error if used while not having been set. """
+ if self._bos_token is None and self.verbose:
+ logger.error("Using bos_token, but it is not set yet.")
+ return None
+ return str(self._bos_token)
+
+ @property
+ def eos_token(self):
+ """ End of sentence token (string). Log an error if used while not having been set. """
+ if self._eos_token is None and self.verbose:
+ logger.error("Using eos_token, but it is not set yet.")
+ return None
+ return str(self._eos_token)
+
+ @property
+ def unk_token(self):
+ """ Unknown token (string). Log an error if used while not having been set. """
+ if self._unk_token is None and self.verbose:
+ logger.error("Using unk_token, but it is not set yet.")
+ return None
+ return str(self._unk_token)
+
+ @property
+ def sep_token(self):
+ """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
+ if self._sep_token is None and self.verbose:
+ logger.error("Using sep_token, but it is not set yet.")
+ return None
+ return str(self._sep_token)
+
+ @property
+ def pad_token(self):
+ """ Padding token (string). Log an error if used while not having been set. """
+ if self._pad_token is None and self.verbose:
+ logger.error("Using pad_token, but it is not set yet.")
+ return None
+ return str(self._pad_token)
+
+ @property
+ def cls_token(self):
+ """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+ if self._cls_token is None and self.verbose:
+ logger.error("Using cls_token, but it is not set yet.")
+ return None
+ return str(self._cls_token)
+
+ @property
+ def mask_token(self):
+ """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+ if self._mask_token is None and self.verbose:
+ logger.error("Using mask_token, but it is not set yet.")
+ return None
+ return str(self._mask_token)
+
+ @property
+ def additional_special_tokens(self):
+ """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """
+ if self._additional_special_tokens is None and self.verbose:
+ logger.error("Using additional_special_tokens, but it is not set yet.")
+ return None
+ return [str(tok) for tok in self._additional_special_tokens]
+
+ @bos_token.setter
+ def bos_token(self, value):
+ self._bos_token = value
+
+ @eos_token.setter
+ def eos_token(self, value):
+ self._eos_token = value
+
+ @unk_token.setter
+ def unk_token(self, value):
+ self._unk_token = value
+
+ @sep_token.setter
+ def sep_token(self, value):
+ self._sep_token = value
+
+ @pad_token.setter
+ def pad_token(self, value):
+ self._pad_token = value
+
+ @cls_token.setter
+ def cls_token(self, value):
+ self._cls_token = value
+
+ @mask_token.setter
+ def mask_token(self, value):
+ self._mask_token = value
+
+ @additional_special_tokens.setter
+ def additional_special_tokens(self, value):
+ self._additional_special_tokens = value
+
+ @property
+ def bos_token_id(self):
+ """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
+ if self._bos_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.bos_token)
+
+ @property
+ def eos_token_id(self):
+ """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
+ if self._eos_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.eos_token)
+
+ @property
+ def unk_token_id(self):
+ """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
+ if self._unk_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.unk_token)
+
+ @property
+ def sep_token_id(self):
+ """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
+ if self._sep_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.sep_token)
+
+ @property
+ def pad_token_id(self):
+ """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
+ if self._pad_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.pad_token)
+
+ @property
+ def pad_token_type_id(self):
+ """ Id of the padding token type in the vocabulary."""
+ return self._pad_token_type_id
+
+ @property
+ def cls_token_id(self):
+ """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
+ if self._cls_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.cls_token)
+
+ @property
+ def mask_token_id(self):
+ """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
+ if self._mask_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.mask_token)
+
+ @property
+ def additional_special_tokens_ids(self):
+ """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
+ return self.convert_tokens_to_ids(self.additional_special_tokens)
+
+ @property
+ def special_tokens_map(self):
+ """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
+ values ('', ''...)
+ Convert tokens of AddedToken type in string.
+ All returned tokens are strings
+ """
+ set_attr = {}
+ for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+ attr_value = getattr(self, "_" + attr)
+ if attr_value:
+ set_attr[attr] = str(attr_value)
+ return set_attr
+
+ @property
+ def special_tokens_map_extended(self):
+ """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
+ values ('', ''...)
+ Keep the tokens as AddedToken if they are of this type.
+
+ AddedToken can be used to control more finely how special tokens are tokenized.
+ """
+ set_attr = {}
+ for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+ attr_value = getattr(self, "_" + attr)
+ if attr_value:
+ set_attr[attr] = attr_value
+ return set_attr
+
+ @property
+ def all_special_tokens(self):
+ """ List all the special tokens ('', ''...) mapped to class attributes
+ Convert tokens of AddedToken type in string.
+ All returned tokens are strings
+ (cls_token, unk_token...).
+ """
+ all_toks = [str(s) for s in self.all_special_tokens_extended]
+ return all_toks
+
+ @property
+ def all_special_tokens_extended(self):
+ """ List all the special tokens ('', ''...) mapped to class attributes
+ Keep the tokens as AddedToken if they are of this type.
+
+ AddedToken can be used to control more finely how special tokens are tokenized.
+ """
+ all_toks = []
+ set_attr = self.special_tokens_map_extended
+ for attr_value in set_attr.values():
+ all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
+ all_toks = list(set(all_toks))
+ return all_toks
+
+ @property
+ def all_special_ids(self):
+ """ List the vocabulary indices of the special tokens ('', ''...) mapped to
+ class attributes (cls_token, unk_token...).
+ """
+ all_toks = self.all_special_tokens
+ all_ids = self.convert_tokens_to_ids(all_toks)
+ return all_ids
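+
+ # Illustrative relationship between the properties above (token strings are examples;
+ # actual values depend on the tokenizer):
+ #   tokenizer.special_tokens_map   # {"cls_token": "[CLS]", "sep_token": "[SEP]", ...}
+ #   tokenizer.all_special_tokens   # ["[CLS]", "[SEP]", ...] as de-duplicated strings
+ #   tokenizer.all_special_ids      # the vocabulary ids of those strings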
+
+
+ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ If set to ``True``, the sequences will be encoded with the special tokens relative
+ to their model.
+ `padding` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
+ Activate and control padding. Accepts the following values:
+
+ * `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
+ * `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`)
+ * `False` or `'do_not_pad'` (default): No padding (i.e. can output batch with sequences of uneven lengths)
+ `truncation` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
+ Activate and control truncation. Accepts the following values:
+
+ * `True` or `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
+ * `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
+ * `'only_second'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
+ * `False` or `'do_not_truncate'` (default): No truncation (i.e. can output batch with sequences length greater than the model max admissible input size)
+ `max_length` (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`):
+ Control the length for padding/truncation. Accepts the following values
+
+ * `None` (default): This will use the predefined model max length if required by one of the truncation/padding parameters. If the model has no specific max input length (e.g. XLNet) truncation/padding to max length is deactivated.
+ * `any integer value` (e.g. `42`): Use this specific maximum length value if required by one of the truncation/padding parameters.
+ stride (:obj:`int`, `optional`, defaults to ``0``):
+ If set to a number along with max_length, the overflowing tokens returned when `return_overflowing_tokens=True`
+ will contain some tokens from the end of the truncated sequence returned to provide some overlap between truncated and overflowing sequences.
+ The value of this argument defines the number of overlapping tokens.
+ is_pretokenized (:obj:`bool`, defaults to :obj:`False`):
+ Set to True to indicate the input is already tokenized
+ pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ >= 7.5 (Volta).
+ return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
+ Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
+ PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
+"""
+
+ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+ return_token_type_ids (:obj:`bool`, `optional`, defaults to :obj:`None`):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
+
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
+ return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`None`):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
+
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True to return overflowing token sequences (default False).
+ return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True to return special tokens mask information (default False).
+ return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Set to True to return (char_start, char_end) for each token (default False).
+ If using Python's tokenizer, this method will raise NotImplementedError.
+ This one is only available on fast tokenizers inheriting from PreTrainedTokenizerFast.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ A Dictionary of shape::
+
+ {
+ input_ids: list[int],
+ token_type_ids: list[int] if return_token_type_ids is True (default)
+ attention_mask: list[int] if return_attention_mask is True (default)
+ overflowing_tokens: list[int] if the tokenizer is a slow tokenizer, else a List[List[int]] if a ``max_length`` is specified and ``return_overflowing_tokens=True``
+ special_tokens_mask: list[int] if ``add_special_tokens`` is set to ``True``
+ and return_special_tokens_mask is True
+ }
+
+ With the fields:
+
+ - ``input_ids``: list of token ids to be fed to a model
+ - ``token_type_ids``: list of token type ids to be fed to a model
+ - ``attention_mask``: list of indices specifying which tokens should be attended to by the model
+ - ``overflowing_tokens``: list of overflowing tokens sequences if a max length is specified and ``return_overflowing_tokens=True``.
+ - ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 1 specifying special added
+ tokens and 0 specifying sequence tokens.
+"""
+
+
+class PreTrainedTokenizerBase(SpecialTokensMixin):
+ """ Base class for slow and fast tokenizers.
+
+ Handles shared (mostly boilerplate) methods for slow and fast tokenizers.
+ """
+
+ vocab_files_names: Dict[str, str] = {}
+ pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
+ pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
+ max_model_input_sizes: Dict[str, int] = {}
+ model_input_names: List[str] = ["token_type_ids", "attention_mask"]
+
+ padding_side: str = "right"
+
+ def __init__(self, **kwargs):
+ # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+ self.init_inputs = ()
+ self.init_kwargs = kwargs
+
+ # For backward compatibility we fallback to set model_max_length from max_len if provided
+ model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
+ self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
+
+ # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
+ self.padding_side = kwargs.pop("padding_side", self.padding_side)
+ assert self.padding_side in [
+ "right",
+ "left",
+ ], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
+ self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
+
+ super().__init__(**kwargs)
+
+ @property
+ def max_len(self) -> int:
+ """ Kept here for backward compatibility.
+ Now renamed to `model_max_length` to avoid ambiguity.
+ """
+ return self.model_max_length
+
+ @property
+ def max_len_single_sentence(self) -> int:
+ return self.model_max_length - self.num_special_tokens_to_add(pair=False)
+
+ @property
+ def max_len_sentences_pair(self) -> int:
+ return self.model_max_length - self.num_special_tokens_to_add(pair=True)
+
+ @max_len_single_sentence.setter
+ def max_len_single_sentence(self, value) -> int:
+ """ For backward compatibility, allow to try to setup 'max_len_single_sentence' """
+ if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
+ logger.warning(
+ "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
+ )
+ else:
+ raise ValueError(
+ "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
+ )
+
+ @max_len_sentences_pair.setter
+ def max_len_sentences_pair(self, value) -> int:
+ """ For backward compatibility, allow to try to setup 'max_len_sentences_pair' """
+ if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
+ logger.warning(
+ "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
+ )
+ else:
+ raise ValueError(
+ "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
+ )
+
+ @classmethod
+ def from_pretrained(cls, *inputs, **kwargs):
+ r"""
+ Instantiate a :class:`~transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
+
+ Args:
+ pretrained_model_name_or_path: either:
+
+ - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
+ - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
+ - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
+ - (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
+
+ cache_dir: (`optional`) string:
+ Path to a directory in which the downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
+
+ force_download: (`optional`) boolean, default False:
+ Force a (re-)download of the vocabulary files and override the cached versions if they exist.
+
+ resume_download: (`optional`) boolean, default False:
+ Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
+
+ proxies: (`optional`) dict, default None:
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+ The proxies are used on each request.
+
+ inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
+
+ kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.
+
+ Examples::
+
+ # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer
+
+ # Download vocabulary from S3 and cache.
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+ # Download vocabulary from S3 (user-uploaded) and cache.
+ tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
+
+ # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
+ tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
+
+ # If the tokenizer uses a single vocabulary file, you can point directly to this file
+ tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
+
+ # You can link tokens to special vocabulary when instantiating
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
+ # You should be sure '<unk>' is in the vocabulary when doing that.
+ # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
+ assert tokenizer.unk_token == '<unk>'
+
+ """
+ return cls._from_pretrained(*inputs, **kwargs)
+
+ @classmethod
+ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+
+ s3_models = list(cls.max_model_input_sizes.keys())
+ vocab_files = {}
+ init_configuration = {}
+ if pretrained_model_name_or_path in s3_models:
+ # Get the vocabulary from AWS S3 bucket
+ for file_id, map_list in cls.pretrained_vocab_files_map.items():
+ vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+ if (
+ cls.pretrained_init_configuration
+ and pretrained_model_name_or_path in cls.pretrained_init_configuration
+ ):
+ init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path].copy()
+ else:
+ # Get the vocabulary from local files
+ logger.info(
+ "Model name '{}' not found in model shortcut name list ({}). "
+ "Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
+ pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
+ )
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ if len(cls.vocab_files_names) > 1:
+ raise ValueError(
+ "Calling {}.from_pretrained() with the path to a single file or url is not supported."
+ "Use a model identifier or the path to a directory instead.".format(cls.__name__)
+ )
+ logger.warning(
+ "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
+ cls.__name__
+ )
+ )
+ file_id = list(cls.vocab_files_names.keys())[0]
+ vocab_files[file_id] = pretrained_model_name_or_path
+ else:
+ # At this point pretrained_model_name_or_path is either a directory or a model identifier name
+ additional_files_names = {
+ "added_tokens_file": ADDED_TOKENS_FILE,
+ "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
+ "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
+ "full_tokenizer_file": FULL_TOKENIZER_FILE,
+ }
+ # Look for the tokenizer files
+ for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
+ if os.path.isdir(pretrained_model_name_or_path):
+ full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
+ if not os.path.exists(full_file_name):
+ logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
+ full_file_name = None
+ else:
+ full_file_name = hf_bucket_url(
+ pretrained_model_name_or_path, filename=file_name, use_cdn=False
+ )
+
+ vocab_files[file_id] = full_file_name
+
+ # Get files from url, cache, or disk depending on the case
+ try:
+ resolved_vocab_files = {}
+ for file_id, file_path in vocab_files.items():
+ if file_path is None:
+ resolved_vocab_files[file_id] = None
+ else:
+ resolved_vocab_files[file_id] = cached_path(
+ file_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ )
+ except EnvironmentError:
+ if pretrained_model_name_or_path in s3_models:
+ msg = "Couldn't reach server at '{}' to download vocabulary files."
+ else:
+ msg = (
+ "Model name '{}' was not found in tokenizers model name list ({}). "
+ "We assumed '{}' was a path or url to a directory containing vocabulary files "
+ "named {}, but couldn't find such vocabulary files at this path or url.".format(
+ pretrained_model_name_or_path,
+ ", ".join(s3_models),
+ pretrained_model_name_or_path,
+ list(cls.vocab_files_names.values()),
+ )
+ )
+
+ raise EnvironmentError(msg)
+
+ if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
+ raise EnvironmentError(
+ "Model name '{}' was not found in tokenizers model name list ({}). "
+ "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
+ "named {} but couldn't find such vocabulary files at this path or url.".format(
+ pretrained_model_name_or_path,
+ ", ".join(s3_models),
+ pretrained_model_name_or_path,
+ list(cls.vocab_files_names.values()),
+ )
+ )
+
+ for file_id, file_path in vocab_files.items():
+ if file_path == resolved_vocab_files[file_id]:
+ logger.info("loading file {}".format(file_path))
+ else:
+ logger.info("loading file {} from cache at {}".format(file_path, resolved_vocab_files[file_id]))
+
+ # Prepare tokenizer initialization kwargs
+ # Did we save some inputs and kwargs to reload?
+ tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
+ if tokenizer_config_file is not None:
+ with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+ init_kwargs = json.load(tokenizer_config_handle)
+ saved_init_inputs = init_kwargs.pop("init_inputs", ())
+ if not init_inputs:
+ init_inputs = saved_init_inputs
+ else:
+ init_kwargs = init_configuration
+
+ # Update with newly provided kwargs
+ init_kwargs.update(kwargs)
+
+ # Set max length if needed
+ if pretrained_model_name_or_path in cls.max_model_input_sizes:
+ # if we're using a pretrained model, ensure the tokenizer
+ # won't index sequences longer than the number of positional embeddings
+ model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]
+ if model_max_length is not None and isinstance(model_max_length, (int, float)):
+ init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length)
+
+ # Merge resolved_vocab_files arguments in init_kwargs.
+ added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
+ for args_name, file_path in resolved_vocab_files.items():
+ if args_name not in init_kwargs:
+ init_kwargs[args_name] = file_path
+
+ # Instantiate tokenizer.
+ try:
+ tokenizer = cls(*init_inputs, **init_kwargs)
+ except OSError:
+ raise OSError(
+ "Unable to load vocabulary from file. "
+ "Please check that the provided vocabulary is accessible and not corrupted."
+ )
+
+ # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
+ tokenizer.init_inputs = init_inputs
+ tokenizer.init_kwargs = init_kwargs
+
+ # If there is a complementary special token map, load it
+ special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
+ if special_tokens_map_file is not None:
+ with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
+ special_tokens_map = json.load(special_tokens_map_handle)
+
+ for key, value in special_tokens_map.items():
+ if isinstance(value, dict):
+ value = AddedToken(**value)
+ setattr(tokenizer, key, value)
+
+ # Add supplementary tokens.
+ special_tokens = tokenizer.all_special_tokens
+ if added_tokens_file is not None:
+ with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
+ added_tok_encoder = json.load(added_tokens_handle)
+
+ # Sort added tokens by index
+ added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
+
+ for token, index in added_tok_encoder_sorted:
+ assert index == len(tokenizer), (
+ f"Non-consecutive added token '{token}' found. "
+ f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
+ )
+ tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
+
+ # Check all our special tokens are registered as "no split" tokens (we don't cut them) and are in the vocab
+ added_tokens = tokenizer.sanitize_special_tokens()
+ if added_tokens:
+ logger.warning(
+ "Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained."
+ )
+
+ return tokenizer
+
+ def save_pretrained(self, save_directory) -> Tuple[str]:
+ """ Save the tokenizer vocabulary files together with:
+ - added tokens,
+ - special-tokens-to-class-attributes-mapping,
+ - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert).
+
+ Warning: This won't save modifications you may have applied to the tokenizer after the instantiation
+ (e.g. modifying tokenizer.do_lower_case after creation).
+
+ This method makes sure the full tokenizer can then be re-loaded using the
+ :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
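+
+ A minimal usage sketch (the tokenizer class and directory name are only illustrative)::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ saved_files = tokenizer.save_pretrained('./my_tokenizer')
+ reloaded_tokenizer = BertTokenizer.from_pretrained('./my_tokenizer')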
+ """
+ if os.path.isfile(save_directory):
+ logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
+ return
+ os.makedirs(save_directory, exist_ok=True)
+
+ special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
+ added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
+ tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
+
+ tokenizer_config = copy.deepcopy(self.init_kwargs)
+ if len(self.init_inputs) > 0:
+ tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
+ for file_id in self.vocab_files_names.keys():
+ tokenizer_config.pop(file_id, None)
+
+ with open(tokenizer_config_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(tokenizer_config, ensure_ascii=False))
+
+ with open(special_tokens_map_file, "w", encoding="utf-8") as f:
+ write_dict = {}
+ for key, value in self.special_tokens_map_extended.items():
+ if isinstance(value, AddedToken):
+ write_dict[key] = value.__getstate__()
+ else:
+ write_dict[key] = value
+ f.write(json.dumps(write_dict, ensure_ascii=False))
+
+ added_vocab = self.get_added_vocab()
+ if added_vocab:
+ with open(added_tokens_file, "w", encoding="utf-8") as f:
+ out_str = json.dumps(added_vocab, ensure_ascii=False)
+ f.write(out_str)
+
+ vocab_files = self.save_vocabulary(save_directory)
+
+ return vocab_files + (special_tokens_map_file, added_tokens_file)
+
+ @add_end_docstrings(
+ ENCODE_KWARGS_DOCSTRING,
+ """
+ **kwargs: passed to the `self.tokenize()` method.
+ """,
+ )
+ def encode(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ):
+ """
+ Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
+
+ Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
+
+ Args:
+ text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`):
+ The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
+ the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+ method)
+ text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
+ string using the `tokenize` method) or a list of integers (tokenized string ids using the
+ `convert_tokens_to_ids` method)
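+
+ A minimal usage sketch (the tokenizer and sentence are illustrative, mirroring demo_inference.py)::
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ ids = tokenizer.encode(text='spoon on the dish', add_special_tokens=True)
+ # `ids` is a plain list of vocabulary ids with the model's special tokens ([CLS]/[SEP] for BERT) added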
+ """
+ encoded_inputs = self.encode_plus(
+ text,
+ text_pair=text_pair,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+
+ return encoded_inputs["input_ids"]
+
+ def num_special_tokens_to_add(self, pair: bool = False) -> int:
+ raise NotImplementedError
+
+ def _get_padding_truncation_strategies(
+ self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
+ ):
+ """ Find the correct padding/truncation strategy with backward compatibility
+ for old arguments (truncation_strategy and pad_to_max_length) and behaviors.
+ """
+ old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
+ old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
+
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
+ # If you only set max_length, it activates truncation for max_length
+ if max_length is not None and padding is False and truncation is False:
+ if verbose:
+ logger.warning(
+ "Truncation was not explicitely activated but `max_length` is provided a specific value, "
+ "please use `truncation=True` to explicitely truncate examples to max length. "
+ "Defaulting to 'longest_first' truncation strategy. "
+ "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
+ "more precisely by providing a specific strategy to `truncation`."
+ )
+ truncation = "longest_first"
+
+ # Get padding strategy
+ if padding is False and old_pad_to_max_length:
+ if verbose:
+ warnings.warn(
+ "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
+ "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
+ "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
+ "maximal input size of the model (e.g. 512 for Bert).",
+ DeprecationWarning,
+ )
+ if max_length is None:
+ padding_strategy = PaddingStrategy.LONGEST
+ else:
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+ elif padding is not False:
+ if padding is True:
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
+ elif not isinstance(padding, PaddingStrategy):
+ padding_strategy = PaddingStrategy(padding)
+ else:
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
+
+ # Get truncation strategy
+ if truncation is False and old_truncation_strategy != "do_not_truncate":
+ if verbose:
+ warnings.warn(
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
+ "use `truncation=True` to truncate examples to a max length. You can give a specific "
+ "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
+ "maximal input size of the model (e.g. 512 for Bert). "
+ " If you have pairs of inputs, you can give a specific truncation strategy selected among "
+ "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
+ "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
+ "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
+ DeprecationWarning,
+ )
+ truncation_strategy = TruncationStrategy(old_truncation_strategy)
+ elif truncation is not False:
+ if truncation is True:
+ truncation_strategy = (
+ TruncationStrategy.LONGEST_FIRST
+ ) # Default to truncate the longest sequences in pairs of inputs
+ elif not isinstance(truncation, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation)
+ else:
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+
+ # Set max length if needed
+ if max_length is None:
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
+ if self.model_max_length > LARGE_INTEGER:
+ if verbose:
+ logger.warning(
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
+ "Default to no padding."
+ )
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
+ else:
+ max_length = self.model_max_length
+
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
+ if self.model_max_length > LARGE_INTEGER:
+ if verbose:
+ logger.warning(
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
+ "Default to no truncation."
+ )
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+ else:
+ max_length = self.model_max_length
+
+ # Test if we have a padding token
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
+ raise ValueError(
+ "Asking to pad but the tokenizer does not have a padding token. "
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
+ )
+
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
+ if (
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
+ and pad_to_multiple_of is not None
+ and max_length is not None
+ and (max_length % pad_to_multiple_of != 0)
+ ):
+ raise ValueError(
+ f"Truncation and padding are both activated but "
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
+ )
+
+ return padding_strategy, truncation_strategy, max_length, kwargs
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_pretokenized: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Returns a dictionary containing the encoded sequence or sequence pair and additional information:
+ the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
+
+ Args:
+ text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
+ The sequence or batch of sequences to be encoded.
+ Each sequence can be a string or a list of strings (pre-tokenized string).
+ If the sequences are provided as list of strings (pretokenized), you must set `is_pretokenized=True`
+ (to lift the ambiguity with a batch of sequences)
+ text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
+ The sequence or batch of sequences to be encoded.
+ Each sequence can be a string or a list of strings (pre-tokenized string).
+ If the sequences are provided as list of strings (pretokenized), you must set `is_pretokenized=True`
+ (to lift the ambiguity with a batch of sequences)
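+
+ Usage sketch (texts are illustrative)::
+
+ enc = tokenizer("spoon on the dish")                          # single sequence -> encode_plus
+ enc = tokenizer(["a red cup", "the chair on the left"])       # batch of sequences -> batch_encode_plus
+ enc = tokenizer([["a", "red", "cup"]], is_pretokenized=True)  # batch of pre-tokenized sequences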
+ """
+ # Input type checking for clearer error
+ assert isinstance(text, str) or (
+ isinstance(text, (list, tuple))
+ and (
+ len(text) == 0
+ or (
+ isinstance(text[0], str)
+ or (isinstance(text[0], (list, tuple)) and (len(text[0]) == 0 or isinstance(text[0][0], str)))
+ )
+ )
+ ), (
+ "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ assert (
+ text_pair is None
+ or isinstance(text_pair, str)
+ or (
+ isinstance(text_pair, (list, tuple))
+ and (
+ len(text_pair) == 0
+ or (
+ isinstance(text_pair[0], str)
+ or (
+ isinstance(text_pair[0], (list, tuple))
+ and (len(text_pair[0]) == 0 or isinstance(text_pair[0][0], str))
+ )
+ )
+ )
+ )
+ ), (
+ "text_pair input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ is_batched = bool(
+ (not is_pretokenized and isinstance(text, (list, tuple)))
+ or (is_pretokenized and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)))
+ )
+
+ if is_batched:
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ is_pretokenized=is_pretokenized,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ is_pretokenized=is_pretokenized,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_pretokenized: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Returns a dictionary containing the encoded sequence or sequence pair and additional information:
+ the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
+
+ Args:
+ text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the latter only for non-fast tokenizers)):
+ The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
+ the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+ method)
+ text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
+ string using the `tokenize` method) or a list of integers (tokenized string ids using the
+ `convert_tokens_to_ids` method)
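+
+ A minimal usage sketch (values are illustrative)::
+
+ enc = tokenizer.encode_plus("spoon on the dish", max_length=20, padding='max_length', truncation=True)
+ # `enc` is a BatchEncoding that typically holds 'input_ids', 'token_type_ids' and 'attention_mask'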
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ text_pair=text_pair,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ is_pretokenized=is_pretokenized,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_pretokenized: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ raise NotImplementedError
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ List[PreTokenizedInputPair],
+ List[EncodedInput],
+ List[EncodedInputPair],
+ ],
+ add_special_tokens: bool = True,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_pretokenized: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Returns a dictionary containing the encoded sequence or sequence pair and additional information:
+ the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
+
+ Args:
+ batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`,
+ :obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`,
+ and, for non-fast tokenizers, also:
+ :obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`):
+ Batch of sequences or pair of sequences to be encoded.
+ This can be a list of strings/string-sequences/int-sequences or a list of pairs of
+ strings/string-sequences/int-sequences (see details in encode_plus)
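+
+ A minimal usage sketch (sentences are illustrative)::
+
+ batch = tokenizer.batch_encode_plus(["a red cup", "the chair on the left"], padding=True)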
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ is_pretokenized=is_pretokenized,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ List[PreTokenizedInputPair],
+ List[EncodedInput],
+ List[EncodedInputPair],
+ ],
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_pretokenized: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ raise NotImplementedError
+
+ def pad(
+ self,
+ encoded_inputs: Union[
+ BatchEncoding,
+ List[BatchEncoding],
+ Dict[str, EncodedInput],
+ Dict[str, List[EncodedInput]],
+ List[Dict[str, EncodedInput]],
+ ],
+ padding: Union[bool, str] = True,
+ max_length: Optional[int] = None,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ """ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length in the batch.
+
+ Padding side (left/right) padding token ids are defined at the tokenizer level
+ (with ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``)
+
+ Args:
+ encoded_inputs: Dictionary of tokenized inputs (`Dict[str, List[int]]`) or batch of tokenized inputs.
+ Batch of tokenized inputs can be given as dicts of lists or lists of dicts, both work so you can
+ use ``tokenizer.pad()`` during pre-processing as well as in a PyTorch Dataloader collate function.
+ (`Dict[str, List[List[int]]]` or `List[Dict[str, List[int]]]`).
+ padding: Boolean or specific strategy to use for padding.
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among:
+ - 'longest' (or `True`, the default): Pad to the longest sequence in the batch
+ - 'max_length': Pad to a max length specified with the `max_length` argument (or the model maximum length if `max_length` is not provided)
+ - 'do_not_pad' (or `False`): Do not pad
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Note that this method only pads; it never truncates.
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ >= 7.5 (Volta).
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
+ Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
+ PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
+ verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Set to ``False`` to avoid printing info messages and warnings.
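+
+ A minimal sketch of the collate use case mentioned above (texts are illustrative)::
+
+ features = [tokenizer("a red cup"), tokenizer("the chair on the left")]
+ batch = tokenizer.pad(features, padding=True, return_tensors='pt')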
+ """
+ # If we have a list of dicts, let's convert it in a dict of lists
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+ encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
+
+ assert "input_ids" in encoded_inputs, (
+ "You should supply an encoding or a list of encodings to this method. "
+ "An encoding is the output of one the encoding methods of the tokenizer, i.e. "
+ "__call__/encode_plus/batch_encode_plus. "
+ )
+
+ if not encoded_inputs["input_ids"]:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = []
+ return encoded_inputs
+
+ # Convert padding_strategy in PaddingStrategy
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
+ padding=padding, max_length=max_length, verbose=verbose
+ )
+
+ if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)):
+ encoded_inputs = self._pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding_strategy=padding_strategy,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
+
+ batch_size = len(encoded_inputs["input_ids"])
+ assert all(
+ len(v) == batch_size for v in encoded_inputs.values()
+ ), "Some items in the output dictionnary have a different batch size than others."
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"])
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+
+ batch_outputs = {}
+ for i in range(batch_size):
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
+ outputs = self._pad(
+ inputs,
+ max_length=max_length,
+ padding_strategy=padding_strategy,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]:
+ if token_ids_1 is None:
+ return len(token_ids_0) * [0]
+ return [0] * len(token_ids_0) + [1] * len(token_ids_1)
+
+ def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens. This implementation does not add special tokens.
+ """
+ if token_ids_1 is None:
+ return token_ids_0
+ return token_ids_0 + token_ids_1
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def prepare_for_model(
+ self,
+ ids: List[int],
+ pair_ids: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ prepend_batch_axis: bool = False,
+ **kwargs
+ ) -> BatchEncoding:
+ """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
+ It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user defined stride) for overflowing tokens
+
+ Args:
+ ids: list of tokenized input ids. Can be obtained from a string by chaining the
+ `tokenize` and `convert_tokens_to_ids` methods.
+ pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
+ `tokenize` and `convert_tokens_to_ids` methods.
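+
+ A minimal usage sketch (the sentence is illustrative)::
+
+ ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("spoon on the dish"))
+ enc = tokenizer.prepare_for_model(ids, max_length=20, padding='max_length', truncation=True)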
+ """
+
+ if "return_lengths" in kwargs:
+ if verbose:
+ warnings.warn(
+ "The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. "
+ "Please use `return_length` instead.",
+ FutureWarning,
+ )
+ return_length = kwargs["return_lengths"]
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ pair = bool(pair_ids is not None)
+ len_ids = len(ids)
+ len_pair_ids = len(pair_ids) if pair else 0
+
+ # Load from model defaults
+ if return_token_type_ids is None:
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ encoded_inputs = {}
+
+ # Compute the total size of the returned encodings
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+ # Truncation: Handle max sequence length
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+ ids,
+ pair_ids=pair_ids,
+ num_tokens_to_remove=total_len - max_length,
+ truncation_strategy=truncation_strategy,
+ stride=stride,
+ )
+ if return_overflowing_tokens:
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+ # Add special tokens
+ if add_special_tokens:
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+ else:
+ sequence = ids + pair_ids if pair else ids
+ token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
+
+ # Build output dictionary
+ encoded_inputs["input_ids"] = sequence
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = token_type_ids
+ if return_special_tokens_mask:
+ if add_special_tokens:
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+ else:
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+ # Check lengths
+ if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum sequence length "
+ "for this model ({} > {}). Running this sequence through the model will result in "
+ "indexing errors".format(len(ids), self.model_max_length)
+ )
+
+ # Padding
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+ encoded_inputs = self.pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding=padding_strategy.value,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ if return_length:
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+ batch_outputs = BatchEncoding(
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+ )
+
+ return batch_outputs
+
+ def truncate_sequences(
+ self,
+ ids: List[int],
+ pair_ids: Optional[List[int]] = None,
+ num_tokens_to_remove: int = 0,
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+ stride: int = 0,
+ ) -> Tuple[List[int], List[int], List[int]]:
+ """ Truncates a sequence pair in place to the maximum length.
+
+ Args:
+ ids: list of tokenized input ids. Can be obtained from a string by chaining the
+ `tokenize` and `convert_tokens_to_ids` methods.
+ pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
+ `tokenize` and `convert_tokens_to_ids` methods.
+ num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
+ number of tokens to remove using the truncation strategy
+ truncation_strategy (:obj:`string`, `optional`, defaults to "longest_first"):
+ String selected in the following options:
+
+ - 'longest_first' (default): Iteratively remove one token at a time from the longest sequence
+ (or from the single input sequence when no pair is provided) until the input is under max_length.
+ The overflowing tokens only contain overflow from the first sequence.
+ - 'only_first': Only truncate the first sequence. Logs an error if the first sequence is shorter than or equal to num_tokens_to_remove.
+ - 'only_second': Only truncate the second sequence
+ - 'do_not_truncate'
+ stride (:obj:`int`, `optional`, defaults to ``0``):
+ If set to a number along with max_length, the overflowing tokens returned will contain some tokens
+ from the main sequence returned. The value of this argument defines the number of additional tokens.
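+
+ A minimal sketch with toy ids::
+
+ ids, pair_ids, overflow = tokenizer.truncate_sequences([1, 2, 3, 4, 5], pair_ids=[6, 7], num_tokens_to_remove=2, truncation_strategy='only_first', stride=1)
+ # ids == [1, 2, 3]; overflow holds the two removed tokens plus `stride` extra context tokens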
+ """
+ if num_tokens_to_remove <= 0:
+ return ids, pair_ids, []
+
+ if not isinstance(truncation_strategy, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation_strategy)
+
+ overflowing_tokens = []
+ if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+ for _ in range(num_tokens_to_remove):
+ if pair_ids is None or len(ids) > len(pair_ids):
+ if not overflowing_tokens:
+ window_len = min(len(ids), stride + 1)
+ else:
+ window_len = 1
+ overflowing_tokens.extend(ids[-window_len:])
+ ids = ids[:-1]
+ else:
+ if not overflowing_tokens:
+ window_len = min(len(pair_ids), stride + 1)
+ else:
+ window_len = 1
+ overflowing_tokens.extend(pair_ids[-window_len:])
+ pair_ids = pair_ids[:-1]
+ elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
+ if len(ids) > num_tokens_to_remove:
+ window_len = min(len(ids), stride + num_tokens_to_remove)
+ overflowing_tokens = ids[-window_len:]
+ ids = ids[:-num_tokens_to_remove]
+ else:
+ logger.error(
+ f"We need to remove {num_tokens_to_remove} to truncate the input"
+ f"but the first sequence has a length {len(ids)}. "
+ f"Please select another truncation strategy than {truncation_strategy}, "
+ f"for instance 'longest_first' or 'only_second'."
+ )
+ elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+ if len(pair_ids) > num_tokens_to_remove:
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+ overflowing_tokens = pair_ids[-window_len:]
+ pair_ids = pair_ids[:-num_tokens_to_remove]
+ else:
+ logger.error(
+ f"We need to remove {num_tokens_to_remove} to truncate the input"
+ f"but the second sequence has a length {len(pair_ids)}. "
+ f"Please select another truncation strategy than {truncation_strategy}, "
+ f"for instance 'longest_first' or 'only_first'."
+ )
+
+ return (ids, pair_ids, overflowing_tokens)
+
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """ Pad encoded inputs (on left/right and up to predefined legnth or max length in the batch)
+
+ Args:
+ encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Note that this method only pads; it never truncates.
+ padding_strategy: PaddingStrategy to use for padding.
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ >= 7.5 (Volta).
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(encoded_inputs["input_ids"])
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
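+ # e.g. max_length=20 with pad_to_multiple_of=8 is rounded up to 24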
+
+ needs_to_be_padded = (
+ padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length
+ )
+
+ if needs_to_be_padded:
+ difference = max_length - len(encoded_inputs["input_ids"])
+ if self.padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
+ elif self.padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
+ else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
+
+ return encoded_inputs
+
+ def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
+ return [self.decode(seq, **kwargs) for seq in sequences]
+
+ def decode(
+ self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+ ) -> str:
+ """
+ Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
+ with options to remove special tokens and clean up tokenization spaces.
+ Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
+
+ Args:
+ token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+ skip_special_tokens: if set to True, will remove special tokens from the decoded string.
+ clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+ """
+ raise NotImplementedError
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
+
+ Args:
+ token_ids_0: list of ids (must not contain special tokens)
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
+ for sequence pairs
+ already_has_special_tokens: (default False) Set to True if the token list is already formatted with
+ special tokens for the model
+
+ Returns:
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ assert already_has_special_tokens and token_ids_1 is None, (
+ "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
+ "Please use a slow (full python) tokenizer to activate this argument."
+ "Or set `return_special_token_mask=True` when calling the encoding method "
+ "to get the special tokens mask in any tokenizer. "
+ )
+
+ all_special_ids = self.all_special_ids # cache the property
+
+ special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
+
+ return special_tokens_mask
+
+ @staticmethod
+ def clean_up_tokenization(out_string: str) -> str:
+ """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
+ """
+ out_string = (
+ out_string.replace(" .", ".")
+ .replace(" ?", "?")
+ .replace(" !", "!")
+ .replace(" ,", ",")
+ .replace(" ' ", "'")
+ .replace(" n't", "n't")
+ .replace(" 'm", "'m")
+ .replace(" 's", "'s")
+ .replace(" 've", "'ve")
+ .replace(" 're", "'re")
+ )
+ return out_string
diff --git a/elia/demo_inference.py b/elia/demo_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..630be118942e8ebc5a9bf6b13df650683f4e4ae8
--- /dev/null
+++ b/elia/demo_inference.py
@@ -0,0 +1,295 @@
+image_path = './image001.png'
+sentence = 'spoon on the dish'
+weights = '/cluster/nvme4/cyx/lavt/vis/model_best_refcoco_0508.pth'
+device = 'cpu'
+
+# pre-process the input image
+from PIL import Image
+import torchvision.transforms as T
+import numpy as np
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from bert.multimodal_bert import MultiModalBert
+import torchvision
+
+from lib import multimodal_segmentation_ppm
+#import transforms as T
+import utils
+
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+
+from modeling.MaskFormerModel import MaskFormerHead
+from addict import Dict
+#from bert.modeling_bert import BertLMPredictionHead, BertEncoder
+import cv2
+import textwrap
+
+class WrapperModel(nn.Module):
+ def __init__(self, image_model, language_model, classifier) :
+ super(WrapperModel, self).__init__()
+ self.image_model = image_model
+ self.language_model = language_model
+ self.classifier = classifier
+
+ config = Dict({
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": False,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ #"max_position_embeddings": 16+20,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 8,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": True,
+ "vocab_size": 30522
+ })
+
+
+
+ def _get_binary_mask(self, target):
+ # return a binary mask for each class
+ y, x = target.size()
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
+ return target_onehot[1:]
+
+ def semantic_inference(self, mask_cls, mask_pred):
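+ # Combine per-query class scores (the first class channel is dropped) with the
+ # per-query mask probabilities; the einsum sums over the query dimension q,
+ # giving one soft segmentation map per remaining class channel.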
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+ return semseg
+
+ def forward(self, image, sentences, attentions):
+ print(image.sum(), sentences.sum(), attentions.sum())
+ input_shape = image.shape[-2:]
+ l_mask = attentions.unsqueeze(dim=-1)
+
+ i0, Wh, Ww = self.image_model.forward_stem(image)
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
+
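+ # The four Swin and BERT stages run in lock-step: after every stage the forward_pwam*
+ # calls fuse the two modalities, and the language-aware visual residuals
+ # (i1_residual .. i4_residual) are collected below as the inputs to the mask head.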
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
+ i1 = i1_temp
+
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
+ i2 = i2_temp
+
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
+ i3 = i3_temp
+
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
+ i4 = i4_temp
+
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
+ outputs = {}
+ outputs['s1'] = i1_residual
+ outputs['s2'] = i2_residual
+ outputs['s3'] = i3_residual
+ outputs['s4'] = i4_residual
+
+ predictions = self.classifier(outputs)
+ return predictions
+
+img = Image.open(image_path).convert("RGB")
+img_ndarray = np.array(img) # (orig_h, orig_w, 3); for visualization
+original_w, original_h = img.size # PIL .size returns width first and height second
+
+image_transforms = T.Compose(
+ [
+ T.Resize((480, 480)),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+)
+
+img = image_transforms(img).unsqueeze(0) # (1, 3, 480, 480)
+img = img.to(device) # for inference (input)
+
+# pre-process the raw sentence
+from bert.tokenization_bert import BertTokenizer
+import torch
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
+ sentence_tokenized = sentence_tokenized[:20] # if the sentence is longer than 20 tokens, this truncates it to 20 tokens
+# pad the tokenized sentence
+padded_sent_toks = [0] * 20
+padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized
+# create a sentence token mask: 1 for real words; 0 for padded tokens
+attention_mask = [0] * 20
+attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized)
+# convert lists to tensors
+padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0) # (1, 20)
+attention_mask = torch.tensor(attention_mask).unsqueeze(0) # (1, 20)
+padded_sent_toks = padded_sent_toks.to(device) # for inference (input)
+attention_mask = attention_mask.to(device) # for inference (input)
+
+# initialize model and load weights
+#from bert.modeling_bert import BertModel
+#from lib import segmentation
+
+# construct a mini args class; like from a config file
+
+
+class args:
+ swin_type = 'base'
+ window12 = True
+ mha = ''
+ fusion_drop = 0.0
+
+
+#single_model = segmentation.__dict__['lavt'](pretrained='', args=args)
+single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='',args=args)
+single_model.to(device)
+model_class = MultiModalBert
+single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim)
+single_bert_model.pooler = None
+
+input_shape = dict()
+input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
+input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
+input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
+input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
+
+
+
+cfg = Dict()
+cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
+cfg.MODEL.MASK_FORMER.NHEADS = 8
+cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
+cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
+cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
+
+cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
+cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
+cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
+cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+
+maskformer_head = MaskFormerHead(cfg, input_shape)
+
+
+model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head)
+
+
+
+checkpoint = torch.load(weights, map_location='cpu')
+
+model.load_state_dict(checkpoint['model'], strict=False)
+model.to(device)
+model.eval()
+#single_bert_model.load_state_dict(checkpoint['bert_model'])
+#single_model.load_state_dict(checkpoint['model'])
+#model = single_model.to(device)
+#bert_model = single_bert_model.to(device)
+
+
+# inference
+#import torch.nn.functional as F
+#last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0]
+#embedding = last_hidden_states.permute(0, 2, 1)
+#output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1))
+#output = output.argmax(1, keepdim=True) # (1, 1, 480, 480)
+#output = F.interpolate(output.float(), (original_h, original_w)) # 'nearest'; resize to the original image size
+#output = output.squeeze() # (orig_h, orig_w)
+#output = output.cpu().data.numpy() # (orig_h, orig_w)
+
+output = model(img, padded_sent_toks, attention_mask)[0]
+#print(output[0].keys())
+#print(output[1].shape)
+mask_cls_results = output["pred_logits"]
+mask_pred_results = output["pred_masks"]
+
+target_shape = img_ndarray.shape[:2]
+#print(target_shape, mask_pred_results.shape)
+mask_pred_results = F.interpolate(mask_pred_results, size=(480,480), mode='bilinear', align_corners=True)
+
+pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
+#output = pred_masks[0]
+
+#output = output.cpu()
+
+
+
+#print(output.shape)
+#output_mask = output.argmax(1).data.numpy()
+#output = (output > 0.5).data.cpu().numpy()
+output = torch.nn.functional.interpolate(pred_masks, target_shape)
+output = (output > 0.5).data.cpu().numpy()
+
+
+# show/save results
+def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4):
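+ # Alpha-blend each object's color (red for id 1 with the default `colors`) into the
+ # image and draw the mask contour in black; `colors` maps object id -> RGB, id 0 = background.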
+ from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated in newer SciPy
+
+ colors = np.reshape(colors, (-1, 3))
+ colors = np.atleast_2d(colors) * cscale
+
+ im_overlay = image.copy()
+ object_ids = np.unique(mask)
+
+ for object_id in object_ids[1:]:
+ # Overlay color on binary mask
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
+ binary_mask = mask == object_id
+
+ # Compose image
+ im_overlay[binary_mask] = foreground[binary_mask]
+
+ # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
+ contours = binary_dilation(binary_mask) ^ binary_mask
+ # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
+ im_overlay[contours, :] = 0
+
+ return im_overlay.astype(image.dtype)
+
+
+output = output.astype(np.uint8) # (orig_h, orig_w), np.uint8
+# Overlay the mask on the image
+print(img_ndarray.shape, output.shape)
+visualization = overlay_davis(img_ndarray, output[0][0]) # red
+visualization = Image.fromarray(visualization)
+# show the visualization
+#visualization.show()
+# Save the visualization
+visualization.save('./demo/spoon_on_the_dish.jpg')
+
+
+
+
diff --git a/elia/requirements.txt b/elia/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9cb7e037ed9179ca75ea3a86497e68307401a645
--- /dev/null
+++ b/elia/requirements.txt
@@ -0,0 +1,14 @@
+requests
+filelock
+tqdm
+timm
+mmcv-full==1.3.12
+mmsegmentation==0.17.0
+ftfy
+regex
+scipy
+scikit-image
+pycocotools==2.0.2
+opencv-python==4.5.3.56
+tokenizers==0.8.1rc1
+h5py
\ No newline at end of file
diff --git a/elia/test_elia.py b/elia/test_elia.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f79bdc289aaa54520c5036751a94b5bcd01c1b
--- /dev/null
+++ b/elia/test_elia.py
@@ -0,0 +1,312 @@
+
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from bert.multimodal_bert import MultiModalBert
+import torchvision
+
+from lib import multimodal_segmentation_ppm
+import transforms as T
+import utils
+
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+
+from modeling.MaskFormerModel import MaskFormerHead
+from addict import Dict
+from bert.modeling_bert import BertLMPredictionHead, BertEncoder
+
+def get_dataset(image_set, transform, args):
+ from data.dataset_refer_bert import ReferDataset
+ ds = ReferDataset(args,
+ split=image_set,
+ image_transforms=transform,
+ target_transforms=None,
+ eval_mode=True
+ )
+ num_classes = 2
+ return ds, num_classes
+
+
+def evaluate(model, data_loader, device):
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+
+ # evaluation variables
+ cum_I, cum_U = 0, 0
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+ seg_total = 0
+ mean_IoU = []
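+ # cum_I / cum_U accumulate the overall IoU; seg_correct[i] counts predictions whose
+ # per-image IoU clears eval_seg_iou_list[i] (reported below as precision@X)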
+ header = 'Test:'
+
+ with torch.no_grad():
+ for data in metric_logger.log_every(data_loader, 100, header):
+ image, target, sentences, attentions = data
+ image, target, sentences, attentions = image.to(device), target.to(device), \
+ sentences.to(device), attentions.to(device)
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ target = target.cpu().data.numpy()
+ for j in range(sentences.size(-1)):
+ #if bert_model is not None:
+ # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
+ # embedding = last_hidden_states.permute(0, 2, 1)
+ # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
+ #else:
+ output = model(image, sentences[:, :, j], attentions[:, :, j])
+ mask_cls_results = output["pred_logits"]
+ mask_pred_results = output["pred_masks"]
+
+ target_shape = target.shape[-2:]
+ mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
+
+ pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
+ output = pred_masks[0]
+
+ output = output.cpu()
+ #print(output.shape)
+ #output_mask = output.argmax(1).data.numpy()
+ output_mask = (output > 0.5).data.numpy()
+ I, U = computeIoU(output_mask, target)
+ if U == 0:
+ this_iou = 0.0
+ else:
+ this_iou = I*1.0/U
+ mean_IoU.append(this_iou)
+ cum_I += I
+ cum_U += U
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+ seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
+ seg_total += 1
+
+ #del image, target, sentences, attentions, output, output_mask
+ #if bert_model is not None:
+ # del last_hidden_states, embedding
+
+ mean_IoU = np.array(mean_IoU)
+ mIoU = np.mean(mean_IoU)
+ print('Final results:')
+ print('Mean IoU is %.2f\n' % (mIoU*100.))
+ results_str = ''
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ results_str += ' precision@%s = %.2f\n' % \
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+ print(results_str)
+
+
+def get_transform(args):
+ transforms = [T.Resize(args.img_size, args.img_size),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+
+ return T.Compose(transforms)
+
+
+def computeIoU(pred_seg, gd_seg):
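+ # Intersection and union (in pixels) of two binary masks; the caller divides
+ # I by U (guarding against U == 0) to obtain the IoU of a single prediction.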
+ I = np.sum(np.logical_and(pred_seg, gd_seg))
+ U = np.sum(np.logical_or(pred_seg, gd_seg))
+
+ return I, U
+
+class WrapperModel(nn.Module):
+ def __init__(self, image_model, language_model, classifier, args) :
+ super(WrapperModel, self).__init__()
+ self.image_model = image_model
+ self.language_model = language_model
+ self.classifier = classifier
+ self.lang_proj = nn.Linear(768,256)
+
+ config = Dict({
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": False,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ #"max_position_embeddings": 16+20,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 8,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": True,
+ "vocab_size": 30522
+ })
+ self.mlm_transformer = BertEncoder(config)
+
+ self.lang_proj = nn.Linear(768,256)
+ self.mlm_vis_proj = nn.Conv2d(1024,512,1)
+ self.mlm_lang_proj = nn.Linear(768,512)
+ #print(vis_proj)
+ self.mlm_head = BertLMPredictionHead(config)
+
+ assert args.img_size % 4 == 0
+ num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
+ print(num_img_tokens)
+ self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
+ self.mlm_modal_embeds = nn.Embedding(3, 512)
+
+ self.mlm_mask_embed = nn.Embedding(1, 512)
+ self.mlm_pos_mlp = nn.Sequential(
+ nn.Linear(2, 512),
+ nn.LayerNorm(512),
+ nn.Linear(512,512),
+ nn.GELU()
+ )
+
+ def _get_binary_mask(self, target):
+        # return a one-hot binary mask per class (background channel dropped)
+ y, x = target.size()
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
+ return target_onehot[1:]
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+ return semseg
+
+ def forward(self, image, sentences, attentions):
+ input_shape = image.shape[-2:]
+ l_mask = attentions.unsqueeze(dim=-1)
+
+ i0, Wh, Ww = self.image_model.forward_stem(image)
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
+
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
+ i1 = i1_temp
+
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
+ i2 = i2_temp
+
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
+ i3 = i3_temp
+
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
+ i4 = i4_temp
+
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
+ outputs = {}
+ outputs['s1'] = i1_residual
+ outputs['s2'] = i2_residual
+ outputs['s3'] = i3_residual
+ outputs['s4'] = i4_residual
+
+ predictions, _ = self.classifier(outputs)
+ return predictions
+
+def main(args):
+#def main(local_rank, args):
+
+ #device = torch.device(args.device)
+ device = 'cuda'
+ dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+ data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
+ sampler=test_sampler, num_workers=args.workers)
+ print(args.model)
+ single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args)
+ #single_model = MultiModalFocal(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], focal_windows=[9,9,9,9], drop_path_rate=0.3)
+ #single_model.init_weights('./focalnet_base_lrf.pth')
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ #single_model.load_state_dict(checkpoint['model'])
+ #model = single_model.to(device)
+
+ if args.model != 'lavt_one':
+ model_class = MultiModalBert
+ #single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=128)
+ single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim)
+ # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
+ if args.ddp_trained_weights:
+ single_bert_model.pooler = None
+ #single_bert_model.load_state_dict(checkpoint['bert_model'])
+ #bert_model = single_bert_model.to(device)
+    else:
+        bert_model = None
+        single_bert_model = None
+
+ #model = WrapperModel(single_model.backbone, single_bert_model, single_model.classifier)
+ #model.load_state_dict(checkpoint['model'])
+ #model.to(device)
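+    # channels and strides of the four backbone stages (s1-s4) consumed by the MaskFormer head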
+ input_shape = dict()
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
+
+
+
+ cfg = Dict()
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
+
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+
+ maskformer_head = MaskFormerHead(cfg, input_shape)
+ #maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
+ #maskformer_head.cuda()
+ #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
+ #single_head = maskformer_head.module
+ #print(single_head)
+
+ model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
+ model.load_state_dict(checkpoint['model'])
+ model.to(device)
+ #model.cuda()
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ #single_model = model.module
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ #single_model = model.module
+ evaluate(model, data_loader_test, device=device)
+
+
+if __name__ == "__main__":
+ from args import get_parser
+ parser = get_parser()
+ args = parser.parse_args()
+ print('Image size: {}'.format(str(args.img_size)))
+ print(args)
+ main(args)
+ #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
diff --git a/elia/test_lavt.py b/elia/test_lavt.py
new file mode 100644
index 0000000000000000000000000000000000000000..e85c8ede2fdf46d6ade4b5d5fbc70d12628fa6d4
--- /dev/null
+++ b/elia/test_lavt.py
@@ -0,0 +1,139 @@
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from bert.modeling_bert import BertModel
+import torchvision
+
+from lib import segmentation
+import transforms as T
+import utils
+
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+
+
+def get_dataset(image_set, transform, args):
+ from data.dataset_refer_bert import ReferDataset
+ ds = ReferDataset(args,
+ split=image_set,
+ image_transforms=transform,
+ target_transforms=None,
+ eval_mode=True
+ )
+ num_classes = 2
+ return ds, num_classes
+
+
+def evaluate(model, data_loader, bert_model, device):
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+
+ # evaluation variables
+ cum_I, cum_U = 0, 0
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+ seg_total = 0
+ mean_IoU = []
+ header = 'Test:'
+
+ with torch.no_grad():
+ for data in metric_logger.log_every(data_loader, 100, header):
+ image, target, sentences, attentions = data
+ image, target, sentences, attentions = image.to(device), target.to(device), \
+ sentences.to(device), attentions.to(device)
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ target = target.cpu().data.numpy()
+ for j in range(sentences.size(-1)):
+ if bert_model is not None:
+ last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
+ embedding = last_hidden_states.permute(0, 2, 1)
+ output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
+ else:
+ output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j])
+
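+                # LAVT predicts a 2-class map; argmax over the channel dim gives the binary mask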
+ output = output.cpu()
+ output_mask = output.argmax(1).data.numpy()
+ I, U = computeIoU(output_mask, target)
+ if U == 0:
+ this_iou = 0.0
+ else:
+ this_iou = I*1.0/U
+ mean_IoU.append(this_iou)
+ cum_I += I
+ cum_U += U
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+ seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
+ seg_total += 1
+
+ del image, target, sentences, attentions, output, output_mask
+ if bert_model is not None:
+ del last_hidden_states, embedding
+
+ mean_IoU = np.array(mean_IoU)
+ mIoU = np.mean(mean_IoU)
+ print('Final results:')
+ print('Mean IoU is %.2f\n' % (mIoU*100.))
+ results_str = ''
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ results_str += ' precision@%s = %.2f\n' % \
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+ print(results_str)
+
+
+def get_transform(args):
+ transforms = [T.Resize(args.img_size, args.img_size),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+
+ return T.Compose(transforms)
+
+
+def computeIoU(pred_seg, gd_seg):
+ I = np.sum(np.logical_and(pred_seg, gd_seg))
+ U = np.sum(np.logical_or(pred_seg, gd_seg))
+
+ return I, U
+
+
+def main(args):
+ device = torch.device(args.device)
+ dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+ data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
+ sampler=test_sampler, num_workers=args.workers)
+ print(args.model)
+ single_model = segmentation.__dict__[args.model](pretrained='',args=args)
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ single_model.load_state_dict(checkpoint['model'])
+ model = single_model.to(device)
+
+ if args.model != 'lavt_one':
+ model_class = BertModel
+ single_bert_model = model_class.from_pretrained(args.ck_bert)
+ # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
+ if args.ddp_trained_weights:
+ single_bert_model.pooler = None
+ single_bert_model.load_state_dict(checkpoint['bert_model'])
+ bert_model = single_bert_model.to(device)
+ else:
+ bert_model = None
+
+ evaluate(model, data_loader_test, bert_model, device=device)
+
+
+if __name__ == "__main__":
+ from args import get_parser
+ parser = get_parser()
+ args = parser.parse_args()
+ print('Image size: {}'.format(str(args.img_size)))
+ main(args)
diff --git a/elia/train_elia.py b/elia/train_elia.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce163e6497f2acb88dd1bc3219cf31997ac5495
--- /dev/null
+++ b/elia/train_elia.py
@@ -0,0 +1,812 @@
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from functools import reduce
+import operator
+from bert.multimodal_bert import MultiModalBert
+
+import torchvision
+from lib import multimodal_segmentation_ppm
+
+import transforms as T
+import utils
+import numpy as np
+
+import torch.nn.functional as F
+
+import gc
+from collections import OrderedDict
+
+import torch.backends.cudnn as cudnn
+
+#from ffrecord.torch import DataLoader,Dataset
+from modeling.MaskFormerModel import MaskFormerHead
+from addict import Dict
+
+from mask2former_utils.criterion import SetCriterion, Criterion
+from mask2former_utils.matcher import HungarianMatcher
+from bert.modeling_bert import BertLMPredictionHead, BertEncoder
+
+
+
+
+class WrapperModel(nn.Module):
+ def __init__(self, image_model, language_model, classifier, args) :
+ super(WrapperModel, self).__init__()
+ self.image_model = image_model
+ self.language_model = language_model
+ self.classifier = classifier
+
+ self.lang_proj = nn.Linear(768,256)
+
+ config = Dict({
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": False,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ #"max_position_embeddings": 16+20,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 8,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": True,
+ "vocab_size": 30522
+ })
+ self.mlm_transformer = BertEncoder(config)
+
+ self.lang_proj = nn.Linear(768,256)
+ self.mlm_vis_proj = nn.Conv2d(1024,512,1)
+ self.mlm_lang_proj = nn.Linear(768,512)
+ #print(vis_proj)
+ self.mlm_head = BertLMPredictionHead(config)
+
+ assert args.img_size % 4 == 0
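+        # MLM sequence length: 20 text tokens + (img_size/32)^2 flattened stage-4 visual tokens;
+        # one extra positional embedding is reserved for the position-MLP token appended in forward()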
+ num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
+ print(num_img_tokens)
+ self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
+ self.mlm_modal_embeds = nn.Embedding(3, 512)
+
+ self.mlm_mask_embed = nn.Embedding(1, 512)
+ self.mlm_pos_mlp = nn.Sequential(
+ nn.Linear(2, 512),
+ nn.LayerNorm(512),
+ nn.Linear(512,512),
+ nn.GELU()
+ )
+
+ def _get_binary_mask(self, target):
+        # return a one-hot binary mask per class (background channel dropped)
+ y, x = target.size()
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
+ return target_onehot[1:]
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+ return semseg
+
+ def forward(self, image, sentences, attentions, mlm_targets, mlm_masks, position):
+ input_shape = image.shape[-2:]
+ l_mask = attentions.unsqueeze(dim=-1)
+
+ i0, Wh, Ww = self.image_model.forward_stem(image)
+ l0, extended_attention_mask = self.language_model.forward_stem(mlm_targets.squeeze(1), attentions)
+
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
+ i1 = i1_temp
+
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
+ i2 = i2_temp
+
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
+ i3 = i3_temp
+
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
+ i4 = i4_temp
+
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
+ outputs = {}
+ outputs['s1'] = i1_residual
+ outputs['s2'] = i2_residual
+ outputs['s3'] = i3_residual
+ outputs['s4'] = i4_residual
+
+ predictions, mask_features = self.classifier(outputs)
+
+ #print(target_reshape.shape)
+ #tmp = np.argwhere(target_reshape[:, 0].detach().cpu().numpy()).reshape(-1, target_reshape.shape[2]*target_reshape[3], 3)
+ #centroid = tmp.mean(1)
+ #print(centroid)
+ #centroid_x, centroid_y = int(centroid[1]), int(centroid[0])
+ #last_hidden_states = brt_model(sentences, attention_mask=attentions)[0] # (6, 10, 768)
+ #embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
+
+
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
+
+
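+        # MLM branch: run a second, language-only pass over `sentences`, project the stage-4 visual
+        # features and language features into a shared 512-d space, append a token built from the
+        # 2-d `position` input, and feed the joint sequence to the small BERT encoder; the MLM head
+        # then scores the text positions against `mlm_targets`.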
+ mlp_embed = self.mlm_pos_mlp(position)
+ #print(centroid_x, centroid_y)
+
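+        # positions where mlm_masks == 0 are set to -1 so the MLM cross-entropy (ignore_index=-1)
+        # only scores the masked tokens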
+ mlm_targets = torch.where(
+ mlm_masks > 0,
+ mlm_targets,
+ torch.ones_like(mlm_targets) * (-1)
+ )
+
+ #print(x_c4[target_reshape[:, [0]].bool()].shape)
+ vis_features = self.mlm_vis_proj(i4_residual).flatten(2).permute(0,2,1)
+ #print(l4.shape)
+ lang_features = self.mlm_lang_proj(l4)
+
+ #print(lang_features.shape, vis_features.shape, mlp_embed.shape)
+ mm_features = torch.cat([lang_features, vis_features, mlp_embed.unsqueeze(1)], dim=1)
+ #print(mm_features.shape)
+
+ #print(mlm_modal_embeds.weight.shape)
+ modal_embeds = torch.cat([self.mlm_modal_embeds.weight[0].unsqueeze(0).repeat(1, lang_features.shape[1], 1), self.mlm_modal_embeds.weight[1].unsqueeze(0).repeat(1, vis_features.shape[1], 1), self.mlm_modal_embeds.weight[2].unsqueeze(0).repeat(1,1,1)], dim=1)
+ #print(modal_embeds.shape)
+
+ #print(mlm_transformer)
+
+
+ #print(attentions.shape)
+ mixed_attention_mask = torch.cat([attentions.unsqueeze(-1), torch.ones(attentions.shape[0], vis_features.shape[1]+1, 1).to(attentions.device)], dim=1)
+ mixed_attention_mask = mixed_attention_mask.permute(0,2,1).unsqueeze(1)
+ mixed_attention_mask = (1-mixed_attention_mask)* -10000.0
+ head_mask = [None] * 8
+ #extended_attention_mask = get_extended_attention_mask(mixed_attention_mask, mm_features.shape, mm_features.device)
+ #print(mm_features.shape, mixed_attention_mask.shape, head_mask)
+ #print(mm_features.shape, self.mlm_pos_embeds.weight.shape, self.mlm_modal_embeds.weight.shape)
+ head_features = self.mlm_transformer(mm_features + self.mlm_pos_embeds.weight.unsqueeze(0) + modal_embeds, mixed_attention_mask, head_mask)[0]
+ #print(head_features.shape, attentions.shape)
+ head_features = head_features[:, :20][attentions.bool()]
+
+ #print(embedding.shape, mask_features.shape)
+ mlm_predictions = self.mlm_head(head_features)
+ mlm_predictions = mlm_predictions.reshape(-1, self.language_model.config.vocab_size)
+ mlm_targets = mlm_targets.squeeze(1)[attentions.bool()]
+ #mlm_loss = mlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets)
+ #loss += mlm_loss
+ #mlm_loss_print=mlm_loss.item()
+
+ return predictions, mask_features, self.lang_proj((l4_residual * l_mask).sum(1)/l_mask.sum(1)), mlm_predictions, mlm_targets
+# IoU calculation for validation
+def IoU(pred, gt):
+ #pred = pred.argmax(1)
+ pred = (pred > 0.5)
+
+ intersection = torch.sum(torch.mul(pred, gt))
+ union = torch.sum(torch.add(pred, gt)) - intersection
+
+ if intersection == 0 or union == 0:
+ iou = 0
+ else:
+ iou = float(intersection) / float(union)
+
+ return iou, intersection, union
+
+def get_dataset(image_set, transform, args):
+ from data.dataset_refer_bert_mlm import ReferDataset
+ ds = ReferDataset(args,
+ split=image_set,
+ image_transforms=transform,
+ target_transforms=None
+ )
+ num_classes = 2
+
+ return ds, num_classes
+
+
+
+def get_transform(args):
+ transforms = [T.Resize(args.img_size, args.img_size),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+
+ return T.Compose(transforms)
+
+
+#def criterion(input, target):
+# weight = torch.FloatTensor([0.9, 1.1]).cuda()
+# return nn.functional.cross_entropy(input, target, weight=weight)
+
+
+def evaluate(model, data_loader):
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+ total_its = 0
+ acc_ious = 0
+
+ # evaluation variables
+ cum_I, cum_U = 0, 0
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+ seg_total = 0
+ mean_IoU = []
+
+ with torch.no_grad():
+ for data in metric_logger.log_every(data_loader, 100, header):
+ total_its += 1
+ #image, target, sentences, attentions = data
+ #image, target, sentences, attentions = image.cuda(non_blocking=True),\
+ # target.cuda(non_blocking=True),\
+ # sentences.cuda(non_blocking=True),\
+ # attentions.cuda(non_blocking=True)
+
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = data
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = image.cuda(non_blocking=True),\
+ target.cuda(non_blocking=True),\
+ sentences.cuda(non_blocking=True),\
+ attentions.cuda(non_blocking=True), \
+ mlm_targets.cuda(non_blocking=True), \
+ mlm_masks.cuda(non_blocking=True), \
+ position.cuda(non_blocking=True)
+
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ #print("sentences", sentences.shape)
+ #print("attentions", attentions.shape)
+
+
+ output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position)
+ mask_cls_results = output["pred_logits"]
+ mask_pred_results = output["pred_masks"]
+
+ target_shape = target.shape[-2:]
+ mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
+
+ pred_masks = model.module.semantic_inference(mask_cls_results, mask_pred_results)
+ output = pred_masks[0]
+
+
+ iou, I, U = IoU(output, target)
+ acc_ious += iou
+ mean_IoU.append(iou)
+ cum_I += I
+ cum_U += U
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+ seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
+ seg_total += 1
+ iou = acc_ious / total_its
+
+ mean_IoU = np.array(mean_IoU)
+ mIoU = np.mean(mean_IoU)
+ print('Final results:')
+ print('Mean IoU is %.2f\n' % (mIoU * 100.))
+ results_str = ''
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ results_str += ' precision@%s = %.2f\n' % \
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+ print(results_str)
+
+ return 100 * iou, 100 * cum_I / cum_U
+
+
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
+ iterations, args):
+ model.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ train_loss = 0
+ total_its = 0
+
+ for data in metric_logger.log_every(data_loader, print_freq, header):
+ total_its += 1
+ #image, target, sentences, attentions = data
+ #image, target, sentences, attentions = image.cuda(non_blocking=True),\
+ # target.cuda(non_blocking=True),\
+ # sentences.cuda(non_blocking=True),\
+ # attentions.cuda(non_blocking=True)
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = data
+ image, target, sentences, attentions, mlm_targets, mlm_masks, position = image.cuda(non_blocking=True),\
+ target.cuda(non_blocking=True),\
+ sentences.cuda(non_blocking=True),\
+ attentions.cuda(non_blocking=True), \
+ mlm_targets.cuda(non_blocking=True), \
+ mlm_masks.cuda(non_blocking=True), \
+ position.cuda(non_blocking=True)
+
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ #l_mask = attentions.unsqueeze(dim=-1)
+
+ output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(image, sentences, attentions, mlm_targets, mlm_masks, position)
+ #print(avg_lang_feature.shape)
+ avg_lang_feature = torch.nn.functional.normalize(avg_lang_feature, dim=1)
+ #print("----")
+ #print(output.shape)
+ #print(mask_features.shape)
+ #print(avg_lang_feature.shape)
+ #print( mlm_predictions.shape)
+ #print(mlm_targets.shape)
+ #print("----")
+
+ target_shape = target.shape[-2:]
+ output['pred_masks'] = F.interpolate(output['pred_masks'], size=target_shape, mode='bilinear', align_corners=True)
+
+ if "aux_outputs" in output:
+ for i, aux_outputs in enumerate(output["aux_outputs"]):
+ output['aux_outputs'][i]['pred_masks'] = F.interpolate(output['aux_outputs'][i]['pred_masks'], size=target_shape, mode='bilinear', align_corners=True)
+
+ # pixel region
+ B, C, H, W = mask_features.shape
+
+ target_reshape = F.interpolate(target.unsqueeze(1).float(), size=mask_features.shape[-2:], mode='nearest').long()
+
+ target_reshape = target_reshape.repeat(1, mask_features.shape[1], 1, 1)
+ #print(avg_pos_feature.shape, avg_lang_feature.shape, avg_neg_feature.shape)
+
+ #cl_loss = 0.0
+ plic_lang_loss = 0.0
+ plic_pos_loss = 0.0
+ plic_neg_loss = 0.0
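+        # per-sample pixel-level contrastive losses: each pixel should be closer to its own region's
+        # average feature than to pixels of the other region, and the pooled sentence feature should
+        # score higher on positive pixels than on negative ones; every term is focal-weighted by
+        # (1 - p_correct)^alpha and temperature-scaled inside the cross-entropy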
+ for i in range(B):
+ if ((target_reshape[[i]] == 0).sum() != 0 and (target_reshape[[i]] == 1).sum() != 0):
+
+ avg_pos_feature = (mask_features[[i]] * target_reshape[[i]]).sum(-1).sum(-1) / target_reshape[[i]].sum(-1).sum(-1)
+ avg_neg_feature = (mask_features[[i]] * (1.0-target_reshape[[i]])).sum(-1).sum(-1) / (1.0-target_reshape[[i]]).sum(-1).sum(-1)
+ avg_pos_feature = torch.nn.functional.normalize(avg_pos_feature, dim=1)
+ avg_neg_feature = torch.nn.functional.normalize(avg_neg_feature, dim=1)
+
+                # note: avg_lang_feature is already L2-normalized before this per-sample loop
+
+ pos_features = mask_features[[i]][target_reshape[[i]]==1].view(1, C, -1)
+ neg_features = mask_features[[i]][target_reshape[[i]]==0].view(1, C, -1)
+ #inter_neg_features = mask_features[[B-i-1]][target_reshape[[B-i-1]]==1].view(1, C, -1)
+ #neg_features = torch.cat([intra_neg_features, inter_neg_features], dim=2)
+
+ pos_features = torch.nn.functional.normalize(pos_features, dim=1)
+ neg_features = torch.nn.functional.normalize(neg_features, dim=1)
+
+ #print(avg_lang_feature.shape, avg_lang_feature[[i]].shape, pos_features.shape)
+ lang_pos_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], pos_features)
+ lang_neg_scores = torch.einsum("bq,bqn->bn", avg_lang_feature[[i]], neg_features)
+
+ lang_matrix = torch.cat([lang_pos_scores.unsqueeze(-1), lang_neg_scores.unsqueeze(1).repeat(1, lang_pos_scores.shape[1], 1)], dim=2)
+ lang_labels = torch.zeros(lang_matrix.shape[1], dtype=torch.long).cuda()
+ lang_labels = lang_labels.unsqueeze(0).repeat(lang_matrix.shape[0], 1)
+
+ lang_score = torch.softmax(lang_matrix, -1)
+ lang_score = 1.0 - lang_score[:, :, 0]
+
+ pos_pos_scores = torch.einsum("bq,bqn->bn", avg_pos_feature, pos_features)
+ pos_neg_scores = torch.einsum("bqn,bqm->bnm", pos_features, neg_features)
+
+ pos_matrix = torch.cat([pos_pos_scores.unsqueeze(-1), pos_neg_scores], dim=2)
+ pos_labels = torch.zeros(pos_matrix.shape[1], dtype=torch.long).cuda()
+ pos_labels = pos_labels.unsqueeze(0).repeat(pos_matrix.shape[0], 1)
+
+ pos_score = torch.softmax(pos_matrix, -1)
+ pos_score = 1.0 - pos_score[:, :, 0]
+ #pos_weight = pos_weight.view(-1, pos_weight.shape[-1])
+
+ #intra_neg_features = torch.nn.functional.normalize(intra_neg_features, dim=1)
+ neg_neg_scores = torch.einsum("bq,bqn->bn", avg_neg_feature, neg_features)
+ neg_pos_scores = torch.einsum("bqn,bqm->bnm", neg_features, pos_features)
+
+ neg_matrix = torch.cat([neg_neg_scores.unsqueeze(-1), neg_pos_scores], dim=2)
+ neg_labels = torch.zeros(neg_matrix.shape[1], dtype=torch.long).cuda()
+ neg_labels = neg_labels.unsqueeze(0).repeat(neg_matrix.shape[0], 1)
+
+ neg_score = torch.softmax(neg_matrix, -1)
+ neg_score = 1.0 - neg_score[:, :, 0]
+ #neg_weight = neg_weight.view(-1, neg_weight.shape[-1])
+
+ pos_loss = (torch.pow(pos_score, args.plic_pos_alpha) * torch.nn.functional.cross_entropy(pos_matrix.view(-1, pos_matrix.shape[-1])/args.plic_pos_temp, pos_labels.view(-1), reduction='none')).mean()
+ neg_loss = (torch.pow(neg_score, args.plic_neg_alpha) * torch.nn.functional.cross_entropy(neg_matrix.view(-1, neg_matrix.shape[-1])/args.plic_neg_temp, neg_labels.view(-1), reduction='none')).mean()
+
+ lang_loss = (torch.pow(lang_score, args.plic_lang_alpha) * torch.nn.functional.cross_entropy(lang_matrix.view(-1, lang_matrix.shape[-1])/args.plic_lang_temp, lang_labels.view(-1), reduction='none')).mean()
+
+ plic_pos_loss += pos_loss
+ plic_neg_loss += neg_loss
+ plic_lang_loss += lang_loss
+ #cl_loss += 0.5 * (torch.nn.functional.cross_entropy(pos_matrix.view(-1, pos_matrix.shape[-1])/cl_temp, pos_labels.view(-1))+torch.nn.functional.cross_entropy(neg_matrix.view(-1, neg_matrix.shape[-1])/cl_temp, neg_labels.view(-1)))
+ plic_pos_loss = (args.plic_pos_weight * plic_pos_loss) / B
+ plic_neg_loss = (args.plic_neg_weight * plic_neg_loss) / B
+ plic_lang_loss = (args.plic_lang_weight * plic_lang_loss) / B
+ plic_loss = plic_pos_loss + plic_neg_loss +plic_lang_loss
+
+
+ #print(output.device, target.device)
+ losses = criterion(output, target)
+ weight_dict = criterion.weight_dict
+
+ loss_ce = 0.0
+ loss_dice = 0.0
+ loss_mask = 0.0
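+        # group the matcher-weighted Mask2Former losses (including deep-supervision aux terms)
+        # into CE / dice / mask components for logging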
+ for k in list(losses.keys()):
+ if k in weight_dict:
+ losses[k] *= criterion.weight_dict[k]
+ if '_ce' in k:
+ loss_ce += losses[k]
+ elif '_dice' in k:
+ loss_dice += losses[k]
+ else:
+ loss_mask += losses[k]
+ else:
+ # remove this loss if not specified in `weight_dict`
+ losses.pop(k)
+ #loss = 0.3 * loss_ce + 0.3 * loss_dice + 0.4 * loss_mask
+ smlm_loss = args.smlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets)
+ loss = loss_ce + loss_dice + loss_mask + plic_loss + smlm_loss
+
+
+ #loss = criterion(output.squeeze(1), target.float())
+ optimizer.zero_grad() # set_to_none=True is only available in pytorch 1.6+
+ loss.backward()
+ optimizer.step()
+ lr_scheduler.step()
+
+ torch.cuda.synchronize()
+ train_loss += loss.item()
+ iterations += 1
+ #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), plic_loss=plic_loss.item(), plic_lang_loss=plic_lang_loss.item(), plic_pos_loss=plic_pos_loss.item(), plic_neg_loss=plic_neg_loss.item(), smlm_loss=smlm_loss.item())
+ #metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"], loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(), cl_loss=cl_loss.item(), cl_lang_loss=cl_lang_loss_print, cl_pos_loss=cl_pos_loss_print, cl_neg_loss=cl_neg_loss_print)
+
+ #del image, target, sentences, attentions, loss, output, data
+ #if bert_model is not None:
+ # del last_hidden_states, embedding
+
+ #gc.collect()
+ #torch.cuda.empty_cache()
+ #del loss
+ #del cl_loss
+ #del cl_lang_loss
+ #del loss_ce
+ #del loss_dice
+ #del loss_mask
+ torch.cuda.synchronize()
+
+
+def main(args):
+#def main(local_rank, args):
+ #ip = os.environ['MASTER_IP']
+ #port = os.environ['MASTER_PORT']
+    #hosts = int(os.environ['WORLD_SIZE'])  # number of machines (here 1)
+    #rank = int(os.environ['RANK'])  # index of the current machine
+    #gpus = torch.cuda.device_count()  # number of GPUs per machine
+ #print(local_rank, rank, gpus) #3 0 8
+ #dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=hosts*gpus, rank=rank*gpus+local_rank)
+ #torch.cuda.set_device(local_rank)
+ #dist.barrier()
+
+ ##utils.init_distributed_mode(args)
+ #args.distributed=True
+ #args.gpu = local_rank
+ #print(args)
+ ##misc.init_distributed_mode(args)
+
+ #print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+ #print("{}".format(args).replace(', ', ',\n'))
+
+ #device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ print('seed', seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ #cudnn.benchmark = True
+
+ dataset, num_classes = get_dataset("train",
+ get_transform(args=args),
+ args=args)
+ dataset_test, _ = get_dataset("val",
+ get_transform(args=args),
+ args=args)
+
+ # batch sampler
+ print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ #num_tasks = hosts*gpus
+ #global_rank = rank*gpus+local_rank
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
+ shuffle=True)
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+ # data loader
+ data_loader = torch.utils.data.DataLoader(
+ dataset, batch_size=args.batch_size,
+ sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)
+
+ data_loader_test = torch.utils.data.DataLoader(
+ dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)
+
+ # model initialization
+ print(args.model)
+ model = multimodal_segmentation_ppm.__dict__[args.model](pretrained=args.pretrained_swin_weights,
+ args=args)
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ #model.cuda()
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=False)
+ #single_model = model.module
+
+ if args.model != 'lavt_one':
+ model_class = MultiModalBert
+ bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=model.backbone.embed_dim)
+        bert_model.pooler = None  # work-around for a bug in Transformers 3.0.2 that appears with DistributedDataParallel
+ #bert_model.cuda()
+ bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
+ #bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[local_rank])
+ #single_bert_model = bert_model.module
+ else:
+ bert_model = None
+ single_bert_model = None
+
+ input_shape = dict()
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
+
+
+
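+    # minimal Mask2Former-style config consumed by MaskFormerHead: one foreground class, queries over
+    # the four backbone stages s1-s4, and the remaining sizes/weights taken from args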
+ cfg = Dict()
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = args.transformer_enc_layers
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
+
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = args.num_object_queries
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = args.dim_feedforward
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = args.dec_layers
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+ cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
+ cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = args.no_object_weight
+ cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = args.class_weight
+ cfg.MODEL.MASK_FORMER.DICE_WEIGHT = args.dice_weight
+ cfg.MODEL.MASK_FORMER.MASK_WEIGHT = args.mask_weight
+
+ cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = args.train_num_points
+ cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
+ cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
+ print(cfg)
+
+ maskformer_head = MaskFormerHead(cfg, input_shape)
+ maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
+ #maskformer_head.cuda()
+ #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
+ #single_head = maskformer_head.module
+ #print(single_head)
+
+ model = WrapperModel(model.backbone, bert_model, maskformer_head, args)
+ model.cuda()
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ single_model = model.module
+
+ # mask2former loss
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
+
+ # loss weights
+ class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
+ # self.criterion = Criterion(self.num_classes)
+
+ # building criterion
+
+ matcher = HungarianMatcher(
+ cost_class=class_weight,
+ cost_mask=mask_weight,
+ cost_dice=dice_weight,
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+ )
+
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
+ if deep_supervision:
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+ aux_weight_dict = {}
+ for i in range(dec_layers - 1):
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+ weight_dict.update(aux_weight_dict)
+
+ losses = ["labels", "masks"]
+ criterion = SetCriterion(
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+ matcher=matcher,
+ weight_dict=weight_dict,
+ eos_coef=no_object_weight,
+ losses=losses,
+ num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+ oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
+ importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
+ device='cuda'
+ )
+
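+    # resume == "auto": pick up from the most recent per-epoch checkpoint found in output_dir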
+ if args.resume == "auto":
+ last_ckpt = ""
+ for e in range(args.epochs):
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
+ if os.path.exists(ckpt_path):
+ last_ckpt = ckpt_path
+ args.resume = last_ckpt
+
+ # resume training
+ if args.resume:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ single_model.load_state_dict(checkpoint['model'])
+ #if args.model != 'lavt_one':
+ # single_bert_model.load_state_dict(checkpoint['bert_model'])
+
+ # parameters to optimize
+ backbone_no_decay = list()
+ backbone_decay = list()
+ for name, m in single_model.image_model.named_parameters():
+ if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
+ backbone_no_decay.append(m)
+ else:
+ backbone_decay.append(m)
+
+ params_to_optimize = [
+ {'params': backbone_no_decay, 'weight_decay': 0.0},
+ {'params': backbone_decay},
+ {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
+ # the following are the parameters of bert
+ {"params": reduce(operator.concat,
+ [[p for p in single_model.language_model.encoder.layer[i].parameters()
+ if p.requires_grad] for i in range(10)])},
+ {"params": single_model.language_model.pwams.parameters()},
+ {"params": single_model.language_model.res_gates.parameters()},
+ {"params": single_model.language_model.norms.parameters()},
+ {"params": single_model.lang_proj.parameters()},
+ #{"params": single_model.language_model.parameters()},
+ {'params': single_model.mlm_head.parameters()},
+ {'params': single_model.mlm_vis_proj.parameters()},
+ {'params': single_model.mlm_lang_proj.parameters()},
+ {'params': single_model.mlm_transformer.parameters()},
+ {'params': single_model.mlm_pos_embeds.parameters()},
+ {'params': single_model.mlm_modal_embeds.parameters()},
+ {'params': single_model.mlm_mask_embed.parameters()},
+ {'params': single_model.mlm_pos_mlp.parameters()},
+ #{'params': mlm_head.parameters(), 'weight_decay': 0.0},
+ #{'params': mlm_vis_proj.parameters(), 'weight_decay': 0.0},
+ #{'params': mlm_lang_proj.parameters(), 'weight_decay': 0.0},
+ #{'params': mlm_transformer.parameters(), 'weight_decay': 0.0},
+ #{'params': mlm_pos_embeds.parameters(), 'weight_decay': 0.0},
+ #{'params': mlm_modal_embeds.parameters(), 'weight_decay': 0.0},
+ #{'params': mlm_mask_embed.parameters(), 'weight_decay': 0.0},
+ #{'params': mlm_pos_mlp.parameters(), 'weight_decay': 0.0},
+ ]
+
+
+ # optimizer
+ optimizer = torch.optim.AdamW(params_to_optimize,
+ lr=args.lr,
+ weight_decay=args.weight_decay,
+ amsgrad=args.amsgrad
+ )
+
+ # learning rate scheduler
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
+ lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
+
+ # housekeeping
+ start_time = time.time()
+ iterations = 0
+ best_oIoU = -0.1
+
+ # resume training (optimizer, lr scheduler, and the epoch)
+ if args.resume:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ resume_epoch = checkpoint['epoch']
+ else:
+ resume_epoch = -999
+
+ # training loops
+ for epoch in range(max(0, resume_epoch+1), args.epochs):
+ data_loader.sampler.set_epoch(epoch)
+ train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
+ iterations, args)
+ iou, overallIoU = evaluate(model, data_loader_test)
+
+ print('Average object IoU {}'.format(iou))
+ print('Overall IoU {}'.format(overallIoU))
+
+
+ dict_to_save = {'model': single_model.state_dict(),
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
+ 'lr_scheduler': lr_scheduler.state_dict()}
+
+ checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch))
+ utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP')
+ if utils.is_main_process():
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
+
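+        # keep only the newest `args.max_ckpt` per-epoch checkpoints on the main process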
+ if utils.is_main_process():
+ ckpt_paths = []
+ for e in range(args.epochs):
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
+ print(ckpt_path)
+ if os.path.exists(ckpt_path):
+ ckpt_paths.append(ckpt_path)
+ print(ckpt_paths)
+ for ckpt_path in ckpt_paths[:-args.max_ckpt]:
+ os.remove(ckpt_path)
+ print("remove {:s}".format(ckpt_path))
+
+
+ save_checkpoint = (best_oIoU < overallIoU)
+ if save_checkpoint:
+ print('Better epoch: {}\n'.format(epoch))
+ dict_to_save = {'model': single_model.state_dict(),
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
+ 'lr_scheduler': lr_scheduler.state_dict()}
+
+ checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))
+ utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP')
+ if utils.is_main_process():
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
+ best_oIoU = overallIoU
+ torch.cuda.empty_cache()
+
+ # summarize
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == "__main__":
+ from args import get_parser
+ parser = get_parser()
+ args = parser.parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ # set up distributed learning
+ utils.init_distributed_mode(args)
+ print('Image size: {}'.format(str(args.img_size)))
+ main(args)
+ #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
diff --git a/elia/train_lavt.py b/elia/train_lavt.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a6d78314f04b84ce618dcc9e33b1044a615fc1c
--- /dev/null
+++ b/elia/train_lavt.py
@@ -0,0 +1,444 @@
+import haienv
+haienv.set_env('lavt2')
+import torch.multiprocessing as mp
+import torch.distributed as dist
+
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from functools import reduce
+import operator
+from bert.modeling_bert import BertModel
+
+import torchvision
+from lib import segmentation
+
+import transforms as T
+import utils
+import numpy as np
+
+import torch.nn.functional as F
+
+import gc
+from collections import OrderedDict
+
+import torch.backends.cudnn as cudnn
+
+from ffrecord.torch import DataLoader, Dataset
+
+from addict import Dict
+from modeling.MaskFormerModel import MaskFormerHead
+
+def get_dataset(image_set, transform, args):
+ from data.dataset_refer_bert import ReferDataset
+ ds = ReferDataset(args,
+ split=image_set,
+ image_transforms=transform,
+ target_transforms=None
+ )
+ num_classes = 2
+
+ return ds, num_classes
+
+
+# IoU calculation for validation
+def IoU(pred, gt):
+ pred = pred.argmax(1)
+
+ intersection = torch.sum(torch.mul(pred, gt))
+ union = torch.sum(torch.add(pred, gt)) - intersection
+
+ if intersection == 0 or union == 0:
+ iou = 0
+ else:
+ iou = float(intersection) / float(union)
+
+ return iou, intersection, union
+
+
+def get_transform(args):
+ transforms = [T.Resize(args.img_size, args.img_size),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+
+ return T.Compose(transforms)
+
+
+def criterion(input, target):
+ weight = torch.FloatTensor([0.9, 1.1]).cuda()
+ return nn.functional.cross_entropy(input, target, weight=weight)
+
+
+def evaluate(model, data_loader, bert_model):
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+ total_its = 0
+ acc_ious = 0
+
+ # evaluation variables
+ cum_I, cum_U = 0, 0
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+ seg_total = 0
+ mean_IoU = []
+
+ with torch.no_grad():
+ for data in metric_logger.log_every(data_loader, 100, header):
+ total_its += 1
+ image, target, sentences, attentions = data
+ image, target, sentences, attentions = image.cuda(non_blocking=True),\
+ target.cuda(non_blocking=True),\
+ sentences.cuda(non_blocking=True),\
+ attentions.cuda(non_blocking=True)
+
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ #print("sentences", sentences.shape)
+ #print("attentions", attentions.shape)
+
+ if bert_model is not None:
+ last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
+ #print("last hidden states", last_hidden_states.shape)
+ embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
+ attentions = attentions.unsqueeze(dim=-1) # (B, N_l, 1)
+ output = model(image, embedding, l_mask=attentions)
+ else:
+ output = model(image, sentences, l_mask=attentions)
+
+ iou, I, U = IoU(output, target)
+ acc_ious += iou
+ mean_IoU.append(iou)
+ cum_I += I
+ cum_U += U
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+ seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
+ seg_total += 1
+ iou = acc_ious / total_its
+
+ mean_IoU = np.array(mean_IoU)
+ mIoU = np.mean(mean_IoU)
+ print('Final results:')
+ print('Mean IoU is %.2f\n' % (mIoU * 100.))
+ results_str = ''
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ results_str += ' precision@%s = %.2f\n' % \
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+ print(results_str)
+
+ return 100 * iou, 100 * cum_I / cum_U
+
+
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
+ iterations, bert_model):
+ model.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ train_loss = 0
+ total_its = 0
+
+ for data in metric_logger.log_every(data_loader, print_freq, header):
+ total_its += 1
+ image, target, sentences, attentions = data
+ image, target, sentences, attentions = image.cuda(non_blocking=True),\
+ target.cuda(non_blocking=True),\
+ sentences.cuda(non_blocking=True),\
+ attentions.cuda(non_blocking=True)
+
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ #print(sentences.shape, attentions.shape, target.shape)
+ #print(sentences)
+ #print('a', sentences.shape)
+ #print('b', attentions.shape)
+
+ if bert_model is not None:
+ last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] # (6, 10, 768)
+ #print('c', last_hidden_states.shape)
+
+ embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
+ #print('e', embedding.shape)
+ attentions = attentions.unsqueeze(dim=-1) # (batch, N_l, 1)
+ #print('f', attentions.shape)
+ output = model(image, embedding, l_mask=attentions)
+ else:
+ output = model(image, sentences, l_mask=attentions)
+
+ loss = criterion(output, target)
+ optimizer.zero_grad() # set_to_none=True is only available in pytorch 1.6+
+ loss.backward()
+ optimizer.step()
+ lr_scheduler.step()
+
+ torch.cuda.synchronize()
+ train_loss += loss.item()
+ iterations += 1
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
+
+ del image, target, sentences, attentions, loss, output, data
+ if bert_model is not None:
+ del last_hidden_states, embedding
+
+ #gc.collect()
+ #torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+
+#def main(args):
+def main(local_rank, args):
+ ip = os.environ['MASTER_IP']
+ port = os.environ['MASTER_PORT']
+    hosts = int(os.environ['WORLD_SIZE'])  # number of machines (here 1)
+    rank = int(os.environ['RANK'])  # index of the current machine
+    gpus = torch.cuda.device_count()  # number of GPUs per machine
+    print(local_rank, rank, gpus)  # e.g. 3 0 8
+ dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=hosts*gpus, rank=rank*gpus+local_rank)
+ torch.cuda.set_device(local_rank)
+ dist.barrier()
+
+ #utils.init_distributed_mode(args)
+ args.distributed=True
+ args.gpu = local_rank
+ print(args)
+ #misc.init_distributed_mode(args)
+
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(', ', ',\n'))
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ #cudnn.benchmark = True
+
+ dataset, num_classes = get_dataset("train",
+ get_transform(args=args),
+ args=args)
+ dataset_test, _ = get_dataset("val",
+ get_transform(args=args),
+ args=args)
+
+ # batch sampler
+ print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
+ #num_tasks = utils.get_world_size()
+ #global_rank = utils.get_rank()
+ num_tasks = hosts*gpus
+ global_rank = rank*gpus+local_rank
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank,
+ shuffle=True)
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+ # data loader
+ data_loader = DataLoader(
+ dataset, batch_size=args.batch_size,
+ sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)
+
+ data_loader_test = DataLoader(
+ dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)
+
+ # model initialization
+ print(args.model)
+ model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights,
+ args=args)
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model.cuda()
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=False)
+ single_model = model.module
+
+ if args.model != 'lavt_one':
+ model_class = BertModel
+ bert_model = model_class.from_pretrained(args.ck_bert)
+        bert_model.pooler = None  # work-around for a bug in Transformers 3.0.2 that appears with DistributedDataParallel
+ bert_model.cuda()
+ bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
+ bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank])
+ single_bert_model = bert_model.module
+ else:
+ bert_model = None
+ single_bert_model = None
+
+ input_shape = dict()
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
+
+
+
+ cfg = Dict()
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
+
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+
+ maskformer_head = MaskFormerHead(cfg, input_shape)
+ maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
+ maskformer_head.cuda()
+ maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
+ single_head = maskformer_head.module
+ print(single_head)
+
+
+ if args.resume == "auto":
+ last_ckpt = ""
+ for e in range(args.epochs):
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
+ if os.path.exists(ckpt_path):
+ last_ckpt = ckpt_path
+ args.resume = last_ckpt
+
+ # resume training
+ if args.resume:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ single_model.load_state_dict(checkpoint['model'])
+ single_head.load_state_dict(checkpoint['head_model'])
+ if args.model != 'lavt_one':
+ single_bert_model.load_state_dict(checkpoint['bert_model'])
+
+ # parameters to optimize
+ backbone_no_decay = list()
+ backbone_decay = list()
+ for name, m in single_model.backbone.named_parameters():
+ if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
+ backbone_no_decay.append(m)
+ else:
+ backbone_decay.append(m)
+
+ if args.model != 'lavt_one':
+ params_to_optimize = [
+ {'params': backbone_no_decay, 'weight_decay': 0.0},
+ {'params': backbone_decay},
+ {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
+ # the following are the parameters of bert
+ {"params": reduce(operator.concat,
+ [[p for p in single_bert_model.encoder.layer[i].parameters()
+ if p.requires_grad] for i in range(10)])},
+ {"params": single_head.parameters()}
+ ]
+ else:
+ params_to_optimize = [
+ {'params': backbone_no_decay, 'weight_decay': 0.0},
+ {'params': backbone_decay},
+ {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
+ # the following are the parameters of bert
+ {"params": reduce(operator.concat,
+ [[p for p in single_model.text_encoder.encoder.layer[i].parameters()
+ if p.requires_grad] for i in range(10)])},
+ ]
+
+ # optimizer
+ optimizer = torch.optim.AdamW(params_to_optimize,
+ lr=args.lr,
+ weight_decay=args.weight_decay,
+ amsgrad=args.amsgrad
+ )
+
+ # learning rate scheduler
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
+ lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
+
+ # housekeeping
+ start_time = time.time()
+ iterations = 0
+ best_oIoU = -0.1
+
+ # resume training (optimizer, lr scheduler, and the epoch)
+ if args.resume:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ resume_epoch = checkpoint['epoch']
+ else:
+ resume_epoch = -999
+
+ # training loops
+ for epoch in range(max(0, resume_epoch+1), args.epochs):
+ data_loader.sampler.set_epoch(epoch)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
+                        iterations, bert_model)
+        iou, overallIoU = evaluate(model, data_loader_test, bert_model)
+
+ print('Average object IoU {}'.format(iou))
+ print('Overall IoU {}'.format(overallIoU))
+
+
+ if single_bert_model is not None:
+ dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(),
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
+ 'lr_scheduler': lr_scheduler.state_dict(), 'head_model': single_head.state_dict()}
+ else:
+ dict_to_save = {'model': single_model.state_dict(),
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
+ 'lr_scheduler': lr_scheduler.state_dict()}
+
+ checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch))
+ utils.save_on_master(dict_to_save, str(checkpoint_path) + '_TEMP')
+ if utils.is_main_process():
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
+
+ if utils.is_main_process():
+ ckpt_paths = []
+ for e in range(args.epochs):
+ ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
+ print(ckpt_path)
+ if os.path.exists(ckpt_path):
+ ckpt_paths.append(ckpt_path)
+ print(ckpt_paths)
+ for ckpt_path in ckpt_paths[:-args.max_ckpt]:
+ os.remove(ckpt_path)
+ print("remove {:s}".format(ckpt_path))
+
+
+ save_checkpoint = (best_oIoU < overallIoU)
+ if save_checkpoint:
+ print('Better epoch: {}\n'.format(epoch))
+ if single_bert_model is not None:
+ dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(),
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
+ 'lr_scheduler': lr_scheduler.state_dict()}
+ else:
+ dict_to_save = {'model': single_model.state_dict(),
+ 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
+ 'lr_scheduler': lr_scheduler.state_dict()}
+
+ checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))
+ utils.save_on_master(dict_to_save, checkpoint_path + '_TEMP')
+ if utils.is_main_process():
+ os.rename(str(checkpoint_path) + '_TEMP', str(checkpoint_path))
+ best_oIoU = overallIoU
+
+ # summarize
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == "__main__":
+ from args import get_parser
+ parser = get_parser()
+ args = parser.parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ # set up distributed learning
+ #utils.init_distributed_mode(args)
+ print('Image size: {}'.format(str(args.img_size)))
+ #main(args)
+ mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
diff --git a/elia/transforms.py b/elia/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d22889dbab930ebe2a41dd01b8067465343079f
--- /dev/null
+++ b/elia/transforms.py
@@ -0,0 +1,124 @@
+import numpy as np
+from PIL import Image
+import random
+
+import torch
+from torchvision import transforms as T
+from torchvision.transforms import functional as F
+
+
+def pad_if_smaller(img, size, fill=0):
+ min_size = min(img.size)
+ if min_size < size:
+ ow, oh = img.size
+ padh = size - oh if oh < size else 0
+ padw = size - ow if ow < size else 0
+ img = F.pad(img, (0, 0, padw, padh), fill=fill)
+ return img
+
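+# Usage sketch for pad_if_smaller (illustrative; not called elsewhere in this file):
+# padding is added only on the right and bottom, so the original pixels keep their
+# coordinates. A 400x300 image padded to a minimum size of 480 becomes 480x480:
+#
+#   img = Image.new('RGB', (400, 300))            # PIL size is (width, height)
+#   assert pad_if_smaller(img, 480).size == (480, 480)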
+
+class Compose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target):
+ for t in self.transforms:
+ image, target = t(image, target)
+ return image, target
+
+
+class Resize(object):
+ def __init__(self, h, w):
+ self.h = h
+ self.w = w
+
+ def __call__(self, image, target):
+ image = F.resize(image, (self.h, self.w))
+ # If size is a sequence like (h, w), the output size will be matched to this.
+ # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
+ target = F.resize(target, (self.h, self.w), interpolation=Image.NEAREST)
+ return image, target
+
+
+class RandomResize(object):
+ def __init__(self, min_size, max_size=None):
+ self.min_size = min_size
+ if max_size is None:
+ max_size = min_size
+ self.max_size = max_size
+
+ def __call__(self, image, target):
+ size = random.randint(self.min_size, self.max_size) # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1)
+ image = F.resize(image, size)
+ # If size is a sequence like (h, w), the output size will be matched to this.
+ # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
+ target = F.resize(target, size, interpolation=Image.NEAREST)
+ return image, target
+
+
+class RandomHorizontalFlip(object):
+ def __init__(self, flip_prob):
+ self.flip_prob = flip_prob
+
+ def __call__(self, image, target):
+ if random.random() < self.flip_prob:
+ image = F.hflip(image)
+ target = F.hflip(target)
+ return image, target
+
+
+class RandomCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, image, target):
+ image = pad_if_smaller(image, self.size)
+ target = pad_if_smaller(target, self.size, fill=255)
+ crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
+ image = F.crop(image, *crop_params)
+ target = F.crop(target, *crop_params)
+ return image, target
+
+
+class CenterCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, image, target):
+ image = F.center_crop(image, self.size)
+ target = F.center_crop(target, self.size)
+ return image, target
+
+
+class ToTensor(object):
+ def __call__(self, image, target):
+ image = F.to_tensor(image)
+ target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64)
+ return image, target
+
+
+class RandomAffine(object):
+ def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None):
+ self.angle = angle
+ self.translate = translate
+ self.scale = scale
+ self.shear = shear
+ self.resample = resample
+ self.fillcolor = fillcolor
+
+ def __call__(self, image, target):
+ affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size)
+ image = F.affine(image, *affine_params)
+ target = F.affine(target, *affine_params)
+ return image, target
+
+
+class Normalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, image, target):
+ image = F.normalize(image, mean=self.mean, std=self.std)
+ return image, target
+
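+# Example composition (a sketch mirroring how these classes are combined in the
+# training/visualization scripts of this repo; the 480 size is illustrative):
+#
+#   transform = Compose([
+#       Resize(480, 480),
+#       ToTensor(),
+#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+#   ])
+#   image_tensor, target_tensor = transform(pil_image, pil_target_mask)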
diff --git a/elia/utils.py b/elia/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8db2c5a47227f7f3bcd7c4f219d97d7141724ccb
--- /dev/null
+++ b/elia/utils.py
@@ -0,0 +1,222 @@
+from __future__ import print_function
+from collections import defaultdict, deque
+import datetime
+import math
+import time
+import torch
+import torch.distributed as dist
+import torch.backends.cudnn as cudnn
+
+import errno
+import os
+
+import sys
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
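+# Usage sketch for SmoothedValue (illustrative): the windowed statistics (median, avg)
+# cover only the last `window_size` updates, while global_avg covers every update.
+#
+#   meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
+#   meter.update(0.7)
+#   meter.update(0.5)
+#   print(str(meter))   # formatted with the window median and the global average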
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'
+ ])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ sys.stdout.flush()
+
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {}'.format(header, total_time_str))
+
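+# Usage sketch for MetricLogger (illustrative): it smooths named metrics and prints
+# progress, ETA, timing and peak GPU memory every `print_freq` iterations.
+#
+#   metric_logger = MetricLogger(delimiter="  ")
+#   for batch in metric_logger.log_every(data_loader, print_freq=10, header='Epoch: [0]'):
+#       loss = ...  # forward / backward
+#       metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])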
+
+def mkdir(path):
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
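+# Usage sketch: save_on_master is a drop-in replacement for torch.save that only
+# writes on rank 0, so every process can reach the checkpointing code without
+# racing on the same file.
+#
+#   save_on_master({'model': model.state_dict(), 'epoch': epoch}, 'checkpoint.pth')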
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ['WORLD_SIZE'])
+ print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}")
+ else:
+ rank = -1
+ world_size = -1
+
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+ torch.distributed.barrier()
+ setup_for_distributed(is_main_process())
+
+ if args.output_dir:
+ mkdir(args.output_dir)
+ if args.model_id:
+ mkdir(os.path.join('./models/', args.model_id))
diff --git a/elia/visualize.py b/elia/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e4ae7b657f1a12b6bf7102b17bf414b64d60fa8
--- /dev/null
+++ b/elia/visualize.py
@@ -0,0 +1,505 @@
+
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from bert.multimodal_bert import MultiModalBert
+import torchvision
+
+from lib import multimodal_segmentation_ppm
+import transforms as T
+import utils
+
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+
+from modeling.MaskFormerModel import MaskFormerHead
+from addict import Dict
+from bert.modeling_bert import BertLMPredictionHead, BertEncoder
+import cv2
+import textwrap
+
+def get_dataset(image_set, transform, args):
+ from data.dataset_refer_bert_vis import ReferDataset
+ ds = ReferDataset(args,
+ split=image_set,
+ image_transforms=transform,
+ target_transforms=None,
+ eval_mode=True
+ )
+ num_classes = 2
+ return ds, num_classes
+
+
+def overlay_davis(image, mask, colors=[[0, 0, 0], [0, 255, 0]], cscale=1, alpha=0.4):
+    from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated
+
+ colors = np.reshape(colors, (-1, 3))
+ colors = np.atleast_2d(colors) * cscale
+
+ im_overlay = image.copy()
+ object_ids = np.unique(mask)
+
+ for object_id in object_ids[1:]:
+ # Overlay color on binary mask
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
+ binary_mask = mask == object_id
+
+ # Compose image
+ im_overlay[binary_mask] = foreground[binary_mask]
+
+        # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
+        contours = binary_dilation(binary_mask) ^ binary_mask
+        # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
+        im_overlay[contours, :] = 0
+
+ return im_overlay.astype(image.dtype)
+
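+# Usage sketch for overlay_davis (illustrative): `image` is an HxWx3 uint8 array and
+# `mask` an HxW integer array whose smallest id (assumed to be the background, 0) is
+# skipped; every other id is tinted with its colour and its contour is drawn in black.
+#
+#   blended = overlay_davis(frame_rgb, pred_mask.astype(np.uint8))
+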
+def evaluate(model, data_loader, device, args):
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+
+ # evaluation variables
+ cum_I, cum_U = 0, 0
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+ seg_total = 0
+ mean_IoU = []
+ header = 'Test:'
+
+ with torch.no_grad():
+ idx = 0
+ for data in metric_logger.log_every(data_loader, 100, header):
+ idx += 1
+ image, target, sentences, attentions, raw_sentences, this_img, orig_img = data
+ image, target, sentences, attentions = image.to(device), target.to(device), \
+ sentences.to(device), attentions.to(device)
+
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ #target = target.cpu().data.numpy()
+
+ b, h, w, c = orig_img.shape
+ #orig_img = orig_img.numpy()[:, :, :, ::-1]
+
+            orig_img = orig_img.data.cpu().numpy()[0, :, :, :].astype(np.uint8)
+            vis = np.zeros((h, w * 2, 3)).astype(np.uint8)
+
+ #image_mean_iou = []
+ target_numpy = target.cpu().numpy()
+
+ for j in range(sentences.size(-1)):
+ #if bert_model is not None:
+ # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
+ # embedding = last_hidden_states.permute(0, 2, 1)
+ # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
+ #else:
+ output = model(image, sentences[:, :, j], attentions[:, :, j])
+ mask_cls_results = output["pred_logits"]
+ mask_pred_results = output["pred_masks"]
+
+ target_shape = target.shape[-2:]
+ mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
+
+ pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
+ #output = pred_masks[0]
+
+ #output = output.cpu()
+
+ I, U = computeIoU(pred_masks.cpu().numpy(), target_numpy)
+ if U == 0:
+ this_iou = 0.0
+ else:
+ this_iou = I*1.0/U
+ mean_IoU.append(this_iou)
+ #image_mean_iou.append(this_iou)
+ cum_I += I
+ cum_U += U
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+ seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
+ seg_total += 1
+
+
+ #print(output.shape)
+ #output_mask = output.argmax(1).data.numpy()
+ #output_mask = (output > 0.5).data.numpy()
+
+ #vis_output_mask = torch.sigmoid(output[:, 1]).data.numpy()
+ #vis_output_mask = torch.sigmoid((output>0.5).float()).data.numpy()
+ #soft
+ #vis_output_mask = output.data.numpy()
+ #vis_output_mask = output_mask
+
+ #print(output.shape, orig_shape)
+ gt_masks = torch.nn.functional.interpolate(target.unsqueeze(0).float(), (h, w))
+ pred_masks = torch.nn.functional.interpolate(pred_masks, (h, w))
+ #print(orig_mask.shape)
+ pred_masks = (pred_masks > 0.5).data.cpu().numpy()
+ #ntarget = target.data.cpu().numpy()
+ ##orig_mask = orig_mask.argmax(1).data.numpy()
+
+ #print(orig_img[0].shape, orig_mask[0][0].shape, flush=True)
+ #print(orig_img.dtype, orig_mask.dtype)
+ predict_imgs = overlay_davis(orig_img, pred_masks[0][0].astype(np.uint8))
+ gt_imgs = overlay_davis(orig_img, gt_masks[0][0].cpu().numpy().astype(np.uint8), colors=[[0, 0, 0], [0, 0, 255]])
+ #print(orig_mask.shape, orig_img.shape)
+ #red_mask = np.zeros((orig_mask.shape[1], orig_mask.shape[2], orig_mask.shape[3], 3)).astype(np.uint8)
+ #print("???", red_mask.shape, orig_mask.shape)
+ #red_mask[:, :, :, 1] = orig_mask * 255
+ #red_mask = cv2.bitwise_and(red_mask, red_mask, orig_mask.astype(np.uint8))
+
+ #temp = cv2.addWeighted(red_mask, 0.5, orig_img, 0.5, 0)
+ #print(orig_img.shape, temp.shape, orig_mask.shape, "WHAT?")
+ #new = orig_img * (1.0 - orig_mask[0][:,:,:,None]) + temp * orig_mask[0][:,:,:,None]
+ #print(new.shape, orig_mask.shape, temp.shape, "check")
+ ##print(vis_output_mask)
+ ##output_mask = output.argmax(1).data.numpy()
+ #
+ #print(raw_sentences[j])
+
+ # print(image.shape, target.shape, output_mask.shape)
+
+ #mean = np.array([0.485, 0.456, 0.406])
+ #std = np.array([0.229, 0.224, 0.225])
+ #np_image = (((image[0].permute(1,2,0).cpu().numpy() * std) + mean) * 255).astype(np.uint8)[:,:,::-1]
+ #np_target = (target * 255).transpose(1,2,0).astype(np.uint8)
+ ##print(output_mask)
+ #np_output_mask = (vis_output_mask*255).transpose(1,2,0).repeat(3, axis=2).astype(np.uint8)
+
+ font = cv2.FONT_HERSHEY_DUPLEX
+ fontScale = 1.0
+ fontColor = (255,0,0)
+ thickness = 1
+ lineType = 2
+
+ wrapped_text = textwrap.wrap(' '.join(raw_sentences[j]), width=35)
+ for k, line in enumerate(wrapped_text):
+ bottomLeftCornerOfText = (10,h-60 + k*30)
+ gt_imgs = cv2.putText(gt_imgs, line,
+ bottomLeftCornerOfText,
+ font,
+ fontScale,
+ fontColor,
+ thickness,
+ lineType)
+
+
+ #temp = j + 2
+ #split = temp // 3
+ #row = temp % 3
+ vis[0:h, 0:w, :] = gt_imgs
+ vis[0:h, w:2*w, :] = predict_imgs
+
+
+
+ #cv2.imwrite("vis/elifan_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], j), new[0].astype(np.uint8))
+ cv2.imwrite("vis/{:s}/{:s}_{:d}_{:d}_{:.2f}.jpg".format(args.vis_dir, this_img[0].split('.')[0], idx, j, this_iou), vis[:, :, ::-1].astype(np.uint8))
+
+ #print('---------------')
+ #cv2.imshow("vis", vis)
+ #cv2.waitKey(0)
+
+ #image_mean_iou = np.mean(np.array(image_mean_iou))
+ #print(image_mean_iou)
+ #if image_mean_iou < 0.5:
+ #cv2.imwrite("vis/elian_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], idx), vis)
+
+
+ #del image, target, sentences, attentions, output, output_mask
+ #if bert_model is not None:
+ # del last_hidden_states, embedding
+
+ mean_IoU = np.array(mean_IoU)
+ mIoU = np.mean(mean_IoU)
+ print('Final results:')
+ print('Mean IoU is %.2f\n' % (mIoU*100.))
+ results_str = ''
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ results_str += ' precision@%s = %.2f\n' % \
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+ print(results_str)
+
+#def evaluate(model, data_loader, device):
+# model.eval()
+# metric_logger = utils.MetricLogger(delimiter=" ")
+#
+# # evaluation variables
+# cum_I, cum_U = 0, 0
+# eval_seg_iou_list = [.5, .6, .7, .8, .9]
+# seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+# seg_total = 0
+# mean_IoU = []
+# header = 'Test:'
+#
+# with torch.no_grad():
+# for data in metric_logger.log_every(data_loader, 100, header):
+# image, target, sentences, attentions = data
+# image, target, sentences, attentions = image.to(device), target.to(device), \
+# sentences.to(device), attentions.to(device)
+# sentences = sentences.squeeze(1)
+# attentions = attentions.squeeze(1)
+# target = target.cpu().data.numpy()
+# for j in range(sentences.size(-1)):
+# #if bert_model is not None:
+# # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
+# # embedding = last_hidden_states.permute(0, 2, 1)
+# # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
+# #else:
+# output = model(image, sentences[:, :, j], attentions[:, :, j])
+# mask_cls_results = output["pred_logits"]
+# mask_pred_results = output["pred_masks"]
+#
+# target_shape = target.shape[-2:]
+# mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
+#
+# pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
+# output = pred_masks[0]
+#
+# output = output.cpu()
+# #print(output.shape)
+# #output_mask = output.argmax(1).data.numpy()
+# output_mask = (output > 0.5).data.numpy()
+# I, U = computeIoU(output_mask, target)
+# if U == 0:
+# this_iou = 0.0
+# else:
+# this_iou = I*1.0/U
+# mean_IoU.append(this_iou)
+# cum_I += I
+# cum_U += U
+# for n_eval_iou in range(len(eval_seg_iou_list)):
+# eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+# seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
+# seg_total += 1
+#
+# #del image, target, sentences, attentions, output, output_mask
+# #if bert_model is not None:
+# # del last_hidden_states, embedding
+#
+# mean_IoU = np.array(mean_IoU)
+# mIoU = np.mean(mean_IoU)
+# print('Final results:')
+# print('Mean IoU is %.2f\n' % (mIoU*100.))
+# results_str = ''
+# for n_eval_iou in range(len(eval_seg_iou_list)):
+# results_str += ' precision@%s = %.2f\n' % \
+# (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+# results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+# print(results_str)
+
+
+def get_transform(args):
+ transforms = [T.Resize(args.img_size, args.img_size),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+
+ return T.Compose(transforms)
+
+
+def computeIoU(pred_seg, gd_seg):
+ I = np.sum(np.logical_and(pred_seg, gd_seg))
+ U = np.sum(np.logical_or(pred_seg, gd_seg))
+
+ return I, U
+
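+# Usage sketch: computeIoU returns raw intersection/union pixel counts so the caller
+# can form a per-sample IoU (guarding against an empty union) and also accumulate
+# cum_I / cum_U for the overall IoU, exactly as evaluate() does above.
+#
+#   I, U = computeIoU(pred_mask, gt_mask)
+#   iou = I / U if U > 0 else 0.0
+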
+class WrapperModel(nn.Module):
+ def __init__(self, image_model, language_model, classifier, args) :
+ super(WrapperModel, self).__init__()
+ self.image_model = image_model
+ self.language_model = language_model
+ self.classifier = classifier
+ self.lang_proj = nn.Linear(768,256)
+
+ config = Dict({
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": False,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ #"max_position_embeddings": 16+20,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 8,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": True,
+ "vocab_size": 30522
+ })
+ self.mlm_transformer = BertEncoder(config)
+
+ self.lang_proj = nn.Linear(768,256)
+ self.mlm_vis_proj = nn.Conv2d(1024,512,1)
+ self.mlm_lang_proj = nn.Linear(768,512)
+ #print(vis_proj)
+ self.mlm_head = BertLMPredictionHead(config)
+
+ assert args.img_size % 4 == 0
+ num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
+ print(num_img_tokens)
+ self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
+ self.mlm_modal_embeds = nn.Embedding(3, 512)
+
+ self.mlm_mask_embed = nn.Embedding(1, 512)
+ self.mlm_pos_mlp = nn.Sequential(
+ nn.Linear(2, 512),
+ nn.LayerNorm(512),
+ nn.Linear(512,512),
+ nn.GELU()
+ )
+
+ def _get_binary_mask(self, target):
+        # return per-class binary masks (the background channel is dropped)
+ y, x = target.size()
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
+ return target_onehot[1:]
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
+ mask_pred = mask_pred.sigmoid()
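+        # einsum "bqc,bqhw->bchw": weight each query's mask (B, Q, H, W) by its class
+        # probability (B, Q, C) and sum over queries to obtain per-class maps (B, C, H, W).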
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+ return semseg
+
+ def forward(self, image, sentences, attentions):
+ input_shape = image.shape[-2:]
+ l_mask = attentions.unsqueeze(dim=-1)
+
+ i0, Wh, Ww = self.image_model.forward_stem(image)
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
+
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
+ i1 = i1_temp
+
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
+ i2 = i2_temp
+
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
+ i3 = i3_temp
+
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
+ i4 = i4_temp
+
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
+ outputs = {}
+ outputs['s1'] = i1_residual
+ outputs['s2'] = i2_residual
+ outputs['s3'] = i3_residual
+ outputs['s4'] = i4_residual
+
+ predictions, mask_predictions = self.classifier(outputs)
+ return predictions
+
+def main(args):
+#def main(local_rank, args):
+
+ #device = torch.device(args.device)
+ device = 'cuda'
+ dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+ data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
+ sampler=test_sampler, num_workers=args.workers)
+ print(args.model)
+ single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args)
+ #single_model = MultiModalFocal(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], focal_windows=[9,9,9,9], drop_path_rate=0.3)
+ #single_model.init_weights('./focalnet_base_lrf.pth')
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ #single_model.load_state_dict(checkpoint['model'])
+ #model = single_model.to(device)
+
+ if args.model != 'lavt_one':
+ model_class = MultiModalBert
+ #single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=128)
+ single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim)
+ # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
+ if args.ddp_trained_weights:
+ single_bert_model.pooler = None
+ #single_bert_model.load_state_dict(checkpoint['bert_model'])
+ #bert_model = single_bert_model.to(device)
+ else:
+ bert_model = None
+
+ #model = WrapperModel(single_model.backbone, single_bert_model, single_model.classifier)
+ #model.load_state_dict(checkpoint['model'])
+ #model.to(device)
+ input_shape = dict()
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
+
+
+
+ cfg = Dict()
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
+
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+
+ maskformer_head = MaskFormerHead(cfg, input_shape)
+ #maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
+ #maskformer_head.cuda()
+ #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
+ #single_head = maskformer_head.module
+ #print(single_head)
+
+ model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
+ model.load_state_dict(checkpoint['model'])
+ model.to(device)
+ #model.cuda()
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ #single_model = model.module
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ #single_model = model.module
+ evaluate(model, data_loader_test, device=device, args=args)
+
+
+if __name__ == "__main__":
+ from args import get_parser
+ parser = get_parser()
+ args = parser.parse_args()
+ print('Image size: {}'.format(str(args.img_size)))
+ print(args)
+ os.makedirs('vis/' + args.vis_dir, exist_ok=True)
+ main(args)
+ #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())
diff --git a/elia/visualize_final.py b/elia/visualize_final.py
new file mode 100644
index 0000000000000000000000000000000000000000..0947b95c496b33953301ac4f18a9533076daf27e
--- /dev/null
+++ b/elia/visualize_final.py
@@ -0,0 +1,506 @@
+
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+
+from bert.multimodal_bert import MultiModalBert
+import torchvision
+
+from lib import multimodal_segmentation_ppm
+import transforms as T
+import utils
+
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+
+from modeling.MaskFormerModel import MaskFormerHead
+from addict import Dict
+from bert.modeling_bert import BertLMPredictionHead, BertEncoder
+import cv2
+import textwrap
+
+def get_dataset(image_set, transform, args):
+ from data.dataset_refer_bert_vis import ReferDataset
+ ds = ReferDataset(args,
+ split=image_set,
+ image_transforms=transform,
+ target_transforms=None,
+ eval_mode=True
+ )
+ num_classes = 2
+ return ds, num_classes
+
+
+def overlay_davis(image, mask, colors=[[0, 0, 0], [0, 255, 0]], cscale=1, alpha=0.4):
+    from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated
+
+ colors = np.reshape(colors, (-1, 3))
+ colors = np.atleast_2d(colors) * cscale
+
+ im_overlay = image.copy()
+ object_ids = np.unique(mask)
+
+ for object_id in object_ids[1:]:
+ # Overlay color on binary mask
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
+ binary_mask = mask == object_id
+
+ # Compose image
+ im_overlay[binary_mask] = foreground[binary_mask]
+
+        # contours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
+        contours = binary_dilation(binary_mask) ^ binary_mask
+        # contours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
+        im_overlay[contours, :] = 0
+
+ return im_overlay.astype(image.dtype)
+
+def evaluate(model, data_loader, device):
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+
+ # evaluation variables
+ cum_I, cum_U = 0, 0
+ eval_seg_iou_list = [.5, .6, .7, .8, .9]
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+ seg_total = 0
+ mean_IoU = []
+ header = 'Test:'
+
+ with torch.no_grad():
+ number = 0
+ idx = 0
+ for data in metric_logger.log_every(data_loader, 100, header):
+ number +=1
+ idx += 1
+ print(number)
+ image, target, sentences, attentions, raw_sentences, this_img, orig_img = data
+ image, target, sentences, attentions = image.to(device), target.to(device), \
+ sentences.to(device), attentions.to(device)
+ #if number <= 40:
+ # continue
+
+ sentences = sentences.squeeze(1)
+ attentions = attentions.squeeze(1)
+ target = target.cpu().data.numpy()
+
+ orig_shape = orig_img.shape
+ orig_img = orig_img.numpy()[:, :, :, ::-1]
+            print('orig_img shape:', orig_img.shape)
+
+ vis = np.zeros((480*2, 480*3,3)).astype(np.uint8)
+
+ image_mean_iou = []
+
+ for j in range(sentences.size(-1)):
+ #if bert_model is not None:
+ # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
+ # embedding = last_hidden_states.permute(0, 2, 1)
+ # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
+ #else:
+ output = model(image, sentences[:, :, j], attentions[:, :, j])
+ mask_cls_results = output["pred_logits"]
+ mask_pred_results = output["pred_masks"]
+
+ target_shape = target.shape[-2:]
+ mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
+
+ pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
+ output = pred_masks[0]
+
+ output = output.cpu()
+
+
+
+ #print(output.shape)
+ #output_mask = output.argmax(1).data.numpy()
+ output_mask = (output > 0.5).data.numpy()
+
+ #vis_output_mask = torch.sigmoid(output[:, 1]).data.numpy()
+ #vis_output_mask = torch.sigmoid((output>0.5).float()).data.numpy()
+ #soft
+ #vis_output_mask = output.data.numpy()
+ #vis_output_mask = output_mask
+
+ #print(output.shape, orig_shape)
+ orig_mask = torch.nn.functional.interpolate(pred_masks, (orig_shape[1], orig_shape[2]))
+ #print(orig_mask.shape)
+ orig_mask = (orig_mask > 0.5).data.cpu().numpy()
+ ##orig_mask = orig_mask.argmax(1).data.numpy()
+
+ print(orig_img[0].shape, orig_mask[0][0].shape, flush=True)
+ print(orig_img.dtype, orig_mask.dtype)
+ new = overlay_davis(orig_img[0], orig_mask[0][0].astype(np.uint8))
+ #print(orig_mask.shape, orig_img.shape)
+ #red_mask = np.zeros((orig_mask.shape[1], orig_mask.shape[2], orig_mask.shape[3], 3)).astype(np.uint8)
+ #print("???", red_mask.shape, orig_mask.shape)
+ #red_mask[:, :, :, 1] = orig_mask * 255
+ #red_mask = cv2.bitwise_and(red_mask, red_mask, orig_mask.astype(np.uint8))
+
+ #temp = cv2.addWeighted(red_mask, 0.5, orig_img, 0.5, 0)
+ #print(orig_img.shape, temp.shape, orig_mask.shape, "WHAT?")
+ #new = orig_img * (1.0 - orig_mask[0][:,:,:,None]) + temp * orig_mask[0][:,:,:,None]
+ #print(new.shape, orig_mask.shape, temp.shape, "check")
+ ##print(vis_output_mask)
+ ##output_mask = output.argmax(1).data.numpy()
+ #
+ #print(raw_sentences[j])
+
+ # print(image.shape, target.shape, output_mask.shape)
+
+ #mean = np.array([0.485, 0.456, 0.406])
+ #std = np.array([0.229, 0.224, 0.225])
+ #np_image = (((image[0].permute(1,2,0).cpu().numpy() * std) + mean) * 255).astype(np.uint8)[:,:,::-1]
+ #np_target = (target * 255).transpose(1,2,0).astype(np.uint8)
+ ##print(output_mask)
+ #np_output_mask = (vis_output_mask*255).transpose(1,2,0).repeat(3, axis=2).astype(np.uint8)
+
+ #font = cv2.FONT_HERSHEY_SIMPLEX
+ #fontScale = 0.75
+ #fontColor = (0,0,255)
+ #thickness = 1
+ #lineType = 2
+
+ #wrapped_text = textwrap.wrap(' '.join(raw_sentences[j]), width=35)
+ #for k, line in enumerate(wrapped_text):
+ # bottomLeftCornerOfText = (10,420+k*20)
+ # np_output_mask = cv2.putText(np_output_mask, line,
+ # bottomLeftCornerOfText,
+ # font,
+ # fontScale,
+ # fontColor,
+ # thickness,
+ # lineType)
+
+ #
+ #temp = j + 2
+ #split = temp // 3
+ #row = temp % 3
+ #vis[0:480, 0:480, :] = np_image
+ #vis[0:480, 480:960, :] = np_target.repeat(3, axis=2)
+ #vis[split*480:(split+1)*480:, row * 480:(row+1)*480, :] = np_output_mask
+
+
+
+ I, U = computeIoU(output_mask, target)
+ if U == 0:
+ this_iou = 0.0
+ else:
+ this_iou = I*1.0/U
+ mean_IoU.append(this_iou)
+ image_mean_iou.append(this_iou)
+ cum_I += I
+ cum_U += U
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+ seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
+ seg_total += 1
+ #cv2.imwrite("vis/elifan_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], j), new[0].astype(np.uint8))
+ cv2.imwrite("vis/elia_refcoco+_green/{:s}_{:d}_{:d}_{:.2f}.jpg".format(this_img[0].split('.')[0], idx, j, this_iou), new.astype(np.uint8))
+
+ print('---------------')
+ #cv2.imshow("vis", vis)
+ #cv2.waitKey(0)
+
+ image_mean_iou = np.mean(np.array(image_mean_iou))
+ print(image_mean_iou)
+ #if image_mean_iou < 0.5:
+ #cv2.imwrite("vis/elian_refcoco/{:s}_{:d}.jpg".format(this_img[0].split('.')[0], idx), vis)
+
+
+ #del image, target, sentences, attentions, output, output_mask
+ #if bert_model is not None:
+ # del last_hidden_states, embedding
+
+ mean_IoU = np.array(mean_IoU)
+ mIoU = np.mean(mean_IoU)
+ print('Final results:')
+ print('Mean IoU is %.2f\n' % (mIoU*100.))
+ results_str = ''
+ for n_eval_iou in range(len(eval_seg_iou_list)):
+ results_str += ' precision@%s = %.2f\n' % \
+ (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+ results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+ print(results_str)
+
+#def evaluate(model, data_loader, device):
+# model.eval()
+# metric_logger = utils.MetricLogger(delimiter=" ")
+#
+# # evaluation variables
+# cum_I, cum_U = 0, 0
+# eval_seg_iou_list = [.5, .6, .7, .8, .9]
+# seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
+# seg_total = 0
+# mean_IoU = []
+# header = 'Test:'
+#
+# with torch.no_grad():
+# for data in metric_logger.log_every(data_loader, 100, header):
+# image, target, sentences, attentions = data
+# image, target, sentences, attentions = image.to(device), target.to(device), \
+# sentences.to(device), attentions.to(device)
+# sentences = sentences.squeeze(1)
+# attentions = attentions.squeeze(1)
+# target = target.cpu().data.numpy()
+# for j in range(sentences.size(-1)):
+# #if bert_model is not None:
+# # last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
+# # embedding = last_hidden_states.permute(0, 2, 1)
+# # output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1))
+# #else:
+# output = model(image, sentences[:, :, j], attentions[:, :, j])
+# mask_cls_results = output["pred_logits"]
+# mask_pred_results = output["pred_masks"]
+#
+# target_shape = target.shape[-2:]
+# mask_pred_results = F.interpolate(mask_pred_results, size=target_shape, mode='bilinear', align_corners=True)
+#
+# pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
+# output = pred_masks[0]
+#
+# output = output.cpu()
+# #print(output.shape)
+# #output_mask = output.argmax(1).data.numpy()
+# output_mask = (output > 0.5).data.numpy()
+# I, U = computeIoU(output_mask, target)
+# if U == 0:
+# this_iou = 0.0
+# else:
+# this_iou = I*1.0/U
+# mean_IoU.append(this_iou)
+# cum_I += I
+# cum_U += U
+# for n_eval_iou in range(len(eval_seg_iou_list)):
+# eval_seg_iou = eval_seg_iou_list[n_eval_iou]
+# seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
+# seg_total += 1
+#
+# #del image, target, sentences, attentions, output, output_mask
+# #if bert_model is not None:
+# # del last_hidden_states, embedding
+#
+# mean_IoU = np.array(mean_IoU)
+# mIoU = np.mean(mean_IoU)
+# print('Final results:')
+# print('Mean IoU is %.2f\n' % (mIoU*100.))
+# results_str = ''
+# for n_eval_iou in range(len(eval_seg_iou_list)):
+# results_str += ' precision@%s = %.2f\n' % \
+# (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
+# results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
+# print(results_str)
+
+
+def get_transform(args):
+ transforms = [T.Resize(args.img_size, args.img_size),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]
+
+ return T.Compose(transforms)
+
+
+def computeIoU(pred_seg, gd_seg):
+ I = np.sum(np.logical_and(pred_seg, gd_seg))
+ U = np.sum(np.logical_or(pred_seg, gd_seg))
+
+ return I, U
+
+class WrapperModel(nn.Module):
+ def __init__(self, image_model, language_model, classifier, args) :
+ super(WrapperModel, self).__init__()
+ self.image_model = image_model
+ self.language_model = language_model
+ self.classifier = classifier
+ self.lang_proj = nn.Linear(768,256)
+
+ config = Dict({
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": False,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ #"max_position_embeddings": 16+20,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 8,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": True,
+ "vocab_size": 30522
+ })
+ self.mlm_transformer = BertEncoder(config)
+
+ self.lang_proj = nn.Linear(768,256)
+ self.mlm_vis_proj = nn.Conv2d(1024,512,1)
+ self.mlm_lang_proj = nn.Linear(768,512)
+ #print(vis_proj)
+ self.mlm_head = BertLMPredictionHead(config)
+
+ assert args.img_size % 4 == 0
+ num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2
+ print(num_img_tokens)
+ self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512)
+ self.mlm_modal_embeds = nn.Embedding(3, 512)
+
+ self.mlm_mask_embed = nn.Embedding(1, 512)
+ self.mlm_pos_mlp = nn.Sequential(
+ nn.Linear(2, 512),
+ nn.LayerNorm(512),
+ nn.Linear(512,512),
+ nn.GELU()
+ )
+
+ def _get_binary_mask(self, target):
+        # return per-class binary masks (the background channel is dropped)
+ y, x = target.size()
+ target_onehot = torch.zeros(self.num_classes + 1, y, x)
+ target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
+ return target_onehot[1:]
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=1)[...,1:]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+ return semseg
+
+ def forward(self, image, sentences, attentions):
+ input_shape = image.shape[-2:]
+ l_mask = attentions.unsqueeze(dim=-1)
+
+ i0, Wh, Ww = self.image_model.forward_stem(image)
+ l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
+
+ i1 = self.image_model.forward_stage1(i0, Wh, Ww)
+ l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
+ i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
+ l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask)
+ i1 = i1_temp
+
+ i2 = self.image_model.forward_stage2(i1, Wh, Ww)
+ l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
+ i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
+ l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask)
+ i2 = i2_temp
+
+ i3 = self.image_model.forward_stage3(i2, Wh, Ww)
+ l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
+ i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
+ l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask)
+ i3 = i3_temp
+
+ i4 = self.image_model.forward_stage4(i3, Wh, Ww)
+ l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
+ i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
+ l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask)
+ i4 = i4_temp
+
+ #i1_residual, i2_residual, i3_residual, i4_residual = features
+ #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
+ #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
+ outputs = {}
+ outputs['s1'] = i1_residual
+ outputs['s2'] = i2_residual
+ outputs['s3'] = i3_residual
+ outputs['s4'] = i4_residual
+
+ predictions = self.classifier(outputs)
+ return predictions
+
+def main(args):
+#def main(local_rank, args):
+
+ #device = torch.device(args.device)
+ device = 'cuda'
+ dataset_test, _ = get_dataset(args.split, get_transform(args=args), args)
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+ data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
+ sampler=test_sampler, num_workers=args.workers)
+ print(args.model)
+ single_model = multimodal_segmentation_ppm.__dict__[args.model](pretrained='',args=args)
+ #single_model = MultiModalFocal(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], focal_windows=[9,9,9,9], drop_path_rate=0.3)
+ #single_model.init_weights('./focalnet_base_lrf.pth')
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ #single_model.load_state_dict(checkpoint['model'])
+ #model = single_model.to(device)
+
+ if args.model != 'lavt_one':
+ model_class = MultiModalBert
+ #single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=128)
+ single_bert_model = model_class.from_pretrained(args.ck_bert, embed_dim=single_model.backbone.embed_dim)
+ # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines
+ if args.ddp_trained_weights:
+ single_bert_model.pooler = None
+ #single_bert_model.load_state_dict(checkpoint['bert_model'])
+ #bert_model = single_bert_model.to(device)
+ else:
+ bert_model = None
+
+ #model = WrapperModel(single_model.backbone, single_bert_model, single_model.classifier)
+ #model.load_state_dict(checkpoint['model'])
+ #model.to(device)
+ input_shape = dict()
+ input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
+ input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
+ input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
+ input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
+
+
+
+ cfg = Dict()
+ cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
+
+ cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+
+ maskformer_head = MaskFormerHead(cfg, input_shape)
+ #maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
+ #maskformer_head.cuda()
+ #maskformer_head = torch.nn.parallel.DistributedDataParallel(maskformer_head, device_ids=[args.local_rank], find_unused_parameters=False)
+ #single_head = maskformer_head.module
+ #print(single_head)
+
+ model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head, args)
+ model.load_state_dict(checkpoint['model'])
+ model.to(device)
+ #model.cuda()
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ #single_model = model.module
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
+ #single_model = model.module
+ evaluate(model, data_loader_test, device=device)
+
+
+if __name__ == "__main__":
+ from args import get_parser
+ parser = get_parser()
+ args = parser.parse_args()
+ print('Image size: {}'.format(str(args.img_size)))
+ print(args)
+ main(args)
+ #mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count())